test: Mock rtnetlink calls

Right now this does nothing but ACK every sent message so that our
interfaces end up getting no conflicting data from the system.

Change-Id: I226fdafca6e799ae1e9810fb928b4270a61604f9
Signed-off-by: William A. Kennington III <wak@google.com>
diff --git a/test/mock_syscall.cpp b/test/mock_syscall.cpp
index 021c304..280c74d 100644
--- a/test/mock_syscall.cpp
+++ b/test/mock_syscall.cpp
@@ -1,18 +1,25 @@
 #include <arpa/inet.h>
 #include <dlfcn.h>
 #include <ifaddrs.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
 #include <net/ethernet.h>
 #include <net/if.h>
 #include <netinet/in.h>
 #include <sys/ioctl.h>
 #include <sys/socket.h>
 #include <sys/types.h>
+#include <unistd.h>
 
 #include <cstdarg>
+#include <cstdio>
 #include <cstring>
 #include <map>
+#include <queue>
 #include <stdexcept>
 #include <string>
+#include <string_view>
+#include <vector>
 
 #define MAX_IFADDRS 5
 
@@ -37,6 +44,8 @@
     return;
 }
 
+std::map<int, std::queue<std::string>> mock_rtnetlinks;
+
 std::map<std::string, int> mock_if_nametoindex;
 std::map<int, std::string> mock_if_indextoname;
 std::map<std::string, ether_addr> mock_macs;
@@ -45,6 +54,7 @@
 {
     mock_ifaddrs = nullptr;
     ifaddr_count = 0;
+    mock_rtnetlinks.clear();
     mock_if_nametoindex.clear();
     mock_if_indextoname.clear();
     mock_macs.clear();
@@ -93,6 +103,38 @@
     mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr;
 }
 
+void validateMsgHdr(const struct msghdr* msg)
+{
+    if (msg->msg_namelen != sizeof(sockaddr_nl))
+    {
+        fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen);
+        abort();
+    }
+    const auto& from = *reinterpret_cast<sockaddr_nl*>(msg->msg_name);
+    if (from.nl_family != AF_NETLINK)
+    {
+        fprintf(stderr, "recvmsg bad family data\n");
+        abort();
+    }
+    if (msg->msg_iovlen != 1)
+    {
+        fprintf(stderr, "recvmsg unsupported iov configuration\n");
+        abort();
+    }
+}
+
+ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in)
+{
+    nlmsgerr ack{};
+    nlmsghdr hdr{};
+    hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack));
+    hdr.nlmsg_type = NLMSG_ERROR;
+    auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
+    memcpy(out.data(), &hdr, sizeof(hdr));
+    memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack));
+    return in.size();
+}
+
 extern "C" {
 
 int getifaddrs(ifaddrs** ifap)
@@ -155,4 +197,108 @@
     return real_ioctl(fd, request, data);
 }
 
+int socket(int domain, int type, int protocol)
+{
+    static auto real_socket =
+        reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket"));
+    int fd = real_socket(domain, type, protocol);
+    if (domain == AF_NETLINK && !(type & SOCK_RAW))
+    {
+        fprintf(stderr, "Netlink sockets must be RAW\n");
+        abort();
+    }
+    if (domain == AF_NETLINK && protocol == NETLINK_ROUTE)
+    {
+        mock_rtnetlinks[fd] = {};
+    }
+    return fd;
+}
+
+int close(int fd)
+{
+    auto it = mock_rtnetlinks.find(fd);
+    if (it != mock_rtnetlinks.end())
+    {
+        mock_rtnetlinks.erase(it);
+    }
+
+    static auto real_close =
+        reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close"));
+    return real_close(fd);
+}
+
+ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags)
+{
+    auto it = mock_rtnetlinks.find(sockfd);
+    if (it == mock_rtnetlinks.end())
+    {
+        static auto real_sendmsg =
+            reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg"));
+        return real_sendmsg(sockfd, msg, flags);
+    }
+    auto& msgs = it->second;
+
+    validateMsgHdr(msg);
+    if (!msgs.empty())
+    {
+        fprintf(stderr, "Unread netlink responses\n");
+        abort();
+    }
+
+    ssize_t ret;
+    std::string_view iov(reinterpret_cast<char*>(msg->msg_iov[0].iov_base),
+                         msg->msg_iov[0].iov_len);
+
+    ret = sendmsg_ack(msgs, iov);
+    if (ret != 0)
+    {
+        return ret;
+    }
+
+    errno = ENOSYS;
+    return -1;
+}
+
+ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags)
+{
+    auto it = mock_rtnetlinks.find(sockfd);
+    if (it == mock_rtnetlinks.end())
+    {
+        static auto real_recvmsg =
+            reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg"));
+        return real_recvmsg(sockfd, msg, flags);
+    }
+    auto& msgs = it->second;
+
+    validateMsgHdr(msg);
+    constexpr size_t required_buf_size = 8192;
+    if (msg->msg_iov[0].iov_len < required_buf_size)
+    {
+        fprintf(stderr, "recvmsg iov too short: %zu\n",
+                msg->msg_iov[0].iov_len);
+        abort();
+    }
+    if (msgs.empty())
+    {
+        fprintf(stderr, "No pending netlink responses\n");
+        abort();
+    }
+
+    ssize_t ret = 0;
+    auto data = reinterpret_cast<char*>(msg->msg_iov[0].iov_base);
+    while (!msgs.empty())
+    {
+        const auto& msg = msgs.front();
+        if (NLMSG_ALIGN(ret) + msg.size() > required_buf_size)
+        {
+            break;
+        }
+        ret = NLMSG_ALIGN(ret);
+        memcpy(data + ret, msg.data(), msg.size());
+        ret += msg.size();
+        msgs.pop();
+    }
+    return ret;
+}
+
 } // extern "C"