nbd proxy and websocket cleanups

As written, the nbd proxy (and every other websocket daemon) suffers
from a problem: there is no way to apply socket backpressure, so under
certain conditions it is trivial to run the BMC out of memory on a
single message.  This is a problem.

This commit implements the idea of an incremental callback handler that
accepts a callback function to be run when the processing of the message
is complete.  This allows applying backpressure on the socket, which in
turn should provide pressure back to the client and prevent buffering
crashes on slow connections or connections with high latency.

Tested: The NBD proxy is not upstream, so there is no way to test it.
No changes made to the normal websocket flow.

Signed-off-by: Michal Orzel <michalx.orzel@intel.com>
Signed-off-by: Ed Tanous <edtanous@google.com>
Change-Id: I3f116cc91eeadc949579deacbeb2d9f5e0f4fa53
diff --git a/http/routing.hpp b/http/routing.hpp
index e41ce93..613b54d 100644
--- a/http/routing.hpp
+++ b/http/routing.hpp
@@ -360,7 +360,7 @@
             myConnection = std::make_shared<
                 crow::websocket::ConnectionImpl<boost::asio::ip::tcp::socket>>(
                 req, std::move(adaptor), openHandler, messageHandler,
-                closeHandler, errorHandler);
+                messageExHandler, closeHandler, errorHandler);
         myConnection->start();
     }
 #else
@@ -375,7 +375,7 @@
             myConnection = std::make_shared<crow::websocket::ConnectionImpl<
                 boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>>(
                 req, std::move(adaptor), openHandler, messageHandler,
-                closeHandler, errorHandler);
+                messageExHandler, closeHandler, errorHandler);
         myConnection->start();
     }
 #endif
@@ -395,6 +395,13 @@
     }
 
     template <typename Func>
+    self_t& onmessageex(Func f)
+    {
+        messageExHandler = f;
+        return *this;
+    }
+
+    template <typename Func>
     self_t& onclose(Func f)
     {
         closeHandler = f;
@@ -412,6 +419,10 @@
     std::function<void(crow::websocket::Connection&)> openHandler;
     std::function<void(crow::websocket::Connection&, const std::string&, bool)>
         messageHandler;
+    std::function<void(crow::websocket::Connection&, std::string_view,
+                       crow::websocket::MessageType type,
+                       std::function<void()>&& whenComplete)>
+        messageExHandler;
     std::function<void(crow::websocket::Connection&, const std::string&)>
         closeHandler;
     std::function<void(crow::websocket::Connection&)> errorHandler;
diff --git a/http/websocket.hpp b/http/websocket.hpp
index 216e96f..9a5aa29 100644
--- a/http/websocket.hpp
+++ b/http/websocket.hpp
@@ -3,6 +3,7 @@
 #include "http_request.hpp"
 
 #include <boost/asio/buffer.hpp>
+#include <boost/beast/core/multi_buffer.hpp>
 #include <boost/beast/websocket.hpp>
 
 #include <array>
@@ -17,6 +18,12 @@
 namespace websocket
 {
 
+enum class MessageType
+{
+    Binary,
+    Text,
+};
+
 struct Connection : std::enable_shared_from_this<Connection>
 {
   public:
@@ -30,9 +37,13 @@
 
     virtual void sendBinary(std::string_view msg) = 0;
     virtual void sendBinary(std::string&& msg) = 0;
+    virtual void sendEx(MessageType type, std::string_view msg,
+                        std::function<void()>&& onDone) = 0;
     virtual void sendText(std::string_view msg) = 0;
     virtual void sendText(std::string&& msg) = 0;
     virtual void close(std::string_view msg = "quit") = 0;
+    virtual void deferRead() = 0;
+    virtual void resumeRead() = 0;
     virtual boost::asio::io_context& getIoContext() = 0;
     virtual ~Connection() = default;
 
@@ -48,12 +59,17 @@
         std::function<void(Connection&)> openHandlerIn,
         std::function<void(Connection&, const std::string&, bool)>
             messageHandlerIn,
+        std::function<void(crow::websocket::Connection&, std::string_view,
+                           crow::websocket::MessageType type,
+                           std::function<void()>&& whenComplete)>
+            messageExHandlerIn,
         std::function<void(Connection&, const std::string&)> closeHandlerIn,
         std::function<void(Connection&)> errorHandlerIn) :
         Connection(reqIn),
         ws(std::move(adaptorIn)), inBuffer(inString, 131088),
         openHandler(std::move(openHandlerIn)),
         messageHandler(std::move(messageHandlerIn)),
+        messageExHandler(std::move(messageExHandlerIn)),
         closeHandler(std::move(closeHandlerIn)),
         errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
     {
@@ -126,28 +142,61 @@
     void sendBinary(std::string_view msg) override
     {
         ws.binary(true);
-        outBuffer.emplace_back(msg);
+        outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
+                                                  boost::asio::buffer(msg)));
         doWrite();
     }
 
+    void sendEx(MessageType type, std::string_view msg,
+                std::function<void()>&& onDone) override
+    {
+        if (doingWrite)
+        {
+            BMCWEB_LOG_CRITICAL
+                << "Cannot mix sendEx usage with sendBinary or sendText";
+            onDone();
+            return;
+        }
+        ws.binary(type == MessageType::Binary);
+
+        ws.async_write(boost::asio::buffer(msg),
+                       [weak(weak_from_this()), onDone{std::move(onDone)}](
+                           const boost::beast::error_code& ec, size_t) {
+            std::shared_ptr<Connection> self = weak.lock();
+
+            // Call the done handler regardless of whether we
+            // errored, but before we close things out
+            onDone();
+
+            if (ec)
+            {
+                BMCWEB_LOG_ERROR << "Error in ws.async_write " << ec;
+                self->close("write error");
+            }
+        });
+    }
+
     void sendBinary(std::string&& msg) override
     {
         ws.binary(true);
-        outBuffer.emplace_back(std::move(msg));
+        outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
+                                                  boost::asio::buffer(msg)));
         doWrite();
     }
 
     void sendText(std::string_view msg) override
     {
         ws.text(true);
-        outBuffer.emplace_back(msg);
+        outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
+                                                  boost::asio::buffer(msg)));
         doWrite();
     }
 
     void sendText(std::string&& msg) override
     {
         ws.text(true);
-        outBuffer.emplace_back(std::move(msg));
+        outBuffer.commit(boost::asio::buffer_copy(outBuffer.prepare(msg.size()),
+                                                  boost::asio::buffer(msg)));
         doWrite();
     }
 
@@ -172,19 +221,41 @@
     {
         BMCWEB_LOG_DEBUG << "Websocket accepted connection";
 
-        doRead();
-
         if (openHandler)
         {
             openHandler(*this);
         }
+        doRead();
+    }
+
+    void deferRead() override
+    {
+        readingDefered = true;
+
+        // While reads are deferred there is no pending async_read handler
+        // holding a reference to this connection, so take ownership of
+        // ourselves here and release it again in resumeRead().
+        selfOwned = shared_from_this();
+    }
+
+    void resumeRead() override
+    {
+        readingDefered = false;
+        doRead();
+
+        // No longer need to keep ourselves alive now that read is active.
+        selfOwned.reset();
     }
 
     void doRead()
     {
-        ws.async_read(inBuffer,
-                      [this, self(shared_from_this())](
-                          boost::beast::error_code ec, std::size_t bytesRead) {
+        if (readingDefered)
+        {
+            return;
+        }
+        ws.async_read(inBuffer, [this, self(shared_from_this())](
+                                    const boost::beast::error_code& ec,
+                                    size_t bytesRead) {
             if (ec)
             {
                 if (ec != boost::beast::websocket::error::closed)
@@ -198,16 +269,10 @@
                 }
                 return;
             }
-            if (messageHandler)
-            {
-                messageHandler(*this, inString, ws.got_text());
-            }
-            inBuffer.consume(bytesRead);
-            inString.clear();
-            doRead();
+
+            handleMessage(bytesRead);
         });
     }
-
     void doWrite()
     {
         // If we're already doing a write, ignore the request, it will be picked
@@ -217,17 +282,17 @@
             return;
         }
 
-        if (outBuffer.empty())
+        if (outBuffer.size() == 0)
         {
             // Done for now
             return;
         }
         doingWrite = true;
-        ws.async_write(boost::asio::buffer(outBuffer.front()),
-                       [this, self(shared_from_this())](
-                           boost::beast::error_code ec, std::size_t) {
+        ws.async_write(outBuffer.data(), [this, self(shared_from_this())](
+                                             const boost::beast::error_code& ec,
+                                             size_t bytesSent) {
             doingWrite = false;
-            outBuffer.erase(outBuffer.begin());
+            outBuffer.consume(bytesSent);
             if (ec == boost::beast::websocket::error::closed)
             {
                 // Do nothing here.  doRead handler will call the
@@ -245,21 +310,59 @@
     }
 
   private:
+    void handleMessage(size_t bytesRead)
+    {
+        if (messageExHandler)
+        {
+            // Note, because of the interactions with the read buffers,
+            // this message handler overrides the normal message handler
+            messageExHandler(*this, inString, MessageType::Binary,
+                             [this, self(shared_from_this()), bytesRead]() {
+                if (self == nullptr)
+                {
+                    return;
+                }
+
+                inBuffer.consume(bytesRead);
+                inString.clear();
+
+                doRead();
+            });
+            return;
+        }
+
+        if (messageHandler)
+        {
+            messageHandler(*this, inString, ws.got_text());
+        }
+        inBuffer.consume(bytesRead);
+        inString.clear();
+        doRead();
+    }
+
     boost::beast::websocket::stream<Adaptor, false> ws;
 
+    bool readingDefered = false;
     std::string inString;
     boost::asio::dynamic_string_buffer<std::string::value_type,
                                        std::string::traits_type,
                                        std::string::allocator_type>
         inBuffer;
-    std::vector<std::string> outBuffer;
+
+    boost::beast::multi_buffer outBuffer;
     bool doingWrite = false;
 
     std::function<void(Connection&)> openHandler;
     std::function<void(Connection&, const std::string&, bool)> messageHandler;
+    std::function<void(crow::websocket::Connection&, std::string_view,
+                       crow::websocket::MessageType type,
+                       std::function<void()>&& whenComplete)>
+        messageExHandler;
     std::function<void(Connection&, const std::string&)> closeHandler;
     std::function<void(Connection&)> errorHandler;
     std::shared_ptr<persistent_data::UserSession> session;
+
+    std::shared_ptr<Connection> selfOwned;
 };
 } // namespace websocket
 } // namespace crow
diff --git a/include/dbus_utility.hpp b/include/dbus_utility.hpp
index 854c2c1..73952a2 100644
--- a/include/dbus_utility.hpp
+++ b/include/dbus_utility.hpp
@@ -16,6 +16,7 @@
 #pragma once
 
 #include "dbus_singleton.hpp"
+#include "logging.hpp"
 
 #include <boost/system/error_code.hpp> // IWYU pragma: keep
 #include <sdbusplus/asio/property.hpp>
@@ -104,6 +105,14 @@
     std::regex_replace(path.begin(), path.begin(), path.end(), reg, "_");
 }
 
+inline void logError(const boost::system::error_code& ec)
+{
+    if (ec)
+    {
+        BMCWEB_LOG_ERROR << "DBus error: " << ec << ", cannot call method";
+    }
+}
+
 // gets the string N strings deep into a path
 // i.e.  /0th/1st/2nd/3rd
 inline bool getNthStringFromPath(const std::string& path, int index,
diff --git a/include/nbd_proxy.hpp b/include/nbd_proxy.hpp
index c0b2907..e3d4c4b 100644
--- a/include/nbd_proxy.hpp
+++ b/include/nbd_proxy.hpp
@@ -18,14 +18,14 @@
 #include "dbus_utility.hpp"
 #include "privileges.hpp"
 
-#include <boost/asio/buffer.hpp>
 #include <boost/asio/local/stream_protocol.hpp>
 #include <boost/asio/write.hpp>
 #include <boost/beast/core/buffers_to_string.hpp>
-#include <boost/beast/core/multi_buffer.hpp>
 #include <boost/container/flat_map.hpp>
 #include <websocket.hpp>
 
+#include <string_view>
+
 namespace crow
 {
 
@@ -34,7 +34,7 @@
 
 using boost::asio::local::stream_protocol;
 
-static constexpr auto nbdBufferSize = 131088;
+static constexpr size_t nbdBufferSize = 131088;
 constexpr const char* requiredPrivilegeString = "ConfigureManager";
 
 struct NbdProxyServer : std::enable_shared_from_this<NbdProxyServer>
@@ -44,6 +44,8 @@
                    const std::string& endpointIdIn, const std::string& pathIn) :
         socketId(socketIdIn),
         endpointId(endpointIdIn), path(pathIn),
+
+        peerSocket(connIn.getIoContext()),
         acceptor(connIn.getIoContext(), stream_protocol::endpoint(socketId)),
         connection(connIn)
     {}
@@ -56,6 +58,17 @@
     ~NbdProxyServer()
     {
         BMCWEB_LOG_DEBUG << "NbdProxyServer destructor";
+
+        BMCWEB_LOG_DEBUG << "peerSocket->close()";
+        boost::system::error_code ec;
+        peerSocket.close(ec);
+
+        BMCWEB_LOG_DEBUG << "std::remove(" << socketId << ")";
+        std::remove(socketId.c_str());
+
+        crow::connections::systemBus->async_method_call(
+            dbus::utility::logError, "xyz.openbmc_project.VirtualMedia", path,
+            "xyz.openbmc_project.VirtualMedia.Proxy", "Unmount");
     }
 
     std::string getEndpointId() const
@@ -65,41 +78,42 @@
 
     void run()
     {
-        acceptor.async_accept([this, self(shared_from_this())](
-                                  const boost::system::error_code& ec,
-                                  stream_protocol::socket socket) {
+        acceptor.async_accept(
+            [weak(weak_from_this())](const boost::system::error_code& ec,
+                                     stream_protocol::socket socket) {
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX socket: async_accept error = "
                                  << ec.message();
                 return;
             }
-            if (peerSocket)
+
+            BMCWEB_LOG_DEBUG << "Connection opened";
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
             {
-                // Something is wrong - socket shouldn't be acquired at this
-                // point
-                BMCWEB_LOG_ERROR
-                    << "Failed to open connection - socket already used";
                 return;
             }
 
-            BMCWEB_LOG_DEBUG << "Connection opened";
-            peerSocket = std::move(socket);
-            doRead();
-
-            // Trigger Write if any data was sent from server
-            // Initially this is negotiation chunk
-            doWrite();
+            self->connection.resumeRead();
+            self->peerSocket = std::move(socket);
+            //  Start reading from socket
+            self->doRead();
         });
 
-        auto mountHandler =
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, const bool) {
+        auto mountHandler = [weak(weak_from_this())](
+                                const boost::system::error_code& ec, bool) {
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
+                return;
+            }
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "DBus error: cannot call mount method = "
                                  << ec.message();
-                connection.close("Failed to mount media");
+
+                self->connection.close("Failed to mount media");
                 return;
             }
         };
@@ -109,90 +123,53 @@
             "xyz.openbmc_project.VirtualMedia.Proxy", "Mount");
     }
 
-    void send(std::string_view data)
+    void send(std::string_view buffer, std::function<void()>&& onDone)
     {
-        boost::asio::buffer_copy(ws2uxBuf.prepare(data.size()),
-                                 boost::asio::buffer(data));
-        ws2uxBuf.commit(data.size());
-        doWrite();
-    }
+        boost::asio::buffer_copy(ws2uxBuf.prepare(buffer.size()),
+                                 boost::asio::buffer(buffer));
+        ws2uxBuf.commit(buffer.size());
 
-    void close()
-    {
-        acceptor.close();
-        if (peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "peerSocket->close()";
-            peerSocket->close();
-            peerSocket.reset();
-            BMCWEB_LOG_DEBUG << "std::remove(" << socketId << ")";
-            std::remove(socketId.c_str());
-        }
-        // The reference to session should exists until unmount is
-        // called
-        auto unmountHandler = [](const boost::system::error_code& ec) {
-            if (ec)
-            {
-                BMCWEB_LOG_ERROR << "DBus error: " << ec
-                                 << ", cannot call unmount method";
-                return;
-            }
-        };
-
-        crow::connections::systemBus->async_method_call(
-            std::move(unmountHandler), "xyz.openbmc_project.VirtualMedia", path,
-            "xyz.openbmc_project.VirtualMedia.Proxy", "Unmount");
+        doWrite(std::move(onDone));
     }
 
   private:
     void doRead()
     {
-        if (!peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "UNIX socket isn't created yet";
-            // Skip if UNIX socket is not created yet.
-            return;
-        }
-
         // Trigger async read
-        peerSocket->async_read_some(
+        peerSocket.async_read_some(
             ux2wsBuf.prepare(nbdBufferSize),
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, std::size_t bytesRead) {
+            [weak(weak_from_this())](const boost::system::error_code& ec,
+                                     size_t bytesRead) {
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX socket: async_read_some error = "
                                  << ec.message();
-                // UNIX socket has been closed by peer, best we can do is to
-                // break all connections
-                close();
+                return;
+            }
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
                 return;
             }
 
-            // Fetch data from UNIX socket
-
-            ux2wsBuf.commit(bytesRead);
-
-            // Paste it to WebSocket as binary
-            connection.sendBinary(
-                boost::beast::buffers_to_string(ux2wsBuf.data()));
-            ux2wsBuf.consume(bytesRead);
-
-            // Allow further reads
-            doRead();
+            // Send to websocket
+            self->ux2wsBuf.commit(bytesRead);
+            self->connection.sendEx(
+                crow::websocket::MessageType::Binary,
+                boost::beast::buffers_to_string(self->ux2wsBuf.data()),
+                [weak(self->weak_from_this())]() {
+                std::shared_ptr<NbdProxyServer> self2 = weak.lock();
+                if (self2 != nullptr)
+                {
+                    self2->ux2wsBuf.consume(self2->ux2wsBuf.size());
+                    self2->doRead();
+                }
+                });
             });
     }
 
-    void doWrite()
+    void doWrite(std::function<void()>&& onDone)
     {
-        if (!peerSocket)
-        {
-            BMCWEB_LOG_DEBUG << "UNIX socket isn't created yet";
-            // Skip if UNIX socket is not created yet. Collect data, and wait
-            // for nbd-client connection
-            return;
-        }
-
         if (uxWriteInProgress)
         {
             BMCWEB_LOG_ERROR << "Write in progress";
@@ -206,23 +183,35 @@
         }
 
         uxWriteInProgress = true;
-        boost::asio::async_write(
-            *peerSocket, ws2uxBuf.data(),
-            [this, self(shared_from_this())](
-                const boost::system::error_code& ec, std::size_t bytesWritten) {
-            ws2uxBuf.consume(bytesWritten);
-            uxWriteInProgress = false;
+        peerSocket.async_write_some(
+            ws2uxBuf.data(),
+            [weak(weak_from_this()),
+             onDone(std::move(onDone))](const boost::system::error_code& ec,
+                                        size_t bytesWritten) mutable {
+            std::shared_ptr<NbdProxyServer> self = weak.lock();
+            if (self == nullptr)
+            {
+                return;
+            }
+
+            self->ws2uxBuf.consume(bytesWritten);
+            self->uxWriteInProgress = false;
+
             if (ec)
             {
                 BMCWEB_LOG_ERROR << "UNIX: async_write error = "
                                  << ec.message();
+                self->connection.close("Internal error");
                 return;
             }
+
             // Retrigger doWrite if there is something in buffer
-            if (ws2uxBuf.size() > 0)
+            if (self->ws2uxBuf.size() > 0)
             {
-                doWrite();
+                self->doWrite(std::move(onDone));
+                return;
             }
+            onDone();
             });
     }
 
@@ -234,17 +223,17 @@
     bool uxWriteInProgress = false;
 
     // UNIX => WebSocket buffer
-    boost::beast::multi_buffer ux2wsBuf;
+    boost::beast::flat_static_buffer<nbdBufferSize> ux2wsBuf;
 
-    // WebSocket <= UNIX buffer
-    boost::beast::multi_buffer ws2uxBuf;
+    // WebSocket => UNIX buffer
+    boost::beast::flat_static_buffer<nbdBufferSize> ws2uxBuf;
+
+    // The socket used to communicate with the client.
+    stream_protocol::socket peerSocket;
 
     // Default acceptor for UNIX socket
     stream_protocol::acceptor acceptor;
 
-    // The socket used to communicate with the client.
-    std::optional<stream_protocol::socket> peerSocket;
-
     crow::websocket::Connection& connection;
 };
 
@@ -343,11 +332,15 @@
                 const dbus::utility::ManagedObjectType& objects) {
         afterGetManagedObjects(conn, ec, objects);
     };
-
     crow::connections::systemBus->async_method_call(
         std::move(openHandler), "xyz.openbmc_project.VirtualMedia",
         "/xyz/openbmc_project/VirtualMedia",
         "org.freedesktop.DBus.ObjectManager", "GetManagedObjects");
+
+    // We need to wait for dbus and the websockets to hook up before data is
+    // sent/received.  Tell the core to hold off messages until the sockets are
+    // up
+    conn.deferRead();
 }
 
 inline void onClose(crow::websocket::Connection& conn,
@@ -360,25 +353,25 @@
         BMCWEB_LOG_DEBUG << "No session to close";
         return;
     }
-    session->second->close();
     // Remove reference to session in global map
     sessions.erase(session);
 }
 
-inline void onMessage(crow::websocket::Connection& conn,
-                      const std::string& data, bool /*isBinary*/)
+inline void onMessage(crow::websocket::Connection& conn, std::string_view data,
+                      crow::websocket::MessageType /*type*/,
+                      std::function<void()>&& whenComplete)
 {
-    BMCWEB_LOG_DEBUG << "nbd-proxy.onmessage(len = " << data.length() << ")";
+    BMCWEB_LOG_DEBUG << "nbd-proxy.onMessage(len = " << data.size() << ")";
+
     // Acquire proxy from sessions
     auto session = sessions.find(&conn);
-    if (session != sessions.end())
+    if (session == sessions.end() || session->second == nullptr)
     {
-        if (session->second)
-        {
-            session->second->send(data);
-            return;
-        }
+        whenComplete();
+        return;
     }
+
+    session->second->send(data, std::move(whenComplete));
 }
 
 inline void requestRoutes(App& app)
@@ -387,7 +380,7 @@
         .websocket()
         .onopen(onOpen)
         .onclose(onClose)
-        .onmessage(onMessage);
+        .onmessageex(onMessage);
 }
 } // namespace nbd_proxy
 } // namespace crow