diff --git a/websocket-sharp/WebSocket.cs b/websocket-sharp/WebSocket.cs index 7b76cecc..48da92e8 100644 --- a/websocket-sharp/WebSocket.cs +++ b/websocket-sharp/WebSocket.cs @@ -69,7 +69,7 @@ namespace WebSocketSharp #region Private Fields private AuthenticationChallenge _authChallenge; - private string _base64key; + private string _base64Key; private RemoteCertificateValidationCallback _certValidationCallback; private bool _client; @@ -168,7 +168,7 @@ namespace WebSocketSharp _protocols = protocols.ToString (", "); - _base64key = createBase64Key (); + _base64Key = CreateBase64Key (); _client = true; _logger = new Logger (); _secure = _uri.Scheme == "wss"; @@ -487,22 +487,26 @@ namespace WebSocketSharp // As server private bool acceptHandshake () { - _logger.Debug (String.Format ( - "A WebSocket connection request from {0}:\n{1}", _context.UserEndPoint, _context)); + _logger.Debug ( + String.Format ( + "A WebSocket connection request from {0}:\n{1}", + _context.UserEndPoint, + _context)); - if (!validateConnectionRequest (_context)) - { - _logger.Error ("An invalid WebSocket connection request."); + var err = checkIfValidHandshakeRequest (_context); + if (err != null) { + _logger.Error (err); - error ("An error has occurred while handshaking."); + error ("An error has occurred while connecting."); Close (HttpStatusCode.BadRequest); return false; } - _base64key = _context.SecWebSocketKey; + _base64Key = _context.SecWebSocketKey; - if (_protocol.Length > 0 && !_context.Headers.Contains ("Sec-WebSocket-Protocol", _protocol)) + if (_protocol.Length > 0 && + !_context.Headers.Contains ("Sec-WebSocket-Protocol", _protocol)) _protocol = String.Empty; var extensions = _context.Headers ["Sec-WebSocket-Extensions"]; @@ -512,6 +516,45 @@ namespace WebSocketSharp return send (createHandshakeResponse ()); } + // As server + private string checkIfValidHandshakeRequest (WebSocketContext context) + { + string key, version; + return !context.IsWebSocketRequest + ? "Not WebSocket connection request." + : !validateHostHeader (context.Host) + ? "Invalid Host header." + : (key = context.SecWebSocketKey) == null || key.Length == 0 + ? "Invalid Sec-WebSocket-Key header." + : (version = context.SecWebSocketVersion) == null || + version != _version + ? "Invalid Sec-WebSocket-Version header." + : !validateCookies (context.CookieCollection, _cookies) + ? "Invalid Cookies." + : null; + } + + // As client + private string checkIfValidHandshakeResponse (HandshakeResponse response) + { + var headers = response.Headers; + + string accept, version; + return response.IsUnauthorized + ? String.Format ( + "An HTTP {0} authorization is required.", + response.AuthChallenge.Scheme) + : !response.IsWebSocketResponse + ? "Not WebSocket connection response to the connection request." + : (accept = headers ["Sec-WebSocket-Accept"]) == null || + accept != CreateResponseKey (_base64Key) + ? "Invalid Sec-WebSocket-Accept header." + : (version = headers ["Sec-WebSocket-Version"]) != null && + version != _version + ? "Invalid Sec-WebSocket-Version header." + : null; + } + private void close (CloseStatusCode code, string reason, bool wait) { close (new PayloadData (((ushort) code).Append (reason)), !code.IsReserved (), wait); @@ -649,7 +692,7 @@ namespace WebSocketSharp return true; } - private void connect () + private bool connect () { lock (_forConnect) { if (IsOpened) { @@ -657,30 +700,24 @@ namespace WebSocketSharp _logger.Error (msg); error (msg); - return; + return false; } try { - if (_client ? doHandshake () : acceptHandshake ()) - open (); + if (_client ? doHandshake () : acceptHandshake ()) { + _readyState = WebSocketState.OPEN; + return true; + } } catch (Exception ex) { processException ( - ex, "An exception has occurred while connecting or opening."); + ex, "An exception has occurred while connecting."); } + + return false; } } - // As client - private static string createBase64Key () - { - var src = new byte [16]; - var rand = new Random (); - rand.NextBytes (src); - - return Convert.ToBase64String (src); - } - // As client private string createExtensionsRequest () { @@ -709,7 +746,7 @@ namespace WebSocketSharp if (!_origin.IsNullOrEmpty ()) headers ["Origin"] = _origin; - headers ["Sec-WebSocket-Key"] = _base64key; + headers ["Sec-WebSocket-Key"] = _base64Key; if (!_protocols.IsNullOrEmpty ()) headers ["Sec-WebSocket-Protocol"] = _protocols; @@ -744,7 +781,7 @@ namespace WebSocketSharp var res = new HandshakeResponse (HttpStatusCode.SwitchingProtocols); var headers = res.Headers; - headers ["Sec-WebSocket-Accept"] = createResponseKey (); + headers ["Sec-WebSocket-Accept"] = CreateResponseKey (_base64Key); if (_protocol.Length > 0) headers ["Sec-WebSocket-Protocol"] = _protocol; @@ -767,32 +804,16 @@ namespace WebSocketSharp return res; } - private string createResponseKey () - { - var buffer = new StringBuilder (_base64key, 64); - buffer.Append (_guid); - SHA1 sha1 = new SHA1CryptoServiceProvider (); - var src = sha1.ComputeHash (Encoding.UTF8.GetBytes (buffer.ToString ())); - - return Convert.ToBase64String (src); - } - // As client private bool doHandshake () { setClientStream (); var res = sendHandshakeRequest (); - var err = res.IsUnauthorized - ? String.Format ("An HTTP {0} authorization is required.", res.AuthChallenge.Scheme) - : !validateConnectionResponse (res) - ? "An invalid response to this WebSocket connection request." - : null; - - if (err != null) - { + var err = checkIfValidHandshakeResponse (res); + if (err != null) { _logger.Error (err); - var msg = "An error has occurred while handshaking."; + var msg = "An error has occurred while connecting."; error (msg); close (CloseStatusCode.ABNORMAL, msg, false); @@ -831,10 +852,14 @@ namespace WebSocketSharp private void open () { - _readyState = WebSocketState.OPEN; - - OnOpen.Emit (this, EventArgs.Empty); - startReceiving (); + try { + OnOpen.Emit (this, EventArgs.Empty); + startReceiving (); + } + catch (Exception ex) { + processException ( + ex, "An exception has occurred while opening."); + } } private bool processCloseFrame (WsFrame frame) @@ -1270,13 +1295,15 @@ namespace WebSocketSharp private void startReceiving () { + if (_readyState != WebSocketState.OPEN) + return; + _exitReceiving = new AutoResetEvent (false); _receivePong = new AutoResetEvent (false); Action receive = null; receive = () => _stream.ReadFrameAsync ( - frame => - { + frame => { if (processFrame (frame)) receive (); else @@ -1288,26 +1315,6 @@ namespace WebSocketSharp receive (); } - // As server - private bool validateConnectionRequest (WebSocketContext context) - { - string version; - return context.IsWebSocketRequest && - validateHostHeader (context.Host) && - !context.SecWebSocketKey.IsNullOrEmpty () && - ((version = context.SecWebSocketVersion) != null && version == _version) && - validateCookies (context.CookieCollection, _cookies); - } - - // As client - private bool validateConnectionResponse (HandshakeResponse response) - { - string accept, version; - return response.IsWebSocketResponse && - ((accept = response.Headers ["Sec-WebSocket-Accept"]) != null && accept == createResponseKey ()) && - ((version = response.Headers ["Sec-WebSocket-Version"]) == null || version == _version); - } - // As server private bool validateCookies (CookieCollection request, CookieCollection response) { @@ -1377,15 +1384,37 @@ namespace WebSocketSharp internal void ConnectAsServer () { try { - if (acceptHandshake ()) + if (acceptHandshake ()) { + _readyState = WebSocketState.OPEN; open (); + } } catch (Exception ex) { processException ( - ex, "An exception has occurred while connecting or opening."); + ex, "An exception has occurred while connecting."); } } + // As client + internal static string CreateBase64Key () + { + var src = new byte [16]; + var rand = new Random (); + rand.NextBytes (src); + + return Convert.ToBase64String (src); + } + + internal static string CreateResponseKey (string base64Key) + { + var buffer = new StringBuilder (base64Key, 64); + buffer.Append (_guid); + SHA1 sha1 = new SHA1CryptoServiceProvider (); + var src = sha1.ComputeHash (Encoding.UTF8.GetBytes (buffer.ToString ())); + + return Convert.ToBase64String (src); + } + internal bool Ping (byte [] frameAsBytes, int timeOut) { return send (frameAsBytes) && @@ -1571,7 +1600,8 @@ namespace WebSocketSharp return; } - connect (); + if (connect ()) + open (); } /// @@ -1590,8 +1620,13 @@ namespace WebSocketSharp return; } - Action connector = connect; - connector.BeginInvoke (ar => connector.EndInvoke (ar), null); + Func connector = connect; + connector.BeginInvoke ( + ar => { + if (connector.EndInvoke (ar)) + open (); + }, + null); } ///