| /* |
| * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & |
| * AFFILIATES. All rights reserved. |
| * SPDX-License-Identifier: Apache-2.0 |
| */ |
| |
| #include "MctpRequester.hpp" |
| |
| #include <linux/mctp.h> |
| #include <sys/socket.h> |
| |
| #include <OcpMctpVdm.hpp> |
| #include <boost/asio/buffer.hpp> |
| #include <boost/asio/error.hpp> |
| #include <boost/asio/generic/datagram_protocol.hpp> |
| #include <boost/asio/io_context.hpp> |
| #include <boost/asio/steady_timer.hpp> |
| #include <boost/container/devector.hpp> |
| #include <phosphor-logging/lg2.hpp> |
| |
| #include <bit> |
| #include <cstddef> |
| #include <cstdint> |
| #include <cstring> |
| #include <expected> |
| #include <format> |
| #include <functional> |
| #include <optional> |
| #include <span> |
| #include <stdexcept> |
| #include <system_error> |
| #include <utility> |
| |
| using namespace std::literals; |
| |
| namespace mctp |
| { |
| |
| static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer( |
| std::span<const uint8_t> buffer) |
| { |
| if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid)) |
| { |
| return nullptr; |
| } |
| |
| return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>( |
| buffer.data()); |
| } |
| |
| static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer) |
| { |
| const ocp::accelerator_management::BindingPciVid* header = |
| getHeaderFromBuffer(buffer); |
| if (header == nullptr) |
| { |
| return std::nullopt; |
| } |
| return header->instance_id & ocp::accelerator_management::instanceIdBitMask; |
| } |
| |
| static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer) |
| { |
| const ocp::accelerator_management::BindingPciVid* header = |
| getHeaderFromBuffer(buffer); |
| if (header == nullptr) |
| { |
| return std::nullopt; |
| } |
| return header->instance_id & ocp::accelerator_management::requestBitMask; |
| } |
| |
| MctpRequester::MctpRequester(boost::asio::io_context& ctx) : |
| io{ctx}, |
| mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}) |
| { |
| startReceive(); |
| } |
| |
| void MctpRequester::startReceive() |
| { |
| mctpSocket.async_receive_from( |
| boost::asio::buffer(buffer), recvEndPoint.endpoint, |
| std::bind_front(&MctpRequester::processRecvMsg, this)); |
| } |
| |
| void MctpRequester::processRecvMsg(const boost::system::error_code& ec, |
| const size_t length) |
| { |
| std::optional<uint8_t> expectedEid = recvEndPoint.eid(); |
| std::optional<uint8_t> receivedMsgType = recvEndPoint.type(); |
| |
| if (!expectedEid || !receivedMsgType) |
| { |
| // we were handed an endpoint that can't be treated as an MCTP endpoint |
| // This is probably a kernel bug...yell about it and rebind. |
| lg2::error("MctpRequester: invalid endpoint"); |
| return; |
| } |
| |
| if (*receivedMsgType != msgType) |
| { |
| // we received a message that this handler doesn't support |
| // drop it on the floor and rebind receive_from |
| lg2::error("MctpRequester: Message type mismatch. We received {MSG}", |
| "MSG", *receivedMsgType); |
| return; |
| } |
| |
| uint8_t eid = *expectedEid; |
| |
| if (ec) |
| { |
| lg2::error( |
| "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.", |
| "EC", ec.value(), "ER", ec.message()); |
| handleResult(eid, static_cast<std::error_code>(ec), {}); |
| return; |
| } |
| |
| // if the received length was greater than our buffer, we would've truncated |
| // and gotten an error code in asio |
| std::span<const uint8_t> responseBuffer{buffer.data(), length}; |
| |
| std::optional<uint8_t> optionalIid = getIid(responseBuffer); |
| std::optional<bool> isRq = getRequestBit(responseBuffer); |
| if (!optionalIid || !isRq) |
| { |
| // we received something from the device, |
| // but we aren't able to parse iid byte |
| // drop this packet on the floor |
| // and rely on the timer to notify the client |
| lg2::error("MctpRequester: Unable to decode message from eid {EID}", |
| "EID", eid); |
| return; |
| } |
| |
| if (isRq.value()) |
| { |
| // we received a request from a downstream device. |
| // We don't currently support this, drop the packet |
| // on the floor and rebind receive, keep the timer running |
| return; |
| } |
| |
| uint8_t iid = *optionalIid; |
| |
| auto it = requestContextQueues.find(eid); |
| if (it == requestContextQueues.end()) |
| { |
| // something very bad has happened here |
| // we've received a packet that is a response |
| // from a device we've never talked to |
| // do our best and rebind receive and keep the timer running |
| lg2::error("Unable to match request to response"); |
| return; |
| } |
| |
| if (iid != it->second.iid) |
| { |
| // we received an iid that doesn't match the one we sent |
| // rebind async_receive_from and drop this packet on the floor |
| lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}", "IID", |
| iid, "EID", eid, "E_IID", it->second.iid); |
| return; |
| } |
| |
| handleResult(eid, std::error_code{}, responseBuffer); |
| } |
| |
| void MctpRequester::handleSendMsgCompletion( |
| uint8_t eid, const boost::system::error_code& ec, size_t /* length */) |
| { |
| if (ec) |
| { |
| lg2::error( |
| "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.", |
| "EC", ec.value(), "ER", ec.message()); |
| handleResult(eid, static_cast<std::error_code>(ec), {}); |
| return; |
| } |
| |
| auto it = requestContextQueues.find(eid); |
| if (it == requestContextQueues.end()) |
| { |
| // something very bad has happened here, |
| // we've sent something to a device that we have |
| // no record of. yell loudly and bail |
| lg2::error( |
| "MctpRequester completed send for an EID that we have no record of"); |
| return; |
| } |
| |
| boost::asio::steady_timer& expiryTimer = it->second.timer; |
| expiryTimer.expires_after(2s); |
| |
| expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) { |
| if (ec != boost::asio::error::operation_aborted) |
| { |
| lg2::error("Operation timed out on eid {EID}", "EID", eid); |
| handleResult(eid, std::make_error_code(std::errc::timed_out), {}); |
| } |
| }); |
| } |
| |
| void MctpRequester::sendRecvMsg( |
| uint8_t eid, std::span<const uint8_t> reqMsg, |
| std::move_only_function<void(const std::error_code&, |
| std::span<const uint8_t>)> |
| callback) |
| { |
| RequestContext reqCtx{reqMsg, std::move(callback)}; |
| |
| // try_emplace only affects the result if the key does not already exist |
| auto [it, inserted] = requestContextQueues.try_emplace(eid, io); |
| (void)inserted; |
| |
| auto& queue = it->second.queue; |
| queue.push_back(std::move(reqCtx)); |
| |
| if (queue.size() == 1) |
| { |
| processQueue(eid); |
| } |
| } |
| |
// Classifies an error code as fatal or not.
//
// Success, std::errc::timed_out and std::errc::host_unreachable are not
// fatal; any other error is.
static bool isFatalError(const std::error_code& ec)
{
    if (!ec)
    {
        return false;
    }

    const bool recoverable = ec == std::errc::timed_out ||
                             ec == std::errc::host_unreachable;
    return !recoverable;
}
| |
// Completes the request at the head of the queue for `eid`.
//
// eid    - endpoint whose outstanding request is being completed
// ec     - outcome to report (default-constructed on success)
// buffer - response bytes, empty on error
//
// Cancels the per-eid timer, invokes the request's callback, re-arms the
// socket receive, removes the request from the queue and starts the next
// queued request for the same eid.
//
// Throws std::runtime_error (after the callback has run) when `ec` is fatal
// according to isFatalError().
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");

        startReceive();
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // Nothing is waiting on this eid, e.g. a response arrived after its
        // request had already been completed by the timeout and its iid still
        // matches the last one we sent. front() on an empty queue is
        // undefined, so drop the result and keep receiving.
        lg2::error(
            "MctpRequester: dropping result for eid {EID}, no request is outstanding",
            "EID", eid);

        startReceive();
        return;
    }

    it->second.timer.cancel();

    // Take the callback out of the queue element before running it: the
    // callback may call sendRecvMsg() for this eid, and the resulting
    // push_back may relocate the queue's elements while the callable is
    // still executing.
    auto callback = std::move(queue.front().callback);
    if (callback)
    {
        callback(ec, buffer); // Call the original callback
    }

    if (isFatalError(ec))
    {
        // some errors are fatal, since these are datagrams,
        // we won't get a receive path error message.
        // and since this daemon services all nvidia iana commands
        // for a given system, we should only restart the service if its
        // unrecoverable, i.e. if we get error codes that the client
        // can't reasonably deal with. If thats the cause, restart
        // and hope that we can deal with it then.
        // since we're fully async, the only reasonable way to bubble
        // this issue up is to chuck an exception and let main deal with it.
        // alternatively we could call cancel on the io_context, but there's
        // not a great way to figure *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    startReceive();

    queue.pop_front();

    processQueue(eid);
}
| |
| std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid) |
| { |
| auto it = requestContextQueues.find(eid); |
| if (it == requestContextQueues.end()) |
| { |
| return std::nullopt; |
| } |
| |
| uint8_t& iid = it->second.iid; |
| ++iid; |
| iid &= ocp::accelerator_management::instanceIdBitMask; |
| return iid; |
| } |
| |
// Writes `iid` into the instance-id bits of the header at the start of
// `buffer`, leaving the other bits of the instance_id field untouched.
//
// Returns std::errc::invalid_argument when `buffer` is too short to contain a
// header or when `iid` is larger than instanceIdBitMask; `buffer` is not
// modified in either case.
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    namespace oam = ocp::accelerator_management;

    const bool tooShort = buffer.size() < sizeof(oam::BindingPciVid);
    if (tooShort || iid > oam::instanceIdBitMask)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* hdr = std::bit_cast<oam::BindingPciVid*>(buffer.data());

    // clear the previous instance id, then merge in the new one
    hdr->instance_id &= ~oam::instanceIdBitMask;
    hdr->instance_id |= iid;
    return {};
}
| |
// Starts transmission of the request at the head of the queue for `eid`.
//
// Does nothing when the queue is empty. Otherwise a fresh instance id is
// obtained and written into the request's header, and the request is sent
// asynchronously to `eid` with the tag-owner flag set;
// handleSendMsgCompletion() runs when the send finishes.
//
// If no instance id can be obtained or written, the request is completed
// with the corresponding error through handleResult().
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pendingRequests = entry->second.queue;
    if (pendingRequests.empty())
    {
        return;
    }

    auto& request = pendingRequests.front();
    std::span<uint8_t> payload{request.reqMsg.data(), request.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    const std::expected<void, std::error_code> injected =
        injectIid(payload, *nextIid);
    if (!injected.has_value())
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    // destination address for this request
    struct sockaddr_mctp destination{};
    destination.smctp_family = AF_MCTP;
    destination.smctp_addr.s_addr = eid;
    destination.smctp_type = msgType;
    destination.smctp_tag = MCTP_TAG_OWNER;

    sendEndPoint = {&destination, sizeof(destination)};

    auto onSent = [this, eid](const boost::system::error_code& ec,
                              size_t bytesSent) {
        handleSendMsgCompletion(eid, ec, bytesSent);
    };

    mctpSocket.async_send_to(
        boost::asio::const_buffer(payload.data(), payload.size()),
        sendEndPoint, std::move(onSent));
}
| |
| } // namespace mctp |