test: Mock rtnetlink calls
Right now this does nothing but ACK every sent message so that our
interfaces end up getting no conflicting data from the system.
Change-Id: I226fdafca6e799ae1e9810fb928b4270a61604f9
Signed-off-by: William A. Kennington III <wak@google.com>
diff --git a/test/mock_syscall.cpp b/test/mock_syscall.cpp
index 021c304..280c74d 100644
--- a/test/mock_syscall.cpp
+++ b/test/mock_syscall.cpp
@@ -1,18 +1,25 @@
#include <arpa/inet.h>
#include <dlfcn.h>
#include <ifaddrs.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
+#include <unistd.h>
#include <cstdarg>
+#include <cstdio>
#include <cstring>
#include <map>
+#include <queue>
#include <stdexcept>
#include <string>
+#include <string_view>
+#include <vector>
#define MAX_IFADDRS 5
@@ -37,6 +44,8 @@
return;
}
+std::map<int, std::queue<std::string>> mock_rtnetlinks;
+
std::map<std::string, int> mock_if_nametoindex;
std::map<int, std::string> mock_if_indextoname;
std::map<std::string, ether_addr> mock_macs;
@@ -45,6 +54,7 @@
{
mock_ifaddrs = nullptr;
ifaddr_count = 0;
+ mock_rtnetlinks.clear();
mock_if_nametoindex.clear();
mock_if_indextoname.clear();
mock_macs.clear();
@@ -93,6 +103,38 @@
mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr;
}
+void validateMsgHdr(const struct msghdr* msg)
+{
+ if (msg->msg_namelen != sizeof(sockaddr_nl))
+ {
+ fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen);
+ abort();
+ }
+ const auto& from = *reinterpret_cast<sockaddr_nl*>(msg->msg_name);
+ if (from.nl_family != AF_NETLINK)
+ {
+ fprintf(stderr, "recvmsg bad family data\n");
+ abort();
+ }
+ if (msg->msg_iovlen != 1)
+ {
+ fprintf(stderr, "recvmsg unsupported iov configuration\n");
+ abort();
+ }
+}
+
+ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in)
+{
+ nlmsgerr ack{};
+ nlmsghdr hdr{};
+ hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack));
+ hdr.nlmsg_type = NLMSG_ERROR;
+ auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
+ memcpy(out.data(), &hdr, sizeof(hdr));
+ memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack));
+ return in.size();
+}
+
extern "C" {
int getifaddrs(ifaddrs** ifap)
@@ -155,4 +197,108 @@
return real_ioctl(fd, request, data);
}
+int socket(int domain, int type, int protocol)
+{
+ static auto real_socket =
+ reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket"));
+ int fd = real_socket(domain, type, protocol);
+ if (domain == AF_NETLINK && !(type & SOCK_RAW))
+ {
+ fprintf(stderr, "Netlink sockets must be RAW\n");
+ abort();
+ }
+ if (domain == AF_NETLINK && protocol == NETLINK_ROUTE)
+ {
+ mock_rtnetlinks[fd] = {};
+ }
+ return fd;
+}
+
+int close(int fd)
+{
+ auto it = mock_rtnetlinks.find(fd);
+ if (it != mock_rtnetlinks.end())
+ {
+ mock_rtnetlinks.erase(it);
+ }
+
+ static auto real_close =
+ reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close"));
+ return real_close(fd);
+}
+
+ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags)
+{
+ auto it = mock_rtnetlinks.find(sockfd);
+ if (it == mock_rtnetlinks.end())
+ {
+ static auto real_sendmsg =
+ reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg"));
+ return real_sendmsg(sockfd, msg, flags);
+ }
+ auto& msgs = it->second;
+
+ validateMsgHdr(msg);
+ if (!msgs.empty())
+ {
+ fprintf(stderr, "Unread netlink responses\n");
+ abort();
+ }
+
+ ssize_t ret;
+ std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base),
+ msg->msg_iov[0].iov_len);
+
+ ret = sendmsg_ack(msgs, iov);
+ if (ret != 0)
+ {
+ return ret;
+ }
+
+ errno = ENOSYS;
+ return -1;
+}
+
+ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags)
+{
+ auto it = mock_rtnetlinks.find(sockfd);
+ if (it == mock_rtnetlinks.end())
+ {
+ static auto real_recvmsg =
+ reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg"));
+ return real_recvmsg(sockfd, msg, flags);
+ }
+ auto& msgs = it->second;
+
+ validateMsgHdr(msg);
+ constexpr size_t required_buf_size = 8192;
+ if (msg->msg_iov[0].iov_len < required_buf_size)
+ {
+ fprintf(stderr, "recvmsg iov too short: %zu\n",
+ msg->msg_iov[0].iov_len);
+ abort();
+ }
+ if (msgs.empty())
+ {
+ fprintf(stderr, "No pending netlink responses\n");
+ abort();
+ }
+
+ ssize_t ret = 0;
+ auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base);
+ while (!msgs.empty())
+ {
+ const auto& msg = msgs.front();
+ if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size)
+ {
+ break;
+ }
+ ret = NLMSG_ALIGN(ret);
+ memcpy(data + ret, msg.data(), msg.size());
+ ret += msg.size();
+ msgs.pop();
+ }
+ return ret;
+}
+
} // extern "C"