blob: 912e14d0e3ed3df5b369821eb023938634d80f01 [file] [log] [blame]
Harshit Agheraa3f24f42025-04-21 20:04:56 +05301/*
2 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION &
3 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
4 */
5
#include "MctpRequester.hpp"

#include <linux/mctp.h>
#include <sys/socket.h>
#include <unistd.h>

#include <OcpMctpVdm.hpp>
#include <boost/asio/buffer.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/local/datagram_protocol.hpp>
#include <boost/asio/steady_timer.hpp>
#include <phosphor-logging/lg2.hpp>

#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
27
28using namespace std::literals;
29
30namespace mctp
31{
32
33MctpRequester::MctpRequester(boost::asio::io_context& ctx, uint8_t msgType) :
34 ctx(ctx), sockfd(socket(AF_MCTP, SOCK_DGRAM, 0)), mctpSocket(ctx),
35 msgType(msgType)
36{
37 if (sockfd < 0)
38 {
39 lg2::error("Failed to create MCTP socket");
40 return;
41 }
42
43 boost::system::error_code ec;
44 mctpSocket.assign(boost::asio::local::datagram_protocol{}, sockfd, ec);
45
46 if (ec)
47 {
48 lg2::error(
49 "MctpRequester failed to connect to the MCTP socket - ErrorCode={EC}, Error={ER}.",
50 "EC", ec.value(), "ER", ec.message());
51 close(sockfd);
52 return;
53 }
54
55 mctpSocket.non_blocking(true);
56}
57
58void MctpRequester::processRecvMsg(
59 mctp_eid_t eid, const std::vector<uint8_t>& reqMsg,
60 const std::function<void(int, std::vector<uint8_t>)>& callback,
61 size_t peekedLength) const
62{
63 // Receive message
64 struct sockaddr sockAddr{};
65 struct sockaddr_mctp respAddr{};
66 socklen_t addrlen = sizeof(respAddr);
67 size_t receivedLength = 0;
68
69 std::vector<uint8_t> fullRespMsg(peekedLength);
70
71 receivedLength = recvfrom(sockfd, fullRespMsg.data(), peekedLength,
72 MSG_TRUNC, &sockAddr, &addrlen);
73
74 std::memcpy(&respAddr, &sockAddr, sizeof(respAddr));
75
76 if (receivedLength <= 0)
77 {
78 lg2::error("MctpRequester: Failed to receive message");
79 callback(-2, std::vector<uint8_t>{});
80 return;
81 }
82
83 if (respAddr.smctp_type != msgType)
84 {
85 lg2::error("MctpRequester: Message type mismatch");
86 callback(-3, std::move(fullRespMsg));
87 return;
88 }
89
90 mctp_eid_t respEid = respAddr.smctp_addr.s_addr;
91
92 if (respEid != eid)
93 {
94 lg2::error(
95 "MctpRequester: EID mismatch - expected={EID}, received={REID}",
96 "EID", eid, "REID", respEid);
97 callback(-4, std::move(fullRespMsg));
98 return;
99 }
100
101 if (receivedLength > sizeof(ocp::accelerator_management::BindingPciVid))
102 {
103 ocp::accelerator_management::BindingPciVid reqHdr{};
104 std::memcpy(&reqHdr, reqMsg.data(),
105 sizeof(ocp::accelerator_management::BindingPciVid));
106
107 ocp::accelerator_management::BindingPciVid respHdr{};
108 std::memcpy(&respHdr, fullRespMsg.data(),
109 sizeof(ocp::accelerator_management::BindingPciVid));
110
111 if (reqHdr.instance_id != respHdr.instance_id)
112 {
113 lg2::error(
114 "MctpRequester: Instance ID mismatch - request={REQ}, response={RESP}",
115 "REQ", static_cast<int>(reqHdr.instance_id), "RESP",
116 static_cast<int>(respHdr.instance_id));
117 callback(-5, std::move(fullRespMsg));
118 return;
119 }
120 }
121
122 callback(0, std::move(fullRespMsg));
123}
124
125void MctpRequester::sendRecvMsg(
126 mctp_eid_t eid, const std::vector<uint8_t>& reqMsg,
127 const std::function<void(int, std::vector<uint8_t>)>& callback)
128{
129 std::vector<uint8_t> respMsg{};
130
131 if (reqMsg.size() < sizeof(ocp::accelerator_management::BindingPciVid))
132 {
133 lg2::error("MctpRequester: Message too small");
134 callback(-2, respMsg);
135 return;
136 }
137
138 // Create address structure
139 struct sockaddr sockAddr{};
140 struct sockaddr_mctp addr{};
141 addr.smctp_family = AF_MCTP;
142 addr.smctp_addr.s_addr = eid;
143 addr.smctp_type = msgType;
144 addr.smctp_tag = MCTP_TAG_OWNER;
145
146 std::memcpy(&sockAddr, &addr, sizeof(addr));
147
148 // Send message
149 ssize_t rc = sendto(sockfd, reqMsg.data(), reqMsg.size(), 0, &sockAddr,
150 sizeof(addr));
151 if (rc < 0)
152 {
153 lg2::error(
154 "MctpRequester failed send data to the MCTP Socket - Error={EC}.",
155 "EC", rc);
156 callback(rc, respMsg);
157 return;
158 }
159
160 // Set up async receive with timeout
161 auto timer = std::make_shared<boost::asio::steady_timer>(ctx);
162 timer->expires_after(2s);
163
164 // Set up handler for when the timer expires
165 timer->async_wait([callback, timer](const boost::system::error_code& ec) {
166 if (ec != boost::asio::error::operation_aborted)
167 {
168 callback(-1, std::vector<uint8_t>{});
169 }
170 });
171
172 // Set up asynchronous receive
173 mctpSocket.async_receive(
174 boost::asio::buffer(respMsg), MSG_PEEK | MSG_TRUNC,
175 [this, eid, reqMsg, callback,
176 timer](const boost::system::error_code& ec, size_t peekedLength) {
177 // Cancel the timer since we got a response
178 timer->cancel();
179
180 if (ec)
181 {
182 lg2::error(
183 "MctpRequester failed to receive data from the MCTP socket - ErrorCode={EC}, Error={ER}.",
184 "EC", ec.value(), "ER", ec.message());
185 callback(-1, std::vector<uint8_t>{});
186 return;
187 }
188
189 this->processRecvMsg(eid, reqMsg, callback, peekedLength);
190 });
191}
192
193} // namespace mctp