Fix due to the added RequestHandshake.cs

This commit is contained in:
sta
2012-08-31 13:52:34 +09:00
parent b3dc165d83
commit 7ea798e321
48 changed files with 317 additions and 256 deletions

View File

@@ -29,7 +29,7 @@
using System;
using WebSocketSharp.Frame;
namespace WebSocketSharp.Stream
namespace WebSocketSharp
{
public interface IWsStream : IDisposable
{

View File

@@ -0,0 +1,169 @@
#region MIT License
/**
* RequestHandshake.cs
*
* The MIT License
*
* Copyright (c) 2012 sta.blockhead
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#endregion
using System;
using System.Net;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Text;
namespace WebSocketSharp
{
public class RequestHandshake
{
private const string _crlf = "\r\n";
private RequestHandshake()
{
}
public RequestHandshake(string uri)
{
Method = "GET";
Uri = uri;
Version = "HTTP/1.1";
Headers = new NameValueCollection();
AddHeader("Upgrade", "websocket");
AddHeader("Connection", "Upgrade");
}
public NameValueCollection Headers { get; private set; }
public bool IsWebSocketRequest {
get {
if (Method != "GET")
return false;
if (Version != "HTTP/1.1")
return false;
if (!HeaderExists("Upgrade", "websocket"))
return false;
if (!HeaderExists("Connection", "Upgrade"))
return false;
if (!HeaderExists("Host"))
return false;
if (!HeaderExists("Sec-WebSocket-Key"))
return false;
if (!HeaderExists("Sec-WebSocket-Version"))
return false;
return true;
}
}
public string Method { get; private set; }
public string Uri { get; private set; }
public string Version { get; private set; }
public static RequestHandshake Parse(byte[] data)
{
var request = Encoding.UTF8.GetString(data)
.Replace("\r\n", "\n").Replace("\n\n", "\n").TrimEnd('\n')
.Split('\n');
return Parse(request);
}
public static RequestHandshake Parse(string[] request)
{
var requestLine = request[0].Split(' ');
if (requestLine.Length != 3)
throw new ArgumentException("Invalid request line.");
var headers = new WebHeaderCollection();
for (int i = 1; i < request.Length; i++)
headers.Add(request[i]);
return new RequestHandshake {
Headers = headers,
Method = requestLine[0],
Uri = requestLine[1],
Version = requestLine[2]
};
}
public void AddHeader(string name, string value)
{
Headers.Add(name, value);
}
public string[] GetHeaderValues(string name)
{
return Headers.GetValues(name);
}
public bool HeaderExists(string name)
{
return Headers[name] != null
? true
: false;
}
public bool HeaderExists(string name, string value)
{
var values = GetHeaderValues(name);
if (values == null)
return false;
foreach (string v in values)
if (String.Compare(value, v, true) == 0)
return true;
return false;
}
public byte[] ToBytes()
{
return Encoding.UTF8.GetBytes(ToString());
}
public override string ToString()
{
var buffer = new StringBuilder();
buffer.AppendFormat("{0} {1} {2}{3}", Method, Uri, Version, _crlf);
foreach (string key in Headers.AllKeys)
buffer.AppendFormat("{0}: {1}{2}", key, Headers[key], _crlf);
buffer.Append(_crlf);
return buffer.ToString();
}
}
}

View File

@@ -46,7 +46,6 @@ using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using WebSocketSharp.Frame;
using WebSocketSharp.Stream;
namespace WebSocketSharp
{
@@ -83,6 +82,71 @@ namespace WebSocketSharp
#endregion
#region Private Constructor
private WebSocket()
{
_binaryType = String.Empty;
_extensions = String.Empty;
_forClose = new Object();
_forSend = new Object();
_fragmentLen = 1024; // Max value is int.MaxValue - 14.
_protocol = String.Empty;
_readyState = WsState.CONNECTING;
_receivedPong = new AutoResetEvent(false);
_unTransmittedBuffer = new SynchronizedCollection<WsFrame>();
}
#endregion
#region Internal Constructor
internal WebSocket(Uri uri, TcpClient tcpClient)
: this()
{
_uri = uri;
_tcpClient = tcpClient;
_isClient = false;
}
#endregion
#region Public Constructors
public WebSocket(string url, params string[] protocols)
: this()
{
_uri = new Uri(url);
if (!isValidScheme(_uri))
{
var msg = "Unsupported WebSocket URI scheme: " + _uri.Scheme;
throw new ArgumentException(msg);
}
_protocols = protocols.ToString(", ");
_isClient = true;
}
public WebSocket(
string url,
EventHandler onOpen,
EventHandler<MessageEventArgs> onMessage,
EventHandler<ErrorEventArgs> onError,
EventHandler<CloseEventArgs> onClose,
params string[] protocols)
: this(url, protocols)
{
OnOpen = onOpen;
OnMessage = onMessage;
OnError = onError;
OnClose = onClose;
Connect();
}
#endregion
#region Properties
public string BinaryType
@@ -171,94 +235,23 @@ namespace WebSocketSharp
#endregion
#region Private Constructors
private WebSocket()
{
_binaryType = String.Empty;
_extensions = String.Empty;
_forClose = new Object();
_forSend = new Object();
_fragmentLen = 1024; // Max value is int.MaxValue - 14.
_protocol = String.Empty;
_readyState = WsState.CONNECTING;
_receivedPong = new AutoResetEvent(false);
_unTransmittedBuffer = new SynchronizedCollection<WsFrame>();
}
#endregion
#region Internal Constructors
internal WebSocket(Uri uri, TcpClient tcpClient)
: this()
{
_uri = uri;
_tcpClient = tcpClient;
_isClient = false;
}
#endregion
#region Public Constructors
public WebSocket(string url, params string[] protocols)
: this()
{
_uri = new Uri(url);
if (!isValidScheme(_uri))
{
var msg = "Unsupported WebSocket URI scheme: " + _uri.Scheme;
throw new ArgumentException(msg);
}
_protocols = protocols.ToString(", ");
_isClient = true;
}
public WebSocket(
string url,
EventHandler onOpen,
EventHandler<MessageEventArgs> onMessage,
EventHandler<ErrorEventArgs> onError,
EventHandler<CloseEventArgs> onClose,
params string[] protocols)
: this(url, protocols)
{
OnOpen = onOpen;
OnMessage = onMessage;
OnError = onError;
OnClose = onClose;
Connect();
}
#endregion
#region Private Methods
private void acceptHandshake()
{
string msg, response;
string[] request;
request = receiveOpeningHandshake();
var request = receiveOpeningHandshake();
#if DEBUG
Console.WriteLine("\nWS: Info@acceptHandshake: Opening handshake from client:\n");
foreach (string s in request)
{
Console.WriteLine("{0}", s);
}
Console.WriteLine("WS: Info@acceptHandshake: Opening handshake from client:\n");
Console.WriteLine(request.ToString());
#endif
string msg;
if (!isValidRequest(request, out msg))
{
throw new InvalidOperationException(msg);
}
response = createResponseHandshake();
var response = createResponseHandshake();
#if DEBUG
Console.WriteLine("\nWS: Info@acceptHandshake: Opening handshake from server:\n{0}", response);
Console.WriteLine("WS: Info@acceptHandshake: Opening handshake from server:\n");
Console.WriteLine(response);
#endif
sendResponseHandshake(response);
@@ -268,7 +261,7 @@ namespace WebSocketSharp
private void close(PayloadData data)
{
#if DEBUG
Console.WriteLine("\nWS: Info@close: Current thread IsBackground?: {0}", Thread.CurrentThread.IsBackground);
Console.WriteLine("WS: Info@close: Current thread IsBackground?: {0}", Thread.CurrentThread.IsBackground);
#endif
lock(_forClose)
{
@@ -362,14 +355,9 @@ namespace WebSocketSharp
if (port <= 0)
{
port = 80;
if (scheme == "wss")
{
port = 443;
}
else
{
port = 80;
}
}
_tcpClient = new TcpClient(host, port);
@@ -385,13 +373,12 @@ namespace WebSocketSharp
_sslStream = new SslStream(_netStream, false, validation);
_sslStream.AuthenticateAsClient(host);
_wsStream = new WsStream<SslStream>(_sslStream);
return;
}
else
{
_wsStream = new WsStream<NetworkStream>(_netStream);
}
_wsStream = new WsStream<NetworkStream>(_netStream);
}
private string createExpectedKey()
@@ -418,55 +405,34 @@ namespace WebSocketSharp
}
}
private string createOpeningHandshake()
private RequestHandshake createOpeningHandshake()
{
byte[] keySrc;
int port;
string crlf, host, origin, path;
string reqConnection, reqHost, reqMethod, reqOrigin, reqUpgrade;
string secWsKey, secWsProtocol, secWsVersion;
Random rand;
var path = _uri.PathAndQuery;
path = _uri.PathAndQuery;
host = _uri.DnsSafeHost;
port = ((IPEndPoint)_tcpClient.Client.RemoteEndPoint).Port;
var host = _uri.DnsSafeHost;
var port = ((IPEndPoint)_tcpClient.Client.RemoteEndPoint).Port;
if (port != 80)
{
host += ":" + port;
}
origin = "http://" + host;
var origin = "http://" + host;
keySrc = new byte[16];
rand = new Random();
var keySrc = new byte[16];
var rand = new Random();
rand.NextBytes(keySrc);
_base64key = Convert.ToBase64String(keySrc);
crlf = "\r\n";
reqMethod = String.Format("GET {0} HTTP/1.1{1}", path, crlf);
reqHost = String.Format("Host: {0}{1}", host, crlf);
reqUpgrade = String.Format("Upgrade: websocket{0}", crlf);
reqConnection = String.Format("Connection: Upgrade{0}", crlf);
reqOrigin = String.Format("Origin: {0}{1}", origin, crlf);
secWsKey = String.Format("Sec-WebSocket-Key: {0}{1}", _base64key, crlf);
secWsProtocol = _protocols != String.Empty
? String.Format("Sec-WebSocket-Protocol: {0}{1}", _protocols, crlf)
: _protocols;
var request = new RequestHandshake(path);
secWsVersion = String.Format("Sec-WebSocket-Version: {0}{1}", _version, crlf);
request.AddHeader("Host", host);
request.AddHeader("Origin", origin);
request.AddHeader("Sec-WebSocket-Key", _base64key);
if (!String.IsNullOrEmpty(_protocols))
request.AddHeader("Sec-WebSocket-Protocol", _protocols);
request.AddHeader("Sec-WebSocket-Version", _version);
return reqMethod +
reqHost +
reqUpgrade +
reqConnection +
secWsKey +
reqOrigin +
secWsProtocol +
secWsVersion +
crlf;
return request;
}
private string createResponseHandshake()
@@ -510,21 +476,19 @@ namespace WebSocketSharp
private void doHandshake()
{
string msg, request;
string[] response;
request = createOpeningHandshake();
var request = createOpeningHandshake();
#if DEBUG
Console.WriteLine("\nWS: Info@doHandshake: Opening handshake from client:\n{0}", request);
Console.WriteLine("WS: Info@doHandshake: Opening handshake from client:\n{0}", request);
#endif
response = sendOpeningHandshake(request);
var response = sendOpeningHandshake(request);
#if DEBUG
Console.WriteLine("\nWS: Info@doHandshake: Opening handshake from server:\n");
Console.WriteLine("WS: Info@doHandshake: Opening handshake from server:\n");
foreach (string s in response)
{
Console.WriteLine("{0}", s);
}
#endif
string msg;
if (!isValidResponse(response, out msg))
{
throw new InvalidOperationException(msg);
@@ -543,12 +507,13 @@ namespace WebSocketSharp
OnError.Emit(this, new ErrorEventArgs(message));
}
private bool isValidRequest(string[] request, out string message)
private bool isValidRequest(RequestHandshake request, out string message)
{
string reqConnection, reqHost, reqUpgrade, secWsVersion;
string[] reqRequest;
List<string> extensionList = new List<string>();
if (!request.IsWebSocketRequest)
{
message = "Not WebSocket request.";
return false;
}
Func<string, Func<string, string, string>> func = s =>
{
@@ -561,90 +526,28 @@ namespace WebSocketSharp
string expectedHost = _uri.DnsSafeHost;
int port = ((IPEndPoint)_tcpClient.Client.LocalEndPoint).Port;
if (port != 80)
{
expectedHost += ":" + port;
}
reqRequest = request[0].Split(' ');
if ("GET".NotEqualsDo(reqRequest[0], func("HTTP Method"), out message, false))
{
if (_uri.PathAndQuery.NotEqualsDo(request.Uri, func("Request URI"), out message, false))
return false;
}
if ("HTTP/1.1".NotEqualsDo(reqRequest[2], func("HTTP Version"), out message, false))
if (expectedHost.NotEqualsDo(request.GetHeaderValues("Host")[0], func("Host"), out message, false))
return false;
if (!request.HeaderExists("Sec-WebSocket-Version", _version))
{
message = "Unsupported Sec-WebSocket-Version.";
return false;
}
for (int i = 1; i < request.Length; i++)
{
if (request[i].Contains("Connection:"))
{
reqConnection = request[i].GetHeaderValue(":");
if ("Upgrade".NotEqualsDo(reqConnection, func("Connection"), out message, true))
{
return false;
}
}
else if (request[i].Contains("Host:"))
{
reqHost = request[i].GetHeaderValue(":");
if (expectedHost.NotEqualsDo(reqHost, func("Host"), out message, true))
{
return false;
}
}
else if (request[i].Contains("Origin:"))
{
continue;
}
else if (request[i].Contains("Upgrade:"))
{
reqUpgrade = request[i].GetHeaderValue(":");
if ("websocket".NotEqualsDo(reqUpgrade, func("Upgrade"), out message, true))
{
return false;
}
}
else if (request[i].Contains("Sec-WebSocket-Extensions:"))
{
extensionList.Add(request[i].GetHeaderValue(":"));
}
else if (request[i].Contains("Sec-WebSocket-Key:"))
{
_base64key = request[i].GetHeaderValue(":");
}
else if (request[i].Contains("Sec-WebSocket-Protocol:"))
{
_protocols = request[i].GetHeaderValue(":");
#if DEBUG
Console.WriteLine("WS: Info@isValidRequest: Sub protocol: {0}", _protocols);
#endif
}
else if (request[i].Contains("Sec-WebSocket-Version:"))
{
secWsVersion = request[i].GetHeaderValue(":");
if (_version.NotEqualsDo(secWsVersion, func("Sec-WebSocket-Version"), out message, true))
{
return false;
}
}
else
{
Console.WriteLine("WS: Info@isValidRequest: Unsupported request header line: {0}", request[i]);
}
}
_base64key = request.GetHeaderValues("Sec-WebSocket-Key")[0];
if (request.HeaderExists("Sec-WebSocket-Protocol"))
_protocols = request.Headers["Sec-WebSocket-Protocol"];
if (request.HeaderExists("Sec-WebSocket-Extensions"))
_extensions = request.Headers["Sec-WebSocket-Extensions"];
if (String.IsNullOrEmpty(_base64key))
{
message = "Sec-WebSocket-Key header field does not exist or the value isn't set.";
return false;
}
#if DEBUG
foreach (string s in extensionList)
{
Console.WriteLine("WS: Info@isValidRequest: Extensions: {0}", s);
}
#endif
message = String.Empty;
return true;
}
@@ -801,6 +704,22 @@ namespace WebSocketSharp
pong(payloadData);
}
private byte[] readHandshake()
{
var buffer = new List<byte>();
while (true)
{
if (_wsStream.ReadByte().EqualsAndSaveTo('\r', buffer) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', buffer) &&
_wsStream.ReadByte().EqualsAndSaveTo('\r', buffer) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', buffer))
break;
}
return buffer.ToArray();
}
private MessageEventArgs receive()
{
List<byte> dataBuffer;
@@ -923,24 +842,9 @@ namespace WebSocketSharp
return eventArgs;
}
private string[] receiveOpeningHandshake()
private RequestHandshake receiveOpeningHandshake()
{
var readData = new List<byte>();
while (true)
{
if (_wsStream.ReadByte().EqualsAndSaveTo('\r', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\r', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', readData))
{
break;
}
}
return Encoding.UTF8.GetString(readData.ToArray())
.Replace("\r\n", "\n").Replace("\n\n", "\n").TrimEnd('\n')
.Split('\n');
return RequestHandshake.Parse(readHandshake());
}
private bool send(WsFrame frame)
@@ -986,7 +890,7 @@ namespace WebSocketSharp
}
private void send<TStream>(Opcode opcode, TStream stream)
where TStream : System.IO.Stream
where TStream : Stream
{
lock(_forSend)
{
@@ -1013,7 +917,7 @@ namespace WebSocketSharp
}
private ulong sendFragmented<TStream>(Opcode opcode, TStream stream)
where TStream : System.IO.Stream
where TStream : Stream
{
WsFrame frame;
PayloadData payloadData;
@@ -1076,25 +980,13 @@ namespace WebSocketSharp
return readLen;
}
private string[] sendOpeningHandshake(string value)
private string[] sendOpeningHandshake(RequestHandshake request)
{
var readData = new List<byte>();
var buffer = Encoding.UTF8.GetBytes(value);
_wsStream.Write(request.ToBytes(), 0, request.ToBytes().Length);
_wsStream.Write(buffer, 0, buffer.Length);
var readData = readHandshake();
while (true)
{
if (_wsStream.ReadByte().EqualsAndSaveTo('\r', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\r', readData) &&
_wsStream.ReadByte().EqualsAndSaveTo('\n', readData))
{
break;
}
}
return Encoding.UTF8.GetString(readData.ToArray())
return Encoding.UTF8.GetString(readData)
.Replace("\r\n", "\n").Replace("\n\n", "\n").TrimEnd('\n')
.Split('\n');
}

View File

@@ -34,10 +34,10 @@ using System.Net.Sockets;
using System.Reflection;
using WebSocketSharp.Frame;
namespace WebSocketSharp.Stream
namespace WebSocketSharp
{
public class WsStream<TStream> : IWsStream
where TStream : System.IO.Stream
where TStream : Stream
{
private TStream _innerStream;
private Object _forRead;

View File

@@ -63,8 +63,6 @@
<Compile Include="CloseEventArgs.cs" />
<Compile Include="WsReceivedTooBigMessageException.cs" />
<Compile Include="ByteOrder.cs" />
<Compile Include="Stream\IWsStream.cs" />
<Compile Include="Stream\WsStream.cs" />
<Compile Include="Frame\WsFrame.cs" />
<Compile Include="Frame\CloseStatusCode.cs" />
<Compile Include="Frame\Fin.cs" />
@@ -76,10 +74,12 @@
<Compile Include="WebSocket.cs" />
<Compile Include="Server\WebSocketServer.cs" />
<Compile Include="Server\WebSocketService.cs" />
<Compile Include="IWsStream.cs" />
<Compile Include="WsStream.cs" />
<Compile Include="RequestHandshake.cs" />
</ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
<ItemGroup>
<Folder Include="Stream\" />
<Folder Include="Frame\" />
<Folder Include="Server\" />
</ItemGroup>

Binary file not shown.