diff --git a/core/data/stream/ChunkedBuffer.hpp b/core/data/stream/ChunkedBuffer.hpp index 0d8db1bd..84ed3a7e 100644 --- a/core/data/stream/ChunkedBuffer.hpp +++ b/core/data/stream/ChunkedBuffer.hpp @@ -149,9 +149,14 @@ public: os::io::Library::v_size pos, os::io::Library::v_size count); - oatpp::String getSubstring(os::io::Library::v_size pos, - os::io::Library::v_size count); + /** + * return substring of the data written to stream; NOT NULL + */ + oatpp::String getSubstring(os::io::Library::v_size pos, os::io::Library::v_size count); + /** + * return data written to stream as oatpp::String; NOT NULL + */ oatpp::String toString() { return getSubstring(0, m_size); } diff --git a/web/protocol/websocket/WebSocket.cpp b/web/protocol/websocket/WebSocket.cpp index e1fcd5f2..6e125ac9 100644 --- a/web/protocol/websocket/WebSocket.cpp +++ b/web/protocol/websocket/WebSocket.cpp @@ -26,8 +26,55 @@ namespace oatpp { namespace web { namespace protocol { namespace websocket { +void WebSocket::packHeaderBits(v_word16& bits, const FrameHeader& frameHeader, v_word8& messageLengthScenario) const { -void WebSocket::readFrameHeader(FrameHeader& frameHeader) { + bits = 0; + + if(frameHeader.fin) bits |= 32768; + if(frameHeader.rsv1) bits |= 16384; + if(frameHeader.rsv2) bits |= 8192; + if(frameHeader.rsv3) bits |= 4096; + + bits |= (frameHeader.opcode & 15) << 8; + + if(frameHeader.hasMask) bits |= 128; + + if(frameHeader.payloadLength < 126) { + bits |= frameHeader.payloadLength & 127; + messageLengthScenario = 1; + } else if(frameHeader.payloadLength < 65536) { + bits |= 126; + messageLengthScenario = 2; + } else { + bits |= 127; // frameHeader.payloadLength > 65535 + messageLengthScenario = 3; + } + +} + +void WebSocket::unpackHeaderBits(v_word16 bits, FrameHeader& frameHeader, v_word8& messageLen1) const { + frameHeader.fin = (bits & 32768) > 0; // 32768 + frameHeader.rsv1 = (bits & 16384) > 0; // 16384 + frameHeader.rsv2 = (bits & 8192) > 0; // 8192 + frameHeader.rsv3 = (bits & 4096) > 0; // 4096 + frameHeader.opcode = (bits & 3840) >> 8; + frameHeader.hasMask = (bits & 128) > 0; + messageLen1 = (bits & 127); +} + +bool WebSocket::checkForContinuation(const FrameHeader& frameHeader) { + if(m_lastOpcode == OPCODE_TEXT || m_lastOpcode == OPCODE_BINARY) { + return false; + } + if(frameHeader.fin) { + m_lastOpcode = -1; + } else { + m_lastOpcode = frameHeader.opcode; + } + return true; +} + +void WebSocket::readFrameHeader(FrameHeader& frameHeader) const { v_word16 bb; auto res = oatpp::data::stream::readExactSizeData(m_connection.get(), &bb, 2); @@ -35,15 +82,8 @@ void WebSocket::readFrameHeader(FrameHeader& frameHeader) { throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::readFrameHeader()]: Error reading frame header"); } - bb = ntohs(bb); - - frameHeader.fin = (bb & 32768) > 0; // 32768 - frameHeader.rsv1 = (bb & 16384) > 0; // 16384 - frameHeader.rsv2 = (bb & 8192) > 0; // 8192 - frameHeader.rsv3 = (bb & 4096) > 0; // 4096 - frameHeader.opcode = (bb & 4095) >> 8; - frameHeader.hasMask = (bb & 128) > 0; - v_word8 messageLen1 = (bb & 127); + v_word8 messageLen1; + unpackHeaderBits(ntohs(bb), frameHeader, messageLen1); if(messageLen1 < 126) { frameHeader.payloadLength = messageLen1; @@ -76,103 +116,276 @@ void WebSocket::readFrameHeader(FrameHeader& frameHeader) { frameHeader.hasMask, frameHeader.payloadLength); + OATPP_LOGD("WebSocket", "rsv1=%d, rsv2=%d, rsv3=%d", + frameHeader.rsv1, + frameHeader.rsv2, + frameHeader.rsv3); + } -void WebSocket::readPayload(const FrameHeader& frameHeader, bool callListener) { +void WebSocket::writeFrameHeader(const FrameHeader& frameHeader) const { + + v_word16 bb; + v_word8 messageLengthScenario; + packHeaderBits(bb, frameHeader, messageLengthScenario); + + bb = htons(bb); + + auto res = oatpp::data::stream::writeExactSizeData(m_connection.get(), &bb, 2); + if(res != 2) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::writeFrameHeader()]: Error writing frame header"); + } + + if(messageLengthScenario == 2) { + v_word16 messageLen2 = htons(frameHeader.payloadLength); + res = oatpp::data::stream::writeExactSizeData(m_connection.get(), &messageLen2, 2); + if(res != 2) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::writeFrameHeader()]: Error writing frame header. Writing payload length scenario 2."); + } + } else if(messageLengthScenario == 3) { + v_word32 messageLen3[2]; + messageLen3[0] = htonl(frameHeader.payloadLength >> 32); + messageLen3[1] = htonl(frameHeader.payloadLength & 0xFFFFFFFF); + res = oatpp::data::stream::writeExactSizeData(m_connection.get(), &messageLen3, 8); + if(res != 8) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::writeFrameHeader()]: Error writing frame header. Writing payload length scenario 3."); + } + } + + if(frameHeader.hasMask) { + res = oatpp::data::stream::writeExactSizeData(m_connection.get(), frameHeader.mask, 4); + if(res != 4) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::writeFrameHeader()]: Error writing frame header. Writing mask."); + } + } + +} + +void WebSocket::readPayload(const FrameHeader& frameHeader, oatpp::data::stream::ChunkedBuffer* shortMessageStream) const { + + if(shortMessageStream && frameHeader.payloadLength > 125) { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::readPayload()]: Invalid payloadLength. See RFC-6455, section-5.5."); + } v_char8 buffer[oatpp::data::buffer::IOBuffer::BUFFER_SIZE]; oatpp::os::io::Library::v_size progress = 0; + while (progress < frameHeader.payloadLength) { + oatpp::os::io::Library::v_size desiredSize = oatpp::data::buffer::IOBuffer::BUFFER_SIZE; if(desiredSize > frameHeader.payloadLength - progress) { desiredSize = frameHeader.payloadLength - progress; } + auto res = m_connection->read(buffer, desiredSize); + if(res > 0) { - if(callListener && m_listener) { - /* decode message and call listener */ - v_char8 decoded[res]; - for(v_int32 i = 0; i < res; i ++) { - decoded[i] = buffer[i] ^ frameHeader.mask[(i + progress) % 4]; - } + + v_char8 decoded[res]; + for(v_int32 i = 0; i < res; i ++) { + decoded[i] = buffer[i] ^ frameHeader.mask[(i + progress) % 4]; + } + if(shortMessageStream) { + shortMessageStream->write(decoded, res); + } else if(m_listener) { m_listener->readMessage(*this, decoded, res); } progress += res; + }else { // if res == 0 then probably stream handles read() error incorrectly. trow. + if(res == oatpp::data::stream::Errors::ERROR_IO_RETRY || res == oatpp::data::stream::Errors::ERROR_IO_WAIT_RETRY) { continue; } throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::readPayload()]: Invalid connection state."); + } } /* call listener to inform abount messge end */ - if(callListener && frameHeader.fin && m_listener) { + if(shortMessageStream == nullptr && frameHeader.fin && m_listener) { m_listener->readMessage(*this, nullptr, 0); } } -void WebSocket::handleFrame(v_int32 opcode, const FrameHeader& frameHeader) { +void WebSocket::handleFrame(const FrameHeader& frameHeader) { - switch (opcode) { + switch (frameHeader.opcode) { case OPCODE_CONTINUATION: if(m_lastOpcode < 0) { throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Invalid communication state."); } - handleFrame(m_lastOpcode, frameHeader); - return; // return here + readPayload(frameHeader, nullptr); + break; case OPCODE_TEXT: - readPayload(frameHeader, true); + if(checkForContinuation(frameHeader)) { + readPayload(frameHeader, nullptr); + } else { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Invalid communication state. OPCODE_CONTINUATION expected"); + } break; case OPCODE_BINARY: - readPayload(frameHeader, true); + if(checkForContinuation(frameHeader)) { + readPayload(frameHeader, nullptr); + } else { + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Invalid communication state. OPCODE_CONTINUATION expected"); + } break; - case OPCODE_CONNECTION_CLOSE: - readPayload(frameHeader, false); - if(m_listener) { - m_listener->onConnectionClose(*this); + case OPCODE_CLOSE: + { + oatpp::data::stream::ChunkedBuffer messageStream; + readPayload(frameHeader, &messageStream); + if(m_listener) { + v_word16 code = 0; + oatpp::String message; + if(messageStream.getSize() >= 2) { + messageStream.readSubstring(&code, 0, 2); + code = ntohs(code); + message = messageStream.getSubstring(2, messageStream.getSize() - 2); + } + if(!message) { + message = ""; + } + m_listener->onClose(*this, code, message); + } } break; case OPCODE_PING: - readPayload(frameHeader, false); - if(m_listener) { - m_listener->onPing(*this); + { + oatpp::data::stream::ChunkedBuffer messageStream; + readPayload(frameHeader, &messageStream); + if(m_listener) { + m_listener->onPing(*this, messageStream.toString()); + } } break; case OPCODE_PONG: - readPayload(frameHeader, false); - if(m_listener) { - m_listener->onPong(*this); + { + oatpp::data::stream::ChunkedBuffer messageStream; + readPayload(frameHeader, &messageStream); + if(m_listener) { + m_listener->onPong(*this, messageStream.toString()); + } } break; default: - OATPP_LOGD("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]", "Unknown frame"); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::handleFrame()]: Unknown frame"); break; } - m_lastOpcode = opcode; +} +void WebSocket::iterateFrame() { + FrameHeader frameHeader; + readFrameHeader(frameHeader); + handleFrame(frameHeader); } void WebSocket::listen() { + m_listening = true; + try { FrameHeader frameHeader; do { readFrameHeader(frameHeader); - handleFrame(frameHeader.opcode, frameHeader); - } while(frameHeader.opcode != OPCODE_CONNECTION_CLOSE); + handleFrame(frameHeader); + } while(frameHeader.opcode != OPCODE_CLOSE && m_listening); } catch(...) { OATPP_LOGD("[oatpp::web::protocol::websocket::WebSocket::listen()]", "Unhandled error occurred"); } } +void WebSocket::stopListening() const { + m_listening = false; +} + +void WebSocket::sendFrame(bool fin, v_word8 opcode, v_int64 messageSize) const { + + oatpp::web::protocol::websocket::WebSocket::FrameHeader frame; + frame.fin = fin; + frame.rsv1 = false; + frame.rsv2 = false; + frame.rsv3 = false; + frame.opcode = opcode; + frame.hasMask = false; + frame.payloadLength = messageSize; + + writeFrameHeader(frame); + +} + +bool WebSocket::sendOneFrameMessage(v_word8 opcode, const oatpp::String& message) const { + if(message && message->getSize() > 0) { + sendFrame(true, opcode, message->getSize()); + auto res = oatpp::data::stream::writeExactSizeData(m_connection.get(), message->getData(), message->getSize()); + if(res != message->getSize()) { + return false; + } + } else { + sendFrame(true, opcode, 0); + } + return true; +} + +void WebSocket::sendClose(v_word16 code, const oatpp::String& message) const { + + code = htons(code); + + oatpp::data::stream::ChunkedBuffer buffer; + buffer.write(&code, 2); + if(message) { + buffer.write(message->getData(), message->getSize()); + } + + if(!sendOneFrameMessage(OPCODE_CLOSE, buffer.toString())) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendClose(...)]: Unknown error while writing to socket."); + } + +} + +void WebSocket::sendClose() const { + if(!sendOneFrameMessage(OPCODE_CLOSE, nullptr)) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendClose()]: Unknown error while writing to socket."); + } +} + +void WebSocket::sendPing(const oatpp::String& message) const { + if(!sendOneFrameMessage(OPCODE_PING, message)) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendPing()]: Unknown error while writing to socket."); + } +} + +void WebSocket::sendPong(const oatpp::String& message) const { + if(!sendOneFrameMessage(OPCODE_PONG, message)) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendPong()]: Unknown error while writing to socket."); + } +} + +void WebSocket::sendOneFrameText(const oatpp::String& message) const { + if(!sendOneFrameMessage(OPCODE_TEXT, message)) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendOneFrameText()]: Unknown error while writing to socket."); + } +} + +void WebSocket::sendOneFrameBinary(const oatpp::String& message) const { + if(!sendOneFrameMessage(OPCODE_BINARY, message)) { + stopListening(); + throw std::runtime_error("[oatpp::web::protocol::websocket::WebSocket::sendOneFrameBinary()]: Unknown error while writing to socket."); + } +} + }}}} diff --git a/web/protocol/websocket/WebSocket.hpp b/web/protocol/websocket/WebSocket.hpp index 954ba1ca..810f46a9 100644 --- a/web/protocol/websocket/WebSocket.hpp +++ b/web/protocol/websocket/WebSocket.hpp @@ -25,7 +25,7 @@ #ifndef oatpp_web_protocol_websocket_WebSocket_hpp #define oatpp_web_protocol_websocket_WebSocket_hpp -#include "oatpp/core/data/stream/Stream.hpp" +#include "oatpp/core/data/stream/ChunkedBuffer.hpp" namespace oatpp { namespace web { namespace protocol { namespace websocket { @@ -35,7 +35,7 @@ public: static constexpr v_word8 OPCODE_CONTINUATION = 0x0; static constexpr v_word8 OPCODE_TEXT = 0x1; static constexpr v_word8 OPCODE_BINARY = 0x2; - static constexpr v_word8 OPCODE_CONNECTION_CLOSE = 0x8; + static constexpr v_word8 OPCODE_CLOSE = 0x8; static constexpr v_word8 OPCODE_PING = 0x9; static constexpr v_word8 OPCODE_PONG = 0xA; @@ -47,28 +47,23 @@ public: typedef oatpp::web::protocol::websocket::WebSocket WebSocket; public: - /** - * Called when WebSocket is connected to client/server - */ - virtual void onConnected(const WebSocket& webSocket) = 0; - /** * Called when "ping" frame received */ - virtual void onPing(const WebSocket& webSocket) = 0; + virtual void onPing(const WebSocket& webSocket, const oatpp::String& message) = 0; /** * Called when "pong" frame received */ - virtual void onPong(const WebSocket& webSocket) = 0; + virtual void onPong(const WebSocket& webSocket, const oatpp::String& message) = 0; /** - * Called when "connection close" frame received + * Called when "close" frame received */ - virtual void onConnectionClose(const WebSocket& webSocket) = 0; + virtual void onClose(const WebSocket& webSocket, v_word16 code, const oatpp::String& message) = 0; /** - * Called when "message" frame received. + * Called when "text" or "binary" frame received. * When all data of message is read, readMessage is called again with size == 0 to * indicate end of the message */ @@ -90,27 +85,110 @@ public: }; private: - void readFrameHeader(FrameHeader& frameHeader); - void handleFrame(v_int32 opcode, const FrameHeader& frameHeader); - void readPayload(const FrameHeader& frameHeader, bool callListener); + + void packHeaderBits(v_word16& bits, const FrameHeader& frameHeader, v_word8& messageLengthScenario) const; + void unpackHeaderBits(v_word16 bits, FrameHeader& frameHeader, v_word8& messageLen1) const; + + bool checkForContinuation(const FrameHeader& frameHeader); + void readFrameHeader(FrameHeader& frameHeader) const; + void handleFrame(const FrameHeader& frameHeader); + + /** + * if(shortMessageStream == nullptr) - read call readMessage() method of listener + * if(shortMessageStream) - read message to shortMessageStream. Don't call listener + */ + void readPayload(const FrameHeader& frameHeader, oatpp::data::stream::ChunkedBuffer* shortMessageStream) const; + private: std::shared_ptr m_connection; std::shared_ptr m_listener; v_int32 m_lastOpcode; + mutable bool m_listening; public: WebSocket(const std::shared_ptr& connection) : m_connection(connection) , m_listener(nullptr) , m_lastOpcode(-1) + , m_listening(false) {} + std::shared_ptr getConnection() const { + return m_connection; + } + void setListener(const std::shared_ptr& listener) { m_listener = listener; } + /** + * Use this method if you know what you are doing. + * Read one frame from connection and call corresponding methods of listener. + * See WebSocket::setListener() + */ + void iterateFrame(); + + /** + * Blocks until stopListening() is called or error occurred + * Read incoming frames and call corresponding methods of listener. + * See WebSocket::setListener() + */ void listen(); + /** + * Break listen loop. See WebSocket::listen() + */ + void stopListening() const; + + /** + * Use this method if you know what you are doing. + * Send custom frame to peer. + */ + void writeFrameHeader(const FrameHeader& frameHeader) const; + + /** + * Use this method if you know what you are doing. + * Send default frame to peer with fin, opcode and messageSize set + */ + void sendFrame(bool fin, v_word8 opcode, v_int64 messageSize) const; + + /** + * Send one frame message with custom opcode + * return true on success, false on error. + * if false returned socket should be closed manually + */ + bool sendOneFrameMessage(v_word8 opcode, const oatpp::String& message) const; + + /** + * throws on error and closes socket + */ + void sendClose(v_word16 code, const oatpp::String& message) const; + + /** + * throws on error and closes socket + */ + void sendClose() const; + + /** + * throws on error and closes socket + */ + void sendPing(const oatpp::String& message) const; + + /** + * throws on error and closes socket + */ + void sendPong(const oatpp::String& message) const; + + /** + * throws on error and closes socket + */ + void sendOneFrameText(const oatpp::String& message) const; + + /** + * throws on error and closes socket + */ + void sendOneFrameBinary(const oatpp::String& message) const; + }; }}}}