| /* | 
 |  * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & | 
 |  * AFFILIATES. All rights reserved. | 
 |  * SPDX-License-Identifier: Apache-2.0 | 
 |  */ | 
 |  | 
 | #include "MctpRequester.hpp" | 
 |  | 
 | #include <linux/mctp.h> | 
 | #include <sys/socket.h> | 
 |  | 
 | #include <OcpMctpVdm.hpp> | 
 | #include <boost/asio/buffer.hpp> | 
 | #include <boost/asio/error.hpp> | 
 | #include <boost/asio/generic/datagram_protocol.hpp> | 
 | #include <boost/asio/io_context.hpp> | 
 | #include <boost/asio/steady_timer.hpp> | 
 | #include <boost/container/devector.hpp> | 
 | #include <phosphor-logging/lg2.hpp> | 
 |  | 
 | #include <bit> | 
 | #include <cstddef> | 
 | #include <cstdint> | 
 | #include <cstring> | 
 | #include <expected> | 
 | #include <format> | 
 | #include <functional> | 
 | #include <optional> | 
 | #include <span> | 
 | #include <stdexcept> | 
 | #include <system_error> | 
 | #include <utility> | 
 |  | 
 | using namespace std::literals; | 
 |  | 
 | namespace mctp | 
 | { | 
 |  | 
 | static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer( | 
 |     std::span<const uint8_t> buffer) | 
 | { | 
 |     if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid)) | 
 |     { | 
 |         return nullptr; | 
 |     } | 
 |  | 
 |     return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>( | 
 |         buffer.data()); | 
 | } | 
 |  | 
 | static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer) | 
 | { | 
 |     const ocp::accelerator_management::BindingPciVid* header = | 
 |         getHeaderFromBuffer(buffer); | 
 |     if (header == nullptr) | 
 |     { | 
 |         return std::nullopt; | 
 |     } | 
 |     return header->instance_id & ocp::accelerator_management::instanceIdBitMask; | 
 | } | 
 |  | 
 | static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer) | 
 | { | 
 |     const ocp::accelerator_management::BindingPciVid* header = | 
 |         getHeaderFromBuffer(buffer); | 
 |     if (header == nullptr) | 
 |     { | 
 |         return std::nullopt; | 
 |     } | 
 |     return header->instance_id & ocp::accelerator_management::requestBitMask; | 
 | } | 
 |  | 
 | MctpRequester::MctpRequester(boost::asio::io_context& ctx) : | 
 |     io{ctx}, | 
 |     mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}) | 
 | { | 
 |     startReceive(); | 
 | } | 
 |  | 
 | void MctpRequester::startReceive() | 
 | { | 
 |     mctpSocket.async_receive_from( | 
 |         boost::asio::buffer(buffer), recvEndPoint.endpoint, | 
 |         std::bind_front(&MctpRequester::processRecvMsg, this)); | 
 | } | 
 |  | 
 | void MctpRequester::processRecvMsg(const boost::system::error_code& ec, | 
 |                                    const size_t length) | 
 | { | 
 |     std::optional<uint8_t> expectedEid = recvEndPoint.eid(); | 
 |     std::optional<uint8_t> receivedMsgType = recvEndPoint.type(); | 
 |  | 
 |     if (!expectedEid || !receivedMsgType) | 
 |     { | 
 |         // we were handed an endpoint that can't be treated as an MCTP endpoint | 
 |         // This is probably a kernel bug...yell about it and rebind. | 
 |         lg2::error("MctpRequester: invalid endpoint"); | 
 |         return; | 
 |     } | 
 |  | 
 |     if (*receivedMsgType != msgType) | 
 |     { | 
 |         // we received a message that this handler doesn't support | 
 |         // drop it on the floor and rebind receive_from | 
 |         lg2::error("MctpRequester: Message type mismatch. We received {MSG}", | 
 |                    "MSG", *receivedMsgType); | 
 |         return; | 
 |     } | 
 |  | 
 |     uint8_t eid = *expectedEid; | 
 |  | 
 |     if (ec) | 
 |     { | 
 |         lg2::error( | 
 |             "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.", | 
 |             "EC", ec.value(), "ER", ec.message()); | 
 |         handleResult(eid, static_cast<std::error_code>(ec), {}); | 
 |         return; | 
 |     } | 
 |  | 
 |     // if the received length was greater than our buffer, we would've truncated | 
 |     // and gotten an error code in asio | 
 |     std::span<const uint8_t> responseBuffer{buffer.data(), length}; | 
 |  | 
 |     std::optional<uint8_t> optionalIid = getIid(responseBuffer); | 
 |     std::optional<bool> isRq = getRequestBit(responseBuffer); | 
 |     if (!optionalIid || !isRq) | 
 |     { | 
 |         // we received something from the device, | 
 |         // but we aren't able to parse iid byte | 
 |         // drop this packet on the floor | 
 |         // and rely on the timer to notify the client | 
 |         lg2::error("MctpRequester: Unable to decode message from eid {EID}", | 
 |                    "EID", eid); | 
 |         return; | 
 |     } | 
 |  | 
 |     if (isRq.value()) | 
 |     { | 
 |         // we received a request from a downstream device. | 
 |         // We don't currently support this, drop the packet | 
 |         // on the floor and rebind receive, keep the timer running | 
 |         return; | 
 |     } | 
 |  | 
 |     uint8_t iid = *optionalIid; | 
 |  | 
 |     auto it = requestContextQueues.find(eid); | 
 |     if (it == requestContextQueues.end()) | 
 |     { | 
 |         // something very bad has happened here | 
 |         // we've received a packet that is a response | 
 |         // from a device we've never talked to | 
 |         // do our best and rebind receive and keep the timer running | 
 |         lg2::error("Unable to match request to response"); | 
 |         return; | 
 |     } | 
 |  | 
 |     if (iid != it->second.iid) | 
 |     { | 
 |         // we received an iid that doesn't match the one we sent | 
 |         // rebind async_receive_from and drop this packet on the floor | 
 |         lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}", "IID", | 
 |                    iid, "EID", eid, "E_IID", it->second.iid); | 
 |         return; | 
 |     } | 
 |  | 
 |     handleResult(eid, std::error_code{}, responseBuffer); | 
 | } | 
 |  | 
 | void MctpRequester::handleSendMsgCompletion( | 
 |     uint8_t eid, const boost::system::error_code& ec, size_t /* length */) | 
 | { | 
 |     if (ec) | 
 |     { | 
 |         lg2::error( | 
 |             "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.", | 
 |             "EC", ec.value(), "ER", ec.message()); | 
 |         handleResult(eid, static_cast<std::error_code>(ec), {}); | 
 |         return; | 
 |     } | 
 |  | 
 |     auto it = requestContextQueues.find(eid); | 
 |     if (it == requestContextQueues.end()) | 
 |     { | 
 |         // something very bad has happened here, | 
 |         // we've sent something to a device that we have | 
 |         // no record of. yell loudly and bail | 
 |         lg2::error( | 
 |             "MctpRequester completed send for an EID that we have no record of"); | 
 |         return; | 
 |     } | 
 |  | 
 |     boost::asio::steady_timer& expiryTimer = it->second.timer; | 
 |     expiryTimer.expires_after(2s); | 
 |  | 
 |     expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) { | 
 |         if (ec != boost::asio::error::operation_aborted) | 
 |         { | 
 |             lg2::error("Operation timed out on eid {EID}", "EID", eid); | 
 |             handleResult(eid, std::make_error_code(std::errc::timed_out), {}); | 
 |         } | 
 |     }); | 
 | } | 
 |  | 
 | void MctpRequester::sendRecvMsg( | 
 |     uint8_t eid, std::span<const uint8_t> reqMsg, | 
 |     std::move_only_function<void(const std::error_code&, | 
 |                                  std::span<const uint8_t>)> | 
 |         callback) | 
 | { | 
 |     RequestContext reqCtx{reqMsg, std::move(callback)}; | 
 |  | 
 |     // try_emplace only affects the result if the key does not already exist | 
 |     auto [it, inserted] = requestContextQueues.try_emplace(eid, io); | 
 |     (void)inserted; | 
 |  | 
 |     auto& queue = it->second.queue; | 
 |     queue.push_back(std::move(reqCtx)); | 
 |  | 
 |     if (queue.size() == 1) | 
 |     { | 
 |         processQueue(eid); | 
 |     } | 
 | } | 
 |  | 
 | static bool isFatalError(const std::error_code& ec) | 
 | { | 
 |     return ec && | 
 |            (ec != std::errc::timed_out && ec != std::errc::host_unreachable); | 
 | } | 
 |  | 
 | void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec, | 
 |                                  std::span<const uint8_t> buffer) | 
 | { | 
 |     auto it = requestContextQueues.find(eid); | 
 |     if (it == requestContextQueues.end()) | 
 |     { | 
 |         lg2::error("We tried to a handle a result for an eid we don't have"); | 
 |  | 
 |         startReceive(); | 
 |         return; | 
 |     } | 
 |  | 
 |     auto& queue = it->second.queue; | 
 |     auto& reqCtx = queue.front(); | 
 |  | 
 |     it->second.timer.cancel(); | 
 |  | 
 |     reqCtx.callback(ec, buffer); // Call the original callback | 
 |  | 
 |     if (isFatalError(ec)) | 
 |     { | 
 |         // some errors are fatal, since these are datagrams, | 
 |         // we won't get a receive path error message. | 
 |         // and since this daemon services all nvidia iana commands | 
 |         // for a given system, we should only restart the service if its | 
 |         // unrecoverable, i.e. if we get error codes that the client | 
 |         // can't reasonably deal with. If thats the cause, restart | 
 |         // and hope that we can deal with it then. | 
 |         // since we're fully async, the only reasonable way to bubble | 
 |         // this issue up is to chuck an exception and let main deal with it. | 
 |         // alternatively we could call cancel on the io_context, but there's | 
 |         // not a great way to figure *what* happened. | 
 |         throw std::runtime_error(std::format( | 
 |             "eid {} encountered a fatal error: {}", eid, ec.message())); | 
 |     } | 
 |  | 
 |     startReceive(); | 
 |  | 
 |     queue.pop_front(); | 
 |  | 
 |     processQueue(eid); | 
 | } | 
 |  | 
 | std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid) | 
 | { | 
 |     auto it = requestContextQueues.find(eid); | 
 |     if (it == requestContextQueues.end()) | 
 |     { | 
 |         return std::nullopt; | 
 |     } | 
 |  | 
 |     uint8_t& iid = it->second.iid; | 
 |     ++iid; | 
 |     iid &= ocp::accelerator_management::instanceIdBitMask; | 
 |     return iid; | 
 | } | 
 |  | 
 | static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer, | 
 |                                                       uint8_t iid) | 
 | { | 
 |     if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid)) | 
 |     { | 
 |         return std::unexpected( | 
 |             std::make_error_code(std::errc::invalid_argument)); | 
 |     } | 
 |  | 
 |     if (iid > ocp::accelerator_management::instanceIdBitMask) | 
 |     { | 
 |         return std::unexpected( | 
 |             std::make_error_code(std::errc::invalid_argument)); | 
 |     } | 
 |  | 
 |     auto* header = std::bit_cast<ocp::accelerator_management::BindingPciVid*>( | 
 |         buffer.data()); | 
 |  | 
 |     header->instance_id &= ~ocp::accelerator_management::instanceIdBitMask; | 
 |     header->instance_id |= iid; | 
 |     return {}; | 
 | } | 
 |  | 
 | void MctpRequester::processQueue(uint8_t eid) | 
 | { | 
 |     auto it = requestContextQueues.find(eid); | 
 |     if (it == requestContextQueues.end()) | 
 |     { | 
 |         lg2::error("We are attempting to process a queue that doesn't exist"); | 
 |         return; | 
 |     } | 
 |  | 
 |     auto& queue = it->second.queue; | 
 |  | 
 |     if (queue.empty()) | 
 |     { | 
 |         return; | 
 |     } | 
 |     auto& reqCtx = queue.front(); | 
 |  | 
 |     std::span<uint8_t> req{reqCtx.reqMsg.data(), reqCtx.reqMsg.size()}; | 
 |  | 
 |     std::optional<uint8_t> iid = getNextIid(eid); | 
 |     if (!iid) | 
 |     { | 
 |         lg2::error("MctpRequester: Unable to get next iid"); | 
 |         handleResult(eid, std::make_error_code(std::errc::no_such_device), {}); | 
 |         return; | 
 |     } | 
 |  | 
 |     std::expected<void, std::error_code> success = injectIid(req, *iid); | 
 |     if (!success) | 
 |     { | 
 |         lg2::error("MctpRequester: unable to set iid"); | 
 |         handleResult(eid, success.error(), {}); | 
 |         return; | 
 |     } | 
 |  | 
 |     struct sockaddr_mctp addr{}; | 
 |     addr.smctp_family = AF_MCTP; | 
 |     addr.smctp_addr.s_addr = eid; | 
 |     addr.smctp_type = msgType; | 
 |     addr.smctp_tag = MCTP_TAG_OWNER; | 
 |  | 
 |     sendEndPoint = {&addr, sizeof(addr)}; | 
 |  | 
 |     mctpSocket.async_send_to( | 
 |         boost::asio::const_buffer(req.data(), req.size()), sendEndPoint, | 
 |         std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid)); | 
 | } | 
 |  | 
 | } // namespace mctp |