Check dictionary sizes to avoid any heap overflows

Tested:
Unit tested

Signed-off-by: Kasun Athukorala <kasunath@google.com>
Change-Id: I34972b7f4daf0e818461ef2a4842966a247fc20e
diff --git a/include/libbej/bej_common.h b/include/libbej/bej_common.h
index 499fd03..9698765 100644
--- a/include/libbej/bej_common.h
+++ b/include/libbej/bej_common.h
@@ -171,8 +171,11 @@
 struct BejDictionaries
 {
     const uint8_t* schemaDictionary;
+    uint32_t schemaDictionarySize;
     const uint8_t* annotationDictionary;
+    uint32_t annotationDictionarySize;
     const uint8_t* errorDictionary;
+    uint32_t errorDictionarySize;
 };
 
 /**
diff --git a/src/bej_decoder_core.c b/src/bej_decoder_core.c
index 43f5b69..73f67c9 100644
--- a/src/bej_decoder_core.c
+++ b/src/bej_decoder_core.c
@@ -918,6 +918,44 @@
         return bejErrorNotSupported;
     }
 
+    const struct BejDictionaryHeader* schemaDictionaryHeader =
+        ((const struct BejDictionaryHeader*)dictionaries->schemaDictionary);
+    if (schemaDictionaryHeader->dictionarySize !=
+        dictionaries->schemaDictionarySize)
+    {
+        fprintf(stderr, "Invalid schema dictionary size: %u. Expected: %u.\n",
+                schemaDictionaryHeader->dictionarySize,
+                dictionaries->schemaDictionarySize);
+        return bejErrorInvalidSize;
+    }
+
+    const struct BejDictionaryHeader* annotationDictionaryHeader =
+        ((const struct BejDictionaryHeader*)dictionaries->annotationDictionary);
+    if (annotationDictionaryHeader->dictionarySize !=
+        dictionaries->annotationDictionarySize)
+    {
+        fprintf(stderr,
+                "Invalid annotation dictionary size: %u. Expected: %u.\n",
+                annotationDictionaryHeader->dictionarySize,
+                dictionaries->annotationDictionarySize);
+        return bejErrorInvalidSize;
+    }
+
+    if (dictionaries->errorDictionary != NULL)
+    {
+        const struct BejDictionaryHeader* errorDictionaryHeader =
+            ((const struct BejDictionaryHeader*)dictionaries->errorDictionary);
+        if (errorDictionaryHeader->dictionarySize !=
+            dictionaries->errorDictionarySize)
+        {
+            fprintf(stderr,
+                    "Invalid error dictionary size: %u. Expected: %u.\n",
+                    errorDictionaryHeader->dictionarySize,
+                    dictionaries->errorDictionarySize);
+            return bejErrorInvalidSize;
+        }
+    }
+
     // Skip the PLDM header.
     const uint8_t* enStream = encodedPldmBlock + pldmHeaderSize;
     uint32_t streamLen = blockLength - pldmHeaderSize;
diff --git a/test/bej_decoder_test.cpp b/test/bej_decoder_test.cpp
index cf76d2b..f881c3e 100644
--- a/test/bej_decoder_test.cpp
+++ b/test/bej_decoder_test.cpp
@@ -66,8 +66,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     BejDecoderJson decoder;
@@ -108,8 +111,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     // Each array element below consists of a set and two properties, resulting
@@ -167,8 +173,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     auto root = std::make_unique<RedfishPropertyParent>();
@@ -228,8 +237,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     auto root = std::make_unique<RedfishPropertyParent>();
@@ -258,8 +270,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     auto root = std::make_unique<RedfishPropertyParent>();
@@ -282,4 +297,61 @@
                 bejErrorInvalidSize);
 }
 
+TEST(BejDecoderSecurityTest, InvalidSchemaDictionarySize)
+{
+    auto inputsOrErr = loadInputs(dummySimpleTestFiles);
+    ASSERT_TRUE(inputsOrErr);
+
+    BejDictionaries dictionaries = {
+        .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = 10,
+        .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
+        .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
+    };
+
+    BejDecoderJson decoder;
+    EXPECT_THAT(decoder.decode(dictionaries, inputsOrErr->encodedStream),
+                bejErrorInvalidSize);
+}
+
+TEST(BejDecoderSecurityTest, InvalidAnnotationDictionarySize)
+{
+    auto inputsOrErr = loadInputs(dummySimpleTestFiles);
+    ASSERT_TRUE(inputsOrErr);
+
+    BejDictionaries dictionaries = {
+        .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
+        .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = 10,
+        .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
+    };
+
+    BejDecoderJson decoder;
+    EXPECT_THAT(decoder.decode(dictionaries, inputsOrErr->encodedStream),
+                bejErrorInvalidSize);
+}
+
+TEST(BejDecoderSecurityTest, InvalidErrorDictionarySize)
+{
+    auto inputsOrErr = loadInputs(dummySimpleTestFiles);
+    ASSERT_TRUE(inputsOrErr);
+
+    BejDictionaries dictionaries = {
+        .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
+        .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
+        .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = 10,
+    };
+
+    BejDecoderJson decoder;
+    EXPECT_THAT(decoder.decode(dictionaries, inputsOrErr->encodedStream),
+                bejErrorInvalidSize);
+}
+
 } // namespace libbej
diff --git a/test/bej_encoder_test.cpp b/test/bej_encoder_test.cpp
index 413d4e1..113a9de 100644
--- a/test/bej_encoder_test.cpp
+++ b/test/bej_encoder_test.cpp
@@ -344,8 +344,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     std::vector<uint8_t> outputBuffer;
@@ -388,8 +391,11 @@
 
     BejDictionaries dictionaries = {
         .schemaDictionary = inputsOrErr->schemaDictionary,
+        .schemaDictionarySize = inputsOrErr->schemaDictionarySize,
         .annotationDictionary = inputsOrErr->annotationDictionary,
+        .annotationDictionarySize = inputsOrErr->annotationDictionarySize,
         .errorDictionary = inputsOrErr->errorDictionary,
+        .errorDictionarySize = inputsOrErr->errorDictionarySize,
     };
 
     libbej::BejEncoderJson encoder;
diff --git a/test/include/bej_common_test.hpp b/test/include/bej_common_test.hpp
index 74d8c83..9d6165d 100644
--- a/test/include/bej_common_test.hpp
+++ b/test/include/bej_common_test.hpp
@@ -28,8 +28,11 @@
 {
     nlohmann::json expectedJson;
     const uint8_t* schemaDictionary;
+    uint32_t schemaDictionarySize;
     const uint8_t* annotationDictionary;
+    uint32_t annotationDictionarySize;
     const uint8_t* errorDictionary;
+    uint32_t errorDictionarySize;
     std::span<const uint8_t> encodedStream;
 };
 
@@ -65,15 +68,19 @@
     jsonInput >> expJson;
 
     static uint8_t schemaDictBuffer[maxBufferSize];
-    if (readBinaryFile(files.schemaDictionaryFile,
-                       std::span(schemaDictBuffer, maxBufferSize)) == 0)
+    uint32_t schemaDictSize = 0;
+    if ((schemaDictSize =
+             readBinaryFile(files.schemaDictionaryFile,
+                            std::span(schemaDictBuffer, maxBufferSize))) == 0)
     {
         return std::nullopt;
     }
 
     static uint8_t annoDictBuffer[maxBufferSize];
-    if (readBinaryFile(files.annotationDictionaryFile,
-                       std::span(annoDictBuffer, maxBufferSize)) == 0)
+    uint32_t annoDictSize = 0;
+    if ((annoDictSize =
+             readBinaryFile(files.annotationDictionaryFile,
+                            std::span(annoDictBuffer, maxBufferSize))) == 0)
     {
         return std::nullopt;
     }
@@ -87,10 +94,12 @@
     }
 
     static uint8_t errorDict[maxBufferSize];
+    uint32_t errorDictSize = 0;
     if (readErrorDictionary)
     {
-        if (readBinaryFile(files.errorDictionaryFile,
-                           std::span(errorDict, maxBufferSize)) == 0)
+        if ((errorDictSize =
+                 readBinaryFile(files.errorDictionaryFile,
+                                std::span(errorDict, maxBufferSize))) == 0)
         {
             return std::nullopt;
         }
@@ -99,8 +108,11 @@
     BejTestInputs inputs = {
         .expectedJson = expJson,
         .schemaDictionary = schemaDictBuffer,
+        .schemaDictionarySize = schemaDictSize,
         .annotationDictionary = annoDictBuffer,
+        .annotationDictionarySize = annoDictSize,
         .errorDictionary = errorDict,
+        .errorDictionarySize = errorDictSize,
         .encodedStream = std::span(encBuffer, encLen),
     };
     return inputs;