cleanups & changes to websockets

This commit is contained in:
uwaa 2024-11-15 15:16:58 +00:00
parent 17816628e8
commit 553f963516
4 changed files with 90 additions and 49 deletions

View file

@ -278,7 +278,7 @@ public sealed class HttpRequest
/// <summary>
/// Writes a response.
/// </summary>
public async Task Write(HttpResponse response)
internal async Task Write(HttpResponse response)
{
await response.Write(Stream);
await Stream.FlushAsync();
@ -287,7 +287,7 @@ public sealed class HttpRequest
/// <summary>
/// Attempts to destroy the connection, ignoring all errors.
/// </summary>
public async void Close(HttpResponse? response)
internal async void Close(HttpResponse? response)
{
if (response != null && Client.Connected)
{
@ -304,46 +304,41 @@ public sealed class HttpRequest
}
/// <summary>
/// Sends a "switching protocol" header for a websocket.
/// Attempts to accept and upgrade a connection to a websocket.
/// </summary>
public async Task<Websocket?> UpgradeToWebsocket(params string[]? protocols)
/// <param name="callback">The websocket execution function to call.</param>
/// <param name="protocols">Subprotocols which can be accepted. If null or empty, any protocol will be accepted.</param>
/// <returns>Returns a <seealso cref="SwitchingProtocols"/> response.</returns>
/// <remarks>
/// If an upgrade has not been requested or no subprotocol can be negotiated, this will return null.
/// </remarks>
public SwitchingProtocols? UpgradeToWebsocket(WebsocketHandler callback, params string[]? protocols)
{
if (!Headers.TryGetValue("Sec-WebSocket-Key", out string? wsKey))
{
await Write(new BadRequest("Missing Sec-WebSocket-Key header"));
return null;
}
//Increase timeouts
Client.SendTimeout = 120_000;
Client.ReceiveTimeout = 120_000;
string acceptKey = Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")));
//Subprotocol negotiation
string? chosenProtocol = null;
string? requestedProtocols = Headers["Sec-WebSocket-Protocol"];
if (requestedProtocols != null && protocols != null && protocols.Length > 0)
{
foreach (string supported in protocols)
foreach (string requested in requestedProtocols.ToLower().Split(',', StringSplitOptions.TrimEntries))
{
foreach (string requested in requestedProtocols.ToLower().Split(',', StringSplitOptions.TrimEntries))
foreach (string supported in protocols)
{
if (requested == supported)
if (requested.Equals(supported, StringComparison.InvariantCultureIgnoreCase))
{
chosenProtocol = supported;
goto a;
}
}
}
await Write(new NotAcceptable("Unsupported websocket subprotocol"));
return null;
}
a:
await Write(new SwitchingProtocols(acceptKey, chosenProtocol));
return new Websocket(this, chosenProtocol);
a:
string acceptKey = Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")));
return new SwitchingProtocols(acceptKey, chosenProtocol, callback);
}
/// <summary>

View file

@ -4,6 +4,7 @@ using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Uwaa.HTTP.Responses;
using Uwaa.HTTP.Routing;
using Uwaa.HTTP.Websockets;
namespace Uwaa.HTTP;
@ -157,8 +158,24 @@ public sealed class HttpServer
OnResponse?.Invoke(req, response);
await req.Write(response).WaitAsync(Timeout);
if (response.StatusCode is not >= 200 or not < 300 || (req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close"))
if (response is SwitchingProtocols swp)
{
//Increase timeouts
req.Client.SendTimeout = 120_000;
req.Client.ReceiveTimeout = 120_000;
//Create and run websocket
Websocket ws = new Websocket(req, swp.ChosenProtocol);
CloseStatus closeStatus = await swp.Callback(ws);
ws.Close(closeStatus);
break; //Close
}
else
{
if (response.StatusCode is not >= 200 or not < 300 || (req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close"))
break; //Close
}
}
catch (RequestParseException e)
{

View file

@ -1,15 +1,32 @@
namespace Uwaa.HTTP.Responses;
using Uwaa.HTTP.Websockets;
internal class SwitchingProtocols : HttpResponse
namespace Uwaa.HTTP.Responses;
/// <summary>
/// Upgrades the connection to a websocket.
/// </summary>
/// <remarks>
/// After the connection is upgraded, a callback is called with the new websocket.
/// </remarks>
public class SwitchingProtocols : HttpResponse
{
public string AcceptKey;
internal readonly string AcceptKey;
public string? ChosenProtocol;
/// <summary>
/// The sub-protocol selected by protocol negotiation, if any.
/// </summary>
public readonly string? ChosenProtocol;
internal SwitchingProtocols(string acceptKey, string? chosenProtocol) : base(101, "Switching Protocols")
/// <summary>
/// Called once a HTTP request is upgrade to a websocket.
/// </summary>
public readonly WebsocketHandler Callback;
internal SwitchingProtocols(string acceptKey, string? chosenProtocol, WebsocketHandler callback) : base(101, "Switching Protocols")
{
AcceptKey = acceptKey;
ChosenProtocol = chosenProtocol;
Callback = callback;
}
protected override void WriteHeaders(WriteHeader writeHeader)
@ -22,3 +39,10 @@ internal class SwitchingProtocols : HttpResponse
writeHeader("Sec-WebSocket-Protocol", ChosenProtocol);
}
}
/// <summary>
/// A delegate called once a HTTP request is upgrade to a websocket.
/// </summary>
/// <param name="ws">The websocket.</param>
/// <returns>The status to send to the client when closing the websocket.</returns>
public delegate Task<CloseStatus> WebsocketHandler(Websocket ws);

View file

@ -20,6 +20,7 @@ public sealed class Websocket
public readonly string? SubProtocol;
readonly List<byte> finalPayload = new List<byte>();
WSOpcode currentOpcode;
internal Websocket(HttpRequest client, string? subProtocol)
{
@ -152,27 +153,31 @@ public sealed class Websocket
}
}
if (opcode is WSOpcode.Close)
switch (opcode)
{
await Write(new DataFrame(WSOpcode.Close, true, Array.Empty<byte>()));
Request.Client.Close();
case WSOpcode.Close:
await Write(new DataFrame(WSOpcode.Close, true, Array.Empty<byte>()));
Request.Client.Close();
return new DataFrame(WSOpcode.Close, true, FlushPayload());
return new DataFrame(WSOpcode.Close, true, FlushPayload());
}
if (opcode is WSOpcode.Text or WSOpcode.Binary)
{
finalPayload.AddRange(payload);
if (fin)
case WSOpcode.Continuation:
case WSOpcode.Text:
case WSOpcode.Binary:
{
return new DataFrame(opcode, fin, FlushPayload());
}
}
if (opcode is WSOpcode.Text or WSOpcode.Binary)
currentOpcode = opcode;
if (opcode is WSOpcode.Ping)
{
await Write(new DataFrame(WSOpcode.Pong, true, payload));
finalPayload.AddRange(payload);
if (fin)
return new DataFrame(currentOpcode, fin, FlushPayload());
else
break;
}
case WSOpcode.Ping:
await Write(new DataFrame(WSOpcode.Pong, true, payload));
break;
}
}
finally
@ -194,14 +199,14 @@ public sealed class Websocket
return final;
}
public Task Write(bool endOfMessage, byte[] payload)
public Task Write(byte[] payload)
{
return Write(new DataFrame(WSOpcode.Binary, endOfMessage, payload));
return Write(new DataFrame(WSOpcode.Binary, true, payload));
}
public Task Write(bool endOfMessage, string payload)
public Task Write(string payload)
{
return Write(new DataFrame(WSOpcode.Text, endOfMessage, Encoding.UTF8.GetBytes(payload)));
return Write(new DataFrame(WSOpcode.Text, true, Encoding.UTF8.GetBytes(payload)));
}
public async Task Write(DataFrame frame)
@ -252,7 +257,7 @@ public sealed class Websocket
}
}
public async void Close(CloseStatus status = CloseStatus.NormalClosure)
internal async void Close(CloseStatus status = CloseStatus.NormalClosure)
{
byte[] closeBuf = new byte[2];
BinaryPrimitives.WriteUInt16BigEndian(closeBuf, (ushort)status);