Add option to skip p2a bridge disable

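Add a P2aDataHandler constructor that takes a skipBridgeDisable flag and
pass it through to AspeedPciBridge and NuvotonPciBridge, whose destructors
then skip disableBridge() and leave the bridge enabled. Expose this in the
host tool as a new "ipmipci-skip-bridge-disable" interface option.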

Signed-off-by: Willy Tu <wltu@google.com>
Change-Id: I439bbaa2b7295adc54a8aa98157db60a7e820837
diff --git a/tools/main.cpp b/tools/main.cpp
index 3876844..f749e3b 100644
--- a/tools/main.cpp
+++ b/tools/main.cpp
@@ -44,13 +44,14 @@
 
 #define IPMILPC "ipmilpc"
 #define IPMIPCI "ipmipci"
+#define IPMIPCI_SKIP_BRIDGE_DISABLE "ipmipci-skip-bridge-disable"
 #define IPMIBT "ipmibt"
 #define IPMINET "ipminet"
 
 namespace
 {
-const std::vector<std::string> interfaceList = {IPMINET, IPMIBT, IPMILPC,
-                                                IPMIPCI};
+const std::vector<std::string> interfaceList = {
+    IPMINET, IPMIBT, IPMILPC, IPMIPCI, IPMIPCI_SKIP_BRIDGE_DISABLE};
 } // namespace
 
 void usage(const char* program)
@@ -254,6 +255,12 @@
             handler = std::make_unique<host_tool::P2aDataHandler>(&blob, &pci,
                                                                   &progress);
         }
+        else if (interface == IPMIPCI_SKIP_BRIDGE_DISABLE)
+        {
+            auto& pci = host_tool::PciAccessImpl::getInstance();
+            handler = std::make_unique<host_tool::P2aDataHandler>(
+                &blob, &pci, &progress, true);
+        }
 
         if (!handler)
         {
diff --git a/tools/p2a.cpp b/tools/p2a.cpp
index c18df99..090b9bd 100644
--- a/tools/p2a.cpp
+++ b/tools/p2a.cpp
@@ -55,14 +55,14 @@
 
     try
     {
-        bridge = std::make_unique<NuvotonPciBridge>(pci);
+        bridge = std::make_unique<NuvotonPciBridge>(pci, skipBridgeDisable);
     }
     catch (NotFoundException& e)
     {}
 
     try
     {
-        bridge = std::make_unique<AspeedPciBridge>(pci);
+        bridge = std::make_unique<AspeedPciBridge>(pci, skipBridgeDisable);
     }
     catch (NotFoundException& e)
     {}
diff --git a/tools/p2a.hpp b/tools/p2a.hpp
index 08153d2..259ca23 100644
--- a/tools/p2a.hpp
+++ b/tools/p2a.hpp
@@ -16,11 +16,18 @@
 class P2aDataHandler : public DataInterface
 {
   public:
+    explicit P2aDataHandler(ipmiblob::BlobInterface* blob, const PciAccess* pci,
+                            ProgressInterface* progress, bool skipBridgeDisable,
+                            const internal::Sys* sys = &internal::sys_impl) :
+        blob(blob),
+        pci(pci), progress(progress), skipBridgeDisable(skipBridgeDisable),
+        sys(sys)
+    {}
+
     P2aDataHandler(ipmiblob::BlobInterface* blob, const PciAccess* pci,
                    ProgressInterface* progress,
                    const internal::Sys* sys = &internal::sys_impl) :
-        blob(blob),
-        pci(pci), progress(progress), sys(sys)
+        P2aDataHandler(blob, pci, progress, false, sys)
     {}
 
     bool sendContents(const std::string& input, std::uint16_t session) override;
@@ -33,6 +40,7 @@
     ipmiblob::BlobInterface* blob;
     const PciAccess* pci;
     ProgressInterface* progress;
+    bool skipBridgeDisable;
     const internal::Sys* sys;
 };
 
diff --git a/tools/pci.hpp b/tools/pci.hpp
index 3aeae7b..6c901ca 100644
--- a/tools/pci.hpp
+++ b/tools/pci.hpp
@@ -82,15 +82,18 @@
 class NuvotonPciBridge : public PciAccessBridge
 {
   public:
-    explicit NuvotonPciBridge(const PciAccess* pci) :
-        PciAccessBridge(&match, bar, dataOffset, dataLength, pci)
+    explicit NuvotonPciBridge(const PciAccess* pci,
+                              bool skipBridgeDisable = false) :
+        PciAccessBridge(&match, bar, dataOffset, dataLength, pci),
+        skipBridgeDisable(skipBridgeDisable)
     {
         enableBridge();
     }
 
     ~NuvotonPciBridge()
     {
-        disableBridge();
+        if (!skipBridgeDisable)
+            disableBridge();
     }
 
   private:
@@ -110,20 +113,25 @@
 
     void enableBridge();
     void disableBridge();
+
+    bool skipBridgeDisable;
 };
 
 class AspeedPciBridge : public PciAccessBridge
 {
   public:
-    explicit AspeedPciBridge(const PciAccess* pci) :
-        PciAccessBridge(&match, bar, dataOffset, dataLength, pci)
+    explicit AspeedPciBridge(const PciAccess* pci,
+                             bool skipBridgeDisable = false) :
+        PciAccessBridge(&match, bar, dataOffset, dataLength, pci),
+        skipBridgeDisable(skipBridgeDisable)
     {
         enableBridge();
     }
 
     ~AspeedPciBridge()
     {
-        disableBridge();
+        if (!skipBridgeDisable)
+            disableBridge();
     }
 
     void configure(const ipmi_flash::PciConfigResponse& configResp) override;
@@ -146,6 +154,8 @@
 
     void enableBridge();
     void disableBridge();
+
+    bool skipBridgeDisable;
 };
 
 } // namespace host_tool
diff --git a/tools/test/tools_pci_unittest.cpp b/tools/test/tools_pci_unittest.cpp
index 4a45468..7f69db2 100644
--- a/tools/test/tools_pci_unittest.cpp
+++ b/tools/test/tools_pci_unittest.cpp
@@ -71,7 +71,8 @@
     virtual struct pci_device getDevice() const = 0;
     virtual void expectSetup(PciAccessMock& pciMock,
                              const struct pci_device& dev) const {};
-    virtual std::unique_ptr<PciBridgeIntf> getBridge(PciAccess* pci) const = 0;
+    virtual std::unique_ptr<PciBridgeIntf>
+        getBridge(PciAccess* pci, bool skipBridgeDisable = false) const = 0;
     virtual std::string getName() const = 0;
 };
 
@@ -120,9 +121,10 @@
             .WillOnce(Return(0));
     }
 
-    std::unique_ptr<PciBridgeIntf> getBridge(PciAccess* pci) const override
+    std::unique_ptr<PciBridgeIntf>
+        getBridge(PciAccess* pci, bool skipBridgeDisable = false) const override
     {
-        return std::make_unique<NuvotonPciBridge>(pci);
+        return std::make_unique<NuvotonPciBridge>(pci, skipBridgeDisable);
     }
 
     std::string getName() const override
@@ -164,9 +166,10 @@
         return dev;
     }
 
-    std::unique_ptr<PciBridgeIntf> getBridge(PciAccess* pci) const override
+    std::unique_ptr<PciBridgeIntf>
+        getBridge(PciAccess* pci, bool skipBridgeDisable = false) const override
     {
-        return std::make_unique<AspeedPciBridge>(pci);
+        return std::make_unique<AspeedPciBridge>(pci, skipBridgeDisable);
     }
 
     std::string getName() const override
@@ -551,6 +554,33 @@
     nuvotonDevice.getBridge(&pciMock);
 }
 
+/* skipBridgeDisable=true: destroying the bridge must not disable it */
+TEST(NuvotonBridgeTest, SkipDisable)
+{
+    PciAccessMock pciMock;
+    struct pci_device dev;
+    std::vector<std::uint8_t> region(mockRegionSize);
+
+    constexpr std::uint8_t defaultVal = 0x40;
+
+    /* Only set standard expectations; not those from nuvotonDevice */
+    expectSetup(pciMock, dev, &nuvotonDevice, region.data(), false);
+
+    {
+        InSequence in;
+
+        /* Exactly one config read is allowed, consumed by enableBridge() */
+        EXPECT_CALL(pciMock, pci_device_cfg_read_u8(Eq(&dev), NotNull(),
+                                                    NuvotonDevice::config))
+            .WillOnce(DoAll(
+                SetArgPointee<1>(defaultVal | NuvotonDevice::bridgeEnabled),
+                Return(0)));
+    }
+
+    /* The returned bridge is a temporary, destroyed right after this call */
+    nuvotonDevice.getBridge(&pciMock, true);
+}
+
 TEST(AspeedWriteTest, TooLarge)
 {
     PciAccessMock pciMock;
@@ -644,6 +674,41 @@
     }
 }
 
+/* With skipBridgeDisable set, destroying the bridge must leave the config
+ * region untouched, i.e. the bridge stays enabled */
+TEST(AspeedBridgeTest, SkipDisable)
+{
+    PciAccessMock pciMock;
+    struct pci_device dev;
+    std::vector<std::uint8_t> region(mockRegionSize);
+
+    constexpr std::uint8_t defaultVal = 0x42;
+
+    region[AspeedDevice::config] = defaultVal | AspeedDevice::bridgeEnabled;
+
+    expectSetup(pciMock, dev, &aspeedDevice, region.data());
+
+    /* Create the bridge with skipBridgeDisable = true */
+    std::unique_ptr<PciBridgeIntf> bridge =
+        aspeedDevice.getBridge(&pciMock, true);
+
+    {
+        std::vector<std::uint8_t> enabledRegion(mockRegionSize);
+        enabledRegion[AspeedDevice::config] =
+            defaultVal | AspeedDevice::bridgeEnabled;
+        EXPECT_THAT(region, ContainerEq(enabledRegion));
+    }
+
+    bridge.reset();
+
+    {
+        std::vector<std::uint8_t> stillEnabledRegion(mockRegionSize);
+        stillEnabledRegion[AspeedDevice::config] =
+            defaultVal | AspeedDevice::bridgeEnabled;
+        EXPECT_THAT(region, ContainerEq(stillEnabledRegion));
+    }
+}
+
 /* Make sure the bridge gets enabled when needed */
 TEST(AspeedBridgeTest, NotEnabledSuccess)
 {