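Validate dictionary sizes to prevent heap overflows

Before decoding, check that the size recorded in each dictionary header
matches the dictionary buffer size supplied by the caller, and return
bejErrorInvalidSize on a mismatch. The schema and annotation dictionaries
are always checked; the error dictionary is checked only when it is
non-NULL.
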
Tested:
Unit tested

Signed-off-by: Kasun Athukorala <kasunath@google.com>
Change-Id: I34972b7f4daf0e818461ef2a4842966a247fc20e
diff --git a/src/bej_decoder_core.c b/src/bej_decoder_core.c
index 43f5b69..73f67c9 100644
--- a/src/bej_decoder_core.c
+++ b/src/bej_decoder_core.c
@@ -918,6 +918,44 @@
         return bejErrorNotSupported;
     }
 
+    const struct BejDictionaryHeader* schemaDictionaryHeader =
+        ((const struct BejDictionaryHeader*)dictionaries->schemaDictionary);
+    if (schemaDictionaryHeader->dictionarySize !=
+        dictionaries->schemaDictionarySize)
+    {
+        fprintf(stderr, "Invalid schema dictionary size: %u. Expected: %u.\n",
+                schemaDictionaryHeader->dictionarySize,
+                dictionaries->schemaDictionarySize);
+        return bejErrorInvalidSize;
+    }
+
+    const struct BejDictionaryHeader* annotationDictionaryHeader =
+        ((const struct BejDictionaryHeader*)dictionaries->annotationDictionary);
+    if (annotationDictionaryHeader->dictionarySize !=
+        dictionaries->annotationDictionarySize)
+    {
+        fprintf(stderr,
+                "Invalid annotation dictionary size: %u. Expected: %u.\n",
+                annotationDictionaryHeader->dictionarySize,
+                dictionaries->annotationDictionarySize);
+        return bejErrorInvalidSize;
+    }
+
+    if (dictionaries->errorDictionary != NULL)
+    {
+        const struct BejDictionaryHeader* errorDictionaryHeader =
+            ((const struct BejDictionaryHeader*)dictionaries->errorDictionary);
+        if (errorDictionaryHeader->dictionarySize !=
+            dictionaries->errorDictionarySize)
+        {
+            fprintf(stderr,
+                    "Invalid error dictionary size: %u. Expected: %u.\n",
+                    errorDictionaryHeader->dictionarySize,
+                    dictionaries->errorDictionarySize);
+            return bejErrorInvalidSize;
+        }
+    }
+
     // Skip the PLDM header.
     const uint8_t* enStream = encodedPldmBlock + pldmHeaderSize;
     uint32_t streamLen = blockLength - pldmHeaderSize;