diff --git a/Hamakaze/Headers/HttpConnectionHeader.cs b/Hamakaze/Headers/HttpConnectionHeader.cs index 50b1318..ff406ec 100644 --- a/Hamakaze/Headers/HttpConnectionHeader.cs +++ b/Hamakaze/Headers/HttpConnectionHeader.cs @@ -9,9 +9,10 @@ namespace Hamakaze.Headers { public const string CLOSE = @"close"; public const string KEEP_ALIVE = @"keep-alive"; + public const string UPGRADE = @"upgrade"; public HttpConnectionHeader(string mode) { - Value = mode ?? throw new ArgumentNullException(nameof(mode)); + Value = (mode ?? throw new ArgumentNullException(nameof(mode))).ToLowerInvariant(); } } } diff --git a/Hamakaze/HttpClient.cs b/Hamakaze/HttpClient.cs index be009b5..24f05d3 100644 --- a/Hamakaze/HttpClient.cs +++ b/Hamakaze/HttpClient.cs @@ -1,14 +1,22 @@ using Hamakaze.Headers; +using Hamakaze.WebSocket; using System; using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography; +using System.Text; namespace Hamakaze { public class HttpClient : IDisposable { public const string PRODUCT_STRING = @"HMKZ"; public const string VERSION_MAJOR = @"1"; - public const string VERSION_MINOR = @"0"; + public const string VERSION_MINOR = @"1"; public const string USER_AGENT = PRODUCT_STRING + @"/" + VERSION_MAJOR + @"." + VERSION_MINOR; + private const string WS_GUID = @"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + private const string WS_PROTO = @"websocket"; + private const int WS_RNG = 16; + private static HttpClient InstanceValue { get; set; } public static HttpClient Instance { get { @@ -47,7 +55,8 @@ namespace Hamakaze { request.UserAgent = DefaultUserAgent; if(!request.HasHeader(HttpAcceptEncodingHeader.NAME)) request.AcceptedEncodings = AcceptedEncodings; - request.Connection = ReuseConnections ? HttpConnectionHeader.KEEP_ALIVE : HttpConnectionHeader.CLOSE; + if(!request.HasHeader(HttpConnectionHeader.NAME)) + request.Connection = ReuseConnections ? HttpConnectionHeader.KEEP_ALIVE : HttpConnectionHeader.CLOSE; HttpTask task = new(Connections, request, disposeRequest, disposeResponse); @@ -85,6 +94,84 @@ namespace Hamakaze { RunTask(CreateTask(request, onComplete, onError, onCancel, onDownloadProgress, onUploadProgress, onStateChange, disposeRequest, disposeResponse)); } + public void CreateWsClient( + string url, + Action onOpen, + Action onMessage, + Action onError, + IEnumerable protocols = null + ) { + CreateWsConnection( + url, + conn => onOpen(new WsClient(conn, onMessage, onError)), + onError, + protocols + ); + } + + public void CreateWsConnection( + string url, + Action onOpen, + Action onError, + IEnumerable protocols = null + ) { + string key = Convert.ToBase64String(RandomNumberGenerator.GetBytes(WS_RNG)); + + HttpRequestMessage req = new HttpRequestMessage(@"GET", url); + req.Connection = HttpConnectionHeader.UPGRADE; + req.SetHeader(@"Cache-Control", @"no-cache"); + req.SetHeader(@"Upgrade", WS_PROTO); + req.SetHeader(@"Sec-WebSocket-Key", key); + req.SetHeader(@"Sec-WebSocket-Version", @"13"); + + if(protocols?.Any() == true) + req.SetHeader(@"Sec-WebSocket-Protocol", string.Join(@", ", protocols)); + + SendRequest( + req, + (t, res) => { + try { + if(res.ProtocolVersion.CompareTo(@"1.1") < 0) + throw new HttpUpgradeProtocolVersionException(@"1.1", res.ProtocolVersion); + + if(res.StatusCode != 101) + throw new HttpUpgradeUnexpectedStatusException(res.StatusCode); + + if(res.Connection != HttpConnectionHeader.UPGRADE) + throw new HttpUpgradeUnexpectedHeaderException( + @"Connection", + HttpConnectionHeader.UPGRADE, + res.Connection + ); + + string hUpgrade = res.GetHeaderLine(@"Upgrade"); + if(hUpgrade != WS_PROTO) + throw new HttpUpgradeUnexpectedHeaderException(@"Upgrade", WS_PROTO, hUpgrade); + + string serverHashStr = res.GetHeaderLine(@"Sec-WebSocket-Accept"); + byte[] expectHash = SHA1.HashData(Encoding.ASCII.GetBytes(key + WS_GUID)); + + if(string.IsNullOrWhiteSpace(serverHashStr)) + throw new HttpUpgradeUnexpectedHeaderException( + @"Sec-WebSocket-Accept", + Convert.ToBase64String(expectHash), + serverHashStr + ); + + byte[] givenHash = Convert.FromBase64String(serverHashStr.Trim()); + + if(!expectHash.SequenceEqual(givenHash)) + throw new HttpUpgradeInvalidHashException(Convert.ToBase64String(expectHash), serverHashStr); + + onOpen(t.Connection.ToWebSocket()); + } catch(Exception ex) { + onError(ex); + } + }, + (t, ex) => onError(ex) + ); + } + public static void Send( HttpRequestMessage request, Action onComplete = null, diff --git a/Hamakaze/HttpConnection.cs b/Hamakaze/HttpConnection.cs index d5b2c9e..509a6b6 100644 --- a/Hamakaze/HttpConnection.cs +++ b/Hamakaze/HttpConnection.cs @@ -4,15 +4,14 @@ using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Security.Authentication; +using Hamakaze.WebSocket; namespace Hamakaze { public class HttpConnection : IDisposable { public IPEndPoint EndPoint { get; } public Stream Stream { get; } - public Socket Socket { get; } - public NetworkStream NetworkStream { get; } - public SslStream SslStream { get; } + private Socket Socket { get; } public string Host { get; } public bool IsSecure { get; } @@ -24,6 +23,7 @@ namespace Hamakaze { public DateTimeOffset LastOperation { get; private set; } = DateTimeOffset.Now; public bool InUse { get; private set; } + public bool HasUpgraded { get; private set; } public HttpConnection(string host, IPEndPoint endPoint, bool secure) { Host = host ?? throw new ArgumentNullException(nameof(host)); @@ -39,19 +39,19 @@ namespace Hamakaze { }; Socket.Connect(endPoint); - NetworkStream = new NetworkStream(Socket, true); + Stream stream = new NetworkStream(Socket, true); if(IsSecure) { - SslStream = new SslStream(NetworkStream, false, (s, ce, ch, e) => e == SslPolicyErrors.None, null); - Stream = SslStream; - SslStream.AuthenticateAsClient( + SslStream sslStream = new SslStream(stream, false, (s, ce, ch, e) => e == SslPolicyErrors.None, null); + Stream = sslStream; + sslStream.AuthenticateAsClient( Host, null, SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13, true ); } else - Stream = NetworkStream; + Stream = stream; } public void MarkUsed() { @@ -61,13 +61,21 @@ namespace Hamakaze { } public bool Acquire() { - return !InUse && (InUse = true); + return !HasUpgraded && !InUse && (InUse = true); } public void Release() { InUse = false; } + public WsConnection ToWebSocket() { + if(HasUpgraded) + throw new HttpConnectionAlreadyUpgradedException(); + HasUpgraded = true; + + return new WsConnection(Stream); + } + private bool IsDisposed; ~HttpConnection() => DoDispose(); @@ -79,7 +87,9 @@ namespace Hamakaze { if(IsDisposed) return; IsDisposed = true; - Stream.Dispose(); + + if(!HasUpgraded) + Stream.Dispose(); } } } diff --git a/Hamakaze/HttpException.cs b/Hamakaze/HttpException.cs index 6219e57..726b9ab 100644 --- a/Hamakaze/HttpException.cs +++ b/Hamakaze/HttpException.cs @@ -5,6 +5,32 @@ namespace Hamakaze { public HttpException(string message) : base(message) { } } + public class HttpUpgradeException : HttpException { + public HttpUpgradeException(string message) : base(message) { } + } + public class HttpUpgradeProtocolVersionException : HttpUpgradeException { + public HttpUpgradeProtocolVersionException(string expectedVersion, string givenVersion) + : base($@"Server HTTP version ({givenVersion}) is lower than what is expected {expectedVersion}.") { } + } + public class HttpUpgradeUnexpectedStatusException : HttpUpgradeException { + public HttpUpgradeUnexpectedStatusException(int statusCode) : base($@"Expected HTTP status code 101, got {statusCode}.") { } + } + public class HttpUpgradeUnexpectedHeaderException : HttpUpgradeException { + public HttpUpgradeUnexpectedHeaderException(string header, string expected, string given) + : base($@"Unexpected {header} header value ""{given}"", expected ""{expected}"".") { } + } + public class HttpUpgradeInvalidHashException : HttpUpgradeException { + public HttpUpgradeInvalidHashException(string expected, string given) + : base($@"Server sent invalid hash ""{given}"", expected ""{expected}"".") { } + } + + public class HttpConnectionException : HttpException { + public HttpConnectionException(string message) : base(message) { } + } + public class HttpConnectionAlreadyUpgradedException : HttpConnectionException { + public HttpConnectionAlreadyUpgradedException() : base(@"This connection has already been upgraded.") { } + } + public class HttpConnectionManagerException : HttpException { public HttpConnectionManagerException(string message) : base(message) { } } diff --git a/Hamakaze/HttpRequestMessage.cs b/Hamakaze/HttpRequestMessage.cs index 77a7907..ba01492 100644 --- a/Hamakaze/HttpRequestMessage.cs +++ b/Hamakaze/HttpRequestMessage.cs @@ -92,7 +92,8 @@ namespace Hamakaze { public HttpRequestMessage(string method, Uri uri) { Method = method ?? throw new ArgumentNullException(nameof(method)); RequestTarget = uri.PathAndQuery; - IsSecure = uri.Scheme.Equals(@"https", StringComparison.InvariantCultureIgnoreCase); + IsSecure = uri.Scheme.Equals(@"https", StringComparison.InvariantCultureIgnoreCase) + || uri.Scheme.Equals(@"wss", StringComparison.InvariantCultureIgnoreCase); Host = uri.Host; ushort defaultPort = (IsSecure ? HTTPS : HTTP); Port = uri.Port == -1 ? defaultPort : (ushort)uri.Port; diff --git a/Hamakaze/HttpResponseMessage.cs b/Hamakaze/HttpResponseMessage.cs index c401ed1..b2d7abb 100644 --- a/Hamakaze/HttpResponseMessage.cs +++ b/Hamakaze/HttpResponseMessage.cs @@ -124,7 +124,7 @@ namespace Hamakaze { using MemoryStream ms = new(); int byt; ushort lastTwo = 0; - for(; ; ) { + for(;;) { byt = stream.ReadByte(); if(byt == -1 && ms.Length == 0) throw new IOException(@"readLine: There is no data."); @@ -238,11 +238,9 @@ namespace Hamakaze { readBuffer(chunkLength); readLine(); } - readLine(); } else if(contentLength != 0) { body = new MemoryStream(); readBuffer(contentLength); - readLine(); } if(body != null) diff --git a/Hamakaze/HttpTask.cs b/Hamakaze/HttpTask.cs index e3ae06e..5bd604f 100644 --- a/Hamakaze/HttpTask.cs +++ b/Hamakaze/HttpTask.cs @@ -1,7 +1,6 @@ using Hamakaze.Headers; using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Net; @@ -25,7 +24,7 @@ namespace Hamakaze { private HttpConnectionManager Connections { get; } private IEnumerable Addresses { get; set; } - private HttpConnection Connection { get; set; } + public HttpConnection Connection { get; private set; } public bool DisposeRequest { get; set; } public bool DisposeResponse { get; set; } @@ -70,103 +69,90 @@ namespace Hamakaze { if(IsCancelled) return false; - switch(State) { - case TaskState.Initial: - State = TaskState.Lookup; - OnStateChange?.Invoke(this, State); - DoLookup(); - break; - case TaskState.Lookup: - State = TaskState.Request; - OnStateChange?.Invoke(this, State); - DoRequest(); - break; - case TaskState.Request: - State = TaskState.Response; - OnStateChange?.Invoke(this, State); - DoResponse(); - break; - case TaskState.Response: - State = TaskState.Finished; - OnStateChange?.Invoke(this, State); - OnComplete?.Invoke(this, Response); - if(DisposeResponse) - Response?.Dispose(); - if(DisposeRequest) - Request?.Dispose(); - return false; - default: - Error(new HttpTaskInvalidStateException()); - return false; + try { + switch(State) { + case TaskState.Initial: + State = TaskState.Lookup; + OnStateChange?.Invoke(this, State); + DoLookup(); + break; + case TaskState.Lookup: + State = TaskState.Request; + OnStateChange?.Invoke(this, State); + DoRequest(); + break; + case TaskState.Request: + State = TaskState.Response; + OnStateChange?.Invoke(this, State); + DoResponse(); + break; + case TaskState.Response: + State = TaskState.Finished; + OnStateChange?.Invoke(this, State); + OnComplete?.Invoke(this, Response); + if(DisposeResponse) + Response?.Dispose(); + if(DisposeRequest) + Request?.Dispose(); + return false; + default: + throw new HttpTaskInvalidStateException(); + } + } catch(Exception ex) { + Error(ex); + return false; } return true; } private void DoLookup() { - try { - Addresses = Dns.GetHostAddresses(Request.Host); - } catch(Exception ex) { - Error(ex); - return; - } + Addresses = Dns.GetHostAddresses(Request.Host); if(!Addresses.Any()) - Error(new HttpTaskNoAddressesException()); + throw new HttpTaskNoAddressesException(); } private void DoRequest() { - Exception exception = null; + Queue addresses = new(Addresses); - try { - foreach(IPAddress addr in Addresses) { - int tries = 0; - IPEndPoint endPoint = new(addr, Request.Port); + while(addresses.TryDequeue(out IPAddress addr)) { + int tries = 0; + IPEndPoint endPoint = new(addr, Request.Port); - exception = null; + Connection = Connections.GetConnection(Request.Host, endPoint, Request.IsSecure); + + retry: + ++tries; + try { + Request.WriteTo(Connection.Stream, (p, t) => OnUploadProgress?.Invoke(this, p, t)); + break; + } catch(HttpRequestMessageStreamException) { + Connection.Dispose(); Connection = Connections.GetConnection(Request.Host, endPoint, Request.IsSecure); - retry: - ++tries; - try { - Request.WriteTo(Connection.Stream, (p, t) => OnUploadProgress?.Invoke(this, p, t)); - break; - } catch(HttpRequestMessageStreamException ex) { - Connection.Dispose(); - Connection = Connections.GetConnection(Request.Host, endPoint, Request.IsSecure); + if(tries < 2) + goto retry; - if(tries < 2) - goto retry; - - exception = ex; - continue; - } finally { - Connection.MarkUsed(); - } + if(!addresses.Any()) + throw; + } finally { + Connection.MarkUsed(); } - } catch(Exception ex) { - Error(ex); } - if(exception != null) - Error(exception); - else if(Connection == null) - Error(new HttpTaskNoConnectionException()); + if(Connection == null) + throw new HttpTaskNoConnectionException(); } private void DoResponse() { - try { - Response = HttpResponseMessage.ReadFrom(Connection.Stream, (p, t) => OnDownloadProgress?.Invoke(this, p, t)); - } catch(Exception ex) { - Error(ex); - return; - } + Response = HttpResponseMessage.ReadFrom(Connection.Stream, (p, t) => OnDownloadProgress?.Invoke(this, p, t)); if(Response.Connection == HttpConnectionHeader.CLOSE || Response.ProtocolVersion.CompareTo(@"1.1") < 0) Connection.Dispose(); if(Response == null) - Error(new HttpTaskRequestFailedException()); + throw new HttpTaskRequestFailedException(); HttpKeepAliveHeader hkah = Response.Headers.Where(x => x.Name == HttpKeepAliveHeader.NAME).Cast().FirstOrDefault(); if(hkah != null) { diff --git a/Hamakaze/WebSocket/WsBinaryMessage.cs b/Hamakaze/WebSocket/WsBinaryMessage.cs new file mode 100644 index 0000000..172ba4f --- /dev/null +++ b/Hamakaze/WebSocket/WsBinaryMessage.cs @@ -0,0 +1,11 @@ +using System; + +namespace Hamakaze.WebSocket { + public class WsBinaryMessage : WsMessage { + public byte[] Data { get; } + + public WsBinaryMessage(byte[] data) { + Data = data ?? Array.Empty(); + } + } +} diff --git a/Hamakaze/WebSocket/WsBufferedSend.cs b/Hamakaze/WebSocket/WsBufferedSend.cs new file mode 100644 index 0000000..b5e7c2d --- /dev/null +++ b/Hamakaze/WebSocket/WsBufferedSend.cs @@ -0,0 +1,31 @@ +using System; + +namespace Hamakaze.WebSocket { + public class WsBufferedSend : IDisposable { + private WsConnection Connection { get; } + + internal WsBufferedSend(WsConnection conn) { + Connection = conn ?? throw new ArgumentNullException(nameof(conn)); + } + + // + + private bool IsDisposed; + + ~WsBufferedSend() { + DoDispose(); + } + + public void Dispose() { + DoDispose(); + GC.SuppressFinalize(this); + } + + private void DoDispose() { + if(IsDisposed) + return; + IsDisposed = true; + + } + } +} diff --git a/Hamakaze/WebSocket/WsClient.cs b/Hamakaze/WebSocket/WsClient.cs new file mode 100644 index 0000000..5bb788e --- /dev/null +++ b/Hamakaze/WebSocket/WsClient.cs @@ -0,0 +1,138 @@ +using System; +using System.Threading; + +namespace Hamakaze.WebSocket { + public class WsClient : IDisposable { + public WsConnection Connection { get; } + public bool IsRunning { get; private set; } = true; + + private Thread ReadThread { get; } + private Action MessageHandler { get; } + private Action ExceptionHandler { get; } + + private Mutex SendLock { get; } + + public WsClient( + WsConnection connection, + Action messageHandler, + Action exceptionHandler + ) { + Connection = connection ?? throw new ArgumentNullException(nameof(connection)); + MessageHandler = messageHandler ?? throw new ArgumentNullException(nameof(messageHandler)); + ExceptionHandler = exceptionHandler ?? throw new ArgumentNullException(nameof(exceptionHandler)); + + SendLock = new(); + + ReadThread = new(ReadThreadBody) { IsBackground = true }; + ReadThread.Start(); + } + + private void ReadThreadBody() { + try { + while(IsRunning) + MessageHandler(Connection.Receive()); + } catch(Exception ex) { + IsRunning = false; + ExceptionHandler(ex); + } + } + + public void Send(string text) { + Connection.Send(text); + } + + public void Send(object obj) { + if(obj == null) + throw new ArgumentNullException(nameof(obj)); + + Connection.Send(obj.ToString()); + } + + public void Send(ReadOnlySpan data) { + Connection.Send(data); + } + + public void Send(byte[] buffer, int offset, int count) { + if(buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + Connection.Send(buffer.AsSpan(offset, count)); + } + + public void Ping() { + Connection.Ping(); + } + + public void Ping(ReadOnlySpan data) { + Connection.Ping(data); + } + + public void Ping(byte[] buffer, int offset, int length) { + if(buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + Connection.Ping(buffer.AsSpan(offset, length)); + } + + public void Pong() { + Connection.Pong(); + } + + public void Pong(ReadOnlySpan data) { + Connection.Pong(data); + } + + public void Pong(byte[] buffer, int offset, int length) { + if(buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + Pong(buffer.AsSpan(offset, length)); + } + + public void Close() { + Connection.Close(WsCloseReason.NormalClosure); + } + + public void CloseEmpty() { + Connection.CloseEmpty(); + } + + public void Close(string reason) { + Connection.Close(WsCloseReason.NormalClosure, reason); + } + + public void Close(byte[] buffer, int offset, int length) { + if(buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + Connection.Close(buffer.AsSpan(offset, length)); + } + + public void Close(WsCloseReason code, byte[] buffer, int offset, int length) { + if(buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + Connection.Close(code, buffer.AsSpan(offset, length)); + } + + private bool IsDisposed; + + ~WsClient() { + DoDispose(); + } + + public void Dispose() { + DoDispose(); + GC.SuppressFinalize(this); + } + + private void DoDispose() { + if(IsDisposed) + return; + IsDisposed = true; + + SendLock.Dispose(); + Connection.Dispose(); + } + } +} diff --git a/Hamakaze/WebSocket/WsCloseMessage.cs b/Hamakaze/WebSocket/WsCloseMessage.cs new file mode 100644 index 0000000..e0a584a --- /dev/null +++ b/Hamakaze/WebSocket/WsCloseMessage.cs @@ -0,0 +1,36 @@ +using System; +using System.Text; + +namespace Hamakaze.WebSocket { + public class WsCloseMessage : WsMessage { + public WsCloseReason Reason { get; } + public string ReasonPhrase { get; } + public byte[] Data { get; } + + public WsCloseMessage(WsCloseReason reason) { + Reason = reason; + ReasonPhrase = string.Empty; + Data = Array.Empty(); + } + + public WsCloseMessage(byte[] data) { + if(data == null) { + Reason = WsCloseReason.NoStatus; + ReasonPhrase = string.Empty; + Data = Array.Empty(); + } else { + Reason = (WsCloseReason)WsUtils.ToU16(data); + Data = data; + + if(data.Length > 2) + try { + ReasonPhrase = Encoding.UTF8.GetString(data, 2, data.Length - 2); + } catch { + ReasonPhrase = string.Empty; + } + else + ReasonPhrase = string.Empty; + } + } + } +} diff --git a/Hamakaze/WebSocket/WsCloseReason.cs b/Hamakaze/WebSocket/WsCloseReason.cs new file mode 100644 index 0000000..5b6e5a0 --- /dev/null +++ b/Hamakaze/WebSocket/WsCloseReason.cs @@ -0,0 +1,16 @@ +namespace Hamakaze.WebSocket { + public enum WsCloseReason : ushort { + NormalClosure = 1000, + GoingAway = 1001, + ProtocolError = 1002, + InvalidData = 1003, + NoStatus = 1005, // virtual -> no data in close frame + AbnormalClosure = 1006, // virtual -> connection dropped + MalformedData = 1007, + PolicyViolation = 1008, + FrameTooLarge = 1009, + MissingExtension = 1010, + UnexpectedCondition = 1011, + TlsHandshakeFailed = 1015, // virtual -> obvious + } +} diff --git a/Hamakaze/WebSocket/WsConnection.cs b/Hamakaze/WebSocket/WsConnection.cs new file mode 100644 index 0000000..164dc13 --- /dev/null +++ b/Hamakaze/WebSocket/WsConnection.cs @@ -0,0 +1,468 @@ +using System; +using System.IO; +using System.Net.Security; +using System.Security.Cryptography; +using System.Text; + +// TODO: optimisations with newer .net feature to reduce memory copying +// i think we're generally aware of how much data we're shoving around +// so memorystream can be considered overkill + +// Should there be internal mutexing on the socket? (leaning towards no) + +// Should all external stream handling be moved to WsClient? +// - IDEA: Buffered send "session" class. +// Would require exposing the raw Write methods +// but i suppose that's what "internal" exists for + +namespace Hamakaze.WebSocket { + public class WsConnection : IDisposable { + public Stream Stream { get; } + + public bool IsSecure { get; } + public bool IsClosed { get; private set; } + + private const int BUFFER_SIZE = 0x2000; + private const byte MASK_FLAG = 0x80; + private const int MASK_SIZE = 4; + + private WsOpcode FragmentedType = 0; + private MemoryStream FragmentedStream; + + public WsConnection(Stream stream) { + Stream = stream ?? throw new ArgumentNullException(nameof(stream)); + IsSecure = stream is SslStream; + } + + private static byte[] GenerateMask() { + return RandomNumberGenerator.GetBytes(MASK_SIZE); + } + + private void StrictRead(byte[] buffer, int offset, int length) { + int read = Stream.Read(buffer, offset, length); + if(read < length) + throw new Exception(@"Was unable to read the requested amount of data."); + } + + private (WsOpcode opcode, long length, bool isFinal, byte[] mask) ReadFrameHeader() { + byte[] buffer = new byte[8]; + StrictRead(buffer, 0, 2); + + WsOpcode opcode = (WsOpcode)(buffer[0] & 0x0F); + bool isFinal = (buffer[0] & (byte)WsOpcode.FlagFinal) > 0; + + if(opcode >= WsOpcode.CtrlClose && !isFinal) + throw new WsInvalidOpcodeException((WsOpcode)buffer[0]); + + bool isControl = (opcode & WsOpcode.CtrlClose) > 0; + + if(isControl && !isFinal) + throw new WsInvalidControlFrameException(@"fragmented"); + + bool isMasked = (buffer[1] & MASK_FLAG) > 0; + + // this may look stupid and you'd be correct but it's better than the stack of casts + // i'd otherwise have to do otherwise because c# converts everything back to int32 + buffer[1] &= 0x7F; + long length = buffer[1]; + + if(length == 126) { + StrictRead(buffer, 0, 2); + length = WsUtils.ToU16(buffer); + } else if(length == 127) { + StrictRead(buffer, 0, 8); + length = WsUtils.ToI64(buffer); + } + + if(isControl && length > 125) + throw new WsInvalidControlFrameException(@"too large"); + + // should there be a sanity check on the length of frames? + // i seriously don't understand the rationale behind both + // having a framing system but then also supporting frame lengths + // of 2^63, feels like 2^16 per frame would be a fine max. + if(length < 0 || length > long.MaxValue) + throw new WsInvalidFrameSizeException(length); + + byte[] mask = null; + + if(isMasked) { + StrictRead(buffer, 0, MASK_SIZE); + mask = buffer; + } + + return (opcode, length, isFinal, mask); + } + + private long ReadFrameBody(Stream target, long length, byte[] mask, long offset = 0) { + if(target == null) + throw new ArgumentNullException(nameof(target)); + if(!target.CanWrite) + throw new ArgumentException(@"Target stream is not writable.", nameof(target)); + + bool isMasked = mask != null; + + int read; + int take = length > BUFFER_SIZE ? BUFFER_SIZE : (int)length; + byte[] buffer = new byte[take]; + + while(length > 0) { + read = Stream.Read(buffer, 0, take); + + if(isMasked) + for(int i = 0; i < read; ++i) + buffer[i] ^= mask[offset++ % MASK_SIZE]; + + target.Write(buffer, 0, read); + + offset += read; + length -= read; + + if(take > length) + take = (int)length; + } + + return offset; + } + + private WsMessage ReadFrame() { + (WsOpcode opcode, long length, bool isFinal, byte[] mask) = ReadFrameHeader(); + + if(opcode is not WsOpcode.DataContinue + and not WsOpcode.DataBinary + and not WsOpcode.DataText + and not WsOpcode.CtrlClose + and not WsOpcode.CtrlPing + and not WsOpcode.CtrlPong) + throw new WsUnsupportedOpcodeException(opcode); + + bool hasBody = length > 0; + bool isContinue = opcode == WsOpcode.DataContinue; + bool canFragment = (opcode & WsOpcode.CtrlClose) == 0; + + MemoryStream bodyStream = null; + + if(hasBody) { + if(canFragment) { + if(isContinue) { + if(FragmentedType == 0) + throw new WsUnexpectedContinueException(); + + opcode = FragmentedType; + + if(FragmentedStream == null) + FragmentedStream = bodyStream = new(); + else + bodyStream = FragmentedStream; + } else { + if(FragmentedType != 0) + throw new WsUnexpectedDataException(); + + if(isFinal) + bodyStream = new(); + else { + FragmentedType = opcode; + FragmentedStream = bodyStream = new(); + } + } + } else + bodyStream = new(); + + ReadFrameBody(bodyStream, length, mask); + } + + WsMessage msg; + + if(isFinal) { + if(canFragment && isContinue) { + FragmentedType = 0; + FragmentedStream = null; + } + + byte[] body = null; + + if(bodyStream != null) { + if(bodyStream.Length > 0) + body = bodyStream.ToArray(); + bodyStream.Dispose(); + } + + switch(opcode) { + case WsOpcode.DataText: + msg = new WsTextMessage(body); + break; + + case WsOpcode.DataBinary: + msg = new WsBinaryMessage(body); + break; + + case WsOpcode.CtrlClose: + msg = new WsCloseMessage(body); + break; + + case WsOpcode.CtrlPing: + msg = new WsPingMessage(body); + break; + + case WsOpcode.CtrlPong: + msg = new WsPongMessage(body); + break; + + default: // fallback, if we end up here something is very fucked + throw new WsUnsupportedOpcodeException(opcode); + } + } else msg = null; + + return msg; + } + + public WsMessage Receive() { + WsMessage msg; + while((msg = ReadFrame()) == null); + return msg; + } + + private void WriteFrameHeader(WsOpcode opcode, long length, bool isFinal, byte[] mask = null) { + bool shouldMask = mask != null; + + if(isFinal) + opcode |= WsOpcode.FlagFinal; + + Stream.WriteByte((byte)opcode); + + byte bLen1 = 0; + if(shouldMask) + bLen1 |= MASK_FLAG; + + byte[] bLenBuff = WsUtils.FromI64(length); + if(length < 126) { + Stream.WriteByte((byte)(bLen1 | bLenBuff[7])); + } else if(length <= ushort.MaxValue) { + Stream.WriteByte((byte)(bLen1 | 126)); + Stream.Write(bLenBuff, 6, 2); + } else { + Stream.WriteByte((byte)(bLen1 | 127)); + Stream.Write(bLenBuff, 0, 8); + } + + if(shouldMask) + Stream.Write(mask, 0, MASK_SIZE); + } + + private long WriteFrameBody(ReadOnlySpan body, byte[] mask = null, long offset = 0) { + if(mask != null) { + byte[] masked = new byte[body.Length]; + + for(int i = 0; i < body.Length; ++i) + masked[i] = (byte)(body[i] ^ mask[offset++ % MASK_SIZE]); + + body = masked; + } + + Stream.Write(body); + + return offset; + } + + private long WriteFrameBody(Stream body, byte[] mask = null, long offset = 0) { + bool shouldMask = mask != null; + + int read; + byte[] buffer = new byte[BUFFER_SIZE]; + while((read = body.Read(buffer, 0, BUFFER_SIZE)) > 0) + offset = WriteFrameBody(buffer.AsSpan(0, read), mask, offset); + + return offset; + } + + private void WriteFrame(WsOpcode opcode, ReadOnlySpan body, bool isFinal) { + byte[] mask = GenerateMask(); + WriteFrameHeader(opcode, body.Length, isFinal, mask); + if(body.Length > 0) + WriteFrameBody(body, mask); + Stream.Flush(); + } + + private void Write(WsOpcode opcode, ReadOnlySpan body) { + if(body.Length > 0xFFFF) { + WriteFrame(opcode, body.Slice(0, 0xFFFF), false); + body = body.Slice(0xFFFF); + + while(body.Length > 0xFFFF) { + WriteFrame(WsOpcode.DataContinue, body.Slice(0, 0xFFFF), false); + body = body.Slice(0xFFFF); + } + + WriteFrame(WsOpcode.DataContinue, body, true); + } else + WriteFrame(opcode, body, true); + } + + private void Write(WsOpcode opcode, Stream stream) { + if(stream == null) + throw new ArgumentNullException(nameof(stream)); + if(!stream.CanRead) + throw new ArgumentException(@"Provided stream cannot be read.", nameof(stream)); + + int read; + byte[] buffer = new byte[BUFFER_SIZE]; + + while((read = stream.Read(buffer, 0, BUFFER_SIZE)) > 0) { + WriteFrame(opcode, buffer.AsSpan(0, read), false); + + if(opcode != WsOpcode.DataContinue) + opcode = WsOpcode.DataContinue; + } + + // this kinda fucking sucks + WriteFrame(WsOpcode.CtrlClose, ReadOnlySpan.Empty, true); + } + + private void Write(WsOpcode opcode, Stream stream, int length) { + if(stream == null) + throw new ArgumentNullException(nameof(stream)); + if(!stream.CanRead) + throw new ArgumentException(@"Provided stream cannot be read.", nameof(stream)); + + int read; + byte[] buffer = new byte[BUFFER_SIZE]; + + if(length > BUFFER_SIZE) { + int take = BUFFER_SIZE; + + while((read = stream.Read(buffer, 0, take)) > 0) { + WriteFrame(opcode, buffer.AsSpan(0, read), false); + + if(opcode != WsOpcode.DataContinue) + opcode = WsOpcode.DataContinue; + + length -= read; + if(take > length) + take = length; + } + + // feel like there'd be a better way to do this + // but i feel like assuming that any successful read with something + // still coming (read == BUFFER_SIZE) will bite me in the ass later somehow + WriteFrame(WsOpcode.CtrlClose, Span.Empty, true); + } else { + read = stream.Read(buffer, 0, BUFFER_SIZE); + if(read > 0) + WriteFrame(WsOpcode.DataBinary, buffer.AsSpan(0, read), true); + } + } + + public void Send(string text) + => Write(WsOpcode.DataText, Encoding.UTF8.GetBytes(text)); + + public void Send(ReadOnlySpan buffer) + => Write(WsOpcode.DataBinary, buffer); + + public void Send(Stream source) + => Write(WsOpcode.DataBinary, source); + + public void Send(Stream source, int count) + => Write(WsOpcode.DataBinary, source, count); + + private void WriteControlFrame(WsOpcode opcode) { + WriteFrameHeader(opcode, 0, true, GenerateMask()); + Stream.Flush(); + } + + private void WriteControlFrame(WsOpcode opcode, ReadOnlySpan buffer) { + if(buffer.Length > 125) + throw new ArgumentException(@"Data may not be more than 125 bytes.", nameof(buffer)); + + byte[] mask = GenerateMask(); + WriteFrameHeader(opcode, buffer.Length, true, mask); + WriteFrameBody(buffer, mask); + Stream.Flush(); + } + + public void Ping() + => WriteControlFrame(WsOpcode.CtrlPing); + + public void Ping(ReadOnlySpan buffer) + => WriteControlFrame(WsOpcode.CtrlPing, buffer); + + public void Pong() + => WriteControlFrame(WsOpcode.CtrlPong); + + public void Pong(ReadOnlySpan buffer) + => WriteControlFrame(WsOpcode.CtrlPong, buffer); + + public void CloseEmpty() { + if(IsClosed) + return; + IsClosed = true; + + WriteControlFrame(WsOpcode.CtrlClose); + } + + public void Close(ReadOnlySpan buffer) { + if(IsClosed) + return; + IsClosed = true; + + WriteControlFrame(WsOpcode.CtrlClose, buffer); + } + + public void Close(WsCloseReason code) + => Close(WsUtils.FromU16((ushort)code)); + + public void Close(WsCloseReason code, ReadOnlySpan reason) { + if(reason.Length > 123) + throw new ArgumentException(@"Reason may not be more than 123 bytes.", nameof(reason)); + + if(IsClosed) + return; + IsClosed = true; + + byte[] mask = GenerateMask(); + WriteFrameHeader(WsOpcode.CtrlClose, 2 + reason.Length, true, mask); + WriteFrameBody(WsUtils.FromU16((ushort)code), mask); + WriteFrameBody(reason, mask, 2); + Stream.Flush(); + } + + public void Close(WsCloseReason code, string reason) { + if(string.IsNullOrEmpty(reason)) { + Close(code); + return; + } + + int length = Encoding.UTF8.GetByteCount(reason); + if(length > 123) + throw new ArgumentException(@"Reason string may not exceed 123 bytes in length.", nameof(reason)); + + if(IsClosed) + return; + IsClosed = true; + + byte[] mask = GenerateMask(); + WriteFrameHeader(WsOpcode.CtrlClose, 2 + reason.Length, true, mask); + WriteFrameBody(WsUtils.FromU16((ushort)code), mask); + WriteFrameBody(Encoding.UTF8.GetBytes(reason), mask, 2); + Stream.Flush(); + } + + private bool IsDisposed; + + ~WsConnection() { + DoDispose(); + } + + public void Dispose() { + DoDispose(); + GC.SuppressFinalize(this); + } + + private void DoDispose() { + if(IsDisposed) + return; + IsDisposed = true; + + Stream.Dispose(); + } + } +} diff --git a/Hamakaze/WebSocket/WsException.cs b/Hamakaze/WebSocket/WsException.cs new file mode 100644 index 0000000..50f0f5a --- /dev/null +++ b/Hamakaze/WebSocket/WsException.cs @@ -0,0 +1,29 @@ +namespace Hamakaze.WebSocket { + public class WsException : HttpException { + public WsException(string message) : base(message) { } + } + + public class WsInvalidOpcodeException : WsException { + public WsInvalidOpcodeException(WsOpcode opcode) : base($@"An invalid WebSocket opcode was encountered: {opcode}.") { } + } + + public class WsUnsupportedOpcodeException : WsException { + public WsUnsupportedOpcodeException(WsOpcode opcode) : base($@"An unsupported WebSocket opcode was encountered: {opcode}.") { } + } + + public class WsInvalidFrameSizeException : WsException { + public WsInvalidFrameSizeException(long size) : base($@"WebSocket frame size is too large: {size} bytes.") { } + } + + public class WsUnexpectedContinueException : WsException { + public WsUnexpectedContinueException() : base(@"A WebSocket continue frame was issued but there is nothing to continue.") { } + } + + public class WsUnexpectedDataException : WsException { + public WsUnexpectedDataException() : base(@"A WebSocket data frame was issued while a fragmented frame is being constructed.") { } + } + + public class WsInvalidControlFrameException : WsException { + public WsInvalidControlFrameException(string variant) : base($@"An invalid WebSocket control frame was encountered: {variant}") { } + } +} diff --git a/Hamakaze/WebSocket/WsMessage.cs b/Hamakaze/WebSocket/WsMessage.cs new file mode 100644 index 0000000..ebb9344 --- /dev/null +++ b/Hamakaze/WebSocket/WsMessage.cs @@ -0,0 +1,5 @@ +namespace Hamakaze.WebSocket { + public abstract class WsMessage { + // nothing, lol + } +} diff --git a/Hamakaze/WebSocket/WsOpcode.cs b/Hamakaze/WebSocket/WsOpcode.cs new file mode 100644 index 0000000..4491160 --- /dev/null +++ b/Hamakaze/WebSocket/WsOpcode.cs @@ -0,0 +1,13 @@ +namespace Hamakaze.WebSocket { + public enum WsOpcode : byte { + DataContinue = 0x00, + DataText = 0x01, + DataBinary = 0x02, + + CtrlClose = 0x08, + CtrlPing = 0x09, + CtrlPong = 0x0A, + + FlagFinal = 0x80, + } +} diff --git a/Hamakaze/WebSocket/WsPingMessage.cs b/Hamakaze/WebSocket/WsPingMessage.cs new file mode 100644 index 0000000..066d199 --- /dev/null +++ b/Hamakaze/WebSocket/WsPingMessage.cs @@ -0,0 +1,11 @@ +using System; + +namespace Hamakaze.WebSocket { + public class WsPingMessage : WsMessage { + public byte[] Data { get; } + + public WsPingMessage(byte[] data) { + Data = data ?? Array.Empty(); + } + } +} diff --git a/Hamakaze/WebSocket/WsPongMessage.cs b/Hamakaze/WebSocket/WsPongMessage.cs new file mode 100644 index 0000000..54d44bd --- /dev/null +++ b/Hamakaze/WebSocket/WsPongMessage.cs @@ -0,0 +1,11 @@ +using System; + +namespace Hamakaze.WebSocket { + public class WsPongMessage : WsMessage { + public byte[] Data { get; } + + public WsPongMessage(byte[] data) { + Data = data ?? Array.Empty(); + } + } +} diff --git a/Hamakaze/WebSocket/WsTextMessage.cs b/Hamakaze/WebSocket/WsTextMessage.cs new file mode 100644 index 0000000..fb41d76 --- /dev/null +++ b/Hamakaze/WebSocket/WsTextMessage.cs @@ -0,0 +1,14 @@ +using System.Text; + +namespace Hamakaze.WebSocket { + public class WsTextMessage : WsMessage { + public string Text { get; } + + public WsTextMessage(byte[] data) { + if(data?.Length > 0) + Text = Encoding.UTF8.GetString(data); + else + Text = string.Empty; + } + } +} diff --git a/Hamakaze/WebSocket/WsUtils.cs b/Hamakaze/WebSocket/WsUtils.cs new file mode 100644 index 0000000..ce2b319 --- /dev/null +++ b/Hamakaze/WebSocket/WsUtils.cs @@ -0,0 +1,38 @@ +using System; + +namespace Hamakaze.WebSocket { + internal static class WsUtils { + public static byte[] FromU16(ushort num) { + byte[] buff = BitConverter.GetBytes(num); + if(BitConverter.IsLittleEndian) + Array.Reverse(buff); + return buff; + } + + public static ushort ToU16(ReadOnlySpan buffer) { + if(BitConverter.IsLittleEndian) + buffer = new byte[2] { + buffer[1], buffer[0], + }; + + return BitConverter.ToUInt16(buffer); + } + + public static byte[] FromI64(long num) { + byte[] buff = BitConverter.GetBytes(num); + if(BitConverter.IsLittleEndian) + Array.Reverse(buff); + return buff; + } + + public static long ToI64(ReadOnlySpan buffer) { + if(BitConverter.IsLittleEndian) + buffer = new byte[8] { + buffer[7], buffer[6], buffer[5], buffer[4], + buffer[3], buffer[2], buffer[1], buffer[0], + }; + + return BitConverter.ToInt64(buffer); + } + } +}