core: Handle MCTP fragment sizes

Message assembly can be terminated if case we receive a
middle/end packet of unexpected size. This provision is
provided in DSP0236 v1.3.1 section 8.8 incorrect transmission
unit.

Reception of middle packets whose size is not equal to
start packet and end packets whose size is not less than
or equal to the start packet causes message assembly
termination.

Signed-off-by: Sumanth Bhat <sumanth.bhat@linux.intel.com>
Change-Id: I6371ab9e22e2c8ece70a9480f224de3f1f2f184e
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d13bab4..fa8999f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,10 @@
 target_link_libraries (test_cmds mctp)
 add_test (NAME control_commands COMMAND test_cmds)
 
+add_executable (test_core tests/test_core.c tests/test-utils.c)
+target_link_libraries (test_core mctp)
+add_test (NAME core COMMAND test_core)
+
 install (TARGETS mctp DESTINATION lib)
 install (FILES libmctp.h DESTINATION include)
 
diff --git a/Makefile.am b/Makefile.am
index 5306074..ab3c80e 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -47,7 +47,8 @@
 TESTS = $(check_PROGRAMS)
 
 check_PROGRAMS = tests/test_eid tests/test_seq tests/test_bridge \
-		 tests/test_astlpc tests/test_serial tests/test_cmds
+		 tests/test_astlpc tests/test_serial tests/test_cmds \
+		 tests/test_core
 # We set a global LDADD here, as there's no way to specify it for all
 # tests. This means other targets' LDADDs need to be overridden.
 LDADD = tests/libtest-utils.a libmctp.la
diff --git a/core.c b/core.c
index 5ec4fcb..84c25d7 100644
--- a/core.c
+++ b/core.c
@@ -45,6 +45,7 @@
 	void		*buf;
 	size_t		buf_size;
 	size_t		buf_alloc_size;
+	size_t		fragment_size;
 };
 
 struct mctp {
@@ -215,6 +216,7 @@
 static void mctp_msg_ctx_reset(struct mctp_msg_ctx *ctx)
 {
 	ctx->buf_size = 0;
+	ctx->fragment_size = 0;
 }
 
 static int mctp_msg_ctx_add_pkt(struct mctp_msg_ctx *ctx,
@@ -540,6 +542,10 @@
 					hdr->src, hdr->dest, tag);
 		}
 
+		/* Save the fragment size, subsequent middle fragments
+		 * should of the same size */
+		ctx->fragment_size = mctp_pktbuf_size(pkt);
+
 		rc = mctp_msg_ctx_add_pkt(ctx, pkt, mctp->max_message_size);
 		if (rc) {
 			mctp_msg_ctx_drop(ctx);
@@ -564,6 +570,16 @@
 			goto out;
 		}
 
+		len = mctp_pktbuf_size(pkt);
+
+		if (len > ctx->fragment_size) {
+			mctp_prdebug("Unexpected fragment size. Expected" \
+				" less than %zu, received = %zu",
+				ctx->fragment_size, len);
+			mctp_msg_ctx_drop(ctx);
+			goto out;
+		}
+
 		rc = mctp_msg_ctx_add_pkt(ctx, pkt, mctp->max_message_size);
 		if (!rc)
 			mctp_rx(mctp, bus, ctx->src, ctx->dest,
@@ -587,6 +603,15 @@
 			goto out;
 		}
 
+		len = mctp_pktbuf_size(pkt);
+
+		if (len != ctx->fragment_size) {
+			mctp_prdebug("Unexpected fragment size. Expected = %zu " \
+				"received = %zu", ctx->fragment_size, len);
+			mctp_msg_ctx_drop(ctx);
+			goto out;
+		}
+
 		rc = mctp_msg_ctx_add_pkt(ctx, pkt, mctp->max_message_size);
 		if (rc) {
 			mctp_msg_ctx_drop(ctx);
diff --git a/tests/test_core.c b/tests/test_core.c
new file mode 100644
index 0000000..42d9187
--- /dev/null
+++ b/tests/test_core.c
@@ -0,0 +1,400 @@
+/* SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later */
+
+#define _GNU_SOURCE
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#if HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <assert.h>
+#include <fcntl.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "libmctp-alloc.h"
+#include "libmctp-log.h"
+#include "range.h"
+#include "test-utils.h"
+
+#define TEST_DEST_EID 9
+#define TEST_SRC_EID  10
+
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(a) (sizeof(a) / sizeof(a[0]))
+#endif
+
+#define __unused __attribute__((unused))
+
+#define MAX_PAYLOAD_SIZE 50000
+
+struct pktbuf {
+	struct mctp_hdr hdr;
+	uint8_t *payload;
+};
+
+struct test_params {
+	bool seen;
+	size_t message_size;
+};
+
+static void rx_message(uint8_t eid __unused, void *data, void *msg __unused,
+		       size_t len)
+{
+	struct test_params *param = (struct test_params *)data;
+
+	mctp_prdebug("MCTP message received: len %zd", len);
+
+	param->seen = true;
+	param->message_size = len;
+}
+
+static uint8_t get_sequence()
+{
+	static uint8_t pkt_seq = 0;
+
+	return (pkt_seq++ % 4);
+}
+
+static uint8_t get_tag()
+{
+	static uint8_t tag = 0;
+
+	return (tag++ % 8);
+}
+
+/*
+ * receive_pktbuf bypasses all bindings and directly invokes mctp_bus_rx.
+ * This is necessary in order invoke test cases on the core functionality.
+ * The memory allocated for the mctp packet is capped at MCTP_BTU
+ * size, however, the mimiced rx pkt still retains the len parameter.
+ * This allows to mimic packets larger than a sane memory allocator can
+ * provide.
+ */
+static void receive_ptkbuf(struct mctp_binding_test *binding,
+			   const struct pktbuf *pktbuf, size_t len)
+{
+	size_t alloc_size = MIN((size_t)MCTP_BTU, len);
+	struct mctp_pktbuf *rx_pkt;
+
+	rx_pkt = __mctp_alloc(sizeof(*rx_pkt) + MCTP_PACKET_SIZE(alloc_size));
+	assert(rx_pkt);
+
+	/* Preserve passed len parameter */
+	rx_pkt->size = MCTP_PACKET_SIZE(len);
+	rx_pkt->start = 0;
+	rx_pkt->end = MCTP_PACKET_SIZE(len);
+	rx_pkt->mctp_hdr_off = 0;
+	rx_pkt->next = NULL;
+	memcpy(rx_pkt->data, &pktbuf->hdr, sizeof(pktbuf->hdr));
+	memcpy(rx_pkt->data + sizeof(pktbuf->hdr), pktbuf->payload, alloc_size);
+
+	mctp_bus_rx((struct mctp_binding *)binding, rx_pkt);
+}
+
+static void receive_one_fragment(struct mctp_binding_test *binding,
+				uint8_t *payload, size_t fragment_size,
+				uint8_t flags_seq_tag, struct pktbuf *pktbuf)
+{
+	pktbuf->hdr.flags_seq_tag = flags_seq_tag;
+	pktbuf->payload = payload;
+	receive_ptkbuf(binding, pktbuf, fragment_size);
+}
+
+static void receive_two_fragment_message(struct mctp_binding_test *binding,
+					uint8_t *payload,
+					size_t fragment1_size,
+					size_t fragment2_size,
+					struct pktbuf *pktbuf)
+{
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	uint8_t flags_seq_tag;
+
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, payload, fragment1_size, flags_seq_tag,
+			pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, payload + fragment1_size, fragment2_size,
+			flags_seq_tag, pktbuf);
+}
+
+static void mctp_core_test_simple_rx()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	uint8_t test_payload[2 * MCTP_BTU];
+	struct pktbuf pktbuf;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	/* Receive 2 fragments of equal size */
+	receive_two_fragment_message(binding, test_payload, MCTP_BTU, MCTP_BTU,
+			&pktbuf);
+
+	assert(test_param.seen);
+	assert(test_param.message_size == 2 * MCTP_BTU);
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+static void mctp_core_test_receive_equal_length_fragments()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	static uint8_t test_payload[MAX_PAYLOAD_SIZE];
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	struct pktbuf pktbuf;
+	uint8_t flags_seq_tag;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	/* Receive 3 fragments, each of size MCTP_BTU */
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload, MCTP_BTU, flags_seq_tag,
+			&pktbuf);
+
+	flags_seq_tag = (get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + MCTP_BTU, MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + (2 * MCTP_BTU), MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	assert(test_param.seen);
+	assert(test_param.message_size == 3 * MCTP_BTU);
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+static void mctp_core_test_receive_unexpected_smaller_middle_fragment()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	static uint8_t test_payload[MAX_PAYLOAD_SIZE];
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	struct pktbuf pktbuf;
+	uint8_t flags_seq_tag;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	/* Middle fragment with size MCTP_BTU - 1 */
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload, MCTP_BTU, flags_seq_tag,
+			&pktbuf);
+
+	flags_seq_tag = (get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + MCTP_BTU, MCTP_BTU - 1,
+			flags_seq_tag, &pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + (2 * MCTP_BTU), MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	assert(!test_param.seen);
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+static void mctp_core_test_receive_unexpected_bigger_middle_fragment()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	static uint8_t test_payload[MAX_PAYLOAD_SIZE];
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	struct pktbuf pktbuf;
+	uint8_t flags_seq_tag;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	/* Middle fragment with size MCTP_BTU + 1 */
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload, MCTP_BTU, flags_seq_tag,
+			&pktbuf);
+
+	flags_seq_tag = (get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + MCTP_BTU, MCTP_BTU + 1,
+			flags_seq_tag, &pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + (2 * MCTP_BTU), MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	assert(!test_param.seen);
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+static void mctp_core_test_receive_smaller_end_fragment()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	static uint8_t test_payload[MAX_PAYLOAD_SIZE];
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	uint8_t end_frag_size = MCTP_BTU - 10;
+	struct pktbuf pktbuf;
+	uint8_t flags_seq_tag;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload, MCTP_BTU, flags_seq_tag,
+			&pktbuf);
+
+	flags_seq_tag = (get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + MCTP_BTU, MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + (2 * MCTP_BTU),
+			end_frag_size, flags_seq_tag, &pktbuf);
+
+	assert(test_param.seen);
+	assert(test_param.message_size ==
+			(size_t)(2 * MCTP_BTU + end_frag_size));
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+static void mctp_core_test_receive_bigger_end_fragment()
+{
+	struct mctp *mctp = NULL;
+	struct mctp_binding_test *binding = NULL;
+	struct test_params test_param;
+	static uint8_t test_payload[MAX_PAYLOAD_SIZE];
+	uint8_t tag = MCTP_HDR_FLAG_TO | get_tag();
+	uint8_t end_frag_size = MCTP_BTU + 10;
+	struct pktbuf pktbuf;
+	uint8_t flags_seq_tag;
+
+	memset(test_payload, 0, sizeof(test_payload));
+	test_param.seen = false;
+	test_param.message_size = 0;
+	mctp_test_stack_init(&mctp, &binding, TEST_DEST_EID);
+	mctp_set_rx_all(mctp, rx_message, &test_param);
+	memset(&pktbuf, 0, sizeof(pktbuf));
+	pktbuf.hdr.dest = TEST_DEST_EID;
+	pktbuf.hdr.src = TEST_SRC_EID;
+
+	flags_seq_tag = MCTP_HDR_FLAG_SOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload, MCTP_BTU, flags_seq_tag,
+			&pktbuf);
+
+	flags_seq_tag = (get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + MCTP_BTU, MCTP_BTU,
+			flags_seq_tag, &pktbuf);
+
+	flags_seq_tag = MCTP_HDR_FLAG_EOM |
+			(get_sequence() << MCTP_HDR_SEQ_SHIFT) | tag;
+	receive_one_fragment(binding, test_payload + (2 * MCTP_BTU),
+			end_frag_size, flags_seq_tag, &pktbuf);
+
+	assert(!test_param.seen);
+
+	mctp_binding_test_destroy(binding);
+	mctp_destroy(mctp);
+}
+
+/* clang-format off */
+#define TEST_CASE(test) { #test, test }
+static const struct {
+	const char *name;
+	void (*test)(void);
+} mctp_core_tests[] = {
+	TEST_CASE(mctp_core_test_simple_rx),
+	TEST_CASE(mctp_core_test_receive_equal_length_fragments),
+	TEST_CASE(mctp_core_test_receive_unexpected_smaller_middle_fragment),
+	TEST_CASE(mctp_core_test_receive_unexpected_bigger_middle_fragment),
+	TEST_CASE(mctp_core_test_receive_smaller_end_fragment),
+	TEST_CASE(mctp_core_test_receive_bigger_end_fragment),
+};
+/* clang-format on */
+
+#ifndef BUILD_ASSERT
+#define BUILD_ASSERT(x)                                                        \
+	do {                                                                   \
+		(void)sizeof(char[0 - (!(x))]);                                \
+	} while (0)
+#endif
+
+int main(void)
+{
+	uint8_t i;
+
+	mctp_set_log_stdio(MCTP_LOG_DEBUG);
+
+	BUILD_ASSERT(ARRAY_SIZE(mctp_core_tests) < SIZE_MAX);
+	for (i = 0; i < ARRAY_SIZE(mctp_core_tests); i++) {
+		mctp_prlog(MCTP_LOG_DEBUG, "begin: %s",
+				mctp_core_tests[i].name);
+		mctp_core_tests[i].test();
+		mctp_prlog(MCTP_LOG_DEBUG, "end: %s\n",
+				mctp_core_tests[i].name);
+	}
+
+	return 0;
+}