blob: 280c74d13b66255ed32114bd58468578053ad930 [file] [log] [blame]
#include <arpa/inet.h>
#include <dlfcn.h>
#include <ifaddrs.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <cerrno>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <queue>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
// Number of slots in mock_ifaddr_storage, i.e. the maximum number of mock
// addresses that can be registered at once.
#define MAX_IFADDRS 5
// Debug flag. It is not read anywhere in this chunk; presumably other test
// code consults it (TODO confirm).
int debugging = false;
/* Data for mocking getifaddrs */
// One getifaddrs() entry together with the sockaddr storage its ifa_addr /
// ifa_netmask pointers refer to, so each entry lives entirely in static
// memory and never needs to be freed.
struct ifaddr_storage
{
struct ifaddrs ifaddr;
// Backing storage for ifaddr.ifa_addr.
struct sockaddr_storage addr;
// Backing storage for ifaddr.ifa_netmask.
struct sockaddr_storage mask;
// Reserved for a broadcast address; not populated by the visible code.
struct sockaddr_storage bcast;
} mock_ifaddr_storage[MAX_IFADDRS];
// Head of the linked list handed out by getifaddrs(). When this is nullptr,
// getifaddrs() reports failure (-1).
struct ifaddrs* mock_ifaddrs = nullptr;
// Number of populated entries at the front of mock_ifaddr_storage.
int ifaddr_count = 0;
/* Stub library functions */
/**
 * No-op replacement for libc freeifaddrs().
 *
 * The list returned by the mocked getifaddrs() points into static storage,
 * so there is nothing to release and the argument is deliberately ignored.
 */
void freeifaddrs([[maybe_unused]] ifaddrs* ifp)
{
}
// Mocked NETLINK_ROUTE sockets, keyed by fd. Each queue holds the raw reply
// messages that recvmsg() will hand back in FIFO order.
std::map<int, std::queue<std::string>> mock_rtnetlinks;
// Interface name -> index, served by if_nametoindex().
std::map<std::string, int> mock_if_nametoindex;
// Interface index -> name, served by if_indextoname().
std::map<int, std::string> mock_if_indextoname;
// Interface name -> MAC address, served by ioctl(SIOCGIFHWADDR).
std::map<std::string, ether_addr> mock_macs;
/**
 * Reset every piece of mock state to empty.
 *
 * Afterwards no interfaces, MACs or netlink sockets are registered, and
 * getifaddrs() fails until mock_addIP() is called again. The backing
 * ifaddr_storage array is not wiped; mock_addIP() overwrites entries in
 * place.
 */
void mock_clear()
{
    mock_macs.clear();
    mock_if_indextoname.clear();
    mock_if_nametoindex.clear();
    mock_rtnetlinks.clear();

    ifaddr_count = 0;
    mock_ifaddrs = nullptr;
}
/**
 * Register a mock network interface.
 *
 * @param name  Interface name served by if_nametoindex()/if_indextoname().
 * @param idx   Interface index; must be positive (0 is the "no interface"
 *              value that if_nametoindex() returns on failure).
 * @param mac   Hardware address returned by ioctl(SIOCGIFHWADDR).
 * @throws std::invalid_argument if idx is not positive.
 *
 * Re-registering a name under a new index, or an index under a new name,
 * replaces the old pairing. Entries it supersedes are removed so the
 * name->index and index->name maps always agree.
 */
void mock_addIF(const std::string& name, int idx, const ether_addr& mac)
{
    if (idx <= 0)
    {
        throw std::invalid_argument("Bad interface index");
    }

    // This name previously had a different index: forget that index.
    if (auto old = mock_if_nametoindex.find(name);
        old != mock_if_nametoindex.end() && old->second != idx)
    {
        mock_if_indextoname.erase(old->second);
    }
    // This index previously belonged to a different interface: that
    // interface no longer exists.
    if (auto old = mock_if_indextoname.find(idx);
        old != mock_if_indextoname.end() && old->second != name)
    {
        mock_if_nametoindex.erase(old->second);
        mock_macs.erase(old->second);
    }

    mock_if_nametoindex[name] = idx;
    mock_if_indextoname[idx] = name;
    mock_macs[name] = mac;
}
void mock_addIP(const char* name, const char* addr, const char* mask,
unsigned int flags)
{
struct ifaddrs* ifaddr = &mock_ifaddr_storage[ifaddr_count].ifaddr;
struct sockaddr_in* in =
reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].addr);
struct sockaddr_in* mask_in =
reinterpret_cast<sockaddr_in*>(&mock_ifaddr_storage[ifaddr_count].mask);
in->sin_family = AF_INET;
in->sin_port = 0;
in->sin_addr.s_addr = inet_addr(addr);
mask_in->sin_family = AF_INET;
mask_in->sin_port = 0;
mask_in->sin_addr.s_addr = inet_addr(mask);
ifaddr->ifa_next = nullptr;
ifaddr->ifa_name = const_cast<char*>(name);
ifaddr->ifa_flags = flags;
ifaddr->ifa_addr = reinterpret_cast<struct sockaddr*>(in);
ifaddr->ifa_netmask = reinterpret_cast<struct sockaddr*>(mask_in);
ifaddr->ifa_data = nullptr;
if (ifaddr_count > 0)
mock_ifaddr_storage[ifaddr_count - 1].ifaddr.ifa_next = ifaddr;
ifaddr_count++;
mock_ifaddrs = &mock_ifaddr_storage[0].ifaddr;
}
/**
 * Abort unless msg describes a netlink exchange the mocks can handle.
 *
 * Requirements:
 *  - a non-null sockaddr_nl address with AF_NETLINK family;
 *  - exactly one non-null iovec.
 *
 * Any violation prints a diagnostic to stderr and calls abort(), because it
 * means the code under test is using the socket in an unsupported way.
 */
void validateMsgHdr(const struct msghdr* msg)
{
    if (msg->msg_namelen != sizeof(sockaddr_nl))
    {
        fprintf(stderr, "bad namelen: %u\n", msg->msg_namelen);
        abort();
    }
    // The length check alone does not guarantee a valid pointer.
    if (msg->msg_name == nullptr)
    {
        fprintf(stderr, "missing netlink address\n");
        abort();
    }
    const auto& from = *static_cast<const sockaddr_nl*>(msg->msg_name);
    if (from.nl_family != AF_NETLINK)
    {
        fprintf(stderr, "recvmsg bad family data\n");
        abort();
    }
    // Callers index msg_iov[0] directly, so it must exist.
    if (msg->msg_iovlen != 1 || msg->msg_iov == nullptr)
    {
        fprintf(stderr, "recvmsg unsupported iov configuration\n");
        abort();
    }
}
/**
 * Queue a successful netlink acknowledgement (NLMSG_ERROR, error == 0) for
 * the request in `in`.
 *
 * Like a kernel ack, the reply echoes the request's nlmsghdr inside
 * nlmsgerr.msg and copies its nlmsg_seq. Callers that match replies to
 * requests by sequence number therefore see a consistent answer. If `in` is
 * shorter than an nlmsghdr, those fields stay zero.
 *
 * @param msgs  Pending-reply queue of the mocked socket; one message is
 *              appended.
 * @param in    Raw request bytes as passed to sendmsg().
 * @return in.size(), i.e. the whole request is reported as sent.
 */
ssize_t sendmsg_ack(std::queue<std::string>& msgs, std::string_view in)
{
    nlmsgerr ack{};
    if (in.size() >= sizeof(nlmsghdr))
    {
        // memcpy: the request bytes carry no alignment guarantee.
        memcpy(&ack.msg, in.data(), sizeof(ack.msg));
    }

    nlmsghdr hdr{};
    hdr.nlmsg_len = NLMSG_LENGTH(sizeof(ack));
    hdr.nlmsg_type = NLMSG_ERROR;
    hdr.nlmsg_seq = ack.msg.nlmsg_seq;

    auto& out = msgs.emplace(hdr.nlmsg_len, '\0');
    memcpy(out.data(), &hdr, sizeof(hdr));
    memcpy(NLMSG_DATA(out.data()), &ack, sizeof(ack));
    return in.size();
}
extern "C" {
/**
 * Mocked getifaddrs(): hands out the list built by mock_addIP().
 *
 * *ifap always receives the current list head, even when it is null.
 * @return 0 if at least one mock address exists, otherwise -1.
 */
int getifaddrs(ifaddrs** ifap)
{
    *ifap = mock_ifaddrs;
    return mock_ifaddrs != nullptr ? 0 : -1;
}
/**
 * Mocked if_nametoindex(): looks ifname up among interfaces registered with
 * mock_addIF().
 *
 * @return the registered index, or 0 with errno = ENXIO if the name is
 *         unknown.
 */
unsigned if_nametoindex(const char* ifname)
{
    const auto found = mock_if_nametoindex.find(ifname);
    if (found != mock_if_nametoindex.end())
    {
        return found->second;
    }
    errno = ENXIO;
    return 0;
}
/**
 * Mocked if_indextoname(): writes the name registered for ifindex into
 * ifname.
 *
 * @param ifindex  Interface index; 0 is rejected with errno = ENXIO.
 * @param ifname   Caller buffer of at least IF_NAMESIZE bytes. The copy is
 *                 bounded by IF_NAMESIZE and always NUL-terminated, so an
 *                 over-long mock name is truncated rather than overflowing
 *                 the buffer.
 * @return ifname on success, nullptr on error.
 */
char* if_indextoname(unsigned ifindex, char* ifname)
{
    if (ifindex == 0)
    {
        errno = ENXIO;
        return nullptr;
    }
    auto it = mock_if_indextoname.find(ifindex);
    // TODO: Return ENXIO once other code is mocked out
    const char* name = (it == mock_if_indextoname.end()) ? "invalid"
                                                          : it->second.c_str();
    std::snprintf(ifname, IF_NAMESIZE, "%s", name);
    return ifname;
}
/**
 * Mocked ioctl(). SIOCGIFHWADDR is answered from mock_macs; every other
 * request is forwarded to the real libc ioctl with its single pointer
 * argument.
 *
 * For SIOCGIFHWADDR:
 *  - returns -1/EFAULT for a null ifreq;
 *  - returns -1/ENXIO for an unknown interface;
 *  - otherwise copies the MAC into ifr_hwaddr.sa_data and returns 0.
 */
int ioctl(int fd, unsigned long int request, ...)
{
    va_list vl;
    va_start(vl, request);
    void* data = va_arg(vl, void*);
    va_end(vl);
    if (request == SIOCGIFHWADDR)
    {
        auto* req = static_cast<ifreq*>(data);
        if (req == nullptr)
        {
            errno = EFAULT;
            return -1;
        }
        // ifr_name is a fixed IFNAMSIZ array that is not guaranteed to be
        // NUL-terminated; never read past its end.
        const std::string name(req->ifr_name, strnlen(req->ifr_name, IFNAMSIZ));
        auto it = mock_macs.find(name);
        if (it == mock_macs.end())
        {
            errno = ENXIO;
            return -1;
        }
        std::memcpy(req->ifr_hwaddr.sa_data, &it->second, sizeof(it->second));
        return 0;
    }
    static auto real_ioctl =
        reinterpret_cast<decltype(&ioctl)>(dlsym(RTLD_NEXT, "ioctl"));
    return real_ioctl(fd, request, data);
}
/**
 * Mocked socket(). Creates a real descriptor and, for AF_NETLINK /
 * NETLINK_ROUTE, registers it in mock_rtnetlinks so that sendmsg()/recvmsg()
 * on it are served by the mocks.
 *
 * Netlink sockets must be SOCK_RAW (optionally OR-ed with SOCK_NONBLOCK /
 * SOCK_CLOEXEC); anything else aborts.
 */
int socket(int domain, int type, int protocol)
{
    static auto real_socket =
        reinterpret_cast<decltype(&socket)>(dlsym(RTLD_NEXT, "socket"));
    if (domain == AF_NETLINK)
    {
        // Strip the creation flags before comparing. A bitwise test such as
        // `type & SOCK_RAW` is always true for SOCK_STREAM (1) and
        // SOCK_DGRAM (2), because SOCK_RAW is 3.
        const int base_type = type & ~(SOCK_NONBLOCK | SOCK_CLOEXEC);
        if (base_type != SOCK_RAW)
        {
            fprintf(stderr, "Netlink sockets must be RAW\n");
            abort();
        }
    }
    int fd = real_socket(domain, type, protocol);
    // Only track descriptors that were actually created.
    if (fd >= 0 && domain == AF_NETLINK && protocol == NETLINK_ROUTE)
    {
        mock_rtnetlinks[fd] = {};
    }
    return fd;
}
/**
 * Mocked close(). Drops any mock netlink state tied to fd, then closes the
 * real descriptor and returns libc's result.
 */
int close(int fd)
{
    // Erasing by key is a no-op for descriptors that were never mocked.
    mock_rtnetlinks.erase(fd);
    static auto real_close =
        reinterpret_cast<decltype(&close)>(dlsym(RTLD_NEXT, "close"));
    return real_close(fd);
}
/**
 * Mocked sendmsg(). Non-mocked descriptors go to the real libc sendmsg.
 *
 * For a mocked NETLINK_ROUTE socket:
 *  - the header is validated and any unread replies cause an abort (each
 *    request must be fully answered first);
 *  - the request is acknowledged via sendmsg_ack().
 *
 * @return the number of bytes "sent", or -1 with errno = ENOSYS if the
 *         request could not be handled (including an empty payload).
 */
ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags)
{
    auto it = mock_rtnetlinks.find(sockfd);
    if (it == mock_rtnetlinks.end())
    {
        static auto real_sendmsg =
            reinterpret_cast<decltype(&sendmsg)>(dlsym(RTLD_NEXT, "sendmsg"));
        return real_sendmsg(sockfd, msg, flags);
    }
    auto& msgs = it->second;
    validateMsgHdr(msg);
    if (!msgs.empty())
    {
        fprintf(stderr, "Unread netlink responses\n");
        abort();
    }
    std::string_view iov(static_cast<const char*>(msg->msg_iov[0].iov_base),
                         msg->msg_iov[0].iov_len);
    // Reject an empty request before anything is queued. Otherwise a reply
    // would be left behind after reporting failure, and the next sendmsg()
    // would abort on "Unread netlink responses".
    if (iov.empty())
    {
        errno = ENOSYS;
        return -1;
    }
    const ssize_t ret = sendmsg_ack(msgs, iov);
    if (ret != 0)
    {
        return ret;
    }
    errno = ENOSYS;
    return -1;
}
/**
 * Mocked recvmsg(). Non-mocked descriptors go to the real libc recvmsg.
 *
 * For a mocked NETLINK_ROUTE socket, queued replies are packed into the
 * single iovec, each starting at an NLMSG_ALIGN boundary, until the next one
 * would exceed 8192 bytes. Replies that do not fit stay queued for the next
 * call. Alignment padding is zeroed and msg_flags is cleared, as a real
 * receive of complete messages would leave it.
 *
 * Aborts if:
 *  - the buffer is smaller than 8192 bytes;
 *  - nothing is pending;
 *  - a single reply can never fit (returning 0 instead would make the
 *    caller spin forever).
 *
 * @return total bytes written into the buffer.
 */
ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags)
{
    auto it = mock_rtnetlinks.find(sockfd);
    if (it == mock_rtnetlinks.end())
    {
        static auto real_recvmsg =
            reinterpret_cast<decltype(&recvmsg)>(dlsym(RTLD_NEXT, "recvmsg"));
        return real_recvmsg(sockfd, msg, flags);
    }
    auto& msgs = it->second;
    validateMsgHdr(msg);
    constexpr size_t required_buf_size = 8192;
    if (msg->msg_iov[0].iov_len < required_buf_size)
    {
        fprintf(stderr, "recvmsg iov too short: %zu\n",
                msg->msg_iov[0].iov_len);
        abort();
    }
    if (msgs.empty())
    {
        fprintf(stderr, "No pending netlink responses\n");
        abort();
    }
    auto* data = static_cast<char*>(msg->msg_iov[0].iov_base);
    size_t used = 0;
    while (!msgs.empty())
    {
        // Named `pending` to avoid shadowing the `msg` parameter.
        const auto& pending = msgs.front();
        if (pending.size() > required_buf_size)
        {
            fprintf(stderr, "Netlink response too large: %zu\n",
                    pending.size());
            abort();
        }
        const size_t offset = NLMSG_ALIGN(used);
        if (offset + pending.size() > required_buf_size)
        {
            break;
        }
        memset(data + used, 0, offset - used);
        memcpy(data + offset, pending.data(), pending.size());
        used = offset + pending.size();
        msgs.pop();
    }
    msg->msg_flags = 0;
    return static_cast<ssize_t>(used);
}
} // extern "C"