nbd proxy and websocket cleanups

As written, the NBD proxy (like all other websocket daemons) has no way
to apply socket backpressure, so under certain conditions it is trivial
to run the BMC out of memory while buffering a given message.  This is
a problem.

This commit implements the idea of an incremental callback handler that
accepts a callback function to be run when processing of the message is
complete.  This allows backpressure to be applied on the socket, which in
turn should propagate pressure back to the client and prevent buffering
crashes on slow or high-latency connections.

Tested: The NBD proxy is not upstream, so there is no way to test it.
No changes were made to the normal websocket flow.

Signed-off-by: Michal Orzel <michalx.orzel@intel.com>
Signed-off-by: Ed Tanous <edtanous@google.com>
Change-Id: I3f116cc91eeadc949579deacbeb2d9f5e0f4fa53
diff --git a/include/dbus_utility.hpp b/include/dbus_utility.hpp
index 854c2c1..73952a2 100644
--- a/include/dbus_utility.hpp
+++ b/include/dbus_utility.hpp
@@ -16,6 +16,7 @@
 #pragma once
 
 #include "dbus_singleton.hpp"
+#include "logging.hpp"
 
 #include <boost/system/error_code.hpp> // IWYU pragma: keep
 #include <sdbusplus/asio/property.hpp>
@@ -104,6 +105,14 @@
     std::regex_replace(path.begin(), path.begin(), path.end(), reg, "_");
 }
 
+inline void logError(const boost::system::error_code& ec)
+{
+    if (ec)
+    {
+        BMCWEB_LOG_ERROR << "DBus error: " << ec << ", cannot call method";
+    }
+}
+
 // gets the string N strings deep into a path
 // i.e.  /0th/1st/2nd/3rd
 inline bool getNthStringFromPath(const std::string& path, int index,
diff --git a/include/nbd_proxy.hpp b/include/nbd_proxy.hpp
index c0b2907..e3d4c4b 100644
--- a/include/nbd_proxy.hpp
+++ b/include/nbd_proxy.hpp
@@ -18,14 +18,14 @@
 #include "dbus_utility.hpp"
 #include "privileges.hpp"
 
-#include <boost/asio/buffer.hpp>
 #include <boost/asio/local/stream_protocol.hpp>
 #include <boost/asio/write.hpp>
 #include <boost/beast/core/buffers_to_string.hpp>
-#include <boost/beast/core/multi_buffer.hpp>
 #include <boost/container/flat_map.hpp>
 #include <websocket.hpp>
 
+#include <string_view>
+
 namespace crow
 {
 
@@ -34,7 +34,7 @@
 
 using boost::asio::local::stream_protocol;
 
-static constexpr auto nbdBufferSize = 131088;
+static constexpr size_t nbdBufferSize = 131088;
 constexpr const char* requiredPrivilegeString = "ConfigureManager";
 
 struct NbdProxyServer : std::enable_shared_from_this<NbdProxyServer>
@@ -44,6 +44,8 @@
                    const std::string& endpointIdIn, const std::string& pathIn) :
         socketId(socketIdIn),
         endpointId(endpointIdIn), path(pathIn),
+
+        peerSocket(connIn.getIoContext()),
         acceptor(connIn.getIoContext(), stream_protocol::endpoint(socketId)),
         connection(connIn)
     {}
@@ -56,6 +58,17 @@
     ~NbdProxyServer()
     {
         BMCWEB_LOG_DEBUG << "NbdProxyServer destructor";
+
+        BMCWEB_LOG_DEBUG << "peerSocket->close()";
+        boost::system::error_code ec;
+        peerSocket.close(ec);
+
+        BMCWEB_LOG_DEBUG << "std::remove(" << socketId << ")";
+        std::remove(socketId.c_str());
+
+        crow::connections::systemBus->async_method_call(
+            dbus::utility::logError, "xyz.openbmc_project.VirtualMedia", path,
+            "xyz.openbmc_project.VirtualMedia.Proxy", "Unmount");
     }
 
     std::string getEndpointId() const
@@ -65,41 +78,42 @@
 
     void run()
     {
-        acceptor.async_accept([this, self(shared_from_this())](
-                                  const boost::system::error_code& ec,
-                                  stream_protocol::socket socket) {
+        acceptor.async_accept(
+            [weak(weak_from_this())](const boost::system::error_code& ec,
+                                     stream_protocol::socket socket) {
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX socket: async_accept error = "
                                  << ec.message();
                 return;
             }
-            if (peerSocket)
+
+            BMCWEB_LOG_DEBUG << "Connection opened";
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
             {
-                // Something is wrong - socket shouldn't be acquired at this
-                // point
-                BMCWEB_LOG_ERROR
-                    << "Failed to open connection - socket already used";
                 return;
             }
 
-            BMCWEB_LOG_DEBUG << "Connection opened";
-            peerSocket = std::move(socket);
-            doRead();
-
-            // Trigger Write if any data was sent from server
-            // Initially this is negotiation chunk
-            doWrite();
+            self->connection.resumeRead();
+            self->peerSocket = std::move(socket);
+            //  Start reading from socket
+            self->doRead();
         });
 
-        auto mountHandler =
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, const bool) {
+        auto mountHandler = [weak(weak_from_this())](
+                                const boost::system::error_code& ec, bool) {
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
+                return;
+            }
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "DBus error: cannot call mount method = "
                                  << ec.message();
-                connection.close("Failed to mount media");
+
+                self->connection.close("Failed to mount media");
                 return;
             }
         };
@@ -109,90 +123,53 @@
             "xyz.openbmc_project.VirtualMedia.Proxy", "Mount");
     }
 
-    void send(std::string_view data)
+    void send(std::string_view buffer, std::function<void()>&& onDone)
     {
-        boost::asio::buffer_copy(ws2uxBuf.prepare(data.size()),
-                                 boost::asio::buffer(data));
-        ws2uxBuf.commit(data.size());
-        doWrite();
-    }
+        boost::asio::buffer_copy(ws2uxBuf.prepare(buffer.size()),
+                                 boost::asio::buffer(buffer));
+        ws2uxBuf.commit(buffer.size());
 
-    void close()
-    {
-        acceptor.close();
-        if (peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "peerSocket->close()";
-            peerSocket->close();
-            peerSocket.reset();
-            BMCWEB_LOG_DEBUG << "std::remove(" << socketId << ")";
-            std::remove(socketId.c_str());
-        }
-        // The reference to session should exists until unmount is
-        // called
-        auto unmountHandler = [](const boost::system::error_code& ec) {
-            if (ec)
-            {
-                BMCWEB_LOG_ERROR << "DBus error: " << ec
-                                 << ", cannot call unmount method";
-                return;
-            }
-        };
-
-        crow::connections::systemBus->async_method_call(
-            std::move(unmountHandler), "xyz.openbmc_project.VirtualMedia", path,
-            "xyz.openbmc_project.VirtualMedia.Proxy", "Unmount");
+        doWrite(std::move(onDone));
     }
 
   private:
     void doRead()
     {
-        if (!peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "UNIX socket isn't created yet";
-            // Skip if UNIX socket is not created yet.
-            return;
-        }
-
         // Trigger async read
-        peerSocket->async_read_some(
+        peerSocket.async_read_some(
             ux2wsBuf.prepare(nbdBufferSize),
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, std::size_t bytesRead) {
+            [weak(weak_from_this())](const boost::system::error_code& ec,
+                                     size_t bytesRead) {
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX socket: async_read_some error = "
                                  << ec.message();
-                // UNIX socket has been closed by peer, best we can do is to
-                // break all connections
-                close();
+                return;
+            }
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
                 return;
             }
 
-            // Fetch data from UNIX socket
-
-            ux2wsBuf.commit(bytesRead);
-
-            // Paste it to WebSocket as binary
-            connection.sendBinary(
-                boost::beast::buffers_to_string(ux2wsBuf.data()));
-            ux2wsBuf.consume(bytesRead);
-
-            // Allow further reads
-            doRead();
+            // Send to websocket
+            self->ux2wsBuf.commit(bytesRead);
+            self->connection.sendEx(
+                crow::websocket::MessageType::Binary,
+                boost::beast::buffers_to_string(self->ux2wsBuf.data()),
+                [weak(self->weak_from_this())]() {
+                std::shared_ptr<NbdProxyServer> self2 = weak.lock();
+                if (self2 != nullptr)
+                {
+                    self2->ux2wsBuf.consume(self2->ux2wsBuf.size());
+                    self2->doRead();
+                }
+                });
             });
     }
 
-    void doWrite()
+    void doWrite(std::function<void()>&& onDone)
     {
-        if (!peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "UNIX socket isn't created yet";
-            // Skip if UNIX socket is not created yet. Collect data, and wait
-            // for nbd-client connection
-            return;
-        }
-
         if (uxWriteInProgress)
         {
             BMCWEB_LOG_ERROR << "Write in progress";
@@ -206,23 +183,35 @@
         }
 
         uxWriteInProgress = true;
-        boost::asio::async_write(
-            *peerSocket, ws2uxBuf.data(),
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, std::size_t bytesWritten) {
-            ws2uxBuf.consume(bytesWritten);
-            uxWriteInProgress = false;
+        peerSocket.async_write_some(
+            ws2uxBuf.data(),
+            [weak(weak_from_this()),
+             onDone(std::move(onDone))](const boost::system::error_code& ec,
+                                        size_t bytesWritten) mutable {
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
+                return;
+            }
+
+            self->ws2uxBuf.consume(bytesWritten);
+            self->uxWriteInProgress = false;
+
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX: async_write error = "
                                  << ec.message();
+                self->connection.close("Internal error");
                 return;
             }
+
             // Retrigger doWrite if there is something in buffer
-            if (ws2uxBuf.size() > 0)
+            if (self->ws2uxBuf.size() > 0)
             {
-                doWrite();
+                self->doWrite(std::move(onDone));
+                return;
             }
+            onDone();
             });
     }
 
@@ -234,17 +223,17 @@
     bool uxWriteInProgress = false;
 
     // UNIX => WebSocket buffer
-    boost::beast::multi_buffer ux2wsBuf;
+    boost::beast::flat_static_buffer<nbdBufferSize> ux2wsBuf;
 
-    // WebSocket <= UNIX buffer
-    boost::beast::multi_buffer ws2uxBuf;
+    // WebSocket => UNIX buffer
+    boost::beast::flat_static_buffer<nbdBufferSize> ws2uxBuf;
+
+    // The socket used to communicate with the client.
+    stream_protocol::socket peerSocket;
 
     // Default acceptor for UNIX socket
     stream_protocol::acceptor acceptor;
 
-    // The socket used to communicate with the client.
-    std::optional<stream_protocol::socket> peerSocket;
-
     crow::websocket::Connection& connection;
 };
 
@@ -343,11 +332,15 @@
                 const dbus::utility::ManagedObjectType& objects) {
         afterGetManagedObjects(conn, ec, objects);
     };
-
     crow::connections::systemBus->async_method_call(
         std::move(openHandler), "xyz.openbmc_project.VirtualMedia",
         "/xyz/openbmc_project/VirtualMedia",
         "org.freedesktop.DBus.ObjectManager", "GetManagedObjects");
+
+    // We need to wait for dbus and the websockets to hook up before data is
+    // sent/received.  Tell the core to hold off messages until the sockets are
+    // up
+    conn.deferRead();
 }
 
 inline void onClose(crow::websocket::Connection& conn,
@@ -360,25 +353,25 @@
         BMCWEB_LOG_DEBUG << "No session to close";
         return;
     }
-    session->second->close();
     // Remove reference to session in global map
     sessions.erase(session);
 }
 
-inline void onMessage(crow::websocket::Connection& conn,
-                      const std::string& data, bool /*isBinary*/)
+inline void onMessage(crow::websocket::Connection& conn, std::string_view data,
+                      crow::websocket::MessageType /*type*/,
+                      std::function<void()>&& whenComplete)
 {
-    BMCWEB_LOG_DEBUG << "nbd-proxy.onmessage(len = " << data.length() << ")";
+    BMCWEB_LOG_DEBUG << "nbd-proxy.onMessage(len = " << data.size() << ")";
+
     // Acquire proxy from sessions
     auto session = sessions.find(&conn);
-    if (session != sessions.end())
+    if (session == sessions.end() || session->second == nullptr)
     {
-        if (session->second)
-        {
-            session->second->send(data);
-            return;
-        }
+        whenComplete();
+        return;
     }
+
+    session->second->send(data, std::move(whenComplete));
 }
 
 inline void requestRoutes(App& app)
@@ -387,7 +380,7 @@
         .websocket()
         .onopen(onOpen)
         .onclose(onClose)
-        .onmessage(onMessage);
+        .onmessageex(onMessage);
 }
 } // namespace nbd_proxy
 } // namespace crow