blob: f3cd0e0509c810f3cd9857b1c0aa4bebbcc0814c [file] [log] [blame]
#pragma once
#include <array>
#include "crow/TinySHA1.hpp"
#include "crow/http_request.h"
#include "crow/socket_adaptors.h"
#include <boost/algorithm/string/predicate.hpp>
namespace crow {
namespace websocket {
// Which part of an incoming WebSocket frame the connection reads next.
// Enumerators keep their implicit values 0..4 (spelled out here).
enum class WebSocketReadState : int {
  MiniHeader = 0,  // the fixed 2-byte frame header
  Len16 = 1,       // 2-byte extended payload length
  Len64 = 2,       // 8-byte extended payload length
  Mask = 3,        // 4-byte masking key
  Payload = 4      // the (masked) payload bytes
};
// Abstract WebSocket connection handed to user handlers. It keeps a copy of
// the HTTP upgrade request and one opaque user-data pointer, which this
// class only stores and returns (it never dereferences or frees it).
struct connection {
 public:
  explicit connection(const crow::request& req) : req(req) {}

  // Sends `msg` as a binary message.
  virtual void send_binary(const std::string& msg) = 0;
  // Sends `msg` as a text message.
  virtual void send_text(const std::string& msg) = 0;
  // Starts the close handshake, using `msg` as the close-frame payload.
  virtual void close(const std::string& msg = "quit") = 0;
  // The io_service on which this connection's I/O is performed.
  virtual boost::asio::io_service& get_io_service() = 0;
  virtual ~connection() = default;

  // Attaches an arbitrary, caller-owned pointer to this connection.
  void userdata(void* u) { userdata_ = u; }
  // Returns the pointer last passed to userdata(void*), or nullptr.
  void* userdata() const { return userdata_; }

  // Copy of the HTTP request that initiated the WebSocket upgrade.
  crow::request req;

 private:
  void* userdata_{nullptr};
};
// Server-side WebSocket connection running over a transport `Adaptor`.
//
// The object manages its own lifetime. It runs `delete this` (in the
// constructor on a bad upgrade, and in check_destroy()), so it must be
// allocated with `new` and must not be owned by the caller.
// NOTE(review): member state is mutated from completion handlers without any
// locking; this presumably relies on everything running on one io_service.
template <typename Adaptor>
class Connection : public connection {
 public:
  // Performs the server side of the WebSocket opening handshake.
  //
  // If the request's "upgrade" header is not "websocket" (case-insensitive),
  // the transport is closed and the object deletes itself. Otherwise the
  // "101 Switching Protocols" response is queued, the open handler runs and
  // frame reading starts.
  //
  // @param req              HTTP upgrade request (copied into `req`).
  // @param adaptor          transport; moved into the connection.
  // @param open_handler     called after the handshake response is queued.
  // @param message_handler  called with every complete message and a flag
  //                         telling whether it was binary.
  // @param close_handler    called at most once, with the close payload or
  //                         "uncleanly" on teardown.
  // @param error_handler    called when reading a frame fails.
  Connection(const crow::request& req, Adaptor&& adaptor,
             std::function<void(connection&)> open_handler,
             std::function<void(connection&, const std::string&, bool)>
                 message_handler,
             std::function<void(connection&, const std::string&)> close_handler,
             std::function<void(connection&)> error_handler)
      : connection(req),
        adaptor_(std::move(adaptor)),
        open_handler_(std::move(open_handler)),
        message_handler_(std::move(message_handler)),
        close_handler_(std::move(close_handler)),
        error_handler_(std::move(error_handler)) {
    if (!boost::iequals(req.get_header_value("upgrade"), "websocket")) {
      // `adaptor` was moved from above; close the transport we now own.
      adaptor_.close();
      delete this;
      return;
    }
    // Sec-WebSocket-Accept = base64(SHA-1(Sec-WebSocket-Key + RFC 6455 GUID))
    // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
    // Sec-WebSocket-Version: 13
    std::string magic = req.get_header_value("Sec-WebSocket-Key") +
                        "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    sha1::SHA1 s;
    s.processBytes(magic.data(), magic.size());
    uint8_t digest[20];
    s.getDigestBytes(digest);
    start(crow::utility::base64encode(reinterpret_cast<char*>(digest), 20));
  }

  // Runs `handler` through the adaptor's io_service (dispatch semantics).
  template <typename CompletionHandler>
  void dispatch(CompletionHandler handler) {
    adaptor_.get_io_service().dispatch(handler);
  }

  // Queues `handler` on the adaptor's io_service (post semantics).
  template <typename CompletionHandler>
  void post(CompletionHandler handler) {
    adaptor_.get_io_service().post(handler);
  }

  boost::asio::io_service& get_io_service() override {
    return adaptor_.get_io_service();
  }

  // Queues a pong frame (opcode 0xA) that echoes `msg`. build_header() emits
  // the extended length form when the payload is 126 bytes or longer.
  void send_pong(const std::string& msg) {
    dispatch([this, msg] {
      write_buffers_.emplace_back(build_header(0xA, msg.size()));
      write_buffers_.emplace_back(msg);
      do_write();
    });
  }

  // Queues `msg` as a single final binary frame (opcode 0x2).
  void send_binary(const std::string& msg) override {
    dispatch([this, msg] {
      write_buffers_.emplace_back(build_header(2, msg.size()));
      write_buffers_.emplace_back(msg);
      do_write();
    });
  }

  // Queues `msg` as a single final text frame (opcode 0x1).
  void send_text(const std::string& msg) override {
    dispatch([this, msg] {
      write_buffers_.emplace_back(build_header(1, msg.size()));
      write_buffers_.emplace_back(msg);
      do_write();
    });
  }

  // Queues a close frame (opcode 0x8) carrying `msg`. If the peer's close has
  // already been received, the close handler runs now (at most once).
  void close(const std::string& msg) override {
    dispatch([this, msg] {
      has_sent_close_ = true;
      if (has_recv_close_ && !is_close_handler_called_) {
        is_close_handler_called_ = true;
        if (close_handler_) {
          close_handler_(*this, msg);
        }
      }
      write_buffers_.emplace_back(build_header(0x8, msg.size()));
      write_buffers_.emplace_back(msg);
      do_write();
    });
  }

 protected:
  // Builds an unmasked frame header with FIN set: the 7-bit length form for
  // size < 126, the 16-bit form (126) below 0x10000, otherwise the 64-bit
  // form (127). Extended lengths are written byte by byte in network order,
  // which avoids unaligned and type-punned stores.
  std::string build_header(int opcode, uint64_t size) {
    std::string header;
    header.reserve(10);
    header.push_back(static_cast<char>(0x80 | opcode));
    if (size < 126) {
      header.push_back(static_cast<char>(size));
    } else if (size < 0x10000) {
      header.push_back(static_cast<char>(126));
      header.push_back(static_cast<char>((size >> 8) & 0xFF));
      header.push_back(static_cast<char>(size & 0xFF));
    } else {
      header.push_back(static_cast<char>(127));
      for (int shift = 56; shift >= 0; shift -= 8) {
        header.push_back(static_cast<char>((size >> shift) & 0xFF));
      }
    }
    return header;
  }

  // Queues the "101 Switching Protocols" response carrying the
  // Sec-WebSocket-Accept value `hello`, runs the open handler and starts
  // reading frames.
  void start(std::string&& hello) {
    static const std::string header =
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        //"Sec-WebSocket-Protocol: binary\r\n"  // TODO(ed): this hardcodes
        // binary mode
        // find a better way
        "Sec-WebSocket-Accept: ";
    static const std::string crlf = "\r\n";
    write_buffers_.emplace_back(header);
    write_buffers_.emplace_back(std::move(hello));
    write_buffers_.emplace_back(crlf);
    write_buffers_.emplace_back(crlf);
    do_write();
    if (open_handler_) {
      open_handler_(*this);
    }
    do_read();
  }

  // Decodes the first `count` (at most 8) bytes of buffer_ as a big-endian
  // unsigned integer.
  uint64_t decode_big_endian(std::size_t count) const {
    uint64_t value = 0;
    for (std::size_t i = 0; i < count; ++i) {
      value = (value << 8) | static_cast<unsigned char>(buffer_[i]);
    }
    return value;
  }

  // Issues the single outstanding read for the current state_. Each
  // completion handler either advances state_ and reads again, or tears the
  // connection down via on_read_error().
  void do_read() {
    is_reading = true;
    switch (state_) {
      case WebSocketReadState::MiniHeader: {
        // async_read (unlike async_read_some) completes only once both header
        // bytes are in, or on error.
        boost::asio::async_read(
            adaptor_.socket(), boost::asio::buffer(buffer_, 2),
            [this](const boost::system::error_code& ec,
                   std::size_t bytes_transferred) {
              is_reading = false;
#ifdef CROW_ENABLE_DEBUG
              if (!ec && bytes_transferred != 2) {
                throw std::runtime_error(
                    "WebSocket:MiniHeader:async_read fail:asio bug?");
              }
#endif
              if (ec) {
                on_read_error();
                return;
              }
              // First header byte (FIN, opcode) in the high half, second byte
              // (MASK, 7-bit length) in the low half.
              mini_header_ = static_cast<uint16_t>(
                  (static_cast<unsigned char>(buffer_[0]) << 8) |
                  static_cast<unsigned char>(buffer_[1]));
              if ((mini_header_ & 0x80) != 0x80) {
                // Frames without the MASK bit are rejected.
                on_read_error();
                return;
              }
              const int length_code = mini_header_ & 0x7f;
              if (length_code == 127) {
                state_ = WebSocketReadState::Len64;
              } else if (length_code == 126) {
                state_ = WebSocketReadState::Len16;
              } else {
                remaining_length_ = static_cast<uint64_t>(length_code);
                state_ = WebSocketReadState::Mask;
              }
              do_read();
            });
      } break;
      case WebSocketReadState::Len16: {
        boost::asio::async_read(
            adaptor_.socket(), boost::asio::buffer(buffer_, 2),
            [this](const boost::system::error_code& ec,
                   std::size_t bytes_transferred) {
              is_reading = false;
#ifdef CROW_ENABLE_DEBUG
              if (!ec && bytes_transferred != 2) {
                throw std::runtime_error(
                    "WebSocket:Len16:async_read fail:asio bug?");
              }
#endif
              if (ec) {
                on_read_error();
                return;
              }
              remaining_length_ = decode_big_endian(2);
              state_ = WebSocketReadState::Mask;
              do_read();
            });
      } break;
      case WebSocketReadState::Len64: {
        boost::asio::async_read(
            adaptor_.socket(), boost::asio::buffer(buffer_, 8),
            [this](const boost::system::error_code& ec,
                   std::size_t bytes_transferred) {
              is_reading = false;
#ifdef CROW_ENABLE_DEBUG
              if (!ec && bytes_transferred != 8) {
                throw std::runtime_error(
                    "WebSocket:Len64:async_read fail:asio bug?");
              }
#endif
              if (ec) {
                on_read_error();
                return;
              }
              // NOTE(review): the 64-bit length is not bounded, so a peer can
              // make fragment_ grow without limit.
              remaining_length_ = decode_big_endian(8);
              state_ = WebSocketReadState::Mask;
              do_read();
            });
      } break;
      case WebSocketReadState::Mask: {
        boost::asio::async_read(
            adaptor_.socket(), boost::asio::buffer(&mask_, sizeof(mask_)),
            [this](const boost::system::error_code& ec,
                   std::size_t bytes_transferred) {
              is_reading = false;
#ifdef CROW_ENABLE_DEBUG
              if (!ec && bytes_transferred != 4) {
                throw std::runtime_error(
                    "WebSocket:Mask:async_read fail:asio bug?");
              }
#endif
              if (ec) {
                on_read_error();
                return;
              }
              state_ = WebSocketReadState::Payload;
              do_read();
            });
      } break;
      case WebSocketReadState::Payload: {
        std::size_t to_read = buffer_.size();
        if (remaining_length_ < to_read) {
          to_read = static_cast<std::size_t>(remaining_length_);
        }
        adaptor_.socket().async_read_some(
            boost::asio::buffer(buffer_, to_read),
            [this](const boost::system::error_code& ec,
                   std::size_t bytes_transferred) {
              is_reading = false;
              if (ec) {
                on_read_error();
                return;
              }
              fragment_.append(buffer_.data(), bytes_transferred);
              remaining_length_ -= bytes_transferred;
              if (remaining_length_ != 0) {
                // Partial payload (short read, or larger than buffer_):
                // keep reading the rest of this frame.
                do_read();
                return;
              }
              // handle_fragment() returns false when the connection may
              // already be deleted; then nothing else may be touched.
              if (handle_fragment()) {
                state_ = WebSocketReadState::MiniHeader;
                do_read();
              }
            });
      } break;
    }
  }

  // FIN bit of the current frame header.
  bool is_FIN() const { return mini_header_ & 0x8000; }
  // Opcode (low nibble of the first header byte) of the current frame.
  int opcode() const { return (mini_header_ & 0x0f00) >> 8; }

  // Adds fragment_ to the message being assembled and, on the final frame,
  // passes the complete message to the message handler.
  void append_to_message() {
    message_ += fragment_;
    if (is_FIN()) {
      if (message_handler_) {
        message_handler_(*this, message_, is_binary_);
      }
      message_.clear();
    }
  }

  // Unmasks fragment_ and acts on it according to the frame opcode.
  // Returns false if the connection was torn down (and possibly deleted)
  // while handling a close frame. The caller must then not touch any member.
  bool handle_fragment() {
    const char* mask_bytes = reinterpret_cast<const char*>(&mask_);
    for (std::size_t i = 0; i < fragment_.size(); ++i) {
      fragment_[i] ^= mask_bytes[i % 4];
    }
    switch (opcode()) {
      case 0:  // Continuation: is_binary_ keeps the first frame's value.
        append_to_message();
        break;
      case 1:  // Text
        is_binary_ = false;
        append_to_message();
        break;
      case 2:  // Binary
        is_binary_ = true;
        append_to_message();
        break;
      case 0x8: {  // Close
        has_recv_close_ = true;
        if (!has_sent_close_) {
          // Echo the peer's close payload; reading continues afterwards.
          close(fragment_);
        } else {
          // Both sides have sent close: tear the connection down.
          adaptor_.close();
          close_connection_ = true;
          if (!is_close_handler_called_) {
            is_close_handler_called_ = true;
            if (close_handler_) {
              close_handler_(*this, fragment_);
            }
          }
          // check_destroy() may `delete this`; touch nothing afterwards.
          check_destroy();
          return false;
        }
      } break;
      case 0x9:  // Ping
        send_pong(fragment_);
        break;
      case 0xA:  // Pong
        pong_received_ = true;
        break;
    }
    fragment_.clear();
    return true;
  }

  // Starts writing everything queued in write_buffers_ unless a write is
  // already in flight. That write's completion handler picks up anything
  // queued in the meantime.
  void do_write() {
    if (!sending_buffers_.empty()) {
      return;
    }
    sending_buffers_.swap(write_buffers_);
    std::vector<boost::asio::const_buffer> buffers;
    buffers.reserve(sending_buffers_.size());
    for (const std::string& s : sending_buffers_) {
      buffers.emplace_back(boost::asio::buffer(s));
    }
    boost::asio::async_write(
        adaptor_.socket(), buffers,
        [this](const boost::system::error_code& ec,
               std::size_t /*bytes_transferred*/) {
          sending_buffers_.clear();
          if (!ec && !close_connection_) {
            if (!write_buffers_.empty()) {
              do_write();
            }
            if (has_sent_close_) {
              // A close frame has been queued: no further traffic expected.
              close_connection_ = true;
            }
          } else {
            close_connection_ = true;
            check_destroy();
          }
        });
  }

  // Reports an unclean close (only if no close notification was given yet)
  // and deletes the connection when no write or read is outstanding. May run
  // `delete this`: callers must not touch members after calling it.
  void check_destroy() {
    if (!is_close_handler_called_) {
      // Mark first so later teardown paths don't notify a second time.
      is_close_handler_called_ = true;
      if (close_handler_) {
        close_handler_(*this, "uncleanly");
      }
    }
    if (sending_buffers_.empty() && !is_reading) {
      delete this;
    }
  }

  // Common teardown for every failed or rejected frame read. Closes the
  // transport, notifies the error handler, then calls check_destroy()
  // (which may delete this).
  void on_read_error() {
    close_connection_ = true;
    adaptor_.close();
    if (error_handler_) {
      error_handler_(*this);
    }
    check_destroy();
  }

 private:
  Adaptor adaptor_;
  // Buffers owned by the in-flight async_write (empty when none is active).
  std::vector<std::string> sending_buffers_;
  // Buffers queued for the next write.
  std::vector<std::string> write_buffers_;
  // Scratch space for header fields and payload chunks.
  std::array<char, 4096> buffer_{};
  bool is_binary_{};
  std::string message_;
  std::string fragment_;
  WebSocketReadState state_{WebSocketReadState::MiniHeader};
  uint64_t remaining_length_{0};
  bool close_connection_{false};
  bool is_reading{false};
  uint32_t mask_{};
  uint16_t mini_header_{};
  bool has_sent_close_{false};
  bool has_recv_close_{false};
  bool error_occured_{false};
  bool pong_received_{false};
  bool is_close_handler_called_{false};
  std::function<void(connection&)> open_handler_;
  std::function<void(connection&, const std::string&, bool)> message_handler_;
  std::function<void(connection&, const std::string&)> close_handler_;
  std::function<void(connection&)> error_handler_;
};
} // namespace websocket
} // namespace crow