requester: Add coroutine API
Added coroutine API to Handler using stdexec to help send and receive
messages in one function call.
For example, getting all PDRs from a terminus requires multiple getPDR
commands. Except for the first getPDR command, each following getPDR
request needs data (e.g. dataTransferHandle) taken from the previous getPDR
response (e.g. nextDataTransferHandle). By using C++ co_await, the code to
get all PDRs can be implemented straightforwardly as a simple loop.
Pseudo code example:
do
{
auto rc = co_await getPDR(dataTransferHndl, ...);
// update dataTransferHndl for next getPDR command
dataTransferHndl = nextDataTransferHndl;
} while(...);
Signed-off-by: Gilbert Chen <gilbert.chen@arm.com>
Signed-off-by: Khang Nguyen <khangng@amperecomputing.com>
Signed-off-by: Thu Nguyen <thu@os.amperecomputing.com>
Change-Id: I7b47d15ac15f6ae661ec94dca6a281844b939a44
diff --git a/requester/handler.hpp b/requester/handler.hpp
index 3b1eca5..cefeb78 100644
--- a/requester/handler.hpp
+++ b/requester/handler.hpp
@@ -9,6 +9,7 @@
#include <sys/socket.h>
#include <phosphor-logging/lg2.hpp>
+#include <sdbusplus/async.hpp>
#include <sdbusplus/timer.hpp>
#include <sdeventplus/event.hpp>
#include <sdeventplus/source/event.hpp>
@@ -282,6 +283,72 @@
return PLDM_SUCCESS;
}
+    /** @brief Unregister a PLDM request message
+     *
+     *  @param[in] eid - endpoint ID of the remote MCTP endpoint
+     *  @param[in] instanceId - instance ID to match request and response
+     *  @param[in] type - PLDM type
+     *  @param[in] command - PLDM command
+     *
+     *  @return PLDM_SUCCESS on success and PLDM_ERROR otherwise
+     */
+    int unregisterRequest(mctp_eid_t eid, uint8_t instanceId, uint8_t type,
+                          uint8_t command)
+    {
+        RequestKey key{eid, instanceId, type, command};
+
+        /* handlers only contain key when the message is already sent */
+        if (handlers.contains(key))
+        {
+            auto& [request, responseHandler, timerInstance] = handlers[key];
+            request->stop();
+            auto rc = timerInstance->stop();
+            if (rc)
+            {
+                error(
+                    "Failed to stop the instance ID expiry timer, response code '{RC}'",
+                    "RC", static_cast<int>(rc));
+            }
+
+            instanceIdDb.free(key.eid, key.instanceId);
+            handlers.erase(key);
+            endpointMessageQueues[eid]->activeRequest = false;
+            /* try to send new request if the endpoint is free */
+            pollEndpointQueue(eid);
+
+            return PLDM_SUCCESS;
+        }
+        else
+        {
+            if (!endpointMessageQueues.contains(eid))
+            {
+                error(
+                    "Can't find request for EID '{EID}' is using InstanceID '{INSTANCEID}' in Endpoint message Queue",
+                    "EID", (unsigned)eid, "INSTANCEID", (unsigned)instanceId);
+                return PLDM_ERROR;
+            }
+            auto& requestMsg = endpointMessageQueues[eid]->requestQueue;
+            /* Walk the queue itself (a reference, not a copy) for the key */
+            for (auto it = requestMsg.begin(); it != requestMsg.end();)
+            {
+                auto msg = *it;
+                if (msg->key == key)
+                {
+                    // it belongs to this queue, so erasing through it is valid
+                    requestMsg.erase(it);
+                    instanceIdDb.free(key.eid, key.instanceId);
+                    return PLDM_SUCCESS;
+                }
+                else
+                {
+                    ++it; // increment iterator only if not erasing
+                }
+            }
+        }
+
+        return PLDM_ERROR;
+    }
+
/** @brief Handle PLDM response message
*
* @param[in] eid - endpoint ID of the remote MCTP endpoint
@@ -325,6 +392,14 @@
}
}
+    /** @brief Wrap registerRequest with coroutine API.
+     *
+     *  @return A sender whose value is a tuple of
+     *          [const pldm_msg* response, size_t response length]; it
+     *          may also complete with an int error or as stopped.
+     */
+    stdexec::sender auto sendRecvMsg(mctp_eid_t eid, pldm::Request&& request);
+
private:
PldmTransport* pldmTransport; //!< PLDM transport object
sdeventplus::Event& event; //!< reference to PLDM daemon's main event loop
@@ -374,6 +449,208 @@
}
};
+/** @class SendRecvMsgOperation
+ *
+ *  Operation state for a single send/receive message operation
+ *
+ *  @tparam RequestInterface - Request class type
+ *  @tparam R - type of the receiver completed by this operation
+ */
+template <class RequestInterface, stdexec::receiver R>
+struct SendRecvMsgOperation
+{
+    SendRecvMsgOperation() = delete;
+
+    /** @brief Capture the request and derive the request key from its header
+     *
+     *  @param[in] handler - handler used to register/unregister the request
+     *  @param[in] eid - endpoint ID of the remote MCTP endpoint
+     *  @param[in] request - PLDM request message, moved into the operation
+     *  @param[in] r - receiver to complete when the operation finishes
+     *
+     *  @note The request is read as a pldm_msg; its size is not checked here.
+     */
+    explicit SendRecvMsgOperation(Handler<RequestInterface>& handler,
+                                  mctp_eid_t eid, pldm::Request&& request,
+                                  R&& r) :
+        handler(handler),
+        request(std::move(request)), receiver(std::move(r))
+    {
+        auto requestMsg =
+            reinterpret_cast<const pldm_msg*>(this->request.data());
+        requestKey = RequestKey{
+            eid,
+            requestMsg->hdr.instance_id,
+            requestMsg->hdr.type,
+            requestMsg->hdr.command,
+        };
+    }
+
+    /** @brief Start the operation
+     *
+     *  Completes with set_stopped if stop was already requested. Otherwise
+     *  registers the request with the handler and completes with set_error
+     *  if registration fails. When stopping is possible, a stop callback
+     *  that invokes onStop() is installed.
+     *
+     *  @param[in] op - operation state to start
+     */
+    friend void tag_invoke(stdexec::start_t, SendRecvMsgOperation& op) noexcept
+    {
+        auto stopToken = stdexec::get_stop_token(stdexec::get_env(op.receiver));
+
+        // operation already cancelled
+        if (stopToken.stop_requested())
+        {
+            return stdexec::set_stopped(std::move(op.receiver));
+        }
+
+        // the handler reports the outcome through onComplete()
+        auto rc = op.handler.registerRequest(
+            op.requestKey.eid, op.requestKey.instanceId, op.requestKey.type,
+            op.requestKey.command, std::move(op.request),
+            [&op](mctp_eid_t eid, const pldm_msg* response, size_t respMsgLen) {
+                op.onComplete(eid, response, respMsgLen);
+            });
+        if (rc)
+        {
+            return stdexec::set_error(std::move(op.receiver), rc);
+        }
+
+        if (stopToken.stop_possible())
+        {
+            op.stopCallback.emplace(std::move(stopToken),
+                                    [&op] { op.onStop(); });
+        }
+    }
+
+    /** @brief Unregisters the request and sets the state to stopped on the
+     *         receiver.
+     */
+    void onStop()
+    {
+        handler.unregisterRequest(requestKey.eid, requestKey.instanceId,
+                                  requestKey.type, requestKey.command);
+        return stdexec::set_stopped(std::move(receiver));
+    }
+
+    /** @brief Response callback registered with the handler
+     *
+     *  Resets the stop callback, then completes the receiver: PLDM_ERROR when
+     *  the response is null or empty, the response and its length otherwise.
+     *
+     *  @param[in] eid - endpoint ID of the remote MCTP endpoint
+     *  @param[in] response - PLDM response message
+     *  @param[in] respMsgLen - length of the response message
+     */
+    void onComplete(mctp_eid_t eid, const pldm_msg* response, size_t respMsgLen)
+    {
+        stopCallback.reset();
+        assert(eid == this->requestKey.eid);
+        if (!response || !respMsgLen)
+        {
+            return stdexec::set_error(std::move(receiver),
+                                      static_cast<int>(PLDM_ERROR));
+        }
+        else
+        {
+            return stdexec::set_value(std::move(receiver), response,
+                                      respMsgLen);
+        }
+    }
+
+  private:
+    /** @brief Reference to a Handler object that manages the request/response
+     *         logic.
+     */
+    requester::Handler<RequestInterface>& handler;
+
+    /** @brief Stores information about the request such as eid, instanceId,
+     *         type, and command.
+     */
+    RequestKey requestKey;
+
+    /** @brief The request message; moved to the handler when started.
+     */
+    pldm::Request request;
+
+    /** @brief The receiver to be notified with the result of the operation.
+     */
+    R receiver;
+
+    /** @brief An optional callback that handles stopping the operation if
+     *         requested.
+     */
+    std::optional<typename stdexec::stop_token_of_t<
+        stdexec::env_of_t<R>>::template callback_type<std::function<void()>>>
+        stopCallback = std::nullopt;
+};
+
+/** @class SendRecvMsgSender
+ *
+ *  Sender for one PLDM request; connecting it yields a SendRecvMsgOperation
+ *
+ *  @tparam RequestInterface - Request class type
+ */
+template <class RequestInterface>
+struct SendRecvMsgSender
+{
+    using is_sender = void;
+
+    SendRecvMsgSender() = delete;
+
+    explicit SendRecvMsgSender(requester::Handler<RequestInterface>& handler,
+                               mctp_eid_t eid, pldm::Request&& request) :
+        handler(handler),
+        eid(eid), request(std::move(request))
+    {}
+
+    friend auto tag_invoke(stdexec::get_completion_signatures_t,
+                           const SendRecvMsgSender&, auto)
+        -> stdexec::completion_signatures<
+            stdexec::set_value_t(const pldm_msg*, size_t),
+            stdexec::set_error_t(int), stdexec::set_stopped_t()>;
+
+    /** @brief Connect to a receiver, creating the operation state */
+    template <stdexec::receiver R>
+    friend auto tag_invoke(stdexec::connect_t, SendRecvMsgSender&& self, R r)
+    {
+        return SendRecvMsgOperation<RequestInterface, R>(
+            self.handler, self.eid, std::move(self.request), std::move(r));
+    }
+
+  private:
+    /** @brief Reference to a Handler object that manages the request/response
+     *         logic.
+     */
+    requester::Handler<RequestInterface>& handler;
+
+    /** @brief MCTP Endpoint ID of request message */
+    mctp_eid_t eid;
+
+    /** @brief Request message */
+    pldm::Request request;
+};
+
+/** @brief Send a PLDM request and hand the response back to the caller.
+ *
+ *  @param[in] eid - endpoint ID of the remote MCTP endpoint
+ *  @param[in] request - PLDM request message
+ *
+ *  @return Sender of a tuple [response message, response message length].
+ */
+template <class RequestInterface>
+stdexec::sender auto
+    Handler<RequestInterface>::sendRecvMsg(mctp_eid_t eid,
+                                           pldm::Request&& request)
+{
+    return stdexec::then(
+        SendRecvMsgSender(*this, eid, std::move(request)),
+        [](const pldm_msg* msg, size_t msgLen) {
+            return std::make_tuple(msg, msgLen);
+        });
+}
+
} // namespace requester
} // namespace pldm
diff --git a/requester/test/handler_test.cpp b/requester/test/handler_test.cpp
index fcc417f..bff08ca 100644
--- a/requester/test/handler_test.cpp
+++ b/requester/test/handler_test.cpp
@@ -8,6 +8,8 @@
#include <libpldm/base.h>
#include <libpldm/transport.h>
+#include <sdbusplus/async.hpp>
+
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -150,3 +152,141 @@
EXPECT_EQ(validResponse, true);
EXPECT_EQ(callbackCount, 2);
}
+
+TEST_F(HandlerTest, singleRequestResponseScenarioUsingCoroutine)
+{
+    exec::async_scope scope;
+    Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
+                                              instanceIdDb, false, seconds(1),
+                                              2, milliseconds(100));
+
+    auto instanceId = instanceIdDb.next(eid);
+    EXPECT_EQ(instanceId, 0);
+
+    // Spawn a coroutine that sends one request and stays suspended until
+    // handleResponse() below delivers the matching response.
+    scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
+        pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
+        const pldm_msg* responseMsg = nullptr;
+        size_t responseLen = 0;
+
+        auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
+        requestPtr->hdr.instance_id = instanceId;
+
+        // No try/catch here: anything thrown by the await is left to
+        // propagate out of the task.
+        std::tie(responseMsg, responseLen) =
+            co_await reqHandler.sendRecvMsg(eid, std::move(request));
+
+        EXPECT_NE(responseLen, 0);
+
+        this->pldmResponseCallBack(eid, responseMsg, responseLen);
+
+        EXPECT_EQ(validResponse, true);
+    }),
+                exec::default_task_context<void>());
+
+    // Build a minimal response for the same instance ID and deliver it by
+    // calling handleResponse() directly; this resumes the coroutine.
+    pldm::Response mockResponse(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
+    auto mockResponsePtr =
+        reinterpret_cast<const pldm_msg*>(mockResponse.data());
+    reqHandler.handleResponse(eid, instanceId, 0, 0, mockResponsePtr,
+                              mockResponse.size() - sizeof(pldm_msg_hdr));
+
+    // Wait until the spawned coroutine has finished.
+    stdexec::sync_wait(scope.on_empty());
+}
+
+TEST_F(HandlerTest, singleRequestCancellationScenarioUsingCoroutine)
+{
+    exec::async_scope scope;
+    Handler<NiceMock<MockRequest>> reqHandler(pldmTransport, event,
+                                              instanceIdDb, false, seconds(1),
+                                              2, milliseconds(100));
+    auto instanceId = instanceIdDb.next(eid);
+    EXPECT_EQ(instanceId, 0);
+
+    bool stopped = false;
+
+    scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
+        pldm::Request request(sizeof(pldm_msg_hdr) + sizeof(uint8_t), 0);
+
+        auto requestPtr = reinterpret_cast<pldm_msg*>(request.data());
+        requestPtr->hdr.instance_id = instanceId;
+
+        co_await reqHandler.sendRecvMsg(eid, std::move(request));
+
+        ADD_FAILURE() << "coroutine resumed after cancellation";
+    }) | stdexec::upon_stopped([&] { stopped = true; }),
+                exec::default_task_context<void>());
+
+    // Cancel while the request is still pending.
+    scope.request_stop();
+
+    EXPECT_TRUE(stopped);
+
+    stdexec::sync_wait(scope.on_empty());
+}
+
+TEST_F(HandlerTest, asyncRequestResponseByCoroutine)
+{
+    struct _
+    {
+        static exec::task<uint8_t> getTIDTask(Handler<MockRequest>& handler,
+                                              mctp_eid_t eid,
+                                              uint8_t instanceId, uint8_t& tid)
+        {
+            pldm::Request request(sizeof(pldm_msg_hdr), 0);
+            auto requestMsg = reinterpret_cast<pldm_msg*>(request.data());
+
+            auto rc = encode_get_tid_req(instanceId, requestMsg);
+            EXPECT_EQ(rc, PLDM_SUCCESS);
+
+            auto [responseMsg, responseLen] =
+                co_await handler.sendRecvMsg(eid, std::move(request));
+            EXPECT_NE(responseLen, 0);
+
+            uint8_t cc = 0;
+            rc = decode_get_tid_resp(responseMsg, responseLen, &cc, &tid);
+            EXPECT_EQ(rc, PLDM_SUCCESS);
+
+            co_return cc;
+        }
+    };
+
+    exec::async_scope scope;
+    Handler<MockRequest> reqHandler(pldmTransport, event, instanceIdDb, false,
+                                    seconds(1), 2, milliseconds(100));
+    auto instanceId = instanceIdDb.next(eid);
+    uint8_t expectedTid = 1;
+
+    // Execute a coroutine to send getTID command. The coroutine is suspended
+    // until reqHandler.handleResponse() is received.
+    scope.spawn(stdexec::just() | stdexec::let_value([&] -> exec::task<void> {
+        uint8_t respTid = 0;
+
+        auto cc = co_await _::getTIDTask(reqHandler, eid, instanceId, respTid);
+
+        EXPECT_EQ(cc, PLDM_SUCCESS);
+        EXPECT_EQ(expectedTid, respTid);
+    }),
+                exec::default_task_context<void>());
+
+    pldm::Response mockResponse(sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES,
+                                0);
+    auto mockResponseMsg = reinterpret_cast<pldm_msg*>(mockResponse.data());
+
+    // Compose response message of getTID command
+    auto rc = encode_get_tid_resp(instanceId, PLDM_SUCCESS, expectedTid,
+                                  mockResponseMsg);
+    EXPECT_EQ(rc, PLDM_SUCCESS);
+
+    // Send response back to resume getTID coroutine to update respTid by
+    // calling reqHandler.handleResponse() manually
+    reqHandler.handleResponse(eid, instanceId, PLDM_BASE, PLDM_GET_TID,
+                              mockResponseMsg,
+                              mockResponse.size() - sizeof(pldm_msg_hdr));
+
+    stdexec::sync_wait(scope.on_empty());
+}