main: cleanup command handling

Change-Id: I7b07bb13948607b40365411c16510073d33ba57b
Signed-off-by: Patrick Venture <venture@google.com>
diff --git a/ipmi.cpp b/ipmi.cpp
index b9a8943..fad2049 100644
--- a/ipmi.cpp
+++ b/ipmi.cpp
@@ -20,6 +20,27 @@
 #include "flash-ipmi.hpp"
 #include "ipmi.hpp"
 
+IpmiFlashHandler getCommandHandler(FlashSubCmds command)
+{
+    static const std::unordered_map<FlashSubCmds, IpmiFlashHandler>
+        subHandlers = {
+            {FlashSubCmds::flashStartTransfer, startTransfer},
+            {FlashSubCmds::flashDataBlock, dataBlock},
+            {FlashSubCmds::flashDataFinish, dataFinish},
+            {FlashSubCmds::flashStartHash, startHash},
+            {FlashSubCmds::flashHashData, hashBlock},
+            {FlashSubCmds::flashHashFinish, hashFinish},
+        };
+
+    auto results = subHandlers.find(command);
+    if (results == subHandlers.end())
+    {
+        return nullptr;
+    }
+
+    return results->second;
+}
+
 bool validateRequestLength(FlashSubCmds command, size_t requestLen)
 {
     static const std::unordered_map<FlashSubCmds, size_t> minimumLengths = {
diff --git a/ipmi.hpp b/ipmi.hpp
index 5401108..a892300 100644
--- a/ipmi.hpp
+++ b/ipmi.hpp
@@ -1,9 +1,23 @@
 #pragma once
 
+#include <functional>
+
 #include "host-ipmid/ipmid-api.h"
 
 #include "flash-ipmi.hpp"
 
+using IpmiFlashHandler =
+    std::function<ipmi_ret_t(UpdateInterface* updater, const uint8_t* reqBuf,
+                             uint8_t* replyBuf, size_t* dataLen)>;
+
+/**
+ * Retrieve the IPMI command handler.
+ *
+ * @param[in] subcommand - the command
+ * @return the function to call or nullptr on error.
+ */
+IpmiFlashHandler getCommandHandler(FlashSubCmds command);
+
 /**
  * Validate the minimum request length if there is one.
  *
diff --git a/main.cpp b/main.cpp
index 4d2e4bf..0c65f85 100644
--- a/main.cpp
+++ b/main.cpp
@@ -38,8 +38,10 @@
 static ipmi_ret_t flashControl(ipmi_cmd_t cmd, const uint8_t* reqBuf,
                                uint8_t* replyCmdBuf, size_t* dataLen)
 {
+    size_t requestLength = (*dataLen);
+
     /* Verify it's at least as long as the shortest message. */
-    if ((*dataLen) < 1)
+    if (requestLength < 1)
     {
         return IPMI_CC_INVALID;
     }
@@ -47,36 +49,16 @@
     auto subCmd = static_cast<FlashSubCmds>(reqBuf[0]);
 
     /* Validate the minimum request length for the command. */
-    if (!validateRequestLength(subCmd, *dataLen))
+    if (!validateRequestLength(subCmd, requestLength))
     {
         return IPMI_CC_INVALID;
     }
 
-    /* TODO: This could be cleaner to just have a function pointer table, may
-     * transition in later patchset.
-     */
-    switch (subCmd)
+    auto handler = getCommandHandler(subCmd);
+    if (handler)
     {
-        case FlashSubCmds::flashStartTransfer:
-            return startTransfer(flashUpdateSingleton.get(), reqBuf,
-                                 replyCmdBuf, dataLen);
-        case FlashSubCmds::flashDataBlock:
-            return dataBlock(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
-                             dataLen);
-        case FlashSubCmds::flashDataFinish:
-            return dataFinish(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
-                              dataLen);
-        case FlashSubCmds::flashStartHash:
-            return startHash(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
-                             dataLen);
-        case FlashSubCmds::flashHashData:
-            return hashBlock(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
-                             dataLen);
-        case FlashSubCmds::flashHashFinish:
-            return hashFinish(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
-                              dataLen);
-        default:
-            return IPMI_CC_INVALID;
+        return handler(flashUpdateSingleton.get(), reqBuf, replyCmdBuf,
+                       dataLen);
     }
 
     return IPMI_CC_INVALID;
diff --git a/test/Makefile.am b/test/Makefile.am
index 350c3a8..4230ce0 100644
--- a/test/Makefile.am
+++ b/test/Makefile.am
@@ -16,7 +16,8 @@
 	ipmi_starthash_unittest \
 	ipmi_hashdata_unittest \
 	ipmi_hashfinish_unittest \
-	ipmi_validate_unittest
+	ipmi_validate_unittest \
+	ipmi_command_unittest
 
 TESTS = $(check_PROGRAMS)
 
@@ -40,3 +41,6 @@
 
 ipmi_validate_unittest_SOURCES = ipmi_validate_unittest.cpp
 ipmi_validate_unittest_LDADD = $(top_builddir)/ipmi.o
+
+ipmi_command_unittest_SOURCES = ipmi_command_unittest.cpp
+ipmi_command_unittest_LDADD = $(top_builddir)/ipmi.o
diff --git a/test/ipmi_command_unittest.cpp b/test/ipmi_command_unittest.cpp
new file mode 100644
index 0000000..50d1f20
--- /dev/null
+++ b/test/ipmi_command_unittest.cpp
@@ -0,0 +1,42 @@
+#include "ipmi.hpp"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void EqualFunctions(IpmiFlashHandler lhs, IpmiFlashHandler rhs)
+{
+    EXPECT_FALSE(lhs == nullptr);
+    EXPECT_FALSE(rhs == nullptr);
+    ipmi_ret_t (*const* lPtr)(UpdateInterface*, const uint8_t*, uint8_t*,
+                              size_t*) =
+        lhs.target<ipmi_ret_t (*)(UpdateInterface*, const uint8_t*, uint8_t*,
+                                  size_t*)>();
+    ipmi_ret_t (*const* rPtr)(UpdateInterface*, const uint8_t*, uint8_t*,
+                              size_t*) =
+        rhs.target<ipmi_ret_t (*)(UpdateInterface*, const uint8_t*, uint8_t*,
+                                  size_t*)>();
+    EXPECT_TRUE(lPtr);
+    EXPECT_TRUE(rPtr);
+    EXPECT_EQ(*lPtr, *rPtr);
+    return;
+}
+}
+
+TEST(IpmiCommandTest, VerifyCommandReturnsExpected)
+{
+    // Given a subcommand that's valid, make sure it returns the expected
+    // pointer.
+
+    auto result = getCommandHandler(FlashSubCmds::flashHashFinish);
+    EqualFunctions(hashFinish, result);
+}
+
+TEST(IpmiCommandTest, VerifyInvalidCommandReturnsNull)
+{
+    // Given a subcommand that's invalid, make sure it returns the nullptr.
+
+    auto result = getCommandHandler(static_cast<FlashSubCmds>(25));
+    EXPECT_EQ(result, nullptr);
+}