| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 1 | /* |
| 2 | * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & |
| 3 | * AFFILIATES. All rights reserved. |
| 4 | * SPDX-License-Identifier: Apache-2.0 |
| 5 | */ |
| 6 | |
| 7 | #include "MctpRequester.hpp" |
| 8 | |
| 9 | #include <linux/mctp.h> |
| 10 | #include <sys/socket.h> |
| 11 | |
| 12 | #include <OcpMctpVdm.hpp> |
| 13 | #include <boost/asio/buffer.hpp> |
| 14 | #include <boost/asio/error.hpp> |
| 15 | #include <boost/asio/generic/datagram_protocol.hpp> |
| 16 | #include <boost/asio/io_context.hpp> |
| 17 | #include <boost/asio/steady_timer.hpp> |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 18 | #include <boost/container/devector.hpp> |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 19 | #include <phosphor-logging/lg2.hpp> |
| 20 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 21 | #include <bit> |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 22 | #include <cstddef> |
| 23 | #include <cstdint> |
| 24 | #include <cstring> |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 25 | #include <expected> |
| 26 | #include <format> |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 27 | #include <functional> |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 28 | #include <optional> |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 29 | #include <span> |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 30 | #include <stdexcept> |
| 31 | #include <system_error> |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 32 | #include <utility> |
| 33 | |
| 34 | using namespace std::literals; |
| 35 | |
| 36 | namespace mctp |
| 37 | { |
| 38 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 39 | static const ocp::accelerator_management::BindingPciVid* getHeaderFromBuffer( |
| 40 | std::span<const uint8_t> buffer) |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 41 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 42 | if (buffer.size() < sizeof(ocp::accelerator_management::BindingPciVid)) |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 43 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 44 | return nullptr; |
| 45 | } |
| 46 | |
| 47 | return std::bit_cast<const ocp::accelerator_management::BindingPciVid*>( |
| 48 | buffer.data()); |
| 49 | } |
| 50 | |
| 51 | static std::optional<uint8_t> getIid(std::span<const uint8_t> buffer) |
| 52 | { |
| 53 | const ocp::accelerator_management::BindingPciVid* header = |
| 54 | getHeaderFromBuffer(buffer); |
| 55 | if (header == nullptr) |
| 56 | { |
| 57 | return std::nullopt; |
| 58 | } |
| 59 | return header->instance_id & ocp::accelerator_management::instanceIdBitMask; |
| 60 | } |
| 61 | |
| 62 | static std::optional<bool> getRequestBit(std::span<const uint8_t> buffer) |
| 63 | { |
| 64 | const ocp::accelerator_management::BindingPciVid* header = |
| 65 | getHeaderFromBuffer(buffer); |
| 66 | if (header == nullptr) |
| 67 | { |
| 68 | return std::nullopt; |
| 69 | } |
| 70 | return header->instance_id & ocp::accelerator_management::requestBitMask; |
| 71 | } |
| 72 | |
| 73 | MctpRequester::MctpRequester(boost::asio::io_context& ctx) : |
| 74 | io{ctx}, |
| 75 | mctpSocket(ctx, boost::asio::generic::datagram_protocol{AF_MCTP, 0}) |
| 76 | { |
| 77 | startReceive(); |
| 78 | } |
| 79 | |
| 80 | void MctpRequester::startReceive() |
| 81 | { |
| 82 | mctpSocket.async_receive_from( |
| 83 | boost::asio::buffer(buffer), recvEndPoint.endpoint, |
| 84 | std::bind_front(&MctpRequester::processRecvMsg, this)); |
| 85 | } |
| 86 | |
| 87 | void MctpRequester::processRecvMsg(const boost::system::error_code& ec, |
| 88 | const size_t length) |
| 89 | { |
| 90 | std::optional<uint8_t> expectedEid = recvEndPoint.eid(); |
| 91 | std::optional<uint8_t> receivedMsgType = recvEndPoint.type(); |
| 92 | |
| 93 | if (!expectedEid || !receivedMsgType) |
| 94 | { |
| 95 | // we were handed an endpoint that can't be treated as an MCTP endpoint |
| 96 | // This is probably a kernel bug...yell about it and rebind. |
| 97 | lg2::error("MctpRequester: invalid endpoint"); |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 98 | return; |
| 99 | } |
| 100 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 101 | if (*receivedMsgType != msgType) |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 102 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 103 | // we received a message that this handler doesn't support |
| 104 | // drop it on the floor and rebind receive_from |
| 105 | lg2::error("MctpRequester: Message type mismatch. We received {MSG}", |
| 106 | "MSG", *receivedMsgType); |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 107 | return; |
| 108 | } |
| 109 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 110 | uint8_t eid = *expectedEid; |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 111 | |
| 112 | if (ec) |
| 113 | { |
| 114 | lg2::error( |
| 115 | "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.", |
| 116 | "EC", ec.value(), "ER", ec.message()); |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 117 | handleResult(eid, static_cast<std::error_code>(ec), {}); |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 118 | return; |
| 119 | } |
| 120 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 121 | // if the received length was greater than our buffer, we would've truncated |
| 122 | // and gotten an error code in asio |
| 123 | std::span<const uint8_t> responseBuffer{buffer.data(), length}; |
| 124 | |
| 125 | std::optional<uint8_t> optionalIid = getIid(responseBuffer); |
| 126 | std::optional<bool> isRq = getRequestBit(responseBuffer); |
| 127 | if (!optionalIid || !isRq) |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 128 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 129 | // we received something from the device, |
| 130 | // but we aren't able to parse iid byte |
| 131 | // drop this packet on the floor |
| 132 | // and rely on the timer to notify the client |
| 133 | lg2::error("MctpRequester: Unable to decode message from eid {EID}", |
| 134 | "EID", eid); |
| 135 | return; |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 136 | } |
| 137 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 138 | if (isRq.value()) |
| 139 | { |
| 140 | // we received a request from a downstream device. |
| 141 | // We don't currently support this, drop the packet |
| 142 | // on the floor and rebind receive, keep the timer running |
| 143 | return; |
| 144 | } |
| 145 | |
| 146 | uint8_t iid = *optionalIid; |
| 147 | |
| 148 | auto it = requestContextQueues.find(eid); |
| 149 | if (it == requestContextQueues.end()) |
| 150 | { |
| 151 | // something very bad has happened here |
| 152 | // we've received a packet that is a response |
| 153 | // from a device we've never talked to |
| 154 | // do our best and rebind receive and keep the timer running |
| 155 | lg2::error("Unable to match request to response"); |
| 156 | return; |
| 157 | } |
| 158 | |
| 159 | if (iid != it->second.iid) |
| 160 | { |
| 161 | // we received an iid that doesn't match the one we sent |
| 162 | // rebind async_receive_from and drop this packet on the floor |
| 163 | lg2::error("Invalid iid {IID} from eid {EID}, expected {E_IID}", "IID", |
| 164 | iid, "EID", eid, "E_IID", it->second.iid); |
| 165 | return; |
| 166 | } |
| 167 | |
| 168 | handleResult(eid, std::error_code{}, responseBuffer); |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 169 | } |
| 170 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 171 | void MctpRequester::handleSendMsgCompletion( |
| 172 | uint8_t eid, const boost::system::error_code& ec, size_t /* length */) |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 173 | { |
| 174 | if (ec) |
| 175 | { |
| 176 | lg2::error( |
| 177 | "MctpRequester failed to send data from the MCTP socket - ErrorCode={EC}, Error={ER}.", |
| 178 | "EC", ec.value(), "ER", ec.message()); |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 179 | handleResult(eid, static_cast<std::error_code>(ec), {}); |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 180 | return; |
| 181 | } |
| 182 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 183 | auto it = requestContextQueues.find(eid); |
| 184 | if (it == requestContextQueues.end()) |
| 185 | { |
| 186 | // something very bad has happened here, |
| 187 | // we've sent something to a device that we have |
| 188 | // no record of. yell loudly and bail |
| 189 | lg2::error( |
| 190 | "MctpRequester completed send for an EID that we have no record of"); |
| 191 | return; |
| 192 | } |
| 193 | |
| 194 | boost::asio::steady_timer& expiryTimer = it->second.timer; |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 195 | expiryTimer.expires_after(2s); |
| 196 | |
| Aditya Kurdunkar | ed0af21 | 2025-06-11 04:38:52 +0530 | [diff] [blame] | 197 | expiryTimer.async_wait([this, eid](const boost::system::error_code& ec) { |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 198 | if (ec != boost::asio::error::operation_aborted) |
| 199 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 200 | lg2::error("Operation timed out on eid {EID}", "EID", eid); |
| 201 | handleResult(eid, std::make_error_code(std::errc::timed_out), {}); |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 202 | } |
| 203 | }); |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 204 | } |
| 205 | |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 206 | void MctpRequester::sendRecvMsg( |
| 207 | uint8_t eid, std::span<const uint8_t> reqMsg, |
| 208 | std::move_only_function<void(const std::error_code&, |
| 209 | std::span<const uint8_t>)> |
| 210 | callback) |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 211 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 212 | RequestContext reqCtx{reqMsg, std::move(callback)}; |
| 213 | |
| 214 | // try_emplace only affects the result if the key does not already exist |
| 215 | auto [it, inserted] = requestContextQueues.try_emplace(eid, io); |
| 216 | (void)inserted; |
| 217 | |
| 218 | auto& queue = it->second.queue; |
| 219 | queue.push_back(std::move(reqCtx)); |
| 220 | |
| 221 | if (queue.size() == 1) |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 222 | { |
| Marc Olberding | d0125c9 | 2025-10-08 14:37:19 -0700 | [diff] [blame^] | 223 | processQueue(eid); |
| 224 | } |
| 225 | } |
| 226 | |
/**
 * @brief Classify an error code as fatal or recoverable.
 *
 * @param ec Error code to classify.
 * @return false when @p ec signals success, std::errc::timed_out or
 *         std::errc::host_unreachable; true for every other error.
 */
static bool isFatalError(const std::error_code& ec)
{
    if (!ec)
    {
        return false;
    }

    const bool recoverable = (ec == std::errc::timed_out) ||
                             (ec == std::errc::host_unreachable);
    return !recoverable;
}
| 232 | |
/**
 * @brief Complete the request at the head of an endpoint's queue.
 *
 * Cancels the endpoint's expiry timer, invokes the head request's callback
 * with the outcome, re-arms the socket receive, removes the request from the
 * queue and starts the next queued request, if any.
 *
 * @param eid    Endpoint whose head request is being completed.
 * @param ec     Outcome to report; a default-constructed code means success.
 * @param buffer Response bytes to hand to the callback (empty on the error
 *               paths in this file).
 *
 * @throws std::runtime_error when @p ec is fatal per isFatalError(), after
 *         the callback has been invoked.
 */
void MctpRequester::handleResult(uint8_t eid, const std::error_code& ec,
                                 std::span<const uint8_t> buffer)
{
    auto it = requestContextQueues.find(eid);
    if (it == requestContextQueues.end())
    {
        lg2::error("We tried to a handle a result for an eid we don't have");

        startReceive();
        return;
    }

    auto& queue = it->second.queue;
    if (queue.empty())
    {
        // A result arrived for a known endpoint with nothing in flight (for
        // example a duplicate or late response). There is no request to
        // complete, and calling front() on an empty queue is undefined
        // behaviour, so treat it like the unknown-endpoint case above.
        lg2::error(
            "MctpRequester: result for eid {EID} with no pending request",
            "EID", eid);

        startReceive();
        return;
    }

    it->second.timer.cancel();

    // Move the callback out of the queue element before invoking it. The
    // callback may re-enter sendRecvMsg() for this same eid, which pushes
    // onto this queue; running the callable from a local keeps it valid if
    // the queue relocates its elements on growth.
    // NOTE(review): the queue type is declared in the header - confirm.
    auto callback = std::move(queue.front().callback);
    callback(ec, buffer); // Call the original callback

    if (isFatalError(ec))
    {
        // some errors are fatal, since these are datagrams,
        // we won't get a receive path error message.
        // and since this daemon services all nvidia iana commands
        // for a given system, we should only restart the service if its
        // unrecoverable, i.e. if we get error codes that the client
        // can't reasonably deal with. If thats the cause, restart
        // and hope that we can deal with it then.
        // since we're fully async, the only reasonable way to bubble
        // this issue up is to chuck an exception and let main deal with it.
        // alternatively we could call cancel on the io_context, but there's
        // not a great way to figure *what* happened.
        throw std::runtime_error(std::format(
            "eid {} encountered a fatal error: {}", eid, ec.message()));
    }

    // Re-arm only after the callback has consumed `buffer`: the next receive
    // writes into the same member storage.
    startReceive();

    // The completed request stays at the head until here so that a
    // re-entrant sendRecvMsg() sees a non-empty queue and does not start a
    // second send for this endpoint.
    queue.pop_front();

    processQueue(eid);
}
| 275 | |
| 276 | std::optional<uint8_t> MctpRequester::getNextIid(uint8_t eid) |
| 277 | { |
| 278 | auto it = requestContextQueues.find(eid); |
| 279 | if (it == requestContextQueues.end()) |
| 280 | { |
| 281 | return std::nullopt; |
| 282 | } |
| 283 | |
| 284 | uint8_t& iid = it->second.iid; |
| 285 | ++iid; |
| 286 | iid &= ocp::accelerator_management::instanceIdBitMask; |
| 287 | return iid; |
| 288 | } |
| 289 | |
/**
 * @brief Write an instance id into the header of a request in place.
 *
 * Only the bits covered by instanceIdBitMask are replaced; the remaining bits
 * of the header's instance_id byte are preserved.
 *
 * @param buffer Mutable request bytes; must hold at least
 *               sizeof(BindingPciVid) bytes.
 * @param iid    Instance id to store; must not exceed instanceIdBitMask.
 * @return Nothing on success, or std::errc::invalid_argument when @p buffer
 *         is too short or @p iid is out of range.
 */
static std::expected<void, std::error_code> injectIid(std::span<uint8_t> buffer,
                                                      uint8_t iid)
{
    namespace am = ocp::accelerator_management;

    const bool tooShort = buffer.size() < sizeof(am::BindingPciVid);
    const bool iidOutOfRange = iid > am::instanceIdBitMask;
    if (tooShort || iidOutOfRange)
    {
        return std::unexpected(
            std::make_error_code(std::errc::invalid_argument));
    }

    auto* hdr = std::bit_cast<am::BindingPciVid*>(buffer.data());

    // Clear the old instance id bits, then merge in the new ones.
    hdr->instance_id &= ~am::instanceIdBitMask;
    hdr->instance_id |= iid;
    return {};
}
| 312 | |
/**
 * @brief Send the request at the head of an endpoint's queue.
 *
 * Stamps the head request with the endpoint's next instance id and issues an
 * asynchronous send to the endpoint. Does nothing when the queue is empty.
 * If the instance id cannot be obtained or written, the request is completed
 * with an error through handleResult().
 *
 * @param eid Endpoint whose queue is serviced.
 */
void MctpRequester::processQueue(uint8_t eid)
{
    const auto entry = requestContextQueues.find(eid);
    if (entry == requestContextQueues.end())
    {
        lg2::error("We are attempting to process a queue that doesn't exist");
        return;
    }

    auto& pending = entry->second.queue;
    if (pending.empty())
    {
        return;
    }

    auto& head = pending.front();
    std::span<uint8_t> request{head.reqMsg.data(), head.reqMsg.size()};

    const std::optional<uint8_t> nextIid = getNextIid(eid);
    if (!nextIid)
    {
        lg2::error("MctpRequester: Unable to get next iid");
        handleResult(eid, std::make_error_code(std::errc::no_such_device), {});
        return;
    }

    if (const std::expected<void, std::error_code> injected =
            injectIid(request, *nextIid);
        !injected)
    {
        lg2::error("MctpRequester: unable to set iid");
        handleResult(eid, injected.error(), {});
        return;
    }

    // Destination address: this endpoint, this requester's message type,
    // with the tag-owner flag set.
    struct sockaddr_mctp dest{};
    dest.smctp_family = AF_MCTP;
    dest.smctp_addr.s_addr = eid;
    dest.smctp_type = msgType;
    dest.smctp_tag = MCTP_TAG_OWNER;

    sendEndPoint = {&dest, sizeof(dest)};

    auto onSent =
        std::bind_front(&MctpRequester::handleSendMsgCompletion, this, eid);
    mctpSocket.async_send_to(
        boost::asio::const_buffer(request.data(), request.size()),
        sendEndPoint, std::move(onSent));
}
| 360 | |
| Harshit Aghera | 560e6af | 2025-04-21 20:04:56 +0530 | [diff] [blame] | 361 | } // namespace mctp |