From 553f96351678f9edb1c82a8d7540df3345013ad6 Mon Sep 17 00:00:00 2001 From: uwaa Date: Fri, 15 Nov 2024 15:16:58 +0000 Subject: [PATCH] cleanups & changes to websockets --- HTTP/HttpRequest.cs | 37 +++++++++------------ HTTP/HttpServer.cs | 19 ++++++++++- HTTP/Responses/SwitchingProtocols.cs | 34 ++++++++++++++++--- HTTP/Websockets/Websocket.cs | 49 +++++++++++++++------------- 4 files changed, 90 insertions(+), 49 deletions(-) diff --git a/HTTP/HttpRequest.cs b/HTTP/HttpRequest.cs index 88cc870..bc2a9bf 100644 --- a/HTTP/HttpRequest.cs +++ b/HTTP/HttpRequest.cs @@ -278,7 +278,7 @@ public sealed class HttpRequest /// /// Writes a response. /// - public async Task Write(HttpResponse response) + internal async Task Write(HttpResponse response) { await response.Write(Stream); await Stream.FlushAsync(); @@ -287,7 +287,7 @@ public sealed class HttpRequest /// /// Attempts to destroy the connection, ignoring all errors. /// - public async void Close(HttpResponse? response) + internal async void Close(HttpResponse? response) { if (response != null && Client.Connected) { @@ -304,46 +304,41 @@ public sealed class HttpRequest } /// - /// Sends a "switching protocol" header for a websocket. + /// Attempts to accept and upgrade a connection to a websocket. /// - public async Task UpgradeToWebsocket(params string[]? protocols) + /// The websocket execution function to call. + /// Subprotocols which can be accepted. If null or empty, any protocol will be accepted. + /// Returns a response. + /// + /// If an upgrade has not been requested or no subprotocol can be negotiated, this will return null. + /// + public SwitchingProtocols? UpgradeToWebsocket(WebsocketHandler callback, params string[]? protocols) { if (!Headers.TryGetValue("Sec-WebSocket-Key", out string? wsKey)) - { - await Write(new BadRequest("Missing Sec-WebSocket-Key header")); return null; - } - - //Increase timeouts - Client.SendTimeout = 120_000; - Client.ReceiveTimeout = 120_000; - - string acceptKey = Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))); //Subprotocol negotiation string? chosenProtocol = null; string? requestedProtocols = Headers["Sec-WebSocket-Protocol"]; if (requestedProtocols != null && protocols != null && protocols.Length > 0) { - foreach (string supported in protocols) + foreach (string requested in requestedProtocols.ToLower().Split(',', StringSplitOptions.TrimEntries)) { - foreach (string requested in requestedProtocols.ToLower().Split(',', StringSplitOptions.TrimEntries)) + foreach (string supported in protocols) { - if (requested == supported) + if (requested.Equals(supported, StringComparison.InvariantCultureIgnoreCase)) { chosenProtocol = supported; goto a; } } } - await Write(new NotAcceptable("Unsupported websocket subprotocol")); return null; } - a: - await Write(new SwitchingProtocols(acceptKey, chosenProtocol)); - - return new Websocket(this, chosenProtocol); + a: + string acceptKey = Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))); + return new SwitchingProtocols(acceptKey, chosenProtocol, callback); } /// diff --git a/HTTP/HttpServer.cs b/HTTP/HttpServer.cs index 9b6d80d..248a560 100644 --- a/HTTP/HttpServer.cs +++ b/HTTP/HttpServer.cs @@ -4,6 +4,7 @@ using System.Net.Sockets; using System.Security.Cryptography.X509Certificates; using Uwaa.HTTP.Responses; using Uwaa.HTTP.Routing; +using Uwaa.HTTP.Websockets; namespace Uwaa.HTTP; @@ -157,8 +158,24 @@ public sealed class HttpServer OnResponse?.Invoke(req, response); await req.Write(response).WaitAsync(Timeout); - if (response.StatusCode is not >= 200 or not < 300 || (req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close")) + + if (response is SwitchingProtocols swp) + { + //Increase timeouts + req.Client.SendTimeout = 120_000; + req.Client.ReceiveTimeout = 120_000; + + //Create and run websocket + Websocket ws = new Websocket(req, swp.ChosenProtocol); + CloseStatus closeStatus = await swp.Callback(ws); + ws.Close(closeStatus); break; //Close + } + else + { + if (response.StatusCode is not >= 200 or not < 300 || (req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close")) + break; //Close + } } catch (RequestParseException e) { diff --git a/HTTP/Responses/SwitchingProtocols.cs b/HTTP/Responses/SwitchingProtocols.cs index 9d6fd01..2609c84 100644 --- a/HTTP/Responses/SwitchingProtocols.cs +++ b/HTTP/Responses/SwitchingProtocols.cs @@ -1,15 +1,32 @@ -namespace Uwaa.HTTP.Responses; +using Uwaa.HTTP.Websockets; -internal class SwitchingProtocols : HttpResponse +namespace Uwaa.HTTP.Responses; + +/// +/// Upgrades the connection to a websocket. +/// +/// +/// After the connection is upgraded, a callback is called with the new websocket. +/// +public class SwitchingProtocols : HttpResponse { - public string AcceptKey; + internal readonly string AcceptKey; - public string? ChosenProtocol; + /// + /// The sub-protocol selected by protocol negotiation, if any. + /// + public readonly string? ChosenProtocol; - internal SwitchingProtocols(string acceptKey, string? chosenProtocol) : base(101, "Switching Protocols") + /// + /// Called once a HTTP request is upgrade to a websocket. + /// + public readonly WebsocketHandler Callback; + + internal SwitchingProtocols(string acceptKey, string? chosenProtocol, WebsocketHandler callback) : base(101, "Switching Protocols") { AcceptKey = acceptKey; ChosenProtocol = chosenProtocol; + Callback = callback; } protected override void WriteHeaders(WriteHeader writeHeader) @@ -22,3 +39,10 @@ internal class SwitchingProtocols : HttpResponse writeHeader("Sec-WebSocket-Protocol", ChosenProtocol); } } + +/// +/// A delegate called once a HTTP request is upgrade to a websocket. +/// +/// The websocket. +/// The status to send to the client when closing the websocket. +public delegate Task WebsocketHandler(Websocket ws); \ No newline at end of file diff --git a/HTTP/Websockets/Websocket.cs b/HTTP/Websockets/Websocket.cs index 1184238..7c32b9a 100644 --- a/HTTP/Websockets/Websocket.cs +++ b/HTTP/Websockets/Websocket.cs @@ -20,6 +20,7 @@ public sealed class Websocket public readonly string? SubProtocol; readonly List finalPayload = new List(); + WSOpcode currentOpcode; internal Websocket(HttpRequest client, string? subProtocol) { @@ -152,27 +153,31 @@ public sealed class Websocket } } - if (opcode is WSOpcode.Close) + switch (opcode) { - await Write(new DataFrame(WSOpcode.Close, true, Array.Empty())); - Request.Client.Close(); + case WSOpcode.Close: + await Write(new DataFrame(WSOpcode.Close, true, Array.Empty())); + Request.Client.Close(); + return new DataFrame(WSOpcode.Close, true, FlushPayload()); - return new DataFrame(WSOpcode.Close, true, FlushPayload()); - } - - if (opcode is WSOpcode.Text or WSOpcode.Binary) - { - finalPayload.AddRange(payload); - - if (fin) + case WSOpcode.Continuation: + case WSOpcode.Text: + case WSOpcode.Binary: { - return new DataFrame(opcode, fin, FlushPayload()); - } - } + if (opcode is WSOpcode.Text or WSOpcode.Binary) + currentOpcode = opcode; - if (opcode is WSOpcode.Ping) - { - await Write(new DataFrame(WSOpcode.Pong, true, payload)); + finalPayload.AddRange(payload); + + if (fin) + return new DataFrame(currentOpcode, fin, FlushPayload()); + else + break; + } + + case WSOpcode.Ping: + await Write(new DataFrame(WSOpcode.Pong, true, payload)); + break; } } finally @@ -194,14 +199,14 @@ public sealed class Websocket return final; } - public Task Write(bool endOfMessage, byte[] payload) + public Task Write(byte[] payload) { - return Write(new DataFrame(WSOpcode.Binary, endOfMessage, payload)); + return Write(new DataFrame(WSOpcode.Binary, true, payload)); } - public Task Write(bool endOfMessage, string payload) + public Task Write(string payload) { - return Write(new DataFrame(WSOpcode.Text, endOfMessage, Encoding.UTF8.GetBytes(payload))); + return Write(new DataFrame(WSOpcode.Text, true, Encoding.UTF8.GetBytes(payload))); } public async Task Write(DataFrame frame) @@ -252,7 +257,7 @@ public sealed class Websocket } } - public async void Close(CloseStatus status = CloseStatus.NormalClosure) + internal async void Close(CloseStatus status = CloseStatus.NormalClosure) { byte[] closeBuf = new byte[2]; BinaryPrimitives.WriteUInt16BigEndian(closeBuf, (ushort)status);