transport: Generalise the pldm_transport_recv_msg() API

Currently pldm_transport_recv_msg() only works for requesters as the TID
param is an input. Responders need the source TID of the message
received so they know where to send the response.

The TID was being used to look up the EID mapped to the TID and failing
the function call if it didn't match. This check doesn't need to happen
at this level, and can be added in at the requester API level if
required.

Make the TID param an output, and use the EID of the message to lookup
the TID.

Change-Id: I671dbfe2d94a9ad8d77ea0ef150f1c744f928c53
Signed-off-by: Rashmica Gupta <rashmica@linux.ibm.com>
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4872744..bd65e3d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@
 4. pdr: Stabilise pldm_entity_association_tree_find_with_locality()
 5. pdr: Stabilize pldm_entity_node_get_remote_container_id()
 6. transport: af-mctp: Assign out-params on success in \*\_recv()
+7. transport: Generalise the pldm_transport_recv_msg() API
 
 ### Removed
 
diff --git a/include/libpldm/transport.h b/include/libpldm/transport.h
index 6554c4f..660ff02 100644
--- a/include/libpldm/transport.h
+++ b/include/libpldm/transport.h
@@ -63,7 +63,7 @@
  * 	up.
  *
  * @param[in] ctx - pldm transport instance
- * @param[in] tid - destination PLDM TID
+ * @param[out] tid - source PLDM TID
  * @param[out] pldm_msg - *pldm_msg will point to the received PLDM msg if
  * 	       return code is PLDM_REQUESTER_SUCCESS; otherwise, NULL. On
  * 	       success this function allocates memory, caller to
@@ -77,7 +77,7 @@
  *
  */
 pldm_requester_rc_t pldm_transport_recv_msg(struct pldm_transport *transport,
-					    pldm_tid_t tid, void **pldm_msg,
+					    pldm_tid_t *tid, void **pldm_msg,
 					    size_t *msg_len);
 
 /**
diff --git a/src/requester/pldm.c b/src/requester/pldm.c
index fc7e782..5297947 100644
--- a/src/requester/pldm.c
+++ b/src/requester/pldm.c
@@ -99,8 +99,35 @@
 				  uint8_t **pldm_resp_msg, size_t *resp_msg_len)
 {
 	pldm_requester_rc_t rc = 0;
-	PLDM_REQ_FN(eid, mctp_fd, pldm_transport_recv_msg, rc,
-		    (void **)pldm_resp_msg, resp_msg_len);
+
+	struct pldm_transport_mctp_demux *demux;
+	bool using_open_transport = false;
+	pldm_tid_t tid = eid;
+	struct pldm_transport *ctx;
+	/* The fd can be for a socket we opened or one the consumer
+	 * opened. */
+	if (open_transport &&
+	    mctp_fd ==
+		    pldm_transport_mctp_demux_get_socket_fd(open_transport)) {
+		using_open_transport = true;
+		demux = open_transport;
+	} else {
+		demux = pldm_transport_mctp_demux_init_with_fd(mctp_fd);
+		if (!demux) {
+			rc = PLDM_REQUESTER_OPEN_FAIL;
+			goto transport_out;
+		}
+	}
+	ctx = pldm_transport_mctp_demux_core(demux);
+	rc = pldm_transport_mctp_demux_map_tid(demux, tid, eid);
+	if (rc) {
+		rc = PLDM_REQUESTER_OPEN_FAIL;
+		goto transport_out;
+	}
+	/* TODO this is the only change, can we work this into the macro? */
+	rc = pldm_transport_recv_msg(ctx, &tid, (void **)pldm_resp_msg,
+				     resp_msg_len);
+
 	struct pldm_msg_hdr *hdr = (struct pldm_msg_hdr *)(*pldm_resp_msg);
 	if (rc != PLDM_REQUESTER_SUCCESS) {
 		return rc;
@@ -116,6 +143,12 @@
 		*pldm_resp_msg = NULL;
 		return PLDM_REQUESTER_RESP_MSG_TOO_SMALL;
 	}
+
+transport_out:
+	if (!using_open_transport) {
+		pldm_transport_mctp_demux_destroy(demux);
+	}
+
 	return rc;
 }
 
diff --git a/src/transport/af-mctp.c b/src/transport/af-mctp.c
index 08d8322..d3f5882 100644
--- a/src/transport/af-mctp.c
+++ b/src/transport/af-mctp.c
@@ -60,6 +60,16 @@
 	return -1;
 }
 
+static int pldm_transport_af_mctp_get_tid(struct pldm_transport_af_mctp *ctx,
+					  mctp_eid_t eid, pldm_tid_t *tid)
+{
+	if (ctx->tid_eid_map[eid] != 0) {
+		*tid = ctx->tid_eid_map[eid];
+		return 0;
+	}
+	return -1;
+}
+
 LIBPLDM_ABI_TESTING
 int pldm_transport_af_mctp_map_tid(struct pldm_transport_af_mctp *ctx,
 				   pldm_tid_t tid, mctp_eid_t eid)
@@ -80,21 +90,19 @@
 }
 
 static pldm_requester_rc_t pldm_transport_af_mctp_recv(struct pldm_transport *t,
-						       pldm_tid_t tid,
+						       pldm_tid_t *tid,
 						       void **pldm_msg,
 						       size_t *msg_len)
 {
 	struct pldm_transport_af_mctp *af_mctp = transport_to_af_mctp(t);
+	struct sockaddr_mctp addr = { 0 };
+	socklen_t addrlen = sizeof(addr);
+	pldm_requester_rc_t res;
 	mctp_eid_t eid = 0;
 	ssize_t length;
 	void *msg;
 	int rc;
 
-	rc = pldm_transport_af_mctp_get_eid(af_mctp, tid, &eid);
-	if (rc) {
-		return PLDM_REQUESTER_RECV_FAIL;
-	}
-
 	length = recv(af_mctp->socket, NULL, 0, MSG_PEEK | MSG_TRUNC);
 	if (length <= 0) {
 		return PLDM_REQUESTER_RECV_FAIL;
@@ -105,16 +113,29 @@
 		return PLDM_REQUESTER_RECV_FAIL;
 	}
 
-	length = recv(af_mctp->socket, msg, length, MSG_TRUNC);
+	length = recvfrom(af_mctp->socket, msg, length, MSG_TRUNC,
+			  (struct sockaddr *)&addr, &addrlen);
 	if (length < (ssize_t)sizeof(struct pldm_msg_hdr)) {
-		free(msg);
-		return PLDM_REQUESTER_INVALID_RECV_LEN;
+		res = PLDM_REQUESTER_INVALID_RECV_LEN;
+		goto cleanup_msg;
+	}
+
+	eid = addr.smctp_addr.s_addr;
+	rc = pldm_transport_af_mctp_get_tid(af_mctp, eid, tid);
+	if (rc) {
+		res = PLDM_REQUESTER_RECV_FAIL;
+		goto cleanup_msg;
 	}
 
 	*pldm_msg = msg;
 	*msg_len = length;
 
 	return PLDM_REQUESTER_SUCCESS;
+
+cleanup_msg:
+	free(msg);
+
+	return res;
 }
 
 static pldm_requester_rc_t pldm_transport_af_mctp_send(struct pldm_transport *t,
diff --git a/src/transport/mctp-demux.c b/src/transport/mctp-demux.c
index dc5e2d2..240f931 100644
--- a/src/transport/mctp-demux.c
+++ b/src/transport/mctp-demux.c
@@ -91,6 +91,18 @@
 	return -1;
 }
 
+static int
+pldm_transport_mctp_demux_get_tid(struct pldm_transport_mctp_demux *ctx,
+				  mctp_eid_t eid, pldm_tid_t *tid)
+{
+	/* mapping exists */
+	if (ctx->tid_eid_map[eid] != 0) {
+		*tid = ctx->tid_eid_map[eid];
+		return 0;
+	}
+	return -1;
+}
+
 LIBPLDM_ABI_TESTING
 int pldm_transport_mctp_demux_map_tid(struct pldm_transport_mctp_demux *ctx,
 				      pldm_tid_t tid, mctp_eid_t eid)
@@ -111,7 +123,7 @@
 }
 
 static pldm_requester_rc_t
-pldm_transport_mctp_demux_recv(struct pldm_transport *t, pldm_tid_t tid,
+pldm_transport_mctp_demux_recv(struct pldm_transport *t, pldm_tid_t *tid,
 			       void **pldm_msg, size_t *msg_len)
 {
 	struct pldm_transport_mctp_demux *demux = transport_to_demux(t);
@@ -128,11 +140,6 @@
 	uint8_t *buf;
 	int rc;
 
-	rc = pldm_transport_mctp_demux_get_eid(demux, tid, &eid);
-	if (rc) {
-		return PLDM_REQUESTER_RECV_FAIL;
-	}
-
 	min_len = sizeof(eid) + sizeof(mctp_msg_type) +
 		  sizeof(struct pldm_msg_hdr);
 	length = recv(demux->socket, NULL, 0, MSG_PEEK | MSG_TRUNC);
@@ -167,11 +174,18 @@
 		goto cleanup_buf;
 	}
 
-	if ((mctp_prefix[0] != eid) || (mctp_prefix[1] != mctp_msg_type)) {
+	if (mctp_prefix[1] != mctp_msg_type) {
 		res = PLDM_REQUESTER_NOT_PLDM_MSG;
 		goto cleanup_buf;
 	}
 
+	eid = mctp_prefix[0];
+	rc = pldm_transport_mctp_demux_get_tid(demux, eid, tid);
+	if (rc) {
+		res = PLDM_REQUESTER_RECV_FAIL;
+		goto cleanup_buf;
+	}
+
 	*pldm_msg = buf;
 	*msg_len = pldm_len;
 
diff --git a/src/transport/test.c b/src/transport/test.c
index dad8690..730d8e6 100644
--- a/src/transport/test.c
+++ b/src/transport/test.c
@@ -98,7 +98,7 @@
 #endif
 
 static pldm_requester_rc_t pldm_transport_test_recv(struct pldm_transport *ctx,
-						    pldm_tid_t tid,
+						    pldm_tid_t *tid,
 						    void **pldm_resp_msg,
 						    size_t *resp_msg_len)
 {
@@ -106,8 +106,6 @@
 	const struct pldm_transport_test_descriptor *desc;
 	void *msg;
 
-	(void)tid;
-
 	if (test->cursor >= test->count) {
 		return PLDM_REQUESTER_RECV_FAIL;
 	}
@@ -126,6 +124,7 @@
 	memcpy(msg, desc->recv_msg.msg, desc->recv_msg.len);
 	*pldm_resp_msg = msg;
 	*resp_msg_len = desc->recv_msg.len;
+	*tid = desc->recv_msg.src;
 
 	test->cursor++;
 
diff --git a/src/transport/transport.c b/src/transport/transport.c
index c78ec89..e800ef9 100644
--- a/src/transport/transport.c
+++ b/src/transport/transport.c
@@ -76,7 +76,7 @@
 
 LIBPLDM_ABI_TESTING
 pldm_requester_rc_t pldm_transport_recv_msg(struct pldm_transport *transport,
-					    pldm_tid_t tid, void **pldm_msg,
+					    pldm_tid_t *tid, void **pldm_msg,
 					    size_t *msg_len)
 {
 	if (!transport || !msg_len) {
@@ -172,7 +172,8 @@
 	for (cnt = 0; cnt <= (PLDM_INSTANCE_MAX + 1) * PLDM_MAX_TIDS &&
 		      pldm_transport_poll(transport, 0) == 1;
 	     cnt++) {
-		rc = pldm_transport_recv_msg(transport, tid, pldm_resp_msg,
+		pldm_tid_t l_tid;
+		rc = pldm_transport_recv_msg(transport, &l_tid, pldm_resp_msg,
 					     resp_msg_len);
 		if (rc == PLDM_REQUESTER_SUCCESS) {
 			/* This isn't the message we wanted */
@@ -207,11 +208,13 @@
 			break;
 		}
 
-		rc = pldm_transport_recv_msg(transport, tid, pldm_resp_msg,
+		pldm_tid_t src_tid;
+		rc = pldm_transport_recv_msg(transport, &src_tid, pldm_resp_msg,
 					     resp_msg_len);
 		if (rc == PLDM_REQUESTER_SUCCESS) {
 			const struct pldm_msg_hdr *resp_hdr = *pldm_resp_msg;
-			if (req_hdr->instance_id == resp_hdr->instance_id) {
+			if ((src_tid == tid) &&
+			    (req_hdr->instance_id == resp_hdr->instance_id)) {
 				return rc;
 			}
 
diff --git a/src/transport/transport.h b/src/transport/transport.h
index dd18c02..f052785 100644
--- a/src/transport/transport.h
+++ b/src/transport/transport.h
@@ -18,7 +18,7 @@
 	const char *name;
 	uint8_t version;
 	pldm_requester_rc_t (*recv)(struct pldm_transport *transport,
-				    pldm_tid_t tid, void **pldm_msg,
+				    pldm_tid_t *tid, void **pldm_resp_msg,
 				    size_t *msg_len);
 	pldm_requester_rc_t (*send)(struct pldm_transport *transport,
 				    pldm_tid_t tid, const void *pldm_msg,
diff --git a/tests/transport.cpp b/tests/transport.cpp
index 83f0a34..0838751 100644
--- a/tests/transport.cpp
+++ b/tests/transport.cpp
@@ -47,12 +47,13 @@
 TEST(Transport, recv_one)
 {
     uint8_t msg[] = {0x01, 0x00, 0x01, 0x00};
+    const pldm_tid_t src_tid = 1;
     const struct pldm_transport_test_descriptor seq[] = {
         {
             .type = PLDM_TRANSPORT_TEST_ELEMENT_MSG_RECV,
             .recv_msg =
                 {
-                    .src = 1,
+                    .src = src_tid,
                     .msg = msg,
                     .len = sizeof(msg),
                 },
@@ -63,13 +64,15 @@
     void* recvd;
     size_t len;
     int rc;
+    pldm_tid_t tid;
 
     EXPECT_EQ(pldm_transport_test_init(&test, seq, ARRAY_SIZE(seq)), 0);
     ctx = pldm_transport_test_core(test);
-    rc = pldm_transport_recv_msg(ctx, 1, &recvd, &len);
+    rc = pldm_transport_recv_msg(ctx, &tid, &recvd, &len);
     EXPECT_EQ(rc, PLDM_REQUESTER_SUCCESS);
     EXPECT_EQ(len, sizeof(msg));
     EXPECT_EQ(memcmp(recvd, msg, len), 0);
+    EXPECT_EQ(tid, src_tid);
     free(recvd);
     pldm_transport_test_destroy(test);
 }