nvidia-gpu: Fix up buffering in MctpRequester
This change does a lot, for better or worse
1. Change MctpRequester to hold both buffers for send and receive
2. This requires changing the callback structure, so the reach is far
3. Changes error reporting to be through std::error_code
4. Collapses QueuingRequester and Requester into a single MctpRequester
5. Doing 4 gets rid of a level of indirection and an extra unordered_map
6. Adds proper iid support, which is made significantly easier by 4/5
7. Fixes issues around the expiry timer, where we would cancel the timer
for a given request whenever a new packet came in to be sent.
This could cause a lockup if a packet truly did time out while an
interleaved packet finished sending. Each queue now has its own
timer.
This also fixes an issue where we were taking receive buffers from
clients and binding them to receive calls without ensuring that they
belonged to the correct message. When a receive completed, it filled
whichever buffer was most recently bound to async_receive_from. This
caused a number of issues, ranging from incorrect device discovery
results and incorrect sensor readings to core dumps.
This change moves the receive and send buffers to be owned by
the MctpRequester, and a non-owning view is provided via
callback to the client. All existing clients just decode in place
given that buffer.
Tested: loaded onto nvl32-obmc. Correct number of sensors showed up
and the readings were nominal
Change-Id: I67c843691ca79e9fcccfa16df6d611918f25f6ca
Signed-off-by: Marc Olberding <molberding@nvidia.com>
diff --git a/src/nvidia-gpu/MctpRequester.cpp b/src/nvidia-gpu/MctpRequester.cpp
index 765859e..f28371f 100644
--- a/src/nvidia-gpu/MctpRequester.cpp
+++ b/src/nvidia-gpu/MctpRequester.cpp
@@ -18,13 +18,17 @@
#include <boost/container/devector.hpp>
#include <phosphor-logging/lg2.hpp>
-#include <cerrno>
+#include <bit>
#include <cstddef>
#include <cstdint>
#include <cstring>
+#include <expected>
+#include <format>
#include <functional>
-#include <memory>
+#include <optional>
#include <span>
+#include <stdexcept>
+#include <system_error>
#include <utility>
using namespace std::literals;
@@ -32,131 +36,314 @@
namespace mctp
{
-Requester::Requester(boost::asio::io_context& ctx) :
- mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}),
- expiryTimer(ctx)
-{}
-
-void Requester::processRecvMsg(
- const std::span<const uint8_t> reqMsg, const std::span<uint8_t> respMsg,
- const boost::system::error_code& ec, const size_t /*length*/)
+static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
+ std::span<const uint8_t> buffer)
{
- const auto* respAddr =
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- reinterpret_cast<const struct sockaddr_mctp*>(recvEndPoint.data());
-
- uint8_t eid = respAddr->smctp_addr.s_addr;
-
- if (!completionCallbacks.contains(eid))
+ if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
{
- lg2::error(
- "MctpRequester failed to get the callback for the EID: {EID}",
- "EID", static_cast<int>(eid));
+ return nullptr;
+ }
+
+ return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>(
+ buffer.data());
+}
+
+static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
+{
+ const ocp::accelerator_management::BindingPciVid* header =
+ getHeaderFromBuffer(buffer);
+ if (header == nullptr)
+ {
+ return std::nullopt;
+ }
+ return header->instance_id & ocp::accelerator_management::instanceIdBitMask;
+}
+
+static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
+{
+ const ocp::accelerator_management::BindingPciVid* header =
+ getHeaderFromBuffer(buffer);
+ if (header == nullptr)
+ {
+ return std::nullopt;
+ }
+ return header->instance_id & ocp::accelerator_management::requestBitMask;
+}
+
+MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
+ io{ctx},
+ mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0})
+{
+ startReceive();
+}
+
+void MctpRequester::startReceive()
+{
+ mctpSocket.async_receive_from(
+ boost::asio::buffer(buffer), recvEndPoint.endpoint,
+ std::bind_front(&MctpRequester::processRecvMsg, this));
+}
+
+void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
+ const size_t length)
+{
+ std::optional<uint8_t> expectedEid = recvEndPoint.eid();
+ std::optional<uint8_t> receivedMsgType = recvEndPoint.type();
+
+ if (!expectedEid || !receivedMsgType)
+ {
+ // we were handed an endpoint that can't be treated as an MCTP endpoint
+ // This is probably a kernel bug...yell about it and rebind.
+ lg2::error("MctpRequester: invalid endpoint");
return;
}
- auto& callback = completionCallbacks.at(eid);
-
- if (respAddr->smctp_type != msgType)
+ if (*receivedMsgType != msgType)
{
- lg2::error("MctpRequester: Message type mismatch");
- callback(EPROTO);
+ // we received a message that this handler doesn't support
+ // drop it on the floor and rebind receive_from
+ lg2::error("MctpRequester: Message type mismatch. We received {MSG}",
+ "MSG", *receivedMsgType);
return;
}
- expiryTimer.cancel();
+ uint8_t eid = *expectedEid;
if (ec)
{
lg2::error(
"MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
"EC", ec.value(), "ER", ec.message());
- callback(EIO);
+ handleResult(eid, static_cast<std::error_code>(ec), {});
return;
}
- if (respMsg.size() > sizeof(ocp::accelerator_management::BindingPciVid))
+ // if the received length was greater than our buffer, we would've truncated
+ // and gotten an error code in asio
+ std::span<const uint8_t> responseBuffer{buffer.data(), length};
+
+ std::optional<uint8_t> optionalIid = getIid(responseBuffer);
+ std::optional<bool> isRq = getRequestBit(responseBuffer);
+ if (!optionalIid || !isRq)
{
- const auto* reqHdr =
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
- reqMsg.data());
-
- uint8_t reqInstanceId = reqHdr->instance_id &
- ocp::accelerator_management::instanceIdBitMask;
- const auto* respHdr =
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- reinterpret_cast<const ocp::accelerator_management::BindingPciVid*>(
- respMsg.data());
-
- uint8_t respInstanceId = respHdr->instance_id &
- ocp::accelerator_management::instanceIdBitMask;
-
- if (reqInstanceId != respInstanceId)
- {
- lg2::error(
- "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
- "REQ", static_cast<int>(reqInstanceId), "RESP",
- static_cast<int>(respInstanceId));
- callback(EPROTO);
- return;
- }
+ // we received something from the device,
+ // but we aren't able to parse iid byte
+ // drop this packet on the floor
+ // and rely on the timer to notify the client
+ lg2::error("MctpRequester: Unable to decode message from eid {EID}",
+ "EID", eid);
+ return;
}
- callback(0);
+ if (isRq.value())
+ {
+ // we received a request from a downstream device.
+ // We don't currently support this, drop the packet
+ // on the floor and rebind receive, keep the timer running
+ return;
+ }
+
+ uint8_t iid = *optionalIid;
+
+ auto it = requestContextQueues.find(eid);
+ if (it == requestContextQueues.end())
+ {
+ // something very bad has happened here
+ // we've received a packet that is a response
+ // from a device we've never talked to
+ // do our best and rebind receive and keep the timer running
+ lg2::error("Unable to match request to response");
+ return;
+ }
+
+ if (iid != it->second.iid)
+ {
+ // we received an iid that doesn't match the one we sent
+ // rebind async_receive_from and drop this packet on the floor
+ lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}", "IID",
+ iid, "EID", eid, "E_IID", it->second.iid);
+ return;
+ }
+
+ handleResult(eid, std::error_code{}, responseBuffer);
}
-void Requester::handleSendMsgCompletion(
- uint8_t eid, const std::span<const uint8_t> reqMsg,
- std::span<uint8_t> respMsg, const boost::system::error_code& ec,
- size_t /* length */)
+void MctpRequester::handleSendMsgCompletion(
+ uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
- if (!completionCallbacks.contains(eid))
- {
- lg2::error(
- "MctpRequester failed to get the callback for the EID: {EID}",
- "EID", static_cast<int>(eid));
- return;
- }
-
- auto& callback = completionCallbacks.at(eid);
-
if (ec)
{
lg2::error(
"MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
"EC", ec.value(), "ER", ec.message());
- callback(EIO);
+ handleResult(eid, static_cast<std::error_code>(ec), {});
return;
}
+ auto it = requestContextQueues.find(eid);
+ if (it == requestContextQueues.end())
+ {
+ // something very bad has happened here,
+ // we've sent something to a device that we have
+ // no record of. yell loudly and bail
+ lg2::error(
+ "MctpRequester completed send for an EID that we have no record of");
+ return;
+ }
+
+ boost::asio::steady_timer& expiryTimer = it->second.timer;
expiryTimer.expires_after(2s);
expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) {
if (ec != boost::asio::error::operation_aborted)
{
- auto& callback = completionCallbacks.at(eid);
- callback(ETIME);
+ lg2::error("Operation timed out on eid {EID}", "EID", eid);
+ handleResult(eid, std::make_error_code(std::errc::timed_out), {});
}
});
-
- mctpSocket.async_receive_from(
- boost::asio::mutable_buffer(respMsg), recvEndPoint,
- std::bind_front(&Requester::processRecvMsg, this, reqMsg, respMsg));
}
-void Requester::sendRecvMsg(uint8_t eid, const std::span<const uint8_t> reqMsg,
- std::span<uint8_t> respMsg,
- std::move_only_function<void(int)> callback)
+void MctpRequester::sendRecvMsg(
+ uint8_t eid, std::span<const uint8_t> reqMsg,
+ std::move_only_function<void(const std::error_code&,
+ std::span<const uint8_t>)>
+ callback)
{
- if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
+ RequestContext reqCtx{reqMsg, std::move(callback)};
+
+ // try_emplace only affects the result if the key does not already exist
+ auto [it, inserted] = requestContextQueues.try_emplace(eid, io);
+ (void)inserted;
+
+ auto& queue = it->second.queue;
+ queue.push_back(std::move(reqCtx));
+
+ if (queue.size() == 1)
{
- lg2::error("MctpRequester: Message too small");
- callback(EPROTO);
+ processQueue(eid);
+ }
+}
+
+static bool isFatalError(const std::error_code& ec)
+{
+ return ec &&
+ (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
+}
+
+void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
+ std::span<const uint8_t> buffer)
+{
+ auto it = requestContextQueues.find(eid);
+ if (it == requestContextQueues.end())
+ {
+ lg2::error("We tried to a handle a result for an eid we don't have");
+
+ startReceive();
return;
}
- completionCallbacks[eid] = std::move(callback);
+ auto& queue = it->second.queue;
+ auto& reqCtx = queue.front();
+
+ it->second.timer.cancel();
+
+ reqCtx.callback(ec, buffer); // Call the original callback
+
+ if (isFatalError(ec))
+ {
+ // some errors are fatal, since these are datagrams,
+ // we won't get a receive path error message.
+ // and since this daemon services all nvidia iana commands
+ // for a given system, we should only restart the service if its
+ // unrecoverable, i.e. if we get error codes that the client
+ // can't reasonably deal with. If thats the cause, restart
+ // and hope that we can deal with it then.
+ // since we're fully async, the only reasonable way to bubble
+ // this issue up is to chuck an exception and let main deal with it.
+ // alternatively we could call cancel on the io_context, but there's
+ // not a great way to figure *what* happened.
+ throw std::runtime_error(std::format(
+ "eid {} encountered a fatal error: {}", eid, ec.message()));
+ }
+
+ startReceive();
+
+ queue.pop_front();
+
+ processQueue(eid);
+}
+
+std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
+{
+ auto it = requestContextQueues.find(eid);
+ if (it == requestContextQueues.end())
+ {
+ return std::nullopt;
+ }
+
+ uint8_t& iid = it->second.iid;
+ ++iid;
+ iid &= ocp::accelerator_management::instanceIdBitMask;
+ return iid;
+}
+
+static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
+ uint8_t iid)
+{
+ if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid))
+ {
+ return std::unexpected(
+ std::make_error_code(std::errc::invalid_argument));
+ }
+
+ if (iid > ocp::accelerator_management::instanceIdBitMask)
+ {
+ return std::unexpected(
+ std::make_error_code(std::errc::invalid_argument));
+ }
+
+ auto* header = std::bit_cast<ocp::accelerator_management::BindingPciVid*>(
+ buffer.data());
+
+ header->instance_id &= ~ocp::accelerator_management::instanceIdBitMask;
+ header->instance_id |= iid;
+ return {};
+}
+
+void MctpRequester::processQueue(uint8_t eid)
+{
+ auto it = requestContextQueues.find(eid);
+ if (it == requestContextQueues.end())
+ {
+ lg2::error("We are attempting to process a queue that doesn't exist");
+ return;
+ }
+
+ auto& queue = it->second.queue;
+
+ if (queue.empty())
+ {
+ return;
+ }
+ auto& reqCtx = queue.front();
+
+ std::span<uint8_t> req{reqCtx.reqMsg.data(), reqCtx.reqMsg.size()};
+
+ std::optional<uint8_t> iid = getNextIid(eid);
+ if (!iid)
+ {
+ lg2::error("MctpRequester: Unable to get next iid");
+ handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
+ return;
+ }
+
+ std::expected<void, std::error_code> success = injectIid(req, *iid);
+ if (!success)
+ {
+ lg2::error("MctpRequester: unable to set iid");
+ handleResult(eid, success.error(), {});
+ return;
+ }
struct sockaddr_mctp addr{};
addr.smctp_family = AF_MCTP;
@@ -167,54 +354,8 @@
sendEndPoint = {&addr, sizeof(addr)};
mctpSocket.async_send_to(
- boost::asio::const_buffer(reqMsg), sendEndPoint,
- std::bind_front(&Requester::handleSendMsgCompletion, this, eid, reqMsg,
- respMsg));
-}
-
-void QueuingRequester::sendRecvMsg(uint8_t eid, std::span<const uint8_t> reqMsg,
- std::span<uint8_t> respMsg,
- std::move_only_function<void(int)> callback)
-{
- auto reqCtx =
- std::make_unique<RequestContext>(reqMsg, respMsg, std::move(callback));
-
- // Add request to queue
- auto& queue = requestContextQueues[eid];
- queue.push_back(std::move(reqCtx));
-
- if (queue.size() == 1)
- {
- processQueue(eid);
- }
-}
-
-void QueuingRequester::handleResult(uint8_t eid, int result)
-{
- auto& queue = requestContextQueues[eid];
- const auto& reqCtx = queue.front();
-
- reqCtx->callback(result); // Call the original callback
-
- queue.pop_front();
-
- processQueue(eid);
-}
-
-void QueuingRequester::processQueue(uint8_t eid)
-{
- auto& queue = requestContextQueues[eid];
-
- if (queue.empty())
- {
- return;
- }
-
- const auto& reqCtx = queue.front();
-
- requester.sendRecvMsg(
- eid, reqCtx->reqMsg, reqCtx->respMsg,
- std::bind_front(&QueuingRequester::handleResult, this, eid));
+ boost::asio::const_buffer(req.data(), req.size()), sendEndPoint,
+ std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid));
}
} // namespace mctp