diff --git a/Example/Program.cs b/Example/Program.cs index 137f0e4f..d1f25694 100644 --- a/Example/Program.cs +++ b/Example/Program.cs @@ -73,8 +73,8 @@ namespace Example { ThreadPool.QueueUserWorkItem(notifyMsg); - using (var ws = new WebSocket("ws://echo.websocket.org", "echo")) - //using (var ws = new WebSocket("wss://echo.websocket.org", "echo")) + using (var ws = new WebSocket("ws://echo.websocket.org")) + //using (var ws = new WebSocket("wss://echo.websocket.org")) //using (var ws = new WebSocket("ws://localhost:4649")) //using (var ws = new WebSocket("ws://localhost:4649/Echo")) //using (var ws = new WebSocket("wss://localhost:4649/Echo")) diff --git a/websocket-sharp/Ext.cs b/websocket-sharp/Ext.cs index dcd82a66..479a3adb 100644 --- a/websocket-sharp/Ext.cs +++ b/websocket-sharp/Ext.cs @@ -258,6 +258,16 @@ namespace WebSocketSharp : null; } + internal static string CheckIfValidProtocols (this string [] protocols) + { + return protocols.Contains ( + protocol => protocol.Length == 0 || !protocol.IsToken ()) + ? "Contains an invalid value." + : protocols.ContainsTwice () + ? "Contains a value twice." + : null; + } + internal static string CheckIfValidSendData (this byte [] data) { return data == null @@ -334,6 +344,36 @@ namespace WebSocketSharp : stream.ToByteArray (); } + internal static bool Contains ( + this IEnumerable source, Func comparer) + { + foreach (T value in source) + if (comparer (value)) + return true; + + return false; + } + + internal static bool ContainsTwice (this string [] values) + { + var len = values.Length; + + Func contains = null; + contains = index => { + if (index < len - 1) { + for (var i = index + 1; i < len; i++) + if (values [i] == values [index]) + return true; + + return contains (++index); + } + + return false; + }; + + return contains (0); + } + internal static T [] Copy (this T [] src, long length) { var dest = new T [length]; diff --git a/websocket-sharp/WebSocket.cs b/websocket-sharp/WebSocket.cs index 754c57fd..4bf27827 100644 --- a/websocket-sharp/WebSocket.cs +++ b/websocket-sharp/WebSocket.cs @@ -38,7 +38,6 @@ using System.Collections.Generic; using System.Collections.Specialized; using System.Diagnostics; using System.IO; -using System.Linq; using System.Net.Sockets; using System.Net.Security; using System.Security.Cryptography; @@ -89,7 +88,7 @@ namespace WebSocketSharp private string _origin; private bool _preAuth; private string _protocol; - private string _protocols; + private string [] _protocols; private volatile WebSocketState _readyState; private AutoResetEvent _receivePong; private bool _secure; @@ -148,10 +147,19 @@ namespace WebSocketSharp /// /// /// An array of that contains the WebSocket subprotocols - /// if any. + /// if any. Each value of must be a token defined + /// in RFC 2616. /// /// - /// is invalid. + /// + /// is invalid. + /// + /// + /// -or- + /// + /// + /// is invalid. + /// /// /// /// is . @@ -165,7 +173,13 @@ namespace WebSocketSharp if (!url.TryCreateWebSocketUri (out _uri, out msg)) throw new ArgumentException (msg, "url"); - _protocols = protocols.ToString (", "); + if (protocols != null && protocols.Length > 0) { + msg = protocols.CheckIfValidProtocols (); + if (msg != null) + throw new ArgumentException (msg, "protocols"); + + _protocols = protocols; + } _base64Key = CreateBase64Key (); _client = true; @@ -677,21 +691,22 @@ namespace WebSocketSharp private string checkIfValidHandshakeResponse (HandshakeResponse response) { var headers = response.Headers; - - string accept, version; return response.IsUnauthorized ? String.Format ( - "An HTTP {0} authorization is required.", + "HTTP {0} authorization is required.", response.AuthChallenge.Scheme) : !response.IsWebSocketResponse - ? "Not WebSocket connection response to the connection request." - : (accept = headers ["Sec-WebSocket-Accept"]) == null || - accept != CreateResponseKey (_base64Key) + ? "Not WebSocket connection response." + : !validateSecWebSocketAcceptHeader ( + headers ["Sec-WebSocket-Accept"]) ? "Invalid Sec-WebSocket-Accept header." - : (version = headers ["Sec-WebSocket-Version"]) != null && - version != _version - ? "Invalid Sec-WebSocket-Version header." - : null; + : !validateSecWebSocketProtocolHeader ( + headers ["Sec-WebSocket-Protocol"]) + ? "Invalid Sec-WebSocket-Protocol header." + : !validateSecWebSocketVersionHeader ( + headers ["Sec-WebSocket-Version"]) + ? "Invalid Sec-WebSocket-Version header." + : null; } private void close (CloseStatusCode code, string reason, bool wait) @@ -894,8 +909,8 @@ namespace WebSocketSharp headers ["Sec-WebSocket-Key"] = _base64Key; - if (!_protocols.IsNullOrEmpty ()) - headers ["Sec-WebSocket-Protocol"] = _protocols; + if (_protocols != null) + headers ["Sec-WebSocket-Protocol"] = _protocols.ToString (", "); var extensions = createExtensionsRequest (); if (extensions.Length > 0) @@ -955,21 +970,17 @@ namespace WebSocketSharp { setClientStream (); var res = sendHandshakeRequest (); - var err = checkIfValidHandshakeResponse (res); - if (err != null) { - _logger.Error (err); + var msg = checkIfValidHandshakeResponse (res); + if (msg != null) { + _logger.Error (msg); - var msg = "An error has occurred while connecting."; + msg = "An error has occurred while connecting."; error (msg); close (CloseStatusCode.ABNORMAL, msg, false); return false; } - var protocol = res.Headers ["Sec-WebSocket-Protocol"]; - if (!protocol.IsNullOrEmpty ()) - _protocol = protocol; - processRespondedExtensions (res.Headers ["Sec-WebSocket-Extensions"]); var cookies = res.Cookies; @@ -1364,6 +1375,32 @@ namespace WebSocketSharp host == expected; } + // As client + private bool validateSecWebSocketAcceptHeader (string value) + { + return value != null && value == CreateResponseKey (_base64Key); + } + + // As client + private bool validateSecWebSocketProtocolHeader (string value) + { + if (value == null) + return _protocols == null; + + if (_protocols == null || + !_protocols.Contains (protocol => protocol == value)) + return false; + + _protocol = value; + return true; + } + + // As client + private bool validateSecWebSocketVersionHeader (string value) + { + return value == null || value == _version; + } + #endregion #region Internal Methods