raw: Add strict copy/ref functions
These don't allow truncation of the input
Change-Id: I48e442c0c2e22ea3c2bd5eaf2808408af5d7be57
Signed-off-by: William A. Kennington III <wak@google.com>
diff --git a/include/stdplus/raw.hpp b/include/stdplus/raw.hpp
index 7da4b92..4f5b03f 100644
--- a/include/stdplus/raw.hpp
+++ b/include/stdplus/raw.hpp
@@ -66,42 +66,50 @@
* @param[in] data - The data buffer being copied from
* @return The copyable type with data populated
*/
-template <typename T, typename Container>
-T copyFrom(const Container& c)
-{
- static_assert(std::is_trivially_copyable_v<T>);
- static_assert(detail::trivialContainer<Container>);
- T ret;
- const size_t bytes = std::size(c) * sizeof(*std::data(c));
- if (bytes < sizeof(ret))
- {
- throw std::runtime_error(
- fmt::format("CopyFrom: {} < {}", bytes, sizeof(ret)));
+#define STDPLUS_COPY_FROM(func, comp) \
+ template <typename T, typename Container> \
+ T func(const Container& c) \
+ { \
+ static_assert(std::is_trivially_copyable_v<T>); \
+ static_assert(detail::trivialContainer<Container>); \
+ T ret; \
+ const size_t bytes = std::size(c) * sizeof(*std::data(c)); \
+ if (bytes comp sizeof(ret)) \
+ { \
+ throw std::runtime_error( \
+ fmt::format(#func ": {} < {}", bytes, sizeof(ret))); \
+ } \
+ std::memcpy(&ret, std::data(c), sizeof(ret)); \
+ return ret; \
}
- std::memcpy(&ret, std::data(c), sizeof(ret));
- return ret;
-}
+STDPLUS_COPY_FROM(copyFrom, <)
+STDPLUS_COPY_FROM(copyFromStrict, !=)
+#undef STDPLUS_COPY_FROM
/** @brief References the data from a buffer if aligned
*
* @param[in] data - The data buffer being referenced
* @return The reference to the data in the new type
*/
-template <typename T, typename Container,
- typename Tp = detail::copyConst<T, detail::dataType<Container>>>
-Tp& refFrom(Container&& c)
-{
- static_assert(std::is_trivially_copyable_v<Tp>);
- static_assert(detail::trivialContainer<Container>);
- static_assert(sizeof(*std::data(c)) % alignof(Tp) == 0);
- const size_t bytes = std::size(c) * sizeof(*std::data(c));
- if (bytes < sizeof(Tp))
- {
- throw std::runtime_error(
- fmt::format("RefFrom: {} < {}", bytes, sizeof(Tp)));
+#define STDPLUS_REF_FROM(func, comp) \
+ template <typename T, typename Container, \
+ typename Tp = detail::copyConst<T, detail::dataType<Container>>> \
+ Tp& func(Container&& c) \
+ { \
+ static_assert(std::is_trivially_copyable_v<Tp>); \
+ static_assert(detail::trivialContainer<Container>); \
+ static_assert(sizeof(*std::data(c)) % alignof(Tp) == 0); \
+ const size_t bytes = std::size(c) * sizeof(*std::data(c)); \
+ if (bytes comp sizeof(Tp)) \
+ { \
+ throw std::runtime_error( \
+ fmt::format(#func ": {} < {}", bytes, sizeof(Tp))); \
+ } \
+ return *reinterpret_cast<Tp*>(std::data(c)); \
}
- return *reinterpret_cast<Tp*>(std::data(c));
-}
+STDPLUS_REF_FROM(refFrom, <)
+STDPLUS_REF_FROM(refFromStrict, !=)
+#undef STDPLUS_REF_FROM
/** @brief Extracts data from a buffer into a copyable type
* Updates the data buffer to show that data was removed
diff --git a/test/raw.cpp b/test/raw.cpp
index 5248b9f..cb013b6 100644
--- a/test/raw.cpp
+++ b/test/raw.cpp
@@ -30,6 +30,7 @@
EXPECT_THROW(copyFrom<int>(cs), std::runtime_error);
std::string_view s;
EXPECT_THROW(copyFrom<int>(s), std::runtime_error);
+ EXPECT_THROW(copyFromStrict<char>(s), std::runtime_error);
}
TEST(CopyFrom, Basic)
@@ -37,12 +38,14 @@
int a = 4;
const std::string_view s(reinterpret_cast<char*>(&a), sizeof(a));
EXPECT_EQ(a, copyFrom<int>(s));
+ EXPECT_EQ(a, copyFromStrict<int>(s));
}
TEST(CopyFrom, Partial)
{
const std::vector<char> s = {'a', 'b', 'c'};
EXPECT_EQ('a', copyFrom<char>(s));
+ EXPECT_THROW(copyFromStrict<char>(s), std::runtime_error);
const char s2[] = "def";
EXPECT_EQ('d', copyFrom<char>(s2));
}
@@ -63,6 +66,7 @@
EXPECT_THROW(refFrom<Int>(cs), std::runtime_error);
std::string_view s;
EXPECT_THROW(refFrom<Int>(s), std::runtime_error);
+ EXPECT_THROW(refFromStrict<Int>(s), std::runtime_error);
}
TEST(RefFrom, Basic)
@@ -70,12 +74,14 @@
Int a = {4, 0, 0, 4};
const std::string_view s(reinterpret_cast<char*>(&a), sizeof(a));
EXPECT_EQ(a, refFrom<Int>(s));
+ EXPECT_EQ(a, refFromStrict<Int>(s));
}
TEST(RefFrom, Partial)
{
const std::vector<char> s = {'a', 'b', 'c'};
EXPECT_EQ('a', refFrom<char>(s));
+ EXPECT_THROW(refFromStrict<char>(s), std::runtime_error);
const char s2[] = "def";
EXPECT_EQ('d', refFrom<char>(s2));
}