Fix due to the modified WebSocket.cs

This commit is contained in:
sta 2012-10-04 15:04:21 +09:00
parent ae5d461a42
commit 94385ea2bc
64 changed files with 313 additions and 252 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -302,35 +302,45 @@ namespace WebSocketSharp {
if (length <= 0) if (length <= 0)
return new byte[]{}; return new byte[]{};
var buffer = new byte[length]; var buffer = new byte[length];
stream.Read(buffer, 0, length); var readLen = stream.Read(buffer, 0, length);
return buffer;
return readLen == length ? buffer : null;
} }
public static byte[] ReadBytes(this Stream stream, long length, int bufferLength) public static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
{ {
var count = length / bufferLength; var count = length / bufferLength;
var rem = length % bufferLength; var rem = length % bufferLength;
var readData = new List<byte>(); var readData = new List<byte>();
var readLen = 0; var readBuffer = new byte[bufferLength];
var buffer = new byte[bufferLength]; long readLen = 0;
var tmpLen = 0;
Action<byte[]> read = (buffer) =>
{
tmpLen = stream.Read(buffer, 0, buffer.Length);
if (tmpLen > 0)
{
readLen += tmpLen;
readData.AddRange(buffer.SubArray(0, tmpLen));
}
};
count.Times(() => count.Times(() =>
{ {
readLen = stream.Read(buffer, 0, bufferLength); read(readBuffer);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
}); });
if (rem > 0) if (rem > 0)
{ {
buffer = new byte[rem]; readBuffer = new byte[rem];
readLen = stream.Read(buffer, 0, (int)rem); read(readBuffer);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
} }
return readData.ToArray(); return readLen == length
? readData.ToArray()
: null;
} }
public static T[] SubArray<T>(this T[] array, int startIndex, int length) public static T[] SubArray<T>(this T[] array, int startIndex, int length)

View File

@ -28,6 +28,7 @@
using System; using System;
using System.IO; using System.IO;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
@ -35,7 +36,7 @@ namespace WebSocketSharp.Frame
{ {
public class WsFrame : IEnumerable<byte> public class WsFrame : IEnumerable<byte>
{ {
#region Private Static Fields #region Field
private static readonly int _readBufferLen; private static readonly int _readBufferLen;
@ -81,7 +82,7 @@ namespace WebSocketSharp.Frame
#endregion #endregion
#region Private Constructors #region Private Constructor
private WsFrame() private WsFrame()
{ {
@ -109,49 +110,159 @@ namespace WebSocketSharp.Frame
public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData) public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData)
: this() : this()
{ {
Fin = fin; Fin = fin;
Opcode = opcode; Opcode = opcode;
Masked = payloadData.Length != 0 ? mask : Mask.UNMASK;
ulong dataLength = payloadData.Length;
if (dataLength == 0)
{
Masked = Mask.UNMASK;
}
else
{
Masked = mask;
}
if (dataLength < 126)
{
PayloadLen = (byte)dataLength;
}
else if (dataLength < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)dataLength).ToBytes(ByteOrder.BIG);
}
else
{
PayloadLen = (byte)127;
ExtPayloadLen = dataLength.ToBytes(ByteOrder.BIG);
}
PayloadData = payloadData; PayloadData = payloadData;
if (Masked == Mask.MASK) init();
{
MaskingKey = new byte[4];
var rand = new Random();
rand.NextBytes(MaskingKey);
PayloadData.Mask(MaskingKey);
}
} }
#endregion #endregion
#region Public Static Methods #region Private Methods
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void init()
{
setPayloadLen(PayloadLength);
if (Masked == Mask.MASK)
maskPayloadData();
}
private void maskPayloadData()
{
var key = new byte[4];
var rand = new Random();
rand.NextBytes(key);
MaskingKey = key;
PayloadData.Mask(key);
}
private static void readExtPayloadLen(Stream stream, WsFrame frame)
{
var length = frame.PayloadLen <= 125
? 0
: frame.PayloadLen == 126 ? 2 : 8;
if (length > 0)
{
var extLength = stream.ReadBytes(length);
if (extLength == null)
throw new IOException();
frame.ExtPayloadLen = extLength;
}
}
private static WsFrame readHeader(Stream stream)
{
var header = stream.ReadBytes(2);
if (header == null)
return null;
// FIN
Fin fin = (header[0] & 0x80) == 0x80 ? Fin.FINAL : Fin.MORE;
// RSV1
Rsv rsv1 = (header[0] & 0x40) == 0x40 ? Rsv.ON : Rsv.OFF;
// RSV2
Rsv rsv2 = (header[0] & 0x20) == 0x20 ? Rsv.ON : Rsv.OFF;
// RSV3
Rsv rsv3 = (header[0] & 0x10) == 0x10 ? Rsv.ON : Rsv.OFF;
// Opcode
Opcode opcode = (Opcode)(header[0] & 0x0f);
// MASK
Mask masked = (header[1] & 0x80) == 0x80 ? Mask.MASK : Mask.UNMASK;
// Payload len
byte payloadLen = (byte)(header[1] & 0x7f);
return new WsFrame {
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen};
}
private static void readMaskingKey(Stream stream, WsFrame frame)
{
if (frame.Masked == Mask.MASK)
{
var maskingKey = stream.ReadBytes(4);
if (maskingKey == null)
throw new IOException();
frame.MaskingKey = maskingKey;
}
}
private static void readPayloadData(Stream stream, WsFrame frame, bool unmask)
{
ulong length = frame.PayloadLen <= 125
? frame.PayloadLen
: frame.PayloadLen == 126
? frame.ExtPayloadLen.To<ushort>(ByteOrder.BIG)
: frame.ExtPayloadLen.To<ulong>(ByteOrder.BIG);
var buffer = length <= (ulong)_readBufferLen
? stream.ReadBytes((int)length)
: stream.ReadBytes((long)length, _readBufferLen);
if (buffer == null)
throw new IOException();
PayloadData payloadData;
if (frame.Masked == Mask.MASK)
{
payloadData = new PayloadData(buffer, true);
if (unmask == true)
{
payloadData.Mask(frame.MaskingKey);
frame.Masked = Mask.UNMASK;
frame.MaskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer);
}
frame.PayloadData = payloadData;
}
private void setPayloadLen(ulong length)
{
if (length < 126)
{
PayloadLen = (byte)length;
return;
}
if (length < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)length).ToBytes(ByteOrder.BIG);
return;
}
PayloadLen = (byte)127;
ExtPayloadLen = length.ToBytes(ByteOrder.BIG);
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
yield return b;
}
public static WsFrame Parse(byte[] src) public static WsFrame Parse(byte[] src)
{ {
@ -166,183 +277,22 @@ namespace WebSocketSharp.Frame
} }
} }
public static WsFrame Parse<TStream>(TStream stream) public static WsFrame Parse(Stream stream)
where TStream : System.IO.Stream
{ {
return Parse(stream, true); return Parse(stream, true);
} }
public static WsFrame Parse<TStream>(TStream stream, bool unmask) public static WsFrame Parse(Stream stream, bool unmask)
where TStream : System.IO.Stream
{ {
Fin fin; var frame = readHeader(stream);
Rsv rsv1, rsv2, rsv3; if (frame == null)
Opcode opcode;
Mask masked;
byte payloadLen;
byte[] extPayloadLen = new byte[]{};
byte[] maskingKey = new byte[]{};
PayloadData payloadData;
byte[] buffer1, buffer2, buffer3;
int buffer1Len = 2;
int buffer2Len = 0;
ulong buffer3Len = 0;
int maskingKeyLen = 4;
int readLen = 0;
buffer1 = new byte[buffer1Len];
readLen = stream.Read(buffer1, 0, buffer1Len);
if (readLen < buffer1Len)
{
return null; return null;
}
// FIN readExtPayloadLen(stream, frame);
fin = (buffer1[0] & 0x80) == 0x80 readMaskingKey(stream, frame);
? Fin.FINAL readPayloadData(stream, frame, unmask);
: Fin.MORE;
// RSV1
rsv1 = (buffer1[0] & 0x40) == 0x40
? Rsv.ON
: Rsv.OFF;
// RSV2
rsv2 = (buffer1[0] & 0x20) == 0x20
? Rsv.ON
: Rsv.OFF;
// RSV3
rsv3 = (buffer1[0] & 0x10) == 0x10
? Rsv.ON
: Rsv.OFF;
// opcode
opcode = (Opcode)(buffer1[0] & 0x0f);
// MASK
masked = (buffer1[1] & 0x80) == 0x80
? Mask.MASK
: Mask.UNMASK;
// Payload len
payloadLen = (byte)(buffer1[1] & 0x7f);
// Extended payload length
if (payloadLen <= 125)
{
buffer3Len = payloadLen;
}
else if (payloadLen == 126)
{
buffer2Len = 2;
}
else
{
buffer2Len = 8;
}
if (buffer2Len > 0) return frame;
{
buffer2 = new byte[buffer2Len];
readLen = stream.Read(buffer2, 0, buffer2Len);
if (readLen < buffer2Len)
{
return null;
}
extPayloadLen = buffer2;
switch (buffer2Len)
{
case 2:
buffer3Len = extPayloadLen.To<ushort>(ByteOrder.BIG);
break;
case 8:
buffer3Len = extPayloadLen.To<ulong>(ByteOrder.BIG);
break;
}
}
if (buffer3Len > PayloadData.MaxLength)
{
throw new WsReceivedTooBigMessageException();
}
// Masking-key
if (masked == Mask.MASK)
{
maskingKey = new byte[maskingKeyLen];
readLen = stream.Read(maskingKey, 0, maskingKeyLen);
if (readLen < maskingKeyLen)
{
return null;
}
}
// Payload Data
if (buffer3Len == 0)
{
buffer3 = new byte[]{};
}
else if (buffer3Len <= (ulong)_readBufferLen)
{
buffer3 = new byte[buffer3Len];
readLen = stream.Read(buffer3, 0, (int)buffer3Len);
if (readLen < (int)buffer3Len)
{
return null;
}
}
else
{
buffer3 = stream.ReadBytes((long)buffer3Len, _readBufferLen);
if ((ulong)buffer3.LongLength < buffer3Len)
{
return null;
}
}
if (masked == Mask.MASK)
{
payloadData = new PayloadData(buffer3, true);
if (unmask == true)
{
payloadData.Mask(maskingKey);
masked = Mask.UNMASK;
maskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer3);
}
return new WsFrame
{
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen,
ExtPayloadLen = extPayloadLen,
MaskingKey = maskingKey,
PayloadData = payloadData
};
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
{
yield return b;
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
} }
public void Print() public void Print()
@ -462,33 +412,27 @@ namespace WebSocketSharp.Frame
public byte[] ToBytes() public byte[] ToBytes()
{ {
var bytes = new List<byte>(); var buffer = new List<byte>();
int first16 = (int)Fin; int header = (int)Fin;
first16 = (first16 << 1) + (int)Rsv1; header = (header << 1) + (int)Rsv1;
first16 = (first16 << 1) + (int)Rsv2; header = (header << 1) + (int)Rsv2;
first16 = (first16 << 1) + (int)Rsv3; header = (header << 1) + (int)Rsv3;
first16 = (first16 << 4) + (int)Opcode; header = (header << 4) + (int)Opcode;
first16 = (first16 << 1) + (int)Masked; header = (header << 1) + (int)Masked;
first16 = (first16 << 7) + (int)PayloadLen; header = (header << 7) + (int)PayloadLen;
bytes.AddRange(((ushort)first16).ToBytes(ByteOrder.BIG)); buffer.AddRange(((ushort)header).ToBytes(ByteOrder.BIG));
if (PayloadLen >= 126) if (PayloadLen >= 126)
{ buffer.AddRange(ExtPayloadLen);
bytes.AddRange(ExtPayloadLen);
}
if (Masked == Mask.MASK) if (Masked == Mask.MASK)
{ buffer.AddRange(MaskingKey);
bytes.AddRange(MaskingKey);
}
if (PayloadLen > 0) if (PayloadLen > 0)
{ buffer.AddRange(PayloadData.ToBytes());
bytes.AddRange(PayloadData.ToBytes());
}
return bytes.ToArray(); return buffer.ToArray();
} }
public override string ToString() public override string ToString()

View File

@ -31,13 +31,13 @@
using System; using System;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp.Net { namespace WebSocketSharp.Net {
@ -122,7 +122,7 @@ namespace WebSocketSharp.Net {
if (!secure) { if (!secure) {
stream = net_stream; stream = net_stream;
} else { } else {
var ssl_stream = new SslStream(net_stream); var ssl_stream = new SslStream(net_stream, false);
ssl_stream.AuthenticateAsServer(cert); ssl_stream.AuthenticateAsServer(cert);
stream = ssl_stream; stream = ssl_stream;
} }

View File

@ -0,0 +1,83 @@
#region MIT License
/**
* SslStream.cs
*
* The MIT License
*
* Copyright (c) 2012 sta.blockhead
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#endregion
using System;
using System.Net.Security;
using System.Net.Sockets;
namespace WebSocketSharp.Net.Security {
public class SslStream : System.Net.Security.SslStream
{
#region Constructors
public SslStream(NetworkStream innerStream)
: base(innerStream)
{
}
public SslStream(NetworkStream innerStream, bool leaveInnerStreamOpen)
: base(innerStream, leaveInnerStreamOpen)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback
) : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback,
LocalCertificateSelectionCallback userCertificateSelectionCallback
) : base(
innerStream,
leaveInnerStreamOpen,
userCertificateValidationCallback,
userCertificateSelectionCallback
)
{
}
#endregion
#region Property
public bool DataAvailable {
get {
return ((NetworkStream)InnerStream).DataAvailable;
}
}
#endregion
}
}

View File

@ -643,6 +643,19 @@ namespace WebSocketSharp {
return frame; return frame;
} }
private WsFrame readFrameWithTimeout()
{
if (!_wsStream.DataAvailable)
{
var timeout = 1 * 100;
Thread.Sleep(timeout);
if (!_wsStream.DataAvailable)
return null;
}
return readFrame();
}
private string[] readHandshake() private string[] readHandshake()
{ {
return _wsStream.ReadHandshake(); return _wsStream.ReadHandshake();
@ -650,7 +663,7 @@ namespace WebSocketSharp {
private MessageEventArgs receive() private MessageEventArgs receive()
{ {
var frame = readFrame(); var frame = _isClient ? readFrame() : readFrameWithTimeout();
if (frame == null) if (frame == null)
return null; return null;

View File

@ -31,11 +31,11 @@ using System.Collections.Generic;
using System.Configuration; using System.Configuration;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using System.Text; using System.Text;
using WebSocketSharp.Frame; using WebSocketSharp.Frame;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp namespace WebSocketSharp
{ {
@ -51,7 +51,7 @@ namespace WebSocketSharp
#endregion #endregion
#region Constructor #region Constructors
public WsStream(NetworkStream innerStream) public WsStream(NetworkStream innerStream)
{ {
@ -65,7 +65,16 @@ namespace WebSocketSharp
#endregion #endregion
#region Public Property #region Properties
public bool DataAvailable {
get {
if (_innerStreamType == typeof(SslStream))
return ((SslStream)_innerStream).DataAvailable;
return ((NetworkStream)_innerStream).DataAvailable;
}
}
public bool IsSecure { public bool IsSecure {
get { return _isSecure; } get { return _isSecure; }
@ -73,7 +82,7 @@ namespace WebSocketSharp
#endregion #endregion
#region Private Methods #region Private Method
private void init(Stream innerStream) private void init(Stream innerStream)
{ {
@ -98,7 +107,7 @@ namespace WebSocketSharp
if (port == 443) if (port == 443)
{ {
RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) => System.Net.Security.RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
{ {
// FIXME: Always returns true // FIXME: Always returns true
return true; return true;
@ -120,7 +129,7 @@ namespace WebSocketSharp
var port = ((IPEndPoint)client.Client.LocalEndPoint).Port; var port = ((IPEndPoint)client.Client.LocalEndPoint).Port;
if (port == 443) if (port == 443)
{ {
var sslStream = new SslStream(netStream); var sslStream = new SslStream(netStream, false);
var certPath = ConfigurationManager.AppSettings["ServerCertPath"]; var certPath = ConfigurationManager.AppSettings["ServerCertPath"];
sslStream.AuthenticateAsServer(new X509Certificate2(certPath)); sslStream.AuthenticateAsServer(new X509Certificate2(certPath));

View File

@ -110,6 +110,7 @@
<Compile Include="Server\IWebSocketServer.cs" /> <Compile Include="Server\IWebSocketServer.cs" />
<Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" /> <Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" />
<Compile Include="Server\WebSocketServerBase.cs" /> <Compile Include="Server\WebSocketServerBase.cs" />
<Compile Include="Net\Security\SslStream.cs" />
</ItemGroup> </ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /> <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
<ItemGroup> <ItemGroup>
@ -117,5 +118,6 @@
<Folder Include="Server\" /> <Folder Include="Server\" />
<Folder Include="Net\" /> <Folder Include="Net\" />
<Folder Include="Net\Sockets\" /> <Folder Include="Net\Sockets\" />
<Folder Include="Net\Security\" />
</ItemGroup> </ItemGroup>
</Project> </Project>

Binary file not shown.