cleanups and fixes

This commit is contained in:
uwaa 2024-11-04 14:56:35 +00:00
parent d48ce6742a
commit d94fa95162
2 changed files with 103 additions and 72 deletions

View file

@ -1,4 +1,6 @@
using System.Net.Sockets;
using System.Net;
using System.Buffers;
using System.Security.Cryptography;
using System.Text;
using System.Web;
@ -46,7 +48,12 @@ public sealed class HttpRequest
/// <summary>
/// Content MIME types which the request wants. If empty, the request does not care.
/// </summary>
public MIMEType[] Accept = [];
public MIMEType[] Accept { get; private set; } = [];
/// <summary>
/// The IP address and port of the requester.
/// </summary>
public IPEndPoint Endpoint { get; internal set; }
/// <summary>
/// HTTP headers included in the request.
@ -69,11 +76,12 @@ public sealed class HttpRequest
/// </summary>
public bool IsWebsocket => Headers.TryGetValue("Upgrade", out string? connection) && connection.Equals("websocket", StringComparison.OrdinalIgnoreCase);
internal HttpRequest(HttpServer server, TcpClient client, Stream stream)
internal HttpRequest(HttpServer server, TcpClient client, Stream stream, IPEndPoint endpoint)
{
Server = server;
Client = client;
Stream = stream;
Endpoint = endpoint;
Buffer = new BufferedStream(stream);
Decoder = Encoding.UTF8.GetDecoder();
@ -84,24 +92,20 @@ public sealed class HttpRequest
internal async Task ReadAllHeaders()
{
//Read initial header
string? header = await ReadLine();
if (header == null)
throw new RequestParseException("Connection closed unexpectedly");
if (header.Length > 1000)
throw new RequestParseException("Initial header is too long");
string? header = await ReadLine() ?? throw new RequestParseException("Connection closed unexpectedly");
{
string[] parts = header.Split(' ', 3, StringSplitOptions.RemoveEmptyEntries);
if (parts.Length != 3)
throw new RequestParseException("Invalid initial header");
//Method
HttpMethod method;
if (!Enum.TryParse(parts[0], true, out method))
throw new RequestParseException("Unknown HTTP method");
Method = method;
string[] pathParts = parts[1].Replace("\\", "/").Split('?', 2, StringSplitOptions.RemoveEmptyEntries);
Path = pathParts[0];
Query = HttpUtility.ParseQueryString(pathParts.Length > 1 ? pathParts[1] : string.Empty);
@ -110,19 +114,14 @@ public sealed class HttpRequest
//Read headers
while (true)
{
string? headerStr = await ReadLine();
if (headerStr == null)
throw new RequestParseException("Connection closed unexpectedly");
string? headerStr = await ReadLine() ?? throw new RequestParseException("Connection closed unexpectedly");
if (string.IsNullOrWhiteSpace(headerStr))
break; //End of headers
if (Headers.Count >= 20)
if (Headers.Count >= 30)
throw new RequestParseException("Too many headers");
if (headerStr.Length > 500)
throw new RequestParseException("A request header is too long");
int splitPoint = headerStr.IndexOf(':');
if (splitPoint == -1)
throw new RequestParseException("A header is invalid");
@ -153,7 +152,17 @@ public sealed class HttpRequest
int resultIndex = 0;
int splStart = 0;
int splEnd = 0;
void flush() => Accept[resultIndex++] = new MIMEType(accept.AsSpan(splStart..(splEnd + 1)));
void flush()
{
try
{
Accept[resultIndex++] = new MIMEType(accept.AsSpan(splStart..(splEnd + 1)));
}
catch (FormatException e)
{
throw new RequestParseException(e.Message, e);
}
}
for (int i = 0; i < accept.Length; i++)
{
switch (accept[i])
@ -186,29 +195,41 @@ public sealed class HttpRequest
flush(); //Flush remaining
}
async Task<string> ReadLine()
async Task<string?> ReadLine()
{
const int maxChars = 4096;
byte[] dataBuffer = new byte[1];
char[] charBuffer = new char[4096];
int charBufferIndex = 0;
while (true)
char[] charBuffer = ArrayPool<char>.Shared.Rent(maxChars);
try
{
if (await Buffer.ReadAsync(dataBuffer) == 0)
break;
if (charBufferIndex >= charBuffer.Length)
throw new RequestParseException("Header is too large");
charBufferIndex += Decoder.GetChars(dataBuffer, 0, 1, charBuffer, charBufferIndex, false);
if (charBufferIndex >= 2 && charBuffer[charBufferIndex - 1] == '\n' && charBuffer[charBufferIndex - 2] == '\r')
int charBufferIndex = 0;
while (true)
{
charBufferIndex -= 2;
break;
if (await Buffer.ReadAsync(dataBuffer) == 0)
if (charBufferIndex == 0)
return null;
else
break;
if (charBufferIndex >= maxChars)
throw new RequestParseException("Header is too large");
charBufferIndex += Decoder.GetChars(dataBuffer, 0, 1, charBuffer, charBufferIndex, false);
if (charBufferIndex >= 2 && charBuffer[charBufferIndex - 1] == '\n' && charBuffer[charBufferIndex - 2] == '\r')
{
charBufferIndex -= 2;
break;
}
}
Decoder.Reset();
return new string(charBuffer, 0, charBufferIndex);
}
finally
{
//Clearing the array is unnecessary but it is good security just in case.
ArrayPool<char>.Shared.Return(charBuffer, true);
}
Decoder.Reset();
return new string(charBuffer, 0, charBufferIndex);
}
@ -361,7 +382,7 @@ public sealed class HttpRequest
/// </summary>
class RequestParseException : IOException
{
public RequestParseException(string? message) : base(message)
public RequestParseException(string? message, Exception? innerException = null) : base(message, innerException)
{
}
}

View file

@ -35,6 +35,26 @@ public sealed class HttpServer
/// </summary>
public int Port => ((IPEndPoint)listener.LocalEndpoint).Port;
/// <summary>
/// Called when a client establishes a connection with the server.
/// </summary>
public event Action<TcpClient>? OnConnectionBegin;
/// <summary>
/// Called when a connection has terminated.
/// </summary>
public event Action<IPEndPoint>? OnConnectionEnd;
/// <summary>
/// Called when a request has been served a response.
/// </summary>
public event Action<HttpRequest, HttpResponse>? OnResponse;
/// <summary>
/// The maximum time the socket may be inactive before it is presumed dead and closed.
/// </summary>
public TimeSpan Timeout { get; set; } = TimeSpan.FromSeconds(20);
readonly Dictionary<IPAddress, int> IPCounts = new Dictionary<IPAddress, int>();
readonly SemaphoreSlim IPCountsLock = new SemaphoreSlim(1, 1);
@ -67,6 +87,8 @@ public sealed class HttpServer
async void HandleClient(TcpClient client)
{
OnConnectionBegin?.Invoke(client);
if (client.Client.RemoteEndPoint is not IPEndPoint endpoint)
return;
@ -97,32 +119,26 @@ public sealed class HttpServer
try
{
//Setup client
client.Client.LingerState = new LingerOption(true, 5);
client.Client.SendTimeout = 20_000;
client.Client.ReceiveTimeout = 20_000;
Stream stream;
try
{
stream = client.GetStream();
//Setup client
Stream stream = client.GetStream();
if (Certificate != null)
{
//Pass through SSL stream
SslStream ssl = new SslStream(stream);
await ssl.AuthenticateAsServerAsync(Certificate);
await ssl.AuthenticateAsServerAsync(Certificate).WaitAsync(Timeout);
stream = ssl;
}
//HTTP request-response loop
while (client.Connected)
{
HttpRequest req = new HttpRequest(this, client, stream);
HttpResponse? response;
bool keepAlive = true;
HttpRequest req = new HttpRequest(this, client, stream, endpoint);
try
{
await req.ReadAllHeaders();
await req.ReadAllHeaders().WaitAsync(Timeout);
//Parse path
ArraySegment<string> pathSpl = req.Path.Split('/', StringSplitOptions.RemoveEmptyEntries);
@ -137,47 +153,39 @@ public sealed class HttpServer
}
//Execute
response = await Router.GetResponse(req, pathSpl);
if (response != null)
{
await req.Write(response);
keepAlive = (response.StatusCode is >= 200 and < 300) && !(req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close");
}
else
{
await req.Write(new HttpResponse(404, "Router produced no response"));
keepAlive = false;
}
HttpResponse? response = (await Router.GetResponse(req, pathSpl)) ?? new HttpResponse(404, "Router produced no response");
OnResponse?.Invoke(req, response);
await req.Write(response).WaitAsync(Timeout);
if (response.StatusCode is not >= 200 or not < 300 || (req.Headers.TryGetValue("connection", out string? connectionValue) && connectionValue == "close"))
break; //Close
}
catch (RequestParseException e)
{
req.Close(new HttpResponse(400, e.Message));
break;
}
catch (TimeoutException)
{
//Timeout
break;
}
catch (IOException)
{
//Remote disconnect
throw;
}
catch (Exception)
{
await req.Write(new HttpResponse(500, "Internal Server Error"));
await req.Write(new HttpResponse(500, "Internal Server Error")).WaitAsync(Timeout);
throw;
}
if (!keepAlive)
{
client.Close();
break;
}
}
}
catch (IOException)
{
//Client likely disconnected unexpectedly
}
catch (Exception)
{
//Error
//Swallow exceptions to prevent the server from crashing.
//When debugging, use a debugger to break on exceptions.
}
}
finally
@ -199,6 +207,8 @@ public sealed class HttpServer
{
IPCountsLock.Release();
}
OnConnectionEnd?.Invoke(endpoint);
}
}
}