net/addr/subnet: Add class for representing a network subnet

This makes it possible to represent and IPv4/IPv6 CIDR network segment
via a class. Provides useful functions to compute base address and
determine if addresses are inside the subnet.

Change-Id: Ib9d01e28b6c8a28ccb622fef87b217fc96daf905
Signed-off-by: William A. Kennington III <wak@google.com>
diff --git a/include/meson.build b/include/meson.build
index 1402818..d7cf82b 100644
--- a/include/meson.build
+++ b/include/meson.build
@@ -11,6 +11,7 @@
   'stdplus/hash/tuple.hpp',
   'stdplus/net/addr/ether.hpp',
   'stdplus/net/addr/ip.hpp',
+  'stdplus/net/addr/subnet.hpp',
   'stdplus/numeric/endian.hpp',
   'stdplus/pinned.hpp',
   'stdplus/raw.hpp',
diff --git a/include/stdplus/net/addr/subnet.hpp b/include/stdplus/net/addr/subnet.hpp
new file mode 100644
index 0000000..862237b
--- /dev/null
+++ b/include/stdplus/net/addr/subnet.hpp
@@ -0,0 +1,194 @@
+#include <stdplus/net/addr/ip.hpp>
+#include <stdplus/numeric/endian.hpp>
+
+#include <limits>
+#include <type_traits>
+
+namespace stdplus
+{
+namespace detail
+{
+
+// AddressSan doesn't understand our masking of shift UB
+__attribute__((no_sanitize("undefined"))) constexpr uint32_t
+    addr32Mask(std::ptrdiff_t pfx) noexcept
+{
+    // Positive prefix check + mask to handle UB when the left shift becomes
+    // more than 31 bits
+    return hton(static_cast<uint32_t>(-int32_t{pfx > 0}) & ~uint32_t{0}
+                                                               << (32 - pfx));
+}
+
+constexpr In4Addr addrToSubnet(In4Addr a, std::size_t pfx) noexcept
+{
+    return In4Addr{in_addr{a.s4_addr32 & addr32Mask(pfx)}};
+}
+
+constexpr In6Addr addrToSubnet(In6Addr a, std::size_t pfx, std::size_t i = 0,
+                               std::size_t s = 0, In6Addr ret = {}) noexcept
+{
+    if (s + 32 < pfx)
+    {
+        ret.s6_addr32[i] = a.s6_addr32[i];
+        return addrToSubnet(a, pfx, i + 1, s + 32, ret);
+    }
+    ret.s6_addr32[i] = a.s6_addr32[i] & addr32Mask(pfx - s);
+    return ret;
+}
+
+constexpr InAnyAddr addrToSubnet(InAnyAddr a, std::size_t pfx) noexcept
+{
+    return std::visit([&](auto av) { return InAnyAddr{addrToSubnet(av, pfx)}; },
+                      a);
+}
+
+constexpr bool subnetContains(auto, auto, std::size_t) noexcept
+{
+    return false;
+}
+
+template <typename T>
+constexpr bool subnetContains(T l, T r, std::size_t pfx) noexcept
+{
+    return addrToSubnet(l, pfx) == addrToSubnet(r, pfx);
+}
+
+constexpr bool subnetContains(InAnyAddr l, auto r, std::size_t pfx) noexcept
+{
+    return std::visit([&](auto v) { return detail::subnetContains(v, r, pfx); },
+                      l);
+}
+
+constexpr std::size_t addrBits(auto a) noexcept
+{
+    return sizeof(a) << 3;
+}
+
+void invalidSubnetPfx(std::size_t pfx);
+
+template <typename Addr, typename Pfx>
+class Subnet46
+{
+  private:
+    static constexpr inline std::size_t maxPfx = sizeof(Addr) * 8;
+    static_assert(std::is_unsigned_v<Pfx> && std::is_integral_v<Pfx>);
+    static_assert(std::numeric_limits<Pfx>::max() >= maxPfx);
+
+    Addr addr;
+    Pfx pfx;
+
+  public:
+    constexpr Subnet46(Addr addr, Pfx pfx) : addr(addr), pfx(pfx)
+    {
+        if (addrBits(addr) < pfx)
+        {
+            invalidSubnetPfx(pfx);
+        }
+    }
+
+    constexpr auto getAddr() const noexcept
+    {
+        return addr;
+    }
+
+    constexpr auto getPfx() const noexcept
+    {
+        return pfx;
+    }
+
+    constexpr bool operator==(Subnet46 rhs) const noexcept
+    {
+        return addr == rhs.addr && pfx == rhs.pfx;
+    }
+
+    constexpr Addr network() const noexcept
+    {
+        return addrToSubnet(addr, pfx);
+    }
+
+    constexpr bool contains(Addr addr) const noexcept
+    {
+        return addrToSubnet(this->addr, pfx) == addrToSubnet(addr, pfx);
+    }
+};
+
+} // namespace detail
+
+using Subnet4 = detail::Subnet46<In4Addr, uint8_t>;
+using Subnet6 = detail::Subnet46<In6Addr, uint8_t>;
+
+class SubnetAny
+{
+  private:
+    InAnyAddr addr;
+    uint8_t pfx;
+
+  public:
+    constexpr SubnetAny(auto addr, uint8_t pfx) : addr(addr), pfx(pfx)
+    {
+        if (detail::addrBits(addr) < pfx)
+        {
+            detail::invalidSubnetPfx(pfx);
+        }
+    }
+    constexpr SubnetAny(InAnyAddr addr, uint8_t pfx) : addr(addr), pfx(pfx)
+    {
+        if (std::visit([](auto v) { return detail::addrBits(v); }, addr) < pfx)
+        {
+            detail::invalidSubnetPfx(pfx);
+        }
+    }
+
+    template <typename T, typename S>
+    constexpr SubnetAny(detail::Subnet46<T, S> o) noexcept :
+        addr(o.getAddr()), pfx(o.getPfx())
+    {}
+
+    constexpr auto getAddr() const noexcept
+    {
+        return addr;
+    }
+
+    constexpr auto getPfx() const noexcept
+    {
+        return pfx;
+    }
+
+    template <typename T, typename S>
+    constexpr bool operator==(detail::Subnet46<T, S> rhs) const noexcept
+    {
+        return addr == rhs.getAddr() && pfx == rhs.getPfx();
+    }
+    constexpr bool operator==(SubnetAny rhs) const noexcept
+    {
+        return addr == rhs.addr && pfx == rhs.pfx;
+    }
+
+    constexpr InAnyAddr network() const noexcept
+    {
+        return detail::addrToSubnet(addr, pfx);
+    }
+
+    constexpr bool contains(In4Addr addr) const noexcept
+    {
+        return detail::subnetContains(this->addr, addr, pfx);
+    }
+    constexpr bool contains(in_addr addr) const noexcept
+    {
+        return contains(In4Addr{addr});
+    }
+    constexpr bool contains(In6Addr addr) const noexcept
+    {
+        return detail::subnetContains(this->addr, addr, pfx);
+    }
+    constexpr bool contains(in6_addr addr) const noexcept
+    {
+        return contains(In6Addr{addr});
+    }
+    constexpr bool contains(InAnyAddr addr) const noexcept
+    {
+        return std::visit([&](auto v) { return contains(v); }, addr);
+    }
+};
+
+} // namespace stdplus
diff --git a/src/meson.build b/src/meson.build
index dd67b35..47fcd33 100644
--- a/src/meson.build
+++ b/src/meson.build
@@ -52,6 +52,7 @@
   'hash/tuple.cpp',
   'net/addr/ether.cpp',
   'net/addr/ip.cpp',
+  'net/addr/subnet.cpp',
   'numeric/endian.cpp',
   'pinned.cpp',
   'raw.cpp',
diff --git a/src/net/addr/subnet.cpp b/src/net/addr/subnet.cpp
new file mode 100644
index 0000000..36d6049
--- /dev/null
+++ b/src/net/addr/subnet.cpp
@@ -0,0 +1,18 @@
+#include <fmt/format.h>
+
+#include <stdplus/net/addr/subnet.hpp>
+
+#include <stdexcept>
+
+namespace stdplus::detail
+{
+
+void invalidSubnetPfx(std::size_t pfx)
+{
+    throw std::invalid_argument(fmt::format("Invalid subnet prefix {}", pfx));
+}
+
+template class Subnet46<In4Addr, uint8_t>;
+template class Subnet46<In6Addr, uint8_t>;
+
+} // namespace stdplus::detail
diff --git a/test/meson.build b/test/meson.build
index 9c0580b..01104c5 100644
--- a/test/meson.build
+++ b/test/meson.build
@@ -8,6 +8,7 @@
   'hash/tuple': [stdplus_dep, gtest_main_dep],
   'net/addr/ether': [stdplus_dep, gtest_main_dep],
   'net/addr/ip': [stdplus_dep, gtest_main_dep],
+  'net/addr/subnet': [stdplus_dep, gtest_main_dep],
   'numeric/endian': [stdplus_dep, gtest_main_dep],
   'pinned': [stdplus_dep, gtest_main_dep],
   'raw': [stdplus_dep, gmock_dep, gtest_main_dep],
diff --git a/test/net/addr/subnet.cpp b/test/net/addr/subnet.cpp
new file mode 100644
index 0000000..561507c
--- /dev/null
+++ b/test/net/addr/subnet.cpp
@@ -0,0 +1,160 @@
+#include <fmt/format.h>
+
+#include <stdplus/net/addr/subnet.hpp>
+
+#include <gtest/gtest.h>
+
+namespace stdplus
+{
+
+auto addr4Full = In4Addr{255, 255, 255, 255};
+auto addr6Full = In6Addr{255, 255, 255, 255, 255, 255, 255, 255,
+                         255, 255, 255, 255, 255, 255, 255, 255};
+
+TEST(Subnet4, Basic)
+{
+    EXPECT_NO_THROW(Subnet4(in_addr{0xffffffff}, 32));
+    EXPECT_NO_THROW(Subnet4(addr4Full, 0));
+    EXPECT_NO_THROW(Subnet4(in_addr{}, 10));
+    EXPECT_THROW(Subnet4(in_addr{0xffffffff}, 33), std::invalid_argument);
+    EXPECT_THROW(Subnet4(in_addr{0xffffffff}, 64), std::invalid_argument);
+
+    EXPECT_NE(Subnet4(in_addr{0xff}, 32), Subnet4(in_addr{}, 32));
+    EXPECT_NE(Subnet4(in_addr{0xff}, 26), Subnet4(in_addr{0xff}, 27));
+    EXPECT_EQ(Subnet4(in_addr{0xff}, 32), Subnet4(in_addr{0xff}, 32));
+    EXPECT_EQ(Subnet4(in_addr{}, 1), Subnet4(in_addr{}, 1));
+}
+
+TEST(Subnet4, Network)
+{
+    EXPECT_EQ((In4Addr{}), Subnet4(In4Addr{}, 32).network());
+    EXPECT_EQ(addr4Full, Subnet4(addr4Full, 32).network());
+    EXPECT_EQ((In4Addr{255, 255, 128, 0}), Subnet4(addr4Full, 17).network());
+    EXPECT_EQ((In4Addr{255, 255, 0, 0}), Subnet4(addr4Full, 16).network());
+    EXPECT_EQ((In4Addr{255, 254, 0, 0}), Subnet4(addr4Full, 15).network());
+    EXPECT_EQ((In4Addr{}), Subnet4(addr4Full, 0).network());
+    EXPECT_EQ((In4Addr{}), Subnet4(In4Addr{}, 0).network());
+}
+
+TEST(Subnet4, Contains)
+{
+    EXPECT_TRUE(Subnet4(addr4Full, 32).contains(addr4Full));
+    EXPECT_FALSE(Subnet4(addr4Full, 32).contains(In4Addr{255, 255, 255, 254}));
+    EXPECT_FALSE(Subnet4(addr4Full, 32).contains(In4Addr{}));
+    EXPECT_TRUE(
+        Subnet4(addr4Full, 17).contains(static_cast<in_addr>(addr4Full)));
+    EXPECT_TRUE(Subnet4(addr4Full, 17).contains(In4Addr{255, 255, 128, 134}));
+    EXPECT_FALSE(Subnet4(addr4Full, 17).contains(In4Addr{255, 255, 127, 132}));
+    EXPECT_TRUE(Subnet4(addr4Full, 14).contains(addr4Full));
+    EXPECT_TRUE(Subnet4(addr4Full, 0).contains(addr4Full));
+    EXPECT_TRUE(Subnet4(In4Addr{}, 0).contains(addr4Full));
+    EXPECT_TRUE(Subnet4(addr4Full, 0).contains(In4Addr{}));
+}
+
+TEST(Subnet6, Basic)
+{
+    EXPECT_NO_THROW(Subnet6(in6_addr{0xff}, 128));
+    EXPECT_NO_THROW(Subnet6(addr6Full, 0));
+    EXPECT_NO_THROW(Subnet6(in6_addr{}, 65));
+    EXPECT_THROW(Subnet6(in6_addr{0xff}, 129), std::invalid_argument);
+    EXPECT_THROW(Subnet6(in6_addr{0xff}, 150), std::invalid_argument);
+
+    EXPECT_NE(Subnet6(in6_addr{0xff}, 32), Subnet6(in6_addr{}, 32));
+    EXPECT_NE(Subnet6(in6_addr{0xff}, 26), Subnet6(in6_addr{0xff}, 27));
+    EXPECT_EQ(Subnet6(in6_addr{0xff}, 32), Subnet6(in6_addr{0xff}, 32));
+    EXPECT_EQ(Subnet6(in6_addr{}, 1), Subnet6(in6_addr{}, 1));
+}
+
+TEST(Subnet6, Network)
+{
+    EXPECT_EQ(In6Addr(), Subnet6(In6Addr(), 128).network());
+    EXPECT_EQ(addr6Full, Subnet6(addr6Full, 128).network());
+    EXPECT_EQ((In6Addr{255, 255, 255, 255, 224}),
+              Subnet6(addr6Full, 35).network());
+    EXPECT_EQ((In6Addr{255, 255, 255, 255}), Subnet6(addr6Full, 32).network());
+    EXPECT_EQ((In6Addr{255, 255, 128, 0}), Subnet6(addr6Full, 17).network());
+    EXPECT_EQ((In6Addr{255, 255, 0, 0}), Subnet6(addr6Full, 16).network());
+    EXPECT_EQ((In6Addr{255, 254, 0, 0}), Subnet6(addr6Full, 15).network());
+    EXPECT_EQ((In6Addr{}), Subnet6(addr6Full, 0).network());
+    EXPECT_EQ((In6Addr{}), Subnet6(In6Addr{}, 0).network());
+}
+
+TEST(Subnet6, Contains)
+{
+    auto addr6NFull = addr6Full;
+    addr6NFull.s6_addr[15] = 254;
+    EXPECT_TRUE(Subnet6(addr6Full, 128).contains(addr6Full));
+    EXPECT_FALSE(Subnet6(addr6Full, 128).contains(addr6NFull));
+    EXPECT_FALSE(Subnet6(addr6Full, 128).contains(In6Addr{}));
+    EXPECT_TRUE(
+        Subnet6(addr6Full, 127).contains(static_cast<in6_addr>(addr6Full)));
+    EXPECT_TRUE(Subnet6(addr6Full, 127).contains(addr6NFull));
+    EXPECT_TRUE(
+        Subnet6(addr6Full, 33).contains(In6Addr{255, 255, 255, 255, 128, 255}));
+    EXPECT_FALSE(
+        Subnet6(addr6Full, 33).contains(In6Addr{255, 255, 255, 255, 127}));
+    EXPECT_TRUE(Subnet6(In6Addr{}, 33).contains(In6Addr{0, 0, 0, 0, 127}));
+    EXPECT_TRUE(
+        Subnet6(addr6Full, 14).contains(In6Addr{255, 255, 0, 0, 0, 0, 0, 145}));
+    EXPECT_FALSE(Subnet6(addr6Full, 14).contains(In6Addr{255, 127, 1}));
+    EXPECT_TRUE(Subnet6(addr6Full, 0).contains(addr6Full));
+    EXPECT_TRUE(Subnet6(In6Addr{}, 0).contains(addr6Full));
+    EXPECT_TRUE(Subnet6(addr6Full, 0).contains(In6Addr{}));
+}
+
+TEST(SubnetAny, Basic)
+{
+    EXPECT_NO_THROW(SubnetAny(in_addr{0xffffffff}, 32));
+    EXPECT_NO_THROW(SubnetAny(addr4Full, 0));
+    EXPECT_NO_THROW(SubnetAny(InAnyAddr{addr4Full}, 0));
+    EXPECT_NO_THROW(SubnetAny(in_addr{}, 10));
+    EXPECT_THROW(SubnetAny(in_addr{0xffffffff}, 33), std::invalid_argument);
+    EXPECT_THROW(SubnetAny(InAnyAddr{in_addr{0xffffffff}}, 33),
+                 std::invalid_argument);
+    EXPECT_THROW(SubnetAny(in_addr{0xffffffff}, 64), std::invalid_argument);
+    EXPECT_NO_THROW(SubnetAny(in6_addr{0xff}, 128));
+    EXPECT_NO_THROW(SubnetAny(addr6Full, 0));
+    EXPECT_NO_THROW(SubnetAny(InAnyAddr{addr6Full}, 0));
+    EXPECT_NO_THROW(SubnetAny(in6_addr{}, 65));
+    EXPECT_THROW(SubnetAny(in6_addr{0xff}, 129), std::invalid_argument);
+    EXPECT_THROW(SubnetAny(InAnyAddr{in6_addr{0xff}}, 129),
+                 std::invalid_argument);
+    EXPECT_THROW(SubnetAny(in6_addr{0xff}, 150), std::invalid_argument);
+
+    EXPECT_NO_THROW(SubnetAny(Subnet4(in_addr{}, 32)));
+    EXPECT_NO_THROW(SubnetAny(Subnet6(in6_addr{0xff}, 128)));
+
+    EXPECT_NE(SubnetAny(in6_addr{0xff}, 32), Subnet6(in6_addr{}, 32));
+    EXPECT_NE(Subnet6(in6_addr{0xff}, 26), SubnetAny(in6_addr{0xff}, 27));
+    EXPECT_EQ(SubnetAny(in6_addr{0xff}, 32), Subnet6(in6_addr{0xff}, 32));
+    EXPECT_EQ(SubnetAny(in6_addr{0xff}, 32), SubnetAny(in6_addr{0xff}, 32));
+    EXPECT_NE(SubnetAny(in6_addr{0xff}, 32), Subnet4(in_addr{0xff}, 32));
+    EXPECT_NE(SubnetAny(in_addr{0xff}, 32), Subnet4(in_addr{}, 32));
+    EXPECT_NE(Subnet4(in_addr{0xff}, 26), SubnetAny(in_addr{0xff}, 27));
+    EXPECT_EQ(SubnetAny(in_addr{0xff}, 32), Subnet4(in_addr{0xff}, 32));
+    EXPECT_EQ(SubnetAny(in_addr{0xff}, 32), SubnetAny(in_addr{0xff}, 32));
+}
+
+TEST(SubnetAny, Network)
+{
+    EXPECT_EQ(In6Addr(), SubnetAny(In6Addr(), 128).network());
+    EXPECT_EQ(addr6Full, SubnetAny(addr6Full, 128).network());
+    EXPECT_EQ(In6Addr(), SubnetAny(addr6Full, 0).network());
+    EXPECT_EQ(In4Addr(), SubnetAny(In4Addr(), 32).network());
+    EXPECT_EQ(addr4Full, SubnetAny(addr4Full, 32).network());
+    EXPECT_EQ(In4Addr(), SubnetAny(addr4Full, 0).network());
+}
+
+TEST(SubnetAny, Contains)
+{
+    EXPECT_TRUE(SubnetAny(addr6Full, 128).contains(addr6Full));
+    EXPECT_TRUE(SubnetAny(addr6Full, 128).contains(InAnyAddr{addr6Full}));
+    EXPECT_FALSE(SubnetAny(addr6Full, 128).contains(in6_addr{}));
+    EXPECT_FALSE(SubnetAny(addr6Full, 128).contains(InAnyAddr(in6_addr{})));
+    EXPECT_TRUE(SubnetAny(addr4Full, 32).contains(addr4Full));
+    EXPECT_TRUE(SubnetAny(addr4Full, 32).contains(InAnyAddr{addr4Full}));
+    EXPECT_FALSE(SubnetAny(addr4Full, 32).contains(in_addr{}));
+    EXPECT_FALSE(SubnetAny(addr4Full, 32).contains(InAnyAddr{In4Addr{}}));
+}
+
+} // namespace stdplus