blob: f28371ffc109f52374b28c97b6d6c69b4f03eb83 [file] [log] [blame]
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#include "MctpRequester.hpp"
#include <linux/mctp.h>
#include <sys/socket.h>
#include <OcpMctpVdm.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/generic/datagram_protocol.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/container/devector.hpp>
#include <phosphor-logging/lg2.hpp>
#include <bit>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <expected>
#include <format>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <system_error>
#include <utility>
using namespace std::literals;
namespace mctp
{
/**
 * View the start of @p buffer as a BindingPciVid header.
 *
 * @param buffer raw message bytes
 * @return pointer into @p buffer (valid only as long as @p buffer is), or
 *         nullptr when @p buffer is too short to contain the header
 */
static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer(
    std::span<const uint8_t> buffer)
{
    using Header = ocp::accelerator_management::BindingPciVid;
    const bool fitsHeader = buffer.size() >= sizeof(Header);
    return fitsHeader ? std::bit_cast<const Header*>(buffer.data()) : nullptr;
}
/**
 * Extract the instance id bits from the header at the start of @p buffer.
 *
 * @return the masked instance id, or std::nullopt if @p buffer is too short
 *         to hold a header
 */
static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer)
{
    if (const auto* header = getHeaderFromBuffer(buffer); header != nullptr)
    {
        return header->instance_id &
               ocp::accelerator_management::instanceIdBitMask;
    }
    return std::nullopt;
}
/**
 * Report whether the header at the start of @p buffer has the request bit
 * set (i.e. the message is a request rather than a response).
 *
 * @return true/false for the request bit, or std::nullopt if @p buffer is
 *         too short to hold a header
 */
static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer)
{
    if (const auto* header = getHeaderFromBuffer(buffer); header != nullptr)
    {
        return (header->instance_id &
                ocp::accelerator_management::requestBitMask) != 0;
    }
    return std::nullopt;
}
/**
 * Open an AF_MCTP datagram socket on @p ctx and arm the first receive.
 *
 * @param ctx io_context that drives all socket and timer operations
 */
MctpRequester::MctpRequester(boost::asio::io_context& ctx) :
    io{ctx}, mctpSocket(ctx, boost::asio::generic::datagram_protocol{
                                 AF_MCTP, 0})
{
    // Arm the first receive so responses can be delivered as soon as
    // requests start going out.
    startReceive();
}
void MctpRequester::startReceive()
{
mctpSocket.async_receive_from(
boost::asio::buffer(buffer), recvEndPoint.endpoint,
std::bind_front(&MctpRequester::processRecvMsg, this));
}
/**
 * Completion handler for the receive armed by startReceive().
 *
 * Validates the sending endpoint and the message header. A response is
 * matched by instance id to the request in flight for its EID and passed to
 * handleResult(). Datagrams that cannot be matched are logged and dropped.
 *
 * This handler owns the receive loop. Unless the receive was aborted, it
 * re-arms exactly one receive after every datagram, including dropped ones.
 * That way an unrelated datagram (e.g. a request from a downstream device)
 * cannot stall delivery of the response being waited for.
 *
 * @param ec     result of async_receive_from
 * @param length number of bytes written into `buffer`
 */
void MctpRequester::processRecvMsg(const boost::system::error_code& ec,
                                   const size_t length)
{
    if (ec == boost::asio::error::operation_aborted)
    {
        // The receive was cancelled (e.g. the socket is being closed).
        // Re-arming would only fail again, so end the receive loop here.
        return;
    }

    // Every `return` inside this lambda finishes with (or drops) the
    // datagram. The receive is re-armed unconditionally afterwards.
    auto handleDatagram = [this, &ec, length]() {
        std::optional<uint8_t> expectedEid = recvEndPoint.eid();
        std::optional<uint8_t> receivedMsgType = recvEndPoint.type();
        if (!expectedEid || !receivedMsgType)
        {
            // We were handed an endpoint that can't be treated as an MCTP
            // endpoint. This is probably a kernel bug, so yell about it.
            lg2::error("MctpRequester: invalid endpoint");
            return;
        }
        if (*receivedMsgType != msgType)
        {
            // We received a message type this handler doesn't support.
            // Drop it on the floor.
            lg2::error(
                "MctpRequester: Message type mismatch. We received {MSG}",
                "MSG", *receivedMsgType);
            return;
        }
        uint8_t eid = *expectedEid;
        if (ec)
        {
            lg2::error(
                "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
                "EC", ec.value(), "ER", ec.message());
            handleResult(eid, static_cast<std::error_code>(ec), {});
            return;
        }
        // If the received length had been greater than our buffer, asio
        // would have truncated it and reported an error code above.
        std::span<const uint8_t> responseBuffer{buffer.data(), length};
        std::optional<uint8_t> optionalIid = getIid(responseBuffer);
        std::optional<bool> isRq = getRequestBit(responseBuffer);
        if (!optionalIid || !isRq)
        {
            // We received something from the device but can't parse the
            // iid byte. Drop this packet and rely on the timer to notify
            // the client.
            lg2::error("MctpRequester: Unable to decode message from eid {EID}",
                       "EID", eid);
            return;
        }
        if (*isRq)
        {
            // A request from a downstream device. This isn't supported.
            // Drop the packet and keep the timer running.
            return;
        }
        uint8_t iid = *optionalIid;
        auto it = requestContextQueues.find(eid);
        if (it == requestContextQueues.end())
        {
            // A response from a device we've never talked to. Drop it and
            // keep the timer running.
            lg2::error("Unable to match request to response");
            return;
        }
        if (iid != it->second.iid)
        {
            // The iid doesn't match the one we sent. Drop this packet.
            lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}",
                       "IID", iid, "EID", eid, "E_IID", it->second.iid);
            return;
        }
        handleResult(eid, std::error_code{}, responseBuffer);
    };

    handleDatagram();
    // Only reached when handleResult() did not throw for a fatal error.
    startReceive();
}

/**
 * Completion handler for async_send_to of the request at the head of
 * @p eid's queue.
 *
 * On a send error the request completes immediately with that error.
 * On success a 2 second response timer is armed. If that timer expires
 * before handleResult() disarms it, the request completes with
 * std::errc::timed_out.
 */
void MctpRequester::handleSendMsgCompletion(
    uint8_t eid, const boost::system::error_code& ec, size_t /* length */)
{
    if (ec)
    {
        lg2::error(
            "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
            "EC", ec.value(), "ER", ec.message());
        handleResult(eid, static_cast<std::error_code>(ec), {});
        return;
    }
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        // Something very bad has happened: we've sent something to a
        // device that we have no record of. Yell loudly and bail.
        lg2::error(
            "MctpRequester completed send for an EID that we have no record of");
        return;
    }
    boost::asio::steady_timer& expiryTimer = it->second.timer;
    expiryTimer.expires_after(2s);
    expiryTimer.async_wait([this,
                            eid](const boost::system::error_code& waitEc) {
        if (waitEc == boost::asio::error::operation_aborted)
        {
            return;
        }
        auto entry = requestContextQueues.find(eid);
        if (entry == requestContextQueues.end() || entry->second.queue.empty())
        {
            return;
        }
        // cancel() cannot recall a handler that was already queued because
        // its deadline passed. handleResult() moves the expiry to
        // time_point::max(), and the next send re-arms it into the future.
        // So if the deadline is not in the past, this handler is stale and
        // must not time out whatever request is now at the head.
        if (entry->second.timer.expiry() >
            boost::asio::steady_timer::clock_type::now())
        {
            return;
        }
        lg2::error("Operation timed out on eid {EID}", "EID", eid);
        handleResult(eid, std::make_error_code(std::errc::timed_out), {});
    });
}

/**
 * Queue a request for @p eid. The request is sent immediately if nothing
 * else is queued for that EID; otherwise it is sent after the requests
 * ahead of it complete.
 *
 * @param eid      destination endpoint id
 * @param reqMsg   request bytes (its iid field is overwritten when sent)
 * @param callback invoked once with the result or an error
 */
void MctpRequester::sendRecvMsg(
    uint8_t eid, std::span<const uint8_t> reqMsg,
    std::move_only_function<void(const std::error_code&,
                                 std::span<const uint8_t>)>
        callback)
{
    // try_emplace only constructs the per-EID entry if the key does not
    // already exist.
    auto& queue =
        requestContextQueues.try_emplace(eid, io).first->second.queue;
    queue.push_back(RequestContext{reqMsg, std::move(callback)});
    // The head of the queue is the request in flight. If this is the only
    // entry, nothing is in flight for this EID, so send it now.
    if (queue.size() == 1)
    {
        processQueue(eid);
    }
}

/// Any error other than a timeout or an unreachable host is treated as
/// unrecoverable; handleResult() throws after reporting it.
static bool isFatalError(const std::error_code& ec)
{
    return ec &&
           (ec != std::errc::timed_out && ec != std::errc::host_unreachable);
}

/**
 * Complete the request at the head of @p eid's queue.
 *
 * Disarms the response timer and invokes the request's callback with @p ec
 * and @p buffer. Then it either throws (for errors isFatalError() accepts)
 * or pops the request and starts the next one via processQueue().
 * Re-arming the socket receive is left to processRecvMsg().
 *
 * @throws std::runtime_error if @p ec is a fatal error
 */
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");
        return;
    }
    if (it->second.queue.empty())
    {
        // e.g. a late duplicate response, or a timeout that raced a
        // response: there is no request left to complete.
        lg2::error("MctpRequester: no outstanding request for eid {EID}",
                   "EID", eid);
        return;
    }
    // Moving the expiry cancels any pending wait, as cancel() does. It also
    // lets a timeout handler that was already queued see that it is stale
    // (see handleSendMsgCompletion).
    it->second.timer.expires_at(boost::asio::steady_timer::time_point::max());

    // Move the callback out before invoking it. The callback may call
    // sendRecvMsg(), whose push_back may reallocate the queue's storage
    // while the callback object is still running. The request stays at the
    // head, so a nested sendRecvMsg() only enqueues and does not send.
    auto callback = std::move(it->second.queue.front().callback);
    callback(ec, buffer);

    if (isFatalError(ec))
    {
        // Since these are datagrams, we won't get a receive-path error
        // message for these failures. This daemon services all NVIDIA IANA
        // commands for the system, so restart only on errors the client
        // can't reasonably deal with. Being fully async, the only
        // reasonable way to surface this is to throw and let main handle
        // it. Cancelling the io_context would lose *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    // Re-find the entry: the callback may have inserted into
    // requestContextQueues via sendRecvMsg().
    it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end() || it->second.queue.empty())
    {
        return;
    }
    it->second.queue.pop_front();
    processQueue(eid);
}
/**
 * Advance and return the instance id for @p eid.
 *
 * The stored iid is incremented and masked to
 * ocp::accelerator_management::instanceIdBitMask, so it wraps within the
 * iid field.
 *
 * @return the new iid, or std::nullopt if @p eid has no entry
 */
std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid)
{
    if (auto entry = requestContextQueues.find(eid);
        entry != requestContextQueues.end())
    {
        uint8_t& current = entry->second.iid;
        ++current;
        current &= ocp::accelerator_management::instanceIdBitMask;
        return current;
    }
    return std::nullopt;
}
/**
 * Overwrite the instance id bits in the header at the start of @p buffer,
 * leaving the other bits of that byte (e.g. the request bit) intact.
 *
 * @param buffer message bytes to patch in place
 * @param iid    new instance id; must fit within instanceIdBitMask
 * @return empty on success, or std::errc::invalid_argument if @p buffer is
 *         too short for a header or @p iid is out of range
 */
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    using ocp::accelerator_management::BindingPciVid;
    using ocp::accelerator_management::instanceIdBitMask;

    const bool tooShort = buffer.size() < sizeof(BindingPciVid);
    const bool iidOutOfRange = iid > instanceIdBitMask;
    if (tooShort || iidOutOfRange)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* header = std::bit_cast<BindingPciVid*>(buffer.data());
    header->instance_id &= ~instanceIdBitMask;
    header->instance_id |= iid;
    return {};
}
/**
 * Send the request at the head of @p eid's queue, if there is one.
 *
 * Stamps the next instance id into the request bytes in place, then sends
 * them with async_send_to to (@p eid, msgType) using an owner tag.
 * handleSendMsgCompletion() is called on completion. If the iid cannot be
 * obtained or injected, the request is completed through handleResult()
 * with an error instead.
 */
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }
    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        return;
    }

    auto& head = pending.front();
    // Mutable view of the queued request; the iid is patched in place and
    // the same bytes are handed to async_send_to below.
    std::span<uint8_t> payload{head.reqMsg.data(), head.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid.has_value())
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }
    if (const auto injected = injectIid(payload, *nextIid); !injected)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    struct sockaddr_mctp dest{};
    dest.smctp_family = AF_MCTP;
    dest.smctp_addr.s_addr = eid;
    dest.smctp_type = msgType;
    dest.smctp_tag = MCTP_TAG_OWNER;
    sendEndPoint = {&dest, sizeof(dest)};

    mctpSocket.async_send_to(
        boost::asio::const_buffer(payload.data(), payload.size()), sendEndPoint,
        std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid));
}
} // namespace mctp