Uwaa/HTTP/Websockets/Websocket.cs

289 lines
No EOL
9.2 KiB
C#

using System.Buffers;
using System.Buffers.Binary;
using System.Text;
namespace Uwaa.HTTP.Websockets;
/// <summary>
/// A websocket wrapper over a HTTP stream.
/// </summary>
public class Websocket
{
    /// <summary>
    /// The chosen sub-protocol negotiated with the remote endpoint.
    /// </summary>
    public readonly string? SubProtocol;

    /// <summary>
    /// Maximum number of bytes <see cref="Read"/> accepts for the data buffered so far plus the payload of the
    /// incoming frame. A frame that would exceed this limit causes an <see cref="IOException"/>.
    /// </summary>
    public int MaxPayloadSize { get; set; } = 100_000;

    internal readonly HttpStream Stream;

    // Payload bytes of a fragmented data message, accumulated until the frame carrying FIN arrives.
    readonly List<byte> finalPayload = new List<byte>();

    // Opcode (Text or Binary) of the message being assembled; Continuation frames inherit it.
    WSOpcode currentOpcode;

    internal Websocket(HttpStream stream, string? subProtocol)
    {
        Stream = stream;
        SubProtocol = subProtocol;
    }

    /// <summary>
    /// Reads frames from the websocket until a complete data message or a close frame arrives.
    /// Ping frames are answered with a Pong carrying the same payload; other control frames are skipped.
    /// When a Close frame arrives, a Close frame is sent back and the received close payload is returned.
    /// </summary>
    /// <returns>
    /// The reassembled data message (Text or Binary, FIN set), or a Close frame whose payload is the one sent by the peer.
    /// </returns>
    /// <exception cref="EndOfStreamException">Thrown when the stream stops unexpectedly.</exception>
    /// <exception cref="IOException">Thrown when the remote endpoint declares a payload larger than <see cref="MaxPayloadSize"/>.</exception>
    public async Task<DataFrame> Read()
    {
        var pool = ArrayPool<byte>.Shared;
        byte[] recvBuffer = pool.Rent(10); // largest header field read at once is the 8-byte extended length
        try
        {
            while (true)
            {
                // Byte 0: FIN (most significant bit), RSV1-3 (0x40, 0x20, 0x10), opcode (low nibble).
                // Byte 1: MASK (most significant bit), 7-bit payload length code.
                await ReadFully(recvBuffer, 2);
                byte firstByte = recvBuffer[0];
                bool fin = (firstByte & 0b1000_0000) != 0;
                WSOpcode opcode = (WSOpcode)(firstByte & 0b0000_1111);

                byte secondByte = recvBuffer[1];
                bool maskEnabled = (secondByte & 0b1000_0000) != 0;
                int lengthCode = secondByte & 0b0111_1111;

                // Length codes 126 and 127 announce a 16-bit or 64-bit big-endian extended length.
                ulong payloadLength;
                switch (lengthCode)
                {
                    case 126:
                        await ReadFully(recvBuffer, 2);
                        payloadLength = BinaryPrimitives.ReadUInt16BigEndian(recvBuffer);
                        break;
                    case 127:
                        await ReadFully(recvBuffer, 8);
                        payloadLength = BinaryPrimitives.ReadUInt64BigEndian(recvBuffer);
                        break;
                    default:
                        payloadLength = (ulong)lengthCode;
                        break;
                }

                // Validate before casting so a huge 64-bit length cannot wrap into a small int.
                if (payloadLength > int.MaxValue || finalPayload.Count + (long)payloadLength > MaxPayloadSize)
                    throw new IOException("Payload too large");
                int length = (int)payloadLength;

                // The 4-byte mask key is left in recvBuffer[0..4); recvBuffer is not touched again
                // until the payload has been unmasked below.
                if (maskEnabled)
                    await ReadFully(recvBuffer, 4);

                byte[] payloadBuffer = pool.Rent(length);
                try
                {
                    await ReadFully(payloadBuffer, length);
                    if (maskEnabled)
                        Unmask(payloadBuffer.AsSpan(0, length), recvBuffer.AsSpan(0, 4));

                    ArraySegment<byte> payload = new ArraySegment<byte>(payloadBuffer, 0, length);
                    switch (opcode)
                    {
                        case WSOpcode.Close:
                            // Any partially assembled data message is abandoned. The close payload is copied
                            // because payloadBuffer goes back to the pool when this iteration ends.
                            finalPayload.Clear();
                            await Write(new DataFrame(WSOpcode.Close, true, Array.Empty<byte>()));
                            return new DataFrame(WSOpcode.Close, true, payload.ToArray());

                        case WSOpcode.Continuation:
                        case WSOpcode.Text:
                        case WSOpcode.Binary:
                            if (opcode is WSOpcode.Text or WSOpcode.Binary)
                                currentOpcode = opcode;
                            finalPayload.AddRange(payload);
                            if (fin)
                                return new DataFrame(currentOpcode, true, FlushPayload());
                            break;

                        case WSOpcode.Ping:
                            // Safe to pass the pooled segment: Write completes before the buffer is returned.
                            await Write(new DataFrame(WSOpcode.Pong, true, payload));
                            break;
                    }
                }
                finally
                {
                    pool.Return(payloadBuffer);
                }
            }
        }
        finally
        {
            pool.Return(recvBuffer);
        }
    }

    /// <summary>
    /// Reads exactly <paramref name="count"/> bytes into the start of <paramref name="buffer"/>.
    /// NOTE(review): assumes a single HttpStream.Read call returns fewer bytes only when the stream ends; TODO confirm.
    /// </summary>
    /// <exception cref="EndOfStreamException">Thrown when fewer than <paramref name="count"/> bytes were read.</exception>
    async Task ReadFully(byte[] buffer, int count)
    {
        if (await Stream.Read(buffer.AsMemory(0, count)) < count)
            throw new EndOfStreamException();
    }

    /// <summary>
    /// XORs <paramref name="data"/> in place with the repeating 4-byte <paramref name="mask"/> (RFC 6455 section 5.3).
    /// </summary>
    static void Unmask(Span<byte> data, ReadOnlySpan<byte> mask)
    {
        for (int i = 0; i < data.Length; i++)
            data[i] ^= mask[i & 3];
    }

    /// <summary>
    /// Returns the accumulated message bytes as a new array and clears the accumulator.
    /// </summary>
    byte[] FlushPayload()
    {
        byte[] final = finalPayload.ToArray();
        finalPayload.Clear();
        return final;
    }

    /// <summary>
    /// Sends <paramref name="payload"/> as a single, final Binary frame.
    /// </summary>
    public Task Write(byte[] payload)
    {
        return Write(new DataFrame(WSOpcode.Binary, true, payload));
    }

    /// <summary>
    /// Sends <paramref name="payload"/>, UTF-8 encoded, as a single, final Text frame.
    /// </summary>
    public Task Write(string payload)
    {
        return Write(new DataFrame(WSOpcode.Text, true, Encoding.UTF8.GetBytes(payload)));
    }

    /// <summary>
    /// Serializes <paramref name="frame"/> to the stream and flushes it.
    /// The length is written in the shortest of the 7-bit, 16-bit or 64-bit forms.
    /// Outgoing frames are never masked (the MASK bit is never set).
    /// NOTE(review): RFC 6455 requires client-to-server frames to be masked; confirm this class is only used server-side.
    /// </summary>
    public async Task Write(DataFrame frame)
    {
        var pool = ArrayPool<byte>.Shared;
        byte[] header = pool.Rent(10); // 1 flag/opcode byte + 1 length byte + up to 8 extended-length bytes
        try
        {
            int count = frame.Payload.Count;
            int flags = frame.EndOfMessage ? 0b1000_0000 : 0;
            header[0] = (byte)(flags | ((int)frame.Opcode & 0b0000_1111));

            int headerLength;
            if (count < 126)
            {
                header[1] = (byte)count;
                headerLength = 2;
            }
            else if (count <= ushort.MaxValue)
            {
                header[1] = 126;
                BinaryPrimitives.WriteUInt16BigEndian(header.AsSpan(2), (ushort)count);
                headerLength = 4;
            }
            else
            {
                header[1] = 127;
                BinaryPrimitives.WriteUInt64BigEndian(header.AsSpan(2), (ulong)count);
                headerLength = 10;
            }

            await Stream.Write(header.AsMemory(0, headerLength));
            await Stream.Write(frame.Payload);
            await Stream.Flush();
        }
        finally
        {
            pool.Return(header);
        }
    }

    /// <summary>
    /// Sends a Close frame whose payload is the 2-byte big-endian <paramref name="status"/> code.
    /// </summary>
    internal async Task Close(CloseStatus status = CloseStatus.NormalClosure)
    {
        var pool = ArrayPool<byte>.Shared;
        byte[] closeBuf = pool.Rent(2);
        try
        {
            BinaryPrimitives.WriteUInt16BigEndian(closeBuf, (ushort)status);
            await Write(new DataFrame(WSOpcode.Close, true, new ArraySegment<byte>(closeBuf, 0, 2)));
        }
        finally
        {
            pool.Return(closeBuf);
        }
    }
}
/// <summary>
/// A remote websocket connected to a local HTTP server.
/// </summary>
public class WebsocketRemote : Websocket
{
    /// <summary>
    /// The HTTP request encompassing the websocket.
    /// </summary>
    public readonly HttpRequest Request;

    /// <summary>
    /// Information about the remote client, as supplied when this websocket was created.
    /// </summary>
    public readonly HttpClientInfo ClientInfo;

    internal WebsocketRemote(HttpRequest request, HttpClientInfo clientInfo, HttpStream stream, string? subProtocol) : base(stream, subProtocol)
    {
        Request = request;
        ClientInfo = clientInfo;
    }
}