tools: add network bridge support

Add an "ipminet" interface to the host tool that sends the image data to
the BMC over a TCP connection. The BMC address is given with --host and the
port with --port (default 623).

Signed-off-by: Benjamin Fair <benjaminfair@google.com>
Change-Id: I88630d79499254d6c80ceaa8c7721c241d394fc8
diff --git a/tools/Makefile.am b/tools/Makefile.am
index 14d9618..bb2a317 100644
--- a/tools/Makefile.am
+++ b/tools/Makefile.am
@@ -20,6 +20,7 @@
 	bt.cpp \
 	lpc.cpp \
 	io.cpp \
+	net.cpp \
 	pci.cpp \
 	p2a.cpp \
 	progress.cpp
diff --git a/tools/main.cpp b/tools/main.cpp
index f73bcad..9746a6d 100644
--- a/tools/main.cpp
+++ b/tools/main.cpp
@@ -17,6 +17,7 @@
 #include "bt.hpp"
 #include "io.hpp"
 #include "lpc.hpp"
+#include "net.hpp"
 #include "p2a.hpp"
 #include "pci.hpp"
 #include "progress.hpp"
@@ -42,10 +43,12 @@
 #define IPMILPC "ipmilpc"
 #define IPMIPCI "ipmipci"
 #define IPMIBT "ipmibt"
+#define IPMINET "ipminet"
 
 namespace
 {
-const std::vector<std::string> interfaceList = {IPMIBT, IPMILPC, IPMIPCI};
+const std::vector<std::string> interfaceList = {IPMINET, IPMIBT, IPMILPC,
+                                                IPMIPCI};
 } // namespace
 
 void usage(const char* program)
@@ -81,7 +84,8 @@
 
 int main(int argc, char* argv[])
 {
-    std::string command, interface, imagePath, signaturePath, type;
+    std::string command, interface, imagePath, signaturePath, type, host;
+    std::string port = "623";
     char* valueEnd = nullptr;
     long address = 0;
     long length = 0;
@@ -101,12 +105,14 @@
             {"length", required_argument, 0, 'l'},
             {"type", required_argument, 0, 't'},
             {"ignore-update", no_argument, 0, 'u'},
+            {"host", required_argument, 0, 'H'},
+            {"port", optional_argument, 0, 'p'},
             {0, 0, 0, 0}
         };
         // clang-format on
 
         int option_index = 0;
-        int c = getopt_long(argc, argv, "c:i:m:s:a:l:t:u", long_options,
+        int c = getopt_long(argc, argv, "c:i:m:s:a:l:t:uH:p:", long_options,
                             &option_index);
         if (c == -1)
         {
@@ -174,6 +180,12 @@
             case 'u':
                 ignoreUpdate = true;
                 break;
+            case 'H':
+                host = std::string{optarg};
+                break;
+            case 'p':
+                port = std::string{optarg};
+                break;
             default:
                 usage(argv[0]);
                 exit(EXIT_FAILURE);
@@ -210,6 +222,16 @@
             handler =
                 std::make_unique<host_tool::BtDataHandler>(&blob, &progress);
         }
+        else if (interface == IPMINET)
+        {
+            if (host.empty())
+            {
+                std::fprintf(stderr, "Host not specified\n");
+                exit(EXIT_FAILURE);
+            }
+            handler = std::make_unique<host_tool::NetDataHandler>(
+                &blob, &progress, host, port);
+        }
         else if (interface == IPMILPC)
         {
             if (hostAddress == 0 || hostLength == 0)
diff --git a/tools/net.cpp b/tools/net.cpp
new file mode 100644
index 0000000..4d2ce13
--- /dev/null
+++ b/tools/net.cpp
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2019 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "net.hpp"
+
+#include "data.hpp"
+#include "flags.hpp"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <sys/sendfile.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <ipmiblob/blob_errors.hpp>
+#include <memory>
+#include <optional>
+#include <stdplus/handle/managed.hpp>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+/* Deleter for Fd: closes the descriptor through the injected internal::Sys
+ * rather than calling ::close() directly. */
+void closefd(int&& fd, const internal::Sys*& sys)
+{
+    sys->close(fd);
+}
+/* Owning descriptor handle that releases its fd via closefd(). */
+using Fd = stdplus::Managed<int, const internal::Sys*>::Handle<closefd>;
+
+} // namespace
+
+namespace host_tool
+{
+
+/**
+ * Send the contents of a file to the BMC over a TCP connection.
+ *
+ * The file is streamed to host:port with sendfile() in blocks of up to
+ * 64 KiB. After each block, an ExtChunkHdr carrying the block length is
+ * written to the blob session at the offset where that block started.
+ *
+ * @param[in] input - path of the file to send.
+ * @param[in] session - blob session that receives the chunk headers.
+ * @return true once the whole file has been sent, false on any error.
+ */
+bool NetDataHandler::sendContents(const std::string& input,
+                                  std::uint16_t session)
+{
+    constexpr size_t blockSize = 64 * 1024;
+    Fd inputFd(std::nullopt, sys);
+
+    {
+        inputFd.reset(sys->open(input.c_str(), O_RDONLY));
+        if (*inputFd < 0)
+        {
+            // Never hand an invalid descriptor to close().
+            (void)inputFd.release();
+            std::fprintf(stderr, "Unable to open file: '%s'\n", input.c_str());
+            return false;
+        }
+
+        std::int64_t fileSize = sys->getSize(input.c_str());
+        if (fileSize == 0)
+        {
+            std::fprintf(stderr,
+                         "Zero-length file, or other file access error\n");
+            return false;
+        }
+
+        progress->start(fileSize);
+    }
+
+    Fd connFd(std::nullopt, sys);
+
+    {
+        struct addrinfo hints;
+        std::memset(&hints, 0, sizeof(hints));
+        hints.ai_flags = AI_NUMERICHOST;
+        // NOTE(review): AF_INET limits parsing to IPv4 literals, so an IPv6
+        // host such as "::1" is rejected by a real getaddrinfo(). AF_UNSPEC
+        // would accept both; the unit tests currently expect AF_INET.
+        hints.ai_family = AF_INET;
+        hints.ai_socktype = SOCK_STREAM;
+
+        struct addrinfo *addrs, *addr;
+        int ret = sys->getaddrinfo(host.c_str(), port.c_str(), &hints, &addrs);
+        // getaddrinfo() signals failure with any non-zero EAI_* code.
+        if (ret != 0)
+        {
+            std::fprintf(stderr, "Couldn't parse address %s with port %s: %s\n",
+                         host.c_str(), port.c_str(), gai_strerror(ret));
+            return false;
+        }
+
+        // Try each resolved address until one connects.
+        for (addr = addrs; addr != nullptr; addr = addr->ai_next)
+        {
+            connFd.reset(sys->socket(addr->ai_family, addr->ai_socktype,
+                                     addr->ai_protocol));
+            if (*connFd == -1)
+            {
+                // Never hand an invalid descriptor to close().
+                (void)connFd.release();
+                continue;
+            }
+
+            if (sys->connect(*connFd, addr->ai_addr, addr->ai_addrlen) != -1)
+                break;
+        }
+
+        // TODO: use stdplus Managed for the addrinfo structs
+        sys->freeaddrinfo(addrs);
+
+        if (addr == nullptr)
+        {
+            std::fprintf(stderr, "Failed to connect\n");
+            return false;
+        }
+    }
+
+    try
+    {
+        ssize_t bytesSent = 0;
+        off_t offset = 0;
+
+        do
+        {
+            bytesSent = sys->sendfile(*connFd, *inputFd, &offset, blockSize);
+            if (bytesSent < 0)
+            {
+                std::fprintf(stderr, "Failed to send data to BMC: %s\n",
+                             strerror(errno));
+                return false;
+            }
+            else if (bytesSent > 0)
+            {
+                /* Ok, so the data is staged, now send the blob write with
+                 * the details.
+                 */
+                struct ipmi_flash::ExtChunkHdr chunk;
+                chunk.length = bytesSent;
+                std::vector<std::uint8_t> chunkBytes(sizeof(chunk));
+                std::memcpy(chunkBytes.data(), &chunk, sizeof(chunk));
+
+                /* This doesn't return anything on success. */
+                blob->writeBytes(session, offset - bytesSent, chunkBytes);
+                progress->updateProgress(bytesSent);
+            }
+            // A return of 0 (nothing left to send) ends the transfer.
+        } while (bytesSent > 0);
+    }
+    catch (const ipmiblob::BlobException& b)
+    {
+        std::fprintf(stderr, "Failed to write chunk header to BMC: %s\n", b.what());
+        return false;
+    }
+
+    return true;
+}
+
+} // namespace host_tool
diff --git a/tools/net.hpp b/tools/net.hpp
new file mode 100644
index 0000000..ba329df
--- /dev/null
+++ b/tools/net.hpp
@@ -0,0 +1,53 @@
+#pragma once
+
+#include "interface.hpp"
+#include "internal/sys.hpp"
+#include "progress.hpp"
+
+#include <unistd.h>
+
+#include <cstdint>
+#include <ipmiblob/blob_interface.hpp>
+#include <stdplus/handle/managed.hpp>
+#include <string>
+
+namespace host_tool
+{
+
+/**
+ * Data handler that streams the image to the BMC over a TCP connection and
+ * records each transferred block through the blob interface.
+ */
+class NetDataHandler : public DataInterface
+{
+  public:
+    /**
+     * @param[in] blob - blob interface used to write per-block chunk headers.
+     * @param[in] progress - progress reporter updated as data is sent.
+     * @param[in] host - numeric address of the BMC to connect to.
+     * @param[in] port - TCP port (service) to connect to on the BMC.
+     * @param[in] sys - system call interface; defaults to internal::sys_impl.
+     */
+    NetDataHandler(ipmiblob::BlobInterface* blob, ProgressInterface* progress,
+                   const std::string& host, const std::string& port,
+                   const internal::Sys* sys = &internal::sys_impl) :
+        blob(blob),
+        progress(progress), host(host), port(port), sys(sys)
+    {
+    }
+
+    bool sendContents(const std::string& input, std::uint16_t session) override;
+    ipmi_flash::FirmwareFlags::UpdateFlags supportedType() const override
+    {
+        return ipmi_flash::FirmwareFlags::UpdateFlags::net;
+    }
+
+  private:
+    ipmiblob::BlobInterface* blob;
+    ProgressInterface* progress;
+    std::string host;
+    std::string port;
+    const internal::Sys* sys;
+};
+
+} // namespace host_tool
diff --git a/tools/test/Makefile.am b/tools/test/Makefile.am
index c05654b..1fe35a0 100644
--- a/tools/test/Makefile.am
+++ b/tools/test/Makefile.am
@@ -18,6 +18,7 @@
 check_PROGRAMS = \
 	tools_bt_unittest \
 	tools_lpc_unittest \
+	tools_net_unittest \
 	tools_updater_unittest \
 	tools_helper_unittest
 
@@ -29,6 +30,9 @@
 tools_lpc_unittest_SOURCES = tools_lpc_unittest.cpp
 tools_lpc_unittest_LDADD = $(top_builddir)/tools/libupdater.la
 
+tools_net_unittest_SOURCES = tools_net_unittest.cpp
+tools_net_unittest_LDADD = $(top_builddir)/tools/libupdater.la
+
 tools_updater_unittest_SOURCES = tools_updater_unittest.cpp
 tools_updater_unittest_LDADD = $(top_builddir)/tools/libupdater.la
 
diff --git a/tools/test/tools_net_unittest.cpp b/tools/test/tools_net_unittest.cpp
new file mode 100644
index 0000000..51556ff
--- /dev/null
+++ b/tools/test/tools_net_unittest.cpp
@@ -0,0 +1,263 @@
+#include "data.hpp"
+#include "internal_sys_mock.hpp"
+#include "net.hpp"
+#include "progress_mock.hpp"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+
+#include <cstring>
+#include <ipmiblob/blob_errors.hpp>
+#include <ipmiblob/test/blob_interface_mock.hpp>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace host_tool
+{
+namespace
+{
+
+using namespace std::literals;
+
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::ContainerEq;
+using ::testing::DoAll;
+using ::testing::Field;
+using ::testing::Gt;
+using ::testing::InSequence;
+using ::testing::NotNull;
+using ::testing::Pointee;
+using ::testing::Return;
+using ::testing::SetArgPointee;
+using ::testing::SetErrnoAndReturn;
+using ::testing::StrEq;
+using ::testing::Throw;
+
+/* Fixture wiring a NetDataHandler to mocked Sys, blob and progress objects.
+ * `addr` is the single (IPv6 loopback) result the mocked getaddrinfo hands
+ * back. */
+class NetHandleTest : public ::testing::Test
+{
+  protected:
+    NetHandleTest() : handler(&blobMock, &progMock, host, port, &sysMock)
+    {
+        sa.sin6_family = AF_INET6;
+        sa.sin6_port = htons(622);
+        sa.sin6_flowinfo = 0;
+        sa.sin6_addr = in6addr_loopback; // ::1
+        sa.sin6_scope_id = 0;
+
+        addr.ai_family = AF_INET6;
+        addr.ai_socktype = SOCK_STREAM;
+        addr.ai_addr = reinterpret_cast<struct sockaddr*>(&sa);
+        addr.ai_addrlen = sizeof(sa);
+        addr.ai_protocol = 0;
+        addr.ai_next = nullptr;
+    }
+
+    /* Expect the input file to be opened, sized, and closed. */
+    void expectOpenFile()
+    {
+        EXPECT_CALL(sysMock, open(StrEq(filePath.c_str()), _))
+            .WillOnce(Return(inFd));
+        EXPECT_CALL(sysMock, close(inFd)).WillOnce(Return(0));
+        EXPECT_CALL(sysMock, getSize(StrEq(filePath.c_str())))
+            .WillOnce(Return(fakeFileSize));
+
+        EXPECT_CALL(progMock, start(fakeFileSize));
+    }
+
+    /* Expect host/port to resolve to `addr`, which must then be freed. */
+    void expectAddrInfo()
+    {
+        EXPECT_CALL(
+            sysMock,
+            getaddrinfo(StrEq(host), StrEq(port),
+                        AllOf(Field(&addrinfo::ai_flags, AI_NUMERICHOST),
+                              Field(&addrinfo::ai_family, AF_INET),
+                              Field(&addrinfo::ai_socktype, SOCK_STREAM)),
+                        NotNull()))
+            .WillOnce(DoAll(SetArgPointee<3>(&addr), Return(0)));
+        EXPECT_CALL(sysMock, freeaddrinfo(&addr));
+    }
+
+    /* Expect a socket to be created, connected successfully, and closed. */
+    void expectConnection()
+    {
+        EXPECT_CALL(sysMock, socket(AF_INET6, SOCK_STREAM, 0))
+            .WillOnce(Return(connFd));
+        EXPECT_CALL(sysMock, close(connFd)).WillOnce(Return(0));
+        EXPECT_CALL(sysMock,
+                    connect(connFd, reinterpret_cast<struct sockaddr*>(&sa),
+                            sizeof(sa)))
+            .WillOnce(Return(0));
+    }
+
+    internal::InternalSysMock sysMock;
+    ipmiblob::BlobInterfaceMock blobMock;
+    ProgressMock progMock;
+
+    const std::string host = "::1"s;
+    const std::string port = "622"s;
+
+    // Value-initialized so fields the fixture doesn't set are zero.
+    struct sockaddr_in6 sa{};
+    struct addrinfo addr{};
+
+    static constexpr std::uint16_t session = 0xbeef;
+    const std::string filePath = "/asdf"s;
+    static constexpr int inFd = 5;
+    static constexpr int connFd = 7;
+    static constexpr size_t fakeFileSize = 128;
+    static constexpr size_t chunkSize = 16;
+
+    NetDataHandler handler;
+};
+
+TEST_F(NetHandleTest, openFileFail)
+{
+    EXPECT_CALL(sysMock, open(StrEq(filePath.c_str()), _))
+        .WillOnce(SetErrnoAndReturn(EACCES, -1));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, getSizeFail)
+{
+    EXPECT_CALL(sysMock, open(StrEq(filePath.c_str()), _))
+        .WillOnce(Return(inFd));
+    EXPECT_CALL(sysMock, close(inFd)).WillOnce(Return(0));
+    EXPECT_CALL(sysMock, getSize(StrEq(filePath.c_str()))).WillOnce(Return(0));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, getaddrinfoFail)
+{
+    expectOpenFile();
+
+    EXPECT_CALL(sysMock,
+                getaddrinfo(StrEq(host), StrEq(port),
+                            AllOf(Field(&addrinfo::ai_flags, AI_NUMERICHOST),
+                                  Field(&addrinfo::ai_family, AF_INET),
+                                  Field(&addrinfo::ai_socktype, SOCK_STREAM)),
+                            NotNull()))
+        .WillOnce(Return(EAI_ADDRFAMILY));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, connectFail)
+{
+    expectOpenFile();
+    expectAddrInfo();
+
+    EXPECT_CALL(sysMock, socket(AF_INET6, SOCK_STREAM, 0))
+        .WillOnce(Return(connFd));
+    EXPECT_CALL(sysMock, close(connFd)).WillOnce(Return(0));
+    EXPECT_CALL(
+        sysMock,
+        connect(connFd, reinterpret_cast<struct sockaddr*>(&sa), sizeof(sa)))
+        .WillOnce(SetErrnoAndReturn(ECONNREFUSED, -1));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, sendfileFail)
+{
+    expectOpenFile();
+    expectAddrInfo();
+    expectConnection();
+
+    EXPECT_CALL(sysMock, sendfile(connFd, inFd, Pointee(0), _))
+        .WillOnce(SetErrnoAndReturn(ETIMEDOUT, -1));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+// A failed chunk-header write must abort the transfer.
+TEST_F(NetHandleTest, blobWriteFail)
+{
+    expectOpenFile();
+    expectAddrInfo();
+    expectConnection();
+
+    EXPECT_CALL(sysMock, sendfile(connFd, inFd, Pointee(0), Gt(fakeFileSize)))
+        .WillOnce(
+            DoAll(SetArgPointee<2>(fakeFileSize), Return(fakeFileSize)));
+    EXPECT_CALL(blobMock, writeBytes(session, 0, _))
+        .WillOnce(Throw(ipmiblob::BlobException("write failed")));
+
+    EXPECT_FALSE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, successOneChunk)
+{
+    expectOpenFile();
+    expectAddrInfo();
+    expectConnection();
+
+    {
+        InSequence seq;
+
+        EXPECT_CALL(sysMock,
+                    sendfile(connFd, inFd, Pointee(0), Gt(fakeFileSize)))
+            .WillOnce(
+                DoAll(SetArgPointee<2>(fakeFileSize), Return(fakeFileSize)));
+        EXPECT_CALL(sysMock, sendfile(connFd, inFd, Pointee(fakeFileSize),
+                                      Gt(fakeFileSize)))
+            .WillOnce(Return(0));
+    }
+
+    struct ipmi_flash::ExtChunkHdr chunk;
+    chunk.length = fakeFileSize;
+    std::vector<std::uint8_t> chunkBytes(sizeof(chunk));
+    std::memcpy(chunkBytes.data(), &chunk, sizeof(chunk));
+    EXPECT_CALL(blobMock, writeBytes(session, 0, ContainerEq(chunkBytes)));
+
+    EXPECT_CALL(progMock, updateProgress(fakeFileSize));
+
+    EXPECT_TRUE(handler.sendContents(filePath, session));
+}
+
+TEST_F(NetHandleTest, successMultiChunk)
+{
+    expectOpenFile();
+    expectAddrInfo();
+    expectConnection();
+
+    struct ipmi_flash::ExtChunkHdr chunk;
+    chunk.length = chunkSize;
+    std::vector<std::uint8_t> chunkBytes(sizeof(chunk));
+    std::memcpy(chunkBytes.data(), &chunk, sizeof(chunk));
+
+    {
+        InSequence seq;
+
+        for (std::uint32_t offset = 0; offset < fakeFileSize;
+             offset += chunkSize)
+        {
+            EXPECT_CALL(sysMock,
+                        sendfile(connFd, inFd, Pointee(offset), Gt(chunkSize)))
+                .WillOnce(DoAll(SetArgPointee<2>(offset + chunkSize),
+                                Return(chunkSize)));
+
+            EXPECT_CALL(blobMock,
+                        writeBytes(session, offset, ContainerEq(chunkBytes)));
+            EXPECT_CALL(progMock, updateProgress(chunkSize));
+        }
+        EXPECT_CALL(sysMock, sendfile(connFd, inFd, Pointee(fakeFileSize),
+                                      Gt(chunkSize)))
+            .WillOnce(Return(0));
+    }
+
+    EXPECT_TRUE(handler.sendContents(filePath, session));
+}
+
+} // namespace
+} // namespace host_tool