optimizations & cleanups

This commit is contained in:
uwaa 2024-11-14 10:26:19 +00:00
parent b12616420b
commit 17816628e8
7 changed files with 142 additions and 112 deletions

View file

@ -15,20 +15,10 @@ namespace Uwaa.HTTP;
/// </summary>
public sealed class HttpRequest
{
/// <summary>
/// The server the request is being made to.
/// </summary>
public readonly HttpServer Server;
/// <summary>
/// The TCP client serving this request.
/// </summary>
public readonly TcpClient Client;
/// <summary>
/// The underlying TCP stream.
/// </summary>
public readonly Stream Stream;
internal readonly TcpClient Client;
/// <summary>
/// The HTTP method of the request.
@ -60,14 +50,25 @@ public sealed class HttpRequest
/// </summary>
public readonly Dictionary<string, string> Headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
readonly BufferedStream Buffer;
/// <summary>
/// The body of the HTTP request, if any.
/// </summary>
public HttpContent? Content { get; internal set; }
/// <summary>
/// The underlying TCP stream.
/// </summary>
internal readonly Stream Stream;
/// <summary>
/// The read buffer.
/// </summary>
internal readonly BufferedStream Buffer;
readonly Decoder Decoder;
readonly StreamWriter Writer;
/// <summary>
/// If true, the request is connected and waiting for a reply.
/// If true, the request is connected.
/// </summary>
public bool Connected => Client.Connected;
@ -76,25 +77,16 @@ public sealed class HttpRequest
/// </summary>
public bool IsWebsocket => Headers.TryGetValue("Upgrade", out string? connection) && connection.Equals("websocket", StringComparison.OrdinalIgnoreCase);
/// <summary>
/// The total declared size of the request body, in bytes.
/// </summary>
public int ContentLength => Headers.TryGetValue("Content-Length", out string? contentLengthStr) && int.TryParse(contentLengthStr, out int contentLength) ? contentLength : 0;
internal HttpRequest(HttpServer server, TcpClient client, Stream stream, IPEndPoint endpoint)
internal HttpRequest(TcpClient client, Stream stream, IPEndPoint endpoint)
{
Server = server;
Client = client;
Stream = stream;
Endpoint = endpoint;
Buffer = new BufferedStream(stream);
Decoder = Encoding.UTF8.GetDecoder();
Writer = new StreamWriter(Buffer, Encoding.ASCII);
Writer.AutoFlush = true;
}
internal async Task ReadAllHeaders()
internal async Task ReadAll()
{
//Read initial header
string? header = await ReadLine() ?? throw new RequestParseException("Connection closed unexpectedly");
@ -140,6 +132,8 @@ public sealed class HttpRequest
}
ParseAccept();
Content = await ReadBody();
}
void ParseAccept()
@ -199,6 +193,35 @@ public sealed class HttpRequest
flush(); //Flush remaining
}
async Task<HttpContent?> ReadBody()
{
if (!Headers.TryGetValue("Content-Length", out string? contentLengthStr))
return null;
if (!Headers.TryGetValue("Content-Type", out string? contentTypeStr))
throw new RequestParseException("Content length was sent but no content type");
MIMEType contentType;
try
{
contentType = new MIMEType(contentTypeStr);
}
catch (FormatException e)
{
throw new RequestParseException(e.Message, e);
}
if (!int.TryParse(contentLengthStr, out int contentLength))
throw new RequestParseException("Invalid content length");
if (contentLength > 10_000_000)
throw new RequestParseException("Too much content (max: 10 MB)");
byte[] data = new byte[contentLength];
await Read(data);
return new HttpContent(contentType, data);
}
async Task<string?> ReadLine()
{
const int maxChars = 4096;
@ -236,64 +259,29 @@ public sealed class HttpRequest
}
}
public ValueTask<int> ReadBytes(Memory<byte> buffer)
internal ValueTask<int> Read(Memory<byte> buffer)
{
return Buffer.ReadAsync(buffer);
}
async Task WriteStatus(int code, string message)
{
await Writer.WriteAsync("HTTP/1.1 ");
await Writer.WriteAsync(code.ToString());
await Writer.WriteAsync(' ');
await Writer.WriteAsync(message);
await WriteLine();
}
async Task WriteHeader(string name, string value)
{
await Writer.WriteAsync(name);
await Writer.WriteAsync(": ");
await Writer.WriteAsync(value);
await WriteLine();
}
async Task WriteLine()
{
await Writer.WriteAsync("\r\n");
}
public ValueTask WriteBytes(ReadOnlyMemory<byte> bytes)
internal ValueTask Write(ReadOnlyMemory<byte> bytes)
{
return Buffer.WriteAsync(bytes);
}
public Task Flush()
internal async Task Flush()
{
Buffer.Flush();
Stream.Flush();
return Stream.FlushAsync();
await Buffer.FlushAsync();
await Stream.FlushAsync();
}
/// <summary>
/// Writes a response without closing the socket.
/// Writes a response.
/// </summary>
public async Task Write(HttpResponse response)
{
if (response.StatusCode == 0)
return;
await WriteStatus(response.StatusCode, response.StatusMessage);
foreach (var header in response.GetHeaders())
await WriteHeader(header.Item1, header.Item2);
await WriteHeader("Access-Control-Allow-Origin", "*");
await WriteLine();
if (response.Body.HasValue)
await WriteBytes(response.Body.Value.Content);
await Flush();
await response.Write(Stream);
await Stream.FlushAsync();
}
/// <summary>
@ -332,6 +320,7 @@ public sealed class HttpRequest
string acceptKey = Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(wsKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")));
//Subprotocol negotiation
string? chosenProtocol = null;
string? requestedProtocols = Headers["Sec-WebSocket-Protocol"];
if (requestedProtocols != null && protocols != null && protocols.Length > 0)
@ -347,20 +336,12 @@ public sealed class HttpRequest
}
}
}
await Write(new BadRequest("Unsupported websocket subprotocol"));
await Write(new NotAcceptable("Unsupported websocket subprotocol"));
return null;
}
a:
await WriteStatus(101, "Switching Protocols");
await WriteHeader("Upgrade", "websocket");
await WriteHeader("Connection", "Upgrade");
await WriteHeader("Sec-WebSocket-Accept", acceptKey);
await WriteHeader("Access-Control-Allow-Origin", "*");
if (chosenProtocol != null)
await WriteHeader("Sec-WebSocket-Protocol", chosenProtocol);
await WriteLine();
await Flush();
await Write(new SwitchingProtocols(acceptKey, chosenProtocol));
return new Websocket(this, chosenProtocol);
}
@ -379,16 +360,6 @@ public sealed class HttpRequest
return false;
}
/// <summary>
/// Reads the entire request body. Only call this once.
/// </summary>
public async Task<string> ReadBody()
{
byte[] data = new byte[ContentLength];
int count = await ReadBytes(data);
return Encoding.UTF8.GetString(data, 0, count);
}
}
/// <summary>

View file

@ -1,4 +1,6 @@
namespace Uwaa.HTTP.Responses;
using System.Text;
namespace Uwaa.HTTP.Responses;
public class HttpResponse
{
@ -25,12 +27,43 @@ public class HttpResponse
Body = body;
}
public virtual IEnumerable<(string, string)> GetHeaders()
public async Task Write(Stream stream)
{
if (StatusCode == 0)
return;
StringBuilder sb = new StringBuilder();
void writeHeader(string name, string value)
{
sb.Append(name);
sb.Append(": ");
sb.Append(value);
sb.Append("\r\n");
}
sb.Append("HTTP/1.1 ");
sb.Append(StatusCode.ToString());
sb.Append(' ');
sb.Append(StatusMessage);
sb.Append("\r\n");
WriteHeaders(writeHeader);
sb.Append("\r\n");
await stream.WriteAsync(Encoding.ASCII.GetBytes(sb.ToString()));
if (Body.HasValue)
await stream.WriteAsync(Body.Value.Content);
}
protected virtual void WriteHeaders(WriteHeader writeHeader)
{
if (Body.HasValue)
{
yield return ("Content-Length", Body.Value.Content.Length.ToString());
yield return ("Content-Type", Body.Value.Type.ToString());
writeHeader("Content-Length", Body.Value.Content.Length.ToString());
writeHeader("Content-Type", Body.Value.Type.ToString());
}
}
}
public delegate void WriteHeader(string name, string value);

View file

@ -135,10 +135,10 @@ public sealed class HttpServer
//HTTP request-response loop
while (client.Connected)
{
HttpRequest req = new HttpRequest(this, client, stream, endpoint);
HttpRequest req = new HttpRequest(client, stream, endpoint);
try
{
await req.ReadAllHeaders().WaitAsync(Timeout);
await req.ReadAll().WaitAsync(Timeout);
//Parse path
ArraySegment<string> pathSpl = req.Path.Split('/', StringSplitOptions.RemoveEmptyEntries);

View file

@ -5,15 +5,16 @@
/// </summary>
public class Redirect : HttpResponse
{
public string Location;
public readonly string Location;
public Redirect(string location) : base(301, "Redirect")
{
Location = location;
}
public override IEnumerable<(string, string)> GetHeaders()
protected override void WriteHeaders(WriteHeader writeHeader)
{
yield return ("Location", Location);
base.WriteHeaders(writeHeader);
writeHeader("Location", Location);
}
}

View file

@ -0,0 +1,24 @@
namespace Uwaa.HTTP.Responses;
internal class SwitchingProtocols : HttpResponse
{
public string AcceptKey;
public string? ChosenProtocol;
internal SwitchingProtocols(string acceptKey, string? chosenProtocol) : base(101, "Switching Protocols")
{
AcceptKey = acceptKey;
ChosenProtocol = chosenProtocol;
}
protected override void WriteHeaders(WriteHeader writeHeader)
{
base.WriteHeaders(writeHeader);
writeHeader("Upgrade", "websocket");
writeHeader("Connection", "Upgrade");
writeHeader("Sec-WebSocket-Accept", AcceptKey);
if (ChosenProtocol != null)
writeHeader("Sec-WebSocket-Protocol", ChosenProtocol);
}
}

View file

@ -23,10 +23,11 @@ class PreflightResponse : OK
{
}
public override IEnumerable<(string, string)> GetHeaders()
protected override void WriteHeaders(WriteHeader writeHeader)
{
yield return ("Access-Control-Allow-Origin", "*");
yield return ("Methods", "GET,HEAD,POST,PUT,DELETE,CONNECT,OPTIONS,TRACE,PATCH");
yield return ("Vary", "Origin");
base.WriteHeaders(writeHeader);
writeHeader("Access-Control-Allow-Origin", "*");
writeHeader("Methods", "GET,HEAD,POST,PUT,DELETE,CONNECT,OPTIONS,TRACE,PATCH");
writeHeader("Vary", "Origin");
}
}

View file

@ -42,7 +42,7 @@ public sealed class Websocket
while (true)
{
//First byte
if (await Request.ReadBytes(recvBuffer.AsMemory(0, 2)) < 2)
if (await Request.Read(recvBuffer.AsMemory(0, 2)) < 2)
throw new EndOfStreamException();
byte firstByte = recvBuffer[0];
@ -70,14 +70,14 @@ public sealed class Websocket
{
if (secondByte == 126)
{
if (await Request.ReadBytes(recvBuffer.AsMemory(0, 2)) < 2)
if (await Request.Read(recvBuffer.AsMemory(0, 2)) < 2)
throw new EndOfStreamException();
payloadLength = BinaryPrimitives.ReadUInt16BigEndian(recvBuffer);
}
else if (secondByte == 127)
{
if (await Request.ReadBytes(recvBuffer.AsMemory(0, 8)) < 8)
if (await Request.Read(recvBuffer.AsMemory(0, 8)) < 8)
throw new EndOfStreamException();
payloadLength = BinaryPrimitives.ReadUInt32BigEndian(recvBuffer);
@ -95,7 +95,7 @@ public sealed class Websocket
byte maskKey1, maskKey2, maskKey3, maskKey4;
if (maskEnabled)
{
if (await Request.ReadBytes(recvBuffer.AsMemory(0, 4)) < 4)
if (await Request.Read(recvBuffer.AsMemory(0, 4)) < 4)
throw new EndOfStreamException();
maskKey1 = recvBuffer[0];
@ -116,7 +116,7 @@ public sealed class Websocket
try
{
ArraySegment<byte> payload = new ArraySegment<byte>(payloadBuffer, 0, (int)payloadLength);
if (await Request.ReadBytes(payload) < payloadLength)
if (await Request.Read(payload) < payloadLength)
throw new EndOfStreamException();
//Unmask payload
@ -220,12 +220,12 @@ public sealed class Websocket
firstByte |= (byte)((int)frame.Opcode & 0b00001111);
writeBuf[0] = firstByte;
await Request.WriteBytes(writeBuf.AsMemory(0, 1));
await Request.Write(writeBuf.AsMemory(0, 1));
if (frame.Payload.Count < 126)
{
writeBuf[0] = (byte)frame.Payload.Count;
await Request.WriteBytes(writeBuf.AsMemory(0, 1));
await Request.Write(writeBuf.AsMemory(0, 1));
}
else
{
@ -233,17 +233,17 @@ public sealed class Websocket
{
writeBuf[0] = 126;
BinaryPrimitives.WriteUInt16BigEndian(writeBuf.AsSpan(1), (ushort)frame.Payload.Count);
await Request.WriteBytes(writeBuf.AsMemory(0, 3));
await Request.Write(writeBuf.AsMemory(0, 3));
}
else
{
writeBuf[0] = 127;
BinaryPrimitives.WriteUInt64BigEndian(writeBuf.AsSpan(1), (ulong)frame.Payload.Count);
await Request.WriteBytes(writeBuf.AsMemory(0, 9));
await Request.Write(writeBuf.AsMemory(0, 9));
}
}
await Request.WriteBytes(frame.Payload);
await Request.Write(frame.Payload);
await Request.Flush();
}
finally