diff --git a/websocket-sharp/Net/WebSockets/TcpListenerWebSocketContext.cs b/websocket-sharp/Net/WebSockets/TcpListenerWebSocketContext.cs index cb097480..e6e9c7cd 100644 --- a/websocket-sharp/Net/WebSockets/TcpListenerWebSocketContext.cs +++ b/websocket-sharp/Net/WebSockets/TcpListenerWebSocketContext.cs @@ -66,7 +66,7 @@ namespace WebSocketSharp.Net.WebSockets _client = client; _secure = secure; _stream = WebSocketStream.CreateServerStream (client, secure, cert); - _request = _stream.ReadHttp (HttpRequest.Parse, 90000); + _request = _stream.ReadHttpRequest (90000); _uri = HttpUtility.CreateRequestUrl ( _request.RequestUri, _request.Headers ["Host"], _request.IsWebSocketRequest, secure); @@ -328,8 +328,8 @@ namespace WebSocketSharp.Net.WebSockets { var res = new HttpResponse (HttpStatusCode.Unauthorized); res.Headers ["WWW-Authenticate"] = challenge; - _stream.WriteHandshake (res); - _request = _stream.ReadHttp (HttpRequest.Parse, 15000); + _stream.WriteBytes (res.ToByteArray ()); + _request = _stream.ReadHttpRequest (15000); } internal void SetUser ( diff --git a/websocket-sharp/WebSocket.cs b/websocket-sharp/WebSocket.cs index 7989b51d..44a600a8 100644 --- a/websocket-sharp/WebSocket.cs +++ b/websocket-sharp/WebSocket.cs @@ -91,6 +91,8 @@ namespace WebSocketSharp private bool _preAuth; private string _protocol; private string [] _protocols; + private NetworkCredential _proxyCredentials; + private Uri _proxyUri; private volatile WebSocketState _readyState; private AutoResetEvent _receivePong; private bool _secure; @@ -776,7 +778,7 @@ namespace WebSocketSharp private bool closeHandshake (byte [] frame, int timeout, Action release) { - var sent = frame != null && _stream.Write (frame); + var sent = frame != null && _stream.WriteBytes (frame); var received = timeout == 0 || (sent && _exitReceiving != null && _exitReceiving.WaitOne (timeout)); @@ -813,7 +815,7 @@ namespace WebSocketSharp private bool concatenateFragmentsInto (Stream dest) { while (true) { - var frame = _stream.ReadFrame (); + var frame = _stream.ReadWebSocketFrame (); if (frame.IsFinal) { // FINAL @@ -1046,7 +1048,7 @@ namespace WebSocketSharp // As client private HttpResponse receiveHandshakeResponse () { - var res = _stream.ReadHandshakeResponse (); + var res = _stream.ReadHttpResponse (90000); _logger.Debug ("A response to this WebSocket connection request:\n" + res.ToString ()); return res; @@ -1060,7 +1062,7 @@ namespace WebSocketSharp return false; } - return _stream.Write (frame); + return _stream.WriteBytes (frame); } } @@ -1070,7 +1072,7 @@ namespace WebSocketSharp _logger.Debug ( String.Format ("A WebSocket connection request to {0}:\n{1}", _uri, request)); - _stream.WriteHandshake (request); + _stream.WriteBytes (request.ToByteArray ()); } // As server @@ -1079,7 +1081,7 @@ namespace WebSocketSharp _logger.Debug ( "A response to the WebSocket connection request:\n" + response.ToString ()); - return _stream.WriteHandshake (response); + return _stream.WriteBytes (response.ToByteArray ()); } private bool send (WebSocketFrame frame) @@ -1090,7 +1092,7 @@ namespace WebSocketSharp return false; } - return _stream.Write (frame.ToByteArray ()); + return _stream.WriteBytes (frame.ToByteArray ()); } } @@ -1260,12 +1262,8 @@ namespace WebSocketSharp // As client private void setClientStream () { - var host = _uri.DnsSafeHost; - var port = _uri.Port; - - _tcpClient = new TcpClient (host, port); _stream = WebSocketStream.CreateClientStream ( - _tcpClient, _secure, host, _certValidationCallback); + _uri, _proxyUri, _proxyCredentials, _secure, _certValidationCallback, out _tcpClient); } private void startReceiving () @@ -1277,7 +1275,7 @@ namespace WebSocketSharp _receivePong = new AutoResetEvent (false); Action receive = null; - receive = () => _stream.ReadFrameAsync ( + receive = () => _stream.ReadWebSocketFrameAsync ( frame => { if (acceptFrame (frame) && _readyState != WebSocketState.Closed) { receive (); @@ -1484,7 +1482,7 @@ namespace WebSocketSharp cache.Add (_compression, cached); } - _stream.Write (cached); + _stream.WriteBytes (cached); } catch (Exception ex) { _logger.Fatal (ex.ToString ()); diff --git a/websocket-sharp/WebSocketStream.cs b/websocket-sharp/WebSocketStream.cs index 063d7fc5..e478a01d 100644 --- a/websocket-sharp/WebSocketStream.cs +++ b/websocket-sharp/WebSocketStream.cs @@ -42,7 +42,7 @@ namespace WebSocketSharp { #region Private Const Fields - private const int _httpHeadersLimitLen = 8192; + private const int _headersMaxLength = 8192; #endregion @@ -65,20 +65,6 @@ namespace WebSocketSharp #endregion - #region Public Constructors - - public WebSocketStream (NetworkStream innerStream) - : this (innerStream, false) - { - } - - public WebSocketStream (SslStream innerStream) - : this (innerStream, true) - { - } - - #endregion - #region Public Properties public bool DataAvailable { @@ -99,7 +85,7 @@ namespace WebSocketSharp #region Private Methods - private static byte [] readHttpEntityBody (Stream stream, string length) + private static byte[] readEntityBody (Stream stream, string length) { long len; if (!Int64.TryParse (length, out len)) @@ -115,17 +101,17 @@ namespace WebSocketSharp : null; } - private static string [] readHttpHeaders (Stream stream) + private static string[] readHeaders (Stream stream, int maxLength) { var buff = new List (); - var count = 0; + var cnt = 0; Action add = i => { buff.Add ((byte) i); - count++; + cnt++; }; var read = false; - while (count < _httpHeadersLimitLen) { + while (cnt < maxLength) { if (stream.ReadByte ().EqualsWith ('\r', add) && stream.ReadByte ().EqualsWith ('\n', add) && stream.ReadByte ().EqualsWith ('\r', add) && @@ -137,27 +123,23 @@ namespace WebSocketSharp if (!read) throw new WebSocketException ( - "The header part of a HTTP data is greater than the limit length."); + "The header part of a HTTP data is greater than the max length."); var crlf = "\r\n"; return Encoding.UTF8.GetString (buff.ToArray ()) .Replace (crlf + " ", " ") .Replace (crlf + "\t", " ") - .Split (new [] { crlf }, StringSplitOptions.RemoveEmptyEntries); + .Split (new[] { crlf }, StringSplitOptions.RemoveEmptyEntries); } - #endregion - - #region Internal Methods - - internal T ReadHttp (Func parser, int millisecondsTimeout) + private static T readHttp (Stream stream, Func parser, int millisecondsTimeout) where T : HttpBase { var timeout = false; var timer = new Timer ( state => { timeout = true; - _innerStream.Close (); + stream.Close (); }, null, millisecondsTimeout, @@ -166,10 +148,10 @@ namespace WebSocketSharp T http = null; Exception exception = null; try { - http = parser (readHttpHeaders (_innerStream)); - var contentLen = http.Headers ["Content-Length"]; + http = parser (readHeaders (stream, _headersMaxLength)); + var contentLen = http.Headers["Content-Length"]; if (contentLen != null && contentLen.Length > 0) - http.EntityBodyData = readHttpEntityBody (_innerStream, contentLen); + http.EntityBodyData = readEntityBody (stream, contentLen); } catch (Exception ex) { exception = ex; @@ -191,7 +173,104 @@ namespace WebSocketSharp return http; } - internal bool Write (byte [] data) + private static HttpResponse sendHttpRequest ( + Stream stream, HttpRequest request, int millisecondsTimeout) + { + var buff = request.ToByteArray (); + stream.Write (buff, 0, buff.Length); + + return readHttp (stream, HttpResponse.Parse, millisecondsTimeout); + } + + #endregion + + #region Internal Methods + + internal static WebSocketStream CreateClientStream ( + Uri targetUri, + Uri proxyUri, + NetworkCredential proxyCredentials, + bool secure, + System.Net.Security.RemoteCertificateValidationCallback validationCallback, + out TcpClient tcpClient) + { + var proxy = proxyUri != null; + tcpClient = proxy + ? new TcpClient (proxyUri.DnsSafeHost, proxyUri.Port) + : new TcpClient (targetUri.DnsSafeHost, targetUri.Port); + + var netStream = tcpClient.GetStream (); + if (proxy) { + var req = HttpRequest.CreateConnectRequest (targetUri); + var res = sendHttpRequest (netStream, req, 90000); + if (res.IsProxyAuthenticationRequired) { + var authChal = res.ProxyAuthenticationChallenge; + if (authChal != null && proxyCredentials != null) { + var authRes = new AuthenticationResponse (authChal, proxyCredentials, 0); + req.Headers["Proxy-Authorization"] = authRes.ToString (); + res = sendHttpRequest (netStream, req, 15000); + } + + if (res.IsProxyAuthenticationRequired) + throw new WebSocketException ("Proxy authentication is required."); + } + + var code = res.StatusCode; + if (code.Length != 3 || code[0] != '2') + throw new WebSocketException ( + String.Format ( + "The proxy has failed a connection to the requested host and port. ({0})", code)); + } + + if (secure) { + var sslStream = new SslStream ( + netStream, + false, + validationCallback ?? ((sender, certificate, chain, sslPolicyErrors) => true)); + + sslStream.AuthenticateAsClient (targetUri.DnsSafeHost); + return new WebSocketStream (sslStream, secure); + } + + return new WebSocketStream (netStream, secure); + } + + internal static WebSocketStream CreateServerStream ( + TcpClient tcpClient, bool secure, X509Certificate certificate) + { + var netStream = tcpClient.GetStream (); + if (secure) { + var sslStream = new SslStream (netStream, false); + sslStream.AuthenticateAsServer (certificate); + + return new WebSocketStream (sslStream, secure); + } + + return new WebSocketStream (netStream, secure); + } + + internal HttpRequest ReadHttpRequest (int millisecondsTimeout) + { + return readHttp (_innerStream, HttpRequest.Parse, millisecondsTimeout); + } + + internal HttpResponse ReadHttpResponse (int millisecondsTimeout) + { + return readHttp (_innerStream, HttpResponse.Parse, millisecondsTimeout); + } + + internal WebSocketFrame ReadWebSocketFrame () + { + return WebSocketFrame.Parse (_innerStream, true); + } + + internal void ReadWebSocketFrameAsync ( + Action completed, Action error) + { + WebSocketFrame.ParseAsync (_innerStream, true, completed, error); + } + + internal bool WriteBytes (byte[] data) { lock (_forWrite) { try { @@ -213,75 +292,11 @@ namespace WebSocketSharp _innerStream.Close (); } - public static WebSocketStream CreateClientStream ( - TcpClient client, - bool secure, - string host, - System.Net.Security.RemoteCertificateValidationCallback validationCallback) - { - var netStream = client.GetStream (); - if (secure) { - if (validationCallback == null) - validationCallback = (sender, certificate, chain, sslPolicyErrors) => true; - - var sslStream = new SslStream (netStream, false, validationCallback); - sslStream.AuthenticateAsClient (host); - - return new WebSocketStream (sslStream); - } - - return new WebSocketStream (netStream); - } - - public static WebSocketStream CreateServerStream ( - TcpClient client, bool secure, X509Certificate cert) - { - var netStream = client.GetStream (); - if (secure) { - var sslStream = new SslStream (netStream, false); - sslStream.AuthenticateAsServer (cert); - - return new WebSocketStream (sslStream); - } - - return new WebSocketStream (netStream); - } - public void Dispose () { _innerStream.Dispose (); } - public WebSocketFrame ReadFrame () - { - return WebSocketFrame.Parse (_innerStream, true); - } - - public void ReadFrameAsync (Action completed, Action error) - { - WebSocketFrame.ParseAsync (_innerStream, true, completed, error); - } - - public HttpRequest ReadHandshakeRequest () - { - return ReadHttp (HttpRequest.Parse, 90000); - } - - public HttpResponse ReadHandshakeResponse () - { - return ReadHttp (HttpResponse.Parse, 90000); - } - - public bool WriteFrame (WebSocketFrame frame) - { - return Write (frame.ToByteArray ()); - } - - public bool WriteHandshake (HttpBase handshake) - { - return Write (handshake.ToByteArray ()); - } - #endregion } }