Fix for subprotocols for client

This commit is contained in:
sta 2014-02-06 13:50:27 +09:00
parent 8567112200
commit 7e6b82306c
3 changed files with 104 additions and 27 deletions

View File

@ -73,8 +73,8 @@ namespace Example {
ThreadPool.QueueUserWorkItem(notifyMsg); ThreadPool.QueueUserWorkItem(notifyMsg);
using (var ws = new WebSocket("ws://echo.websocket.org", "echo")) using (var ws = new WebSocket("ws://echo.websocket.org"))
//using (var ws = new WebSocket("wss://echo.websocket.org", "echo")) //using (var ws = new WebSocket("wss://echo.websocket.org"))
//using (var ws = new WebSocket("ws://localhost:4649")) //using (var ws = new WebSocket("ws://localhost:4649"))
//using (var ws = new WebSocket("ws://localhost:4649/Echo")) //using (var ws = new WebSocket("ws://localhost:4649/Echo"))
//using (var ws = new WebSocket("wss://localhost:4649/Echo")) //using (var ws = new WebSocket("wss://localhost:4649/Echo"))

View File

@ -258,6 +258,16 @@ namespace WebSocketSharp
: null; : null;
} }
internal static string CheckIfValidProtocols (this string [] protocols)
{
return protocols.Contains (
protocol => protocol.Length == 0 || !protocol.IsToken ())
? "Contains an invalid value."
: protocols.ContainsTwice ()
? "Contains a value twice."
: null;
}
internal static string CheckIfValidSendData (this byte [] data) internal static string CheckIfValidSendData (this byte [] data)
{ {
return data == null return data == null
@ -334,6 +344,36 @@ namespace WebSocketSharp
: stream.ToByteArray (); : stream.ToByteArray ();
} }
internal static bool Contains<T> (
this IEnumerable<T> source, Func<T, bool> comparer)
{
foreach (T value in source)
if (comparer (value))
return true;
return false;
}
internal static bool ContainsTwice (this string [] values)
{
var len = values.Length;
Func<int, bool> contains = null;
contains = index => {
if (index < len - 1) {
for (var i = index + 1; i < len; i++)
if (values [i] == values [index])
return true;
return contains (++index);
}
return false;
};
return contains (0);
}
internal static T [] Copy<T> (this T [] src, long length) internal static T [] Copy<T> (this T [] src, long length)
{ {
var dest = new T [length]; var dest = new T [length];

View File

@ -38,7 +38,6 @@ using System.Collections.Generic;
using System.Collections.Specialized; using System.Collections.Specialized;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq;
using System.Net.Sockets; using System.Net.Sockets;
using System.Net.Security; using System.Net.Security;
using System.Security.Cryptography; using System.Security.Cryptography;
@ -89,7 +88,7 @@ namespace WebSocketSharp
private string _origin; private string _origin;
private bool _preAuth; private bool _preAuth;
private string _protocol; private string _protocol;
private string _protocols; private string [] _protocols;
private volatile WebSocketState _readyState; private volatile WebSocketState _readyState;
private AutoResetEvent _receivePong; private AutoResetEvent _receivePong;
private bool _secure; private bool _secure;
@ -148,10 +147,19 @@ namespace WebSocketSharp
/// </param> /// </param>
/// <param name="protocols"> /// <param name="protocols">
/// An array of <see cref="string"/> that contains the WebSocket subprotocols /// An array of <see cref="string"/> that contains the WebSocket subprotocols
/// if any. /// if any. Each value of <paramref name="protocols"/> must be a token defined
/// in <see href="http://tools.ietf.org/html/rfc2616#section-2.2">RFC 2616</see>.
/// </param> /// </param>
/// <exception cref="ArgumentException"> /// <exception cref="ArgumentException">
/// <para>
/// <paramref name="url"/> is invalid. /// <paramref name="url"/> is invalid.
/// </para>
/// <para>
/// -or-
/// </para>
/// <para>
/// <paramref name="protocols"/> is invalid.
/// </para>
/// </exception> /// </exception>
/// <exception cref="ArgumentNullException"> /// <exception cref="ArgumentNullException">
/// <paramref name="url"/> is <see langword="null"/>. /// <paramref name="url"/> is <see langword="null"/>.
@ -165,7 +173,13 @@ namespace WebSocketSharp
if (!url.TryCreateWebSocketUri (out _uri, out msg)) if (!url.TryCreateWebSocketUri (out _uri, out msg))
throw new ArgumentException (msg, "url"); throw new ArgumentException (msg, "url");
_protocols = protocols.ToString (", "); if (protocols != null && protocols.Length > 0) {
msg = protocols.CheckIfValidProtocols ();
if (msg != null)
throw new ArgumentException (msg, "protocols");
_protocols = protocols;
}
_base64Key = CreateBase64Key (); _base64Key = CreateBase64Key ();
_client = true; _client = true;
@ -677,19 +691,20 @@ namespace WebSocketSharp
private string checkIfValidHandshakeResponse (HandshakeResponse response) private string checkIfValidHandshakeResponse (HandshakeResponse response)
{ {
var headers = response.Headers; var headers = response.Headers;
string accept, version;
return response.IsUnauthorized return response.IsUnauthorized
? String.Format ( ? String.Format (
"An HTTP {0} authorization is required.", "HTTP {0} authorization is required.",
response.AuthChallenge.Scheme) response.AuthChallenge.Scheme)
: !response.IsWebSocketResponse : !response.IsWebSocketResponse
? "Not WebSocket connection response to the connection request." ? "Not WebSocket connection response."
: (accept = headers ["Sec-WebSocket-Accept"]) == null || : !validateSecWebSocketAcceptHeader (
accept != CreateResponseKey (_base64Key) headers ["Sec-WebSocket-Accept"])
? "Invalid Sec-WebSocket-Accept header." ? "Invalid Sec-WebSocket-Accept header."
: (version = headers ["Sec-WebSocket-Version"]) != null && : !validateSecWebSocketProtocolHeader (
version != _version headers ["Sec-WebSocket-Protocol"])
? "Invalid Sec-WebSocket-Protocol header."
: !validateSecWebSocketVersionHeader (
headers ["Sec-WebSocket-Version"])
? "Invalid Sec-WebSocket-Version header." ? "Invalid Sec-WebSocket-Version header."
: null; : null;
} }
@ -894,8 +909,8 @@ namespace WebSocketSharp
headers ["Sec-WebSocket-Key"] = _base64Key; headers ["Sec-WebSocket-Key"] = _base64Key;
if (!_protocols.IsNullOrEmpty ()) if (_protocols != null)
headers ["Sec-WebSocket-Protocol"] = _protocols; headers ["Sec-WebSocket-Protocol"] = _protocols.ToString (", ");
var extensions = createExtensionsRequest (); var extensions = createExtensionsRequest ();
if (extensions.Length > 0) if (extensions.Length > 0)
@ -955,21 +970,17 @@ namespace WebSocketSharp
{ {
setClientStream (); setClientStream ();
var res = sendHandshakeRequest (); var res = sendHandshakeRequest ();
var err = checkIfValidHandshakeResponse (res); var msg = checkIfValidHandshakeResponse (res);
if (err != null) { if (msg != null) {
_logger.Error (err); _logger.Error (msg);
var msg = "An error has occurred while connecting."; msg = "An error has occurred while connecting.";
error (msg); error (msg);
close (CloseStatusCode.ABNORMAL, msg, false); close (CloseStatusCode.ABNORMAL, msg, false);
return false; return false;
} }
var protocol = res.Headers ["Sec-WebSocket-Protocol"];
if (!protocol.IsNullOrEmpty ())
_protocol = protocol;
processRespondedExtensions (res.Headers ["Sec-WebSocket-Extensions"]); processRespondedExtensions (res.Headers ["Sec-WebSocket-Extensions"]);
var cookies = res.Cookies; var cookies = res.Cookies;
@ -1364,6 +1375,32 @@ namespace WebSocketSharp
host == expected; host == expected;
} }
// As client
private bool validateSecWebSocketAcceptHeader (string value)
{
return value != null && value == CreateResponseKey (_base64Key);
}
// As client
private bool validateSecWebSocketProtocolHeader (string value)
{
if (value == null)
return _protocols == null;
if (_protocols == null ||
!_protocols.Contains (protocol => protocol == value))
return false;
_protocol = value;
return true;
}
// As client
private bool validateSecWebSocketVersionHeader (string value)
{
return value == null || value == _version;
}
#endregion #endregion
#region Internal Methods #region Internal Methods