transport: Generalise the pldm_transport_recv_msg() API

Currently pldm_transport_recv_msg() only works for requesters because the
TID parameter is an input. Responders need the source TID of the received
message so that they know where to send the response.

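The TID was used to look up the EID mapped to it, and the call failed if
no mapping existed or (for mctp-demux) if the received message came from
a different EID. This check doesn't need to happen at the transport
level; the requester API now performs it instead by comparing the source
TID of each received response with the TID the request was sent to.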

Make the TID parameter an output, and use the source EID of the received
message to look up the TID.

Change-Id: I671dbfe2d94a9ad8d77ea0ef150f1c744f928c53
Signed-off-by: Rashmica Gupta <rashmica@linux.ibm.com>
diff --git a/src/transport/af-mctp.c b/src/transport/af-mctp.c
index 08d8322..d3f5882 100644
--- a/src/transport/af-mctp.c
+++ b/src/transport/af-mctp.c
@@ -60,6 +60,40 @@
 	return -1;
 }
 
+/*
+ * Look up the TID mapped to @eid; the inverse of
+ * pldm_transport_af_mctp_get_eid().
+ *
+ * tid_eid_map is indexed by TID and holds the mapped EID (presumably as
+ * stored by pldm_transport_af_mctp_map_tid(ctx, tid, eid) -- confirm),
+ * with 0 marking an unmapped slot. Going from EID to TID therefore requires
+ * searching the stored values rather than indexing the map by EID.
+ *
+ * Returns 0 and writes *tid on success (the lowest matching TID if several
+ * map to @eid), or -1 (leaving *tid untouched) if none is mapped to @eid.
+ */
+static int pldm_transport_af_mctp_get_tid(struct pldm_transport_af_mctp *ctx,
+					  mctp_eid_t eid, pldm_tid_t *tid)
+{
+	size_t i;
+
+	/* 0 marks an unmapped slot and must never match */
+	if (eid == 0) {
+		return -1;
+	}
+
+	/* assumes tid_eid_map is an array member, not a pointer */
+	for (i = 0; i < sizeof(ctx->tid_eid_map) / sizeof(ctx->tid_eid_map[0]);
+	     i++) {
+		if (ctx->tid_eid_map[i] == eid) {
+			*tid = (pldm_tid_t)i;
+			return 0;
+		}
+	}
+
+	return -1;
+}
+
 LIBPLDM_ABI_TESTING
 int pldm_transport_af_mctp_map_tid(struct pldm_transport_af_mctp *ctx,
 				   pldm_tid_t tid, mctp_eid_t eid)
@@ -80,21 +90,19 @@
 }
 
 static pldm_requester_rc_t pldm_transport_af_mctp_recv(struct pldm_transport *t,
-						       pldm_tid_t tid,
+						       pldm_tid_t *tid,
 						       void **pldm_msg,
 						       size_t *msg_len)
 {
 	struct pldm_transport_af_mctp *af_mctp = transport_to_af_mctp(t);
+	struct sockaddr_mctp addr = { 0 };
+	socklen_t addrlen = sizeof(addr);
+	pldm_requester_rc_t res;
 	mctp_eid_t eid = 0;
 	ssize_t length;
 	void *msg;
 	int rc;
 
-	rc = pldm_transport_af_mctp_get_eid(af_mctp, tid, &eid);
-	if (rc) {
-		return PLDM_REQUESTER_RECV_FAIL;
-	}
-
 	length = recv(af_mctp->socket, NULL, 0, MSG_PEEK | MSG_TRUNC);
 	if (length <= 0) {
 		return PLDM_REQUESTER_RECV_FAIL;
@@ -105,16 +113,29 @@
 		return PLDM_REQUESTER_RECV_FAIL;
 	}
 
-	length = recv(af_mctp->socket, msg, length, MSG_TRUNC);
+	length = recvfrom(af_mctp->socket, msg, length, MSG_TRUNC,
+			  (struct sockaddr *)&addr, &addrlen);
 	if (length < (ssize_t)sizeof(struct pldm_msg_hdr)) {
-		free(msg);
-		return PLDM_REQUESTER_INVALID_RECV_LEN;
+		res = PLDM_REQUESTER_INVALID_RECV_LEN;
+		goto cleanup_msg;
+	}
+
+	eid = addr.smctp_addr.s_addr;
+	rc = pldm_transport_af_mctp_get_tid(af_mctp, eid, tid);
+	if (rc) {
+		res = PLDM_REQUESTER_RECV_FAIL;
+		goto cleanup_msg;
 	}
 
 	*pldm_msg = msg;
 	*msg_len = length;
 
 	return PLDM_REQUESTER_SUCCESS;
+
+cleanup_msg:
+	free(msg);
+
+	return res;
 }
 
 static pldm_requester_rc_t pldm_transport_af_mctp_send(struct pldm_transport *t,
diff --git a/src/transport/mctp-demux.c b/src/transport/mctp-demux.c
index dc5e2d2..240f931 100644
--- a/src/transport/mctp-demux.c
+++ b/src/transport/mctp-demux.c
@@ -91,6 +91,41 @@
 	return -1;
 }
 
+/*
+ * Look up the TID mapped to @eid; the inverse of
+ * pldm_transport_mctp_demux_get_eid().
+ *
+ * tid_eid_map is indexed by TID and holds the mapped EID (presumably as
+ * stored by pldm_transport_mctp_demux_map_tid(ctx, tid, eid) -- confirm),
+ * with 0 marking an unmapped slot. Going from EID to TID therefore requires
+ * searching the stored values rather than indexing the map by EID.
+ *
+ * Returns 0 and writes *tid on success (the lowest matching TID if several
+ * map to @eid), or -1 (leaving *tid untouched) if none is mapped to @eid.
+ */
+static int
+pldm_transport_mctp_demux_get_tid(struct pldm_transport_mctp_demux *ctx,
+				  mctp_eid_t eid, pldm_tid_t *tid)
+{
+	size_t i;
+
+	/* 0 marks an unmapped slot and must never match */
+	if (eid == 0) {
+		return -1;
+	}
+
+	/* assumes tid_eid_map is an array member, not a pointer */
+	for (i = 0; i < sizeof(ctx->tid_eid_map) / sizeof(ctx->tid_eid_map[0]);
+	     i++) {
+		if (ctx->tid_eid_map[i] == eid) {
+			*tid = (pldm_tid_t)i;
+			return 0;
+		}
+	}
+
+	return -1;
+}
+
 LIBPLDM_ABI_TESTING
 int pldm_transport_mctp_demux_map_tid(struct pldm_transport_mctp_demux *ctx,
 				      pldm_tid_t tid, mctp_eid_t eid)
@@ -111,7 +123,7 @@
 }
 
 static pldm_requester_rc_t
-pldm_transport_mctp_demux_recv(struct pldm_transport *t, pldm_tid_t tid,
+pldm_transport_mctp_demux_recv(struct pldm_transport *t, pldm_tid_t *tid,
 			       void **pldm_msg, size_t *msg_len)
 {
 	struct pldm_transport_mctp_demux *demux = transport_to_demux(t);
@@ -128,11 +140,6 @@
 	uint8_t *buf;
 	int rc;
 
-	rc = pldm_transport_mctp_demux_get_eid(demux, tid, &eid);
-	if (rc) {
-		return PLDM_REQUESTER_RECV_FAIL;
-	}
-
 	min_len = sizeof(eid) + sizeof(mctp_msg_type) +
 		  sizeof(struct pldm_msg_hdr);
 	length = recv(demux->socket, NULL, 0, MSG_PEEK | MSG_TRUNC);
@@ -167,11 +174,18 @@
 		goto cleanup_buf;
 	}
 
-	if ((mctp_prefix[0] != eid) || (mctp_prefix[1] != mctp_msg_type)) {
+	if (mctp_prefix[1] != mctp_msg_type) {
 		res = PLDM_REQUESTER_NOT_PLDM_MSG;
 		goto cleanup_buf;
 	}
 
+	eid = mctp_prefix[0];
+	rc = pldm_transport_mctp_demux_get_tid(demux, eid, tid);
+	if (rc) {
+		res = PLDM_REQUESTER_RECV_FAIL;
+		goto cleanup_buf;
+	}
+
 	*pldm_msg = buf;
 	*msg_len = pldm_len;
 
diff --git a/src/transport/test.c b/src/transport/test.c
index dad8690..730d8e6 100644
--- a/src/transport/test.c
+++ b/src/transport/test.c
@@ -98,7 +98,7 @@
 #endif
 
 static pldm_requester_rc_t pldm_transport_test_recv(struct pldm_transport *ctx,
-						    pldm_tid_t tid,
+						    pldm_tid_t *tid,
 						    void **pldm_resp_msg,
 						    size_t *resp_msg_len)
 {
@@ -106,8 +106,6 @@
 	const struct pldm_transport_test_descriptor *desc;
 	void *msg;
 
-	(void)tid;
-
 	if (test->cursor >= test->count) {
 		return PLDM_REQUESTER_RECV_FAIL;
 	}
@@ -126,6 +124,7 @@
 	memcpy(msg, desc->recv_msg.msg, desc->recv_msg.len);
 	*pldm_resp_msg = msg;
 	*resp_msg_len = desc->recv_msg.len;
+	*tid = desc->recv_msg.src;
 
 	test->cursor++;
 
diff --git a/src/transport/transport.c b/src/transport/transport.c
index c78ec89..e800ef9 100644
--- a/src/transport/transport.c
+++ b/src/transport/transport.c
@@ -76,7 +76,7 @@
 
 LIBPLDM_ABI_TESTING
 pldm_requester_rc_t pldm_transport_recv_msg(struct pldm_transport *transport,
-					    pldm_tid_t tid, void **pldm_msg,
+					    pldm_tid_t *tid, void **pldm_msg,
 					    size_t *msg_len)
 {
 	if (!transport || !msg_len) {
@@ -172,7 +172,8 @@
 	for (cnt = 0; cnt <= (PLDM_INSTANCE_MAX + 1) * PLDM_MAX_TIDS &&
 		      pldm_transport_poll(transport, 0) == 1;
 	     cnt++) {
-		rc = pldm_transport_recv_msg(transport, tid, pldm_resp_msg,
+		pldm_tid_t l_tid;
+		rc = pldm_transport_recv_msg(transport, &l_tid, pldm_resp_msg,
 					     resp_msg_len);
 		if (rc == PLDM_REQUESTER_SUCCESS) {
 			/* This isn't the message we wanted */
@@ -207,11 +208,13 @@
 			break;
 		}
 
-		rc = pldm_transport_recv_msg(transport, tid, pldm_resp_msg,
+		pldm_tid_t src_tid;
+		rc = pldm_transport_recv_msg(transport, &src_tid, pldm_resp_msg,
 					     resp_msg_len);
 		if (rc == PLDM_REQUESTER_SUCCESS) {
 			const struct pldm_msg_hdr *resp_hdr = *pldm_resp_msg;
-			if (req_hdr->instance_id == resp_hdr->instance_id) {
+			if ((src_tid == tid) &&
+			    (req_hdr->instance_id == resp_hdr->instance_id)) {
 				return rc;
 			}
 
diff --git a/src/transport/transport.h b/src/transport/transport.h
index dd18c02..f052785 100644
--- a/src/transport/transport.h
+++ b/src/transport/transport.h
@@ -18,7 +18,7 @@
 	const char *name;
 	uint8_t version;
 	pldm_requester_rc_t (*recv)(struct pldm_transport *transport,
-				    pldm_tid_t tid, void **pldm_msg,
+				    pldm_tid_t *tid, void **pldm_resp_msg,
 				    size_t *msg_len);
 	pldm_requester_rc_t (*send)(struct pldm_transport *transport,
 				    pldm_tid_t tid, const void *pldm_msg,