using System.Buffers;
using System.Buffers.Binary;
using System.Text;

namespace Uwaa.HTTP.Websockets;

/// <summary>
/// A websocket wrapper over a HTTP stream.
/// </summary>
public class Websocket
{
    /// <summary>
    /// Largest total message size (all fragments plus the incoming frame) that <see cref="Read"/> accepts, in bytes.
    /// </summary>
    const int MaxMessageSize = 100_000;

    /// <summary>
    /// The chosen sub-protocol negotiated with the remote endpoint.
    /// </summary>
    public readonly string? SubProtocol;

    internal readonly HttpStream Stream;

    // Accumulates the payloads of a fragmented message until its final frame arrives.
    readonly List<byte> finalPayload = new List<byte>();

    // Opcode (Text or Binary) of the message currently being reassembled; continuation frames inherit it.
    WSOpcode currentOpcode;

    // Set once a Close frame has been written, so a Close received afterwards is not answered with a second one.
    bool closeSent;

    internal Websocket(HttpStream stream, string? subProtocol)
    {
        Stream = stream;
        SubProtocol = subProtocol;
    }

    /// <summary>
    /// Reads the next complete message from the websocket.
    /// </summary>
    /// <remarks>
    /// Fragmented Text/Binary messages are reassembled and returned as a single frame. Ping frames are answered
    /// with a Pong and reading continues; Pong frames and unrecognised opcodes are skipped. When a Close frame
    /// arrives, a Close frame is sent back (echoing the status code, unless a Close was already sent) and the
    /// received Close frame is returned with its own payload. RSV bits are not inspected.
    /// </remarks>
    /// <returns>The reassembled data message, or the Close frame received from the remote endpoint.</returns>
    /// <exception cref="EndOfStreamException">Thrown when the stream stops unexpectedly.</exception>
    /// <exception cref="IOException">Thrown when the client declares a payload which is too large.</exception>
    public async Task<DataFrame> Read()
    {
        ArrayPool<byte> pool = ArrayPool<byte>.Shared;
        byte[] header = pool.Rent(8);
        try
        {
            while (true)
            {
                //First two bytes: FIN/RSV/opcode, then MASK/payload length code
                await ReadExact(header.AsMemory(0, 2));

                byte firstByte = header[0];
                bool fin = (firstByte & 0b1000_0000) != 0;
                WSOpcode opcode = (WSOpcode)(firstByte & 0b0000_1111);

                byte secondByte = header[1];
                bool masked = (secondByte & 0b1000_0000) != 0;
                int lengthCode = secondByte & 0b0111_1111;

                //Payload length: 0-125 inline, 126 => 16-bit extended, 127 => 64-bit extended (big-endian)
                ulong payloadLength;
                switch (lengthCode)
                {
                    case 126:
                        await ReadExact(header.AsMemory(0, 2));
                        payloadLength = BinaryPrimitives.ReadUInt16BigEndian(header);
                        break;
                    case 127:
                        await ReadExact(header.AsMemory(0, 8));
                        payloadLength = BinaryPrimitives.ReadUInt64BigEndian(header);
                        break;
                    default:
                        payloadLength = (ulong)lengthCode;
                        break;
                }

                // finalPayload never grows beyond MaxMessageSize, so the subtraction cannot go negative,
                // and comparing this way avoids overflow for hostile 64-bit lengths.
                if (payloadLength > (ulong)(MaxMessageSize - finalPayload.Count))
                    throw new IOException("Payload too large");

                //Mask key stays in header[0..4]; the payload is read into a separate buffer so it is not overwritten.
                if (masked)
                    await ReadExact(header.AsMemory(0, 4));

                int length = (int)payloadLength;
                byte[] payloadBuffer = pool.Rent(length);
                try
                {
                    ArraySegment<byte> payload = new ArraySegment<byte>(payloadBuffer, 0, length);
                    await ReadExact(payload);

                    if (masked)
                        Unmask(payload.AsSpan(), header.AsSpan(0, 4));

                    switch (opcode)
                    {
                        case WSOpcode.Close:
                        {
                            if (!closeSent)
                            {
                                // RFC 6455 5.5.1: reply with a Close frame, typically echoing the received status code.
                                ArraySegment<byte> echo = payload.Count >= 2 ? payload.Slice(0, 2) : ArraySegment<byte>.Empty;
                                await Write(new DataFrame(WSOpcode.Close, true, echo));
                            }

                            // Any partially reassembled message is abandoned; return the Close frame's own payload.
                            // Copied because payloadBuffer goes back to the pool in the finally below.
                            finalPayload.Clear();
                            return new DataFrame(WSOpcode.Close, true, payload.ToArray());
                        }

                        case WSOpcode.Continuation:
                        case WSOpcode.Text:
                        case WSOpcode.Binary:
                            if (opcode != WSOpcode.Continuation)
                                currentOpcode = opcode;

                            finalPayload.AddRange(payload);

                            if (fin)
                                return new DataFrame(currentOpcode, true, FlushPayload());
                            break;

                        case WSOpcode.Ping:
                            await Write(new DataFrame(WSOpcode.Pong, true, payload));
                            break;
                    }
                }
                finally
                {
                    pool.Return(payloadBuffer);
                }
            }
        }
        finally
        {
            pool.Return(header);
        }
    }

    /// <summary>
    /// Fills <paramref name="buffer"/> completely, issuing as many reads as necessary.
    /// </summary>
    /// <remarks>
    /// NOTE(review): assumes <c>HttpStream.Read</c> may return fewer bytes than requested (like <see cref="System.IO.Stream"/>);
    /// looping is harmless if it already reads fully. A zero-byte read is treated as end of stream.
    /// </remarks>
    /// <exception cref="EndOfStreamException">Thrown when the stream ends before the buffer is filled.</exception>
    async Task ReadExact(Memory<byte> buffer)
    {
        int total = 0;
        while (total < buffer.Length)
        {
            int read = await Stream.Read(buffer.Slice(total));
            if (read <= 0)
                throw new EndOfStreamException();
            total += read;
        }
    }

    /// <summary>
    /// XORs <paramref name="data"/> in place with the repeating 4-byte <paramref name="maskKey"/>.
    /// </summary>
    static void Unmask(Span<byte> data, ReadOnlySpan<byte> maskKey)
    {
        for (int i = 0; i < data.Length; i++)
            data[i] ^= maskKey[i & 3];
    }

    /// <summary>
    /// Returns a copy of the accumulated message payload and clears the accumulator.
    /// </summary>
    byte[] FlushPayload()
    {
        byte[] final = finalPayload.ToArray();
        finalPayload.Clear();
        return final;
    }

    /// <summary>
    /// Sends <paramref name="payload"/> as a single, final Binary frame.
    /// </summary>
    public Task Write(byte[] payload)
    {
        return Write(new DataFrame(WSOpcode.Binary, true, payload));
    }

    /// <summary>
    /// Sends <paramref name="payload"/>, UTF-8 encoded, as a single, final Text frame.
    /// </summary>
    public Task Write(string payload)
    {
        return Write(new DataFrame(WSOpcode.Text, true, Encoding.UTF8.GetBytes(payload)));
    }

    /// <summary>
    /// Writes one frame, unmasked, using the minimal payload-length encoding, then flushes the stream.
    /// </summary>
    public async Task Write(DataFrame frame)
    {
        ArrayPool<byte> pool = ArrayPool<byte>.Shared;
        byte[] header = pool.Rent(10);
        try
        {
            int count = frame.Payload.Count;

            byte firstByte = (byte)((int)frame.Opcode & 0b0000_1111);
            if (frame.EndOfMessage)
                firstByte |= 0b1000_0000;
            header[0] = firstByte;

            // RFC 6455 5.2: the minimal number of bytes must be used to encode the length,
            // so 126..65535 (inclusive) uses the 16-bit form.
            int headerLength;
            if (count < 126)
            {
                header[1] = (byte)count;
                headerLength = 2;
            }
            else if (count <= ushort.MaxValue)
            {
                header[1] = 126;
                BinaryPrimitives.WriteUInt16BigEndian(header.AsSpan(2), (ushort)count);
                headerLength = 4;
            }
            else
            {
                header[1] = 127;
                BinaryPrimitives.WriteUInt64BigEndian(header.AsSpan(2), (ulong)count);
                headerLength = 10;
            }

            if (frame.Opcode == WSOpcode.Close)
                closeSent = true;

            await Stream.Write(header.AsMemory(0, headerLength));
            await Stream.Write(frame.Payload);
            await Stream.Flush();
        }
        finally
        {
            pool.Return(header);
        }
    }

    /// <summary>
    /// Sends a Close frame carrying <paramref name="status"/> as a big-endian 16-bit status code.
    /// </summary>
    internal async Task Close(CloseStatus status = CloseStatus.NormalClosure)
    {
        byte[] closeBuf = new byte[2];
        BinaryPrimitives.WriteUInt16BigEndian(closeBuf, (ushort)status);
        await Write(new DataFrame(WSOpcode.Close, true, new ArraySegment<byte>(closeBuf, 0, 2)));
    }
}

/// <summary>
/// A remote websocket connected to a local HTTP server.
/// </summary>
public class WebsocketRemote : Websocket
{
    /// <summary>
    /// The HTTP request encompassing the websocket.
    /// </summary>
    public readonly HttpRequest Request;

    /// <summary>
    /// The <see cref="HttpClientInfo"/> describing the client that opened the websocket.
    /// </summary>
    public readonly HttpClientInfo ClientInfo;

    internal WebsocketRemote(HttpRequest request, HttpClientInfo clientInfo, HttpStream stream, string? subProtocol) : base(stream, subProtocol)
    {
        Request = request;
        ClientInfo = clientInfo;
    }
}