Fix due to the modified WebSocket.cs

This commit is contained in:
sta 2012-10-04 15:04:21 +09:00
parent ae5d461a42
commit 94385ea2bc
64 changed files with 313 additions and 252 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -302,35 +302,45 @@ namespace WebSocketSharp {
if (length <= 0)
return new byte[]{};
var buffer = new byte[length];
stream.Read(buffer, 0, length);
return buffer;
var buffer = new byte[length];
var readLen = stream.Read(buffer, 0, length);
return readLen == length ? buffer : null;
}
public static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
{
var count = length / bufferLength;
var rem = length % bufferLength;
var readData = new List<byte>();
var readLen = 0;
var buffer = new byte[bufferLength];
var count = length / bufferLength;
var rem = length % bufferLength;
var readData = new List<byte>();
var readBuffer = new byte[bufferLength];
long readLen = 0;
var tmpLen = 0;
Action<byte[]> read = (buffer) =>
{
tmpLen = stream.Read(buffer, 0, buffer.Length);
if (tmpLen > 0)
{
readLen += tmpLen;
readData.AddRange(buffer.SubArray(0, tmpLen));
}
};
count.Times(() =>
{
readLen = stream.Read(buffer, 0, bufferLength);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
read(readBuffer);
});
if (rem > 0)
{
buffer = new byte[rem];
readLen = stream.Read(buffer, 0, (int)rem);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
readBuffer = new byte[rem];
read(readBuffer);
}
return readData.ToArray();
return readLen == length
? readData.ToArray()
: null;
}
public static T[] SubArray<T>(this T[] array, int startIndex, int length)

View File

@ -28,6 +28,7 @@
using System;
using System.IO;
using System.Collections;
using System.Collections.Generic;
using System.Text;
@ -35,7 +36,7 @@ namespace WebSocketSharp.Frame
{
public class WsFrame : IEnumerable<byte>
{
#region Private Static Fields
#region Field
private static readonly int _readBufferLen;
@ -81,7 +82,7 @@ namespace WebSocketSharp.Frame
#endregion
#region Private Constructors
#region Private Constructor
private WsFrame()
{
@ -109,49 +110,159 @@ namespace WebSocketSharp.Frame
public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData)
: this()
{
Fin = fin;
Opcode = opcode;
ulong dataLength = payloadData.Length;
if (dataLength == 0)
{
Masked = Mask.UNMASK;
}
else
{
Masked = mask;
}
if (dataLength < 126)
{
PayloadLen = (byte)dataLength;
}
else if (dataLength < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)dataLength).ToBytes(ByteOrder.BIG);
}
else
{
PayloadLen = (byte)127;
ExtPayloadLen = dataLength.ToBytes(ByteOrder.BIG);
}
Fin = fin;
Opcode = opcode;
Masked = payloadData.Length != 0 ? mask : Mask.UNMASK;
PayloadData = payloadData;
if (Masked == Mask.MASK)
{
MaskingKey = new byte[4];
var rand = new Random();
rand.NextBytes(MaskingKey);
PayloadData.Mask(MaskingKey);
}
init();
}
#endregion
#region Public Static Methods
#region Private Methods
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void init()
{
setPayloadLen(PayloadLength);
if (Masked == Mask.MASK)
maskPayloadData();
}
private void maskPayloadData()
{
var key = new byte[4];
var rand = new Random();
rand.NextBytes(key);
MaskingKey = key;
PayloadData.Mask(key);
}
private static void readExtPayloadLen(Stream stream, WsFrame frame)
{
var length = frame.PayloadLen <= 125
? 0
: frame.PayloadLen == 126 ? 2 : 8;
if (length > 0)
{
var extLength = stream.ReadBytes(length);
if (extLength == null)
throw new IOException();
frame.ExtPayloadLen = extLength;
}
}
private static WsFrame readHeader(Stream stream)
{
var header = stream.ReadBytes(2);
if (header == null)
return null;
// FIN
Fin fin = (header[0] & 0x80) == 0x80 ? Fin.FINAL : Fin.MORE;
// RSV1
Rsv rsv1 = (header[0] & 0x40) == 0x40 ? Rsv.ON : Rsv.OFF;
// RSV2
Rsv rsv2 = (header[0] & 0x20) == 0x20 ? Rsv.ON : Rsv.OFF;
// RSV3
Rsv rsv3 = (header[0] & 0x10) == 0x10 ? Rsv.ON : Rsv.OFF;
// Opcode
Opcode opcode = (Opcode)(header[0] & 0x0f);
// MASK
Mask masked = (header[1] & 0x80) == 0x80 ? Mask.MASK : Mask.UNMASK;
// Payload len
byte payloadLen = (byte)(header[1] & 0x7f);
return new WsFrame {
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen};
}
private static void readMaskingKey(Stream stream, WsFrame frame)
{
if (frame.Masked == Mask.MASK)
{
var maskingKey = stream.ReadBytes(4);
if (maskingKey == null)
throw new IOException();
frame.MaskingKey = maskingKey;
}
}
private static void readPayloadData(Stream stream, WsFrame frame, bool unmask)
{
ulong length = frame.PayloadLen <= 125
? frame.PayloadLen
: frame.PayloadLen == 126
? frame.ExtPayloadLen.To<ushort>(ByteOrder.BIG)
: frame.ExtPayloadLen.To<ulong>(ByteOrder.BIG);
var buffer = length <= (ulong)_readBufferLen
? stream.ReadBytes((int)length)
: stream.ReadBytes((long)length, _readBufferLen);
if (buffer == null)
throw new IOException();
PayloadData payloadData;
if (frame.Masked == Mask.MASK)
{
payloadData = new PayloadData(buffer, true);
if (unmask == true)
{
payloadData.Mask(frame.MaskingKey);
frame.Masked = Mask.UNMASK;
frame.MaskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer);
}
frame.PayloadData = payloadData;
}
private void setPayloadLen(ulong length)
{
if (length < 126)
{
PayloadLen = (byte)length;
return;
}
if (length < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)length).ToBytes(ByteOrder.BIG);
return;
}
PayloadLen = (byte)127;
ExtPayloadLen = length.ToBytes(ByteOrder.BIG);
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
yield return b;
}
public static WsFrame Parse(byte[] src)
{
@ -166,183 +277,22 @@ namespace WebSocketSharp.Frame
}
}
public static WsFrame Parse<TStream>(TStream stream)
where TStream : System.IO.Stream
public static WsFrame Parse(Stream stream)
{
return Parse(stream, true);
}
public static WsFrame Parse<TStream>(TStream stream, bool unmask)
where TStream : System.IO.Stream
public static WsFrame Parse(Stream stream, bool unmask)
{
Fin fin;
Rsv rsv1, rsv2, rsv3;
Opcode opcode;
Mask masked;
byte payloadLen;
byte[] extPayloadLen = new byte[]{};
byte[] maskingKey = new byte[]{};
PayloadData payloadData;
byte[] buffer1, buffer2, buffer3;
int buffer1Len = 2;
int buffer2Len = 0;
ulong buffer3Len = 0;
int maskingKeyLen = 4;
int readLen = 0;
buffer1 = new byte[buffer1Len];
readLen = stream.Read(buffer1, 0, buffer1Len);
if (readLen < buffer1Len)
{
var frame = readHeader(stream);
if (frame == null)
return null;
}
// FIN
fin = (buffer1[0] & 0x80) == 0x80
? Fin.FINAL
: Fin.MORE;
// RSV1
rsv1 = (buffer1[0] & 0x40) == 0x40
? Rsv.ON
: Rsv.OFF;
// RSV2
rsv2 = (buffer1[0] & 0x20) == 0x20
? Rsv.ON
: Rsv.OFF;
// RSV3
rsv3 = (buffer1[0] & 0x10) == 0x10
? Rsv.ON
: Rsv.OFF;
// opcode
opcode = (Opcode)(buffer1[0] & 0x0f);
// MASK
masked = (buffer1[1] & 0x80) == 0x80
? Mask.MASK
: Mask.UNMASK;
// Payload len
payloadLen = (byte)(buffer1[1] & 0x7f);
// Extended payload length
if (payloadLen <= 125)
{
buffer3Len = payloadLen;
}
else if (payloadLen == 126)
{
buffer2Len = 2;
}
else
{
buffer2Len = 8;
}
readExtPayloadLen(stream, frame);
readMaskingKey(stream, frame);
readPayloadData(stream, frame, unmask);
if (buffer2Len > 0)
{
buffer2 = new byte[buffer2Len];
readLen = stream.Read(buffer2, 0, buffer2Len);
if (readLen < buffer2Len)
{
return null;
}
extPayloadLen = buffer2;
switch (buffer2Len)
{
case 2:
buffer3Len = extPayloadLen.To<ushort>(ByteOrder.BIG);
break;
case 8:
buffer3Len = extPayloadLen.To<ulong>(ByteOrder.BIG);
break;
}
}
if (buffer3Len > PayloadData.MaxLength)
{
throw new WsReceivedTooBigMessageException();
}
// Masking-key
if (masked == Mask.MASK)
{
maskingKey = new byte[maskingKeyLen];
readLen = stream.Read(maskingKey, 0, maskingKeyLen);
if (readLen < maskingKeyLen)
{
return null;
}
}
// Payload Data
if (buffer3Len == 0)
{
buffer3 = new byte[]{};
}
else if (buffer3Len <= (ulong)_readBufferLen)
{
buffer3 = new byte[buffer3Len];
readLen = stream.Read(buffer3, 0, (int)buffer3Len);
if (readLen < (int)buffer3Len)
{
return null;
}
}
else
{
buffer3 = stream.ReadBytes((long)buffer3Len, _readBufferLen);
if ((ulong)buffer3.LongLength < buffer3Len)
{
return null;
}
}
if (masked == Mask.MASK)
{
payloadData = new PayloadData(buffer3, true);
if (unmask == true)
{
payloadData.Mask(maskingKey);
masked = Mask.UNMASK;
maskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer3);
}
return new WsFrame
{
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen,
ExtPayloadLen = extPayloadLen,
MaskingKey = maskingKey,
PayloadData = payloadData
};
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
{
yield return b;
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
return frame;
}
public void Print()
@ -462,33 +412,27 @@ namespace WebSocketSharp.Frame
public byte[] ToBytes()
{
var bytes = new List<byte>();
var buffer = new List<byte>();
int first16 = (int)Fin;
first16 = (first16 << 1) + (int)Rsv1;
first16 = (first16 << 1) + (int)Rsv2;
first16 = (first16 << 1) + (int)Rsv3;
first16 = (first16 << 4) + (int)Opcode;
first16 = (first16 << 1) + (int)Masked;
first16 = (first16 << 7) + (int)PayloadLen;
bytes.AddRange(((ushort)first16).ToBytes(ByteOrder.BIG));
int header = (int)Fin;
header = (header << 1) + (int)Rsv1;
header = (header << 1) + (int)Rsv2;
header = (header << 1) + (int)Rsv3;
header = (header << 4) + (int)Opcode;
header = (header << 1) + (int)Masked;
header = (header << 7) + (int)PayloadLen;
buffer.AddRange(((ushort)header).ToBytes(ByteOrder.BIG));
if (PayloadLen >= 126)
{
bytes.AddRange(ExtPayloadLen);
}
buffer.AddRange(ExtPayloadLen);
if (Masked == Mask.MASK)
{
bytes.AddRange(MaskingKey);
}
buffer.AddRange(MaskingKey);
if (PayloadLen > 0)
{
bytes.AddRange(PayloadData.ToBytes());
}
buffer.AddRange(PayloadData.ToBytes());
return bytes.ToArray();
return buffer.ToArray();
}
public override string ToString()

View File

@ -31,13 +31,13 @@
using System;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Reflection;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp.Net {
@ -122,7 +122,7 @@ namespace WebSocketSharp.Net {
if (!secure) {
stream = net_stream;
} else {
var ssl_stream = new SslStream(net_stream);
var ssl_stream = new SslStream(net_stream, false);
ssl_stream.AuthenticateAsServer(cert);
stream = ssl_stream;
}

View File

@ -0,0 +1,83 @@
#region MIT License
/**
* SslStream.cs
*
* The MIT License
*
* Copyright (c) 2012 sta.blockhead
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#endregion
using System;
using System.Net.Security;
using System.Net.Sockets;
namespace WebSocketSharp.Net.Security {
public class SslStream : System.Net.Security.SslStream
{
#region Constructors
public SslStream(NetworkStream innerStream)
: base(innerStream)
{
}
public SslStream(NetworkStream innerStream, bool leaveInnerStreamOpen)
: base(innerStream, leaveInnerStreamOpen)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback
) : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback,
LocalCertificateSelectionCallback userCertificateSelectionCallback
) : base(
innerStream,
leaveInnerStreamOpen,
userCertificateValidationCallback,
userCertificateSelectionCallback
)
{
}
#endregion
#region Property
public bool DataAvailable {
get {
return ((NetworkStream)InnerStream).DataAvailable;
}
}
#endregion
}
}

View File

@ -643,6 +643,19 @@ namespace WebSocketSharp {
return frame;
}
private WsFrame readFrameWithTimeout()
{
if (!_wsStream.DataAvailable)
{
var timeout = 1 * 100;
Thread.Sleep(timeout);
if (!_wsStream.DataAvailable)
return null;
}
return readFrame();
}
private string[] readHandshake()
{
return _wsStream.ReadHandshake();
@ -650,7 +663,7 @@ namespace WebSocketSharp {
private MessageEventArgs receive()
{
var frame = readFrame();
var frame = _isClient ? readFrame() : readFrameWithTimeout();
if (frame == null)
return null;

View File

@ -31,11 +31,11 @@ using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using WebSocketSharp.Frame;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp
{
@ -51,7 +51,7 @@ namespace WebSocketSharp
#endregion
#region Constructor
#region Constructors
public WsStream(NetworkStream innerStream)
{
@ -65,7 +65,16 @@ namespace WebSocketSharp
#endregion
#region Public Property
#region Properties
public bool DataAvailable {
get {
if (_innerStreamType == typeof(SslStream))
return ((SslStream)_innerStream).DataAvailable;
return ((NetworkStream)_innerStream).DataAvailable;
}
}
public bool IsSecure {
get { return _isSecure; }
@ -73,7 +82,7 @@ namespace WebSocketSharp
#endregion
#region Private Methods
#region Private Method
private void init(Stream innerStream)
{
@ -98,7 +107,7 @@ namespace WebSocketSharp
if (port == 443)
{
RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
System.Net.Security.RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
{
// FIXME: Always returns true
return true;
@ -120,7 +129,7 @@ namespace WebSocketSharp
var port = ((IPEndPoint)client.Client.LocalEndPoint).Port;
if (port == 443)
{
var sslStream = new SslStream(netStream);
var sslStream = new SslStream(netStream, false);
var certPath = ConfigurationManager.AppSettings["ServerCertPath"];
sslStream.AuthenticateAsServer(new X509Certificate2(certPath));

View File

@ -110,6 +110,7 @@
<Compile Include="Server\IWebSocketServer.cs" />
<Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" />
<Compile Include="Server\WebSocketServerBase.cs" />
<Compile Include="Net\Security\SslStream.cs" />
</ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
<ItemGroup>
@ -117,5 +118,6 @@
<Folder Include="Server\" />
<Folder Include="Net\" />
<Folder Include="Net\Sockets\" />
<Folder Include="Net\Security\" />
</ItemGroup>
</Project>

Binary file not shown.