Fix due to the modified WebSocket.cs

This commit is contained in:
sta
2012-10-04 15:04:21 +09:00
parent ae5d461a42
commit 94385ea2bc
64 changed files with 313 additions and 252 deletions

View File

@@ -302,35 +302,45 @@ namespace WebSocketSharp {
if (length <= 0)
return new byte[]{};
var buffer = new byte[length];
stream.Read(buffer, 0, length);
return buffer;
var buffer = new byte[length];
var readLen = stream.Read(buffer, 0, length);
return readLen == length ? buffer : null;
}
public static byte[] ReadBytes(this Stream stream, long length, int bufferLength)
{
var count = length / bufferLength;
var rem = length % bufferLength;
var readData = new List<byte>();
var readLen = 0;
var buffer = new byte[bufferLength];
var count = length / bufferLength;
var rem = length % bufferLength;
var readData = new List<byte>();
var readBuffer = new byte[bufferLength];
long readLen = 0;
var tmpLen = 0;
Action<byte[]> read = (buffer) =>
{
tmpLen = stream.Read(buffer, 0, buffer.Length);
if (tmpLen > 0)
{
readLen += tmpLen;
readData.AddRange(buffer.SubArray(0, tmpLen));
}
};
count.Times(() =>
{
readLen = stream.Read(buffer, 0, bufferLength);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
read(readBuffer);
});
if (rem > 0)
{
buffer = new byte[rem];
readLen = stream.Read(buffer, 0, (int)rem);
if (readLen > 0)
readData.AddRange(buffer.SubArray(0, readLen));
readBuffer = new byte[rem];
read(readBuffer);
}
return readData.ToArray();
return readLen == length
? readData.ToArray()
: null;
}
public static T[] SubArray<T>(this T[] array, int startIndex, int length)

View File

@@ -28,6 +28,7 @@
using System;
using System.IO;
using System.Collections;
using System.Collections.Generic;
using System.Text;
@@ -35,7 +36,7 @@ namespace WebSocketSharp.Frame
{
public class WsFrame : IEnumerable<byte>
{
#region Private Static Fields
#region Field
private static readonly int _readBufferLen;
@@ -81,7 +82,7 @@ namespace WebSocketSharp.Frame
#endregion
#region Private Constructors
#region Private Constructor
private WsFrame()
{
@@ -109,49 +110,159 @@ namespace WebSocketSharp.Frame
public WsFrame(Fin fin, Opcode opcode, Mask mask, PayloadData payloadData)
: this()
{
Fin = fin;
Opcode = opcode;
ulong dataLength = payloadData.Length;
if (dataLength == 0)
{
Masked = Mask.UNMASK;
}
else
{
Masked = mask;
}
if (dataLength < 126)
{
PayloadLen = (byte)dataLength;
}
else if (dataLength < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)dataLength).ToBytes(ByteOrder.BIG);
}
else
{
PayloadLen = (byte)127;
ExtPayloadLen = dataLength.ToBytes(ByteOrder.BIG);
}
Fin = fin;
Opcode = opcode;
Masked = payloadData.Length != 0 ? mask : Mask.UNMASK;
PayloadData = payloadData;
if (Masked == Mask.MASK)
{
MaskingKey = new byte[4];
var rand = new Random();
rand.NextBytes(MaskingKey);
PayloadData.Mask(MaskingKey);
}
init();
}
#endregion
#region Public Static Methods
#region Private Methods
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void init()
{
setPayloadLen(PayloadLength);
if (Masked == Mask.MASK)
maskPayloadData();
}
private void maskPayloadData()
{
var key = new byte[4];
var rand = new Random();
rand.NextBytes(key);
MaskingKey = key;
PayloadData.Mask(key);
}
private static void readExtPayloadLen(Stream stream, WsFrame frame)
{
var length = frame.PayloadLen <= 125
? 0
: frame.PayloadLen == 126 ? 2 : 8;
if (length > 0)
{
var extLength = stream.ReadBytes(length);
if (extLength == null)
throw new IOException();
frame.ExtPayloadLen = extLength;
}
}
private static WsFrame readHeader(Stream stream)
{
var header = stream.ReadBytes(2);
if (header == null)
return null;
// FIN
Fin fin = (header[0] & 0x80) == 0x80 ? Fin.FINAL : Fin.MORE;
// RSV1
Rsv rsv1 = (header[0] & 0x40) == 0x40 ? Rsv.ON : Rsv.OFF;
// RSV2
Rsv rsv2 = (header[0] & 0x20) == 0x20 ? Rsv.ON : Rsv.OFF;
// RSV3
Rsv rsv3 = (header[0] & 0x10) == 0x10 ? Rsv.ON : Rsv.OFF;
// Opcode
Opcode opcode = (Opcode)(header[0] & 0x0f);
// MASK
Mask masked = (header[1] & 0x80) == 0x80 ? Mask.MASK : Mask.UNMASK;
// Payload len
byte payloadLen = (byte)(header[1] & 0x7f);
return new WsFrame {
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen};
}
private static void readMaskingKey(Stream stream, WsFrame frame)
{
if (frame.Masked == Mask.MASK)
{
var maskingKey = stream.ReadBytes(4);
if (maskingKey == null)
throw new IOException();
frame.MaskingKey = maskingKey;
}
}
private static void readPayloadData(Stream stream, WsFrame frame, bool unmask)
{
ulong length = frame.PayloadLen <= 125
? frame.PayloadLen
: frame.PayloadLen == 126
? frame.ExtPayloadLen.To<ushort>(ByteOrder.BIG)
: frame.ExtPayloadLen.To<ulong>(ByteOrder.BIG);
var buffer = length <= (ulong)_readBufferLen
? stream.ReadBytes((int)length)
: stream.ReadBytes((long)length, _readBufferLen);
if (buffer == null)
throw new IOException();
PayloadData payloadData;
if (frame.Masked == Mask.MASK)
{
payloadData = new PayloadData(buffer, true);
if (unmask == true)
{
payloadData.Mask(frame.MaskingKey);
frame.Masked = Mask.UNMASK;
frame.MaskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer);
}
frame.PayloadData = payloadData;
}
private void setPayloadLen(ulong length)
{
if (length < 126)
{
PayloadLen = (byte)length;
return;
}
if (length < 0x010000)
{
PayloadLen = (byte)126;
ExtPayloadLen = ((ushort)length).ToBytes(ByteOrder.BIG);
return;
}
PayloadLen = (byte)127;
ExtPayloadLen = length.ToBytes(ByteOrder.BIG);
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
yield return b;
}
public static WsFrame Parse(byte[] src)
{
@@ -166,183 +277,22 @@ namespace WebSocketSharp.Frame
}
}
public static WsFrame Parse<TStream>(TStream stream)
where TStream : System.IO.Stream
public static WsFrame Parse(Stream stream)
{
return Parse(stream, true);
}
public static WsFrame Parse<TStream>(TStream stream, bool unmask)
where TStream : System.IO.Stream
public static WsFrame Parse(Stream stream, bool unmask)
{
Fin fin;
Rsv rsv1, rsv2, rsv3;
Opcode opcode;
Mask masked;
byte payloadLen;
byte[] extPayloadLen = new byte[]{};
byte[] maskingKey = new byte[]{};
PayloadData payloadData;
byte[] buffer1, buffer2, buffer3;
int buffer1Len = 2;
int buffer2Len = 0;
ulong buffer3Len = 0;
int maskingKeyLen = 4;
int readLen = 0;
buffer1 = new byte[buffer1Len];
readLen = stream.Read(buffer1, 0, buffer1Len);
if (readLen < buffer1Len)
{
var frame = readHeader(stream);
if (frame == null)
return null;
}
// FIN
fin = (buffer1[0] & 0x80) == 0x80
? Fin.FINAL
: Fin.MORE;
// RSV1
rsv1 = (buffer1[0] & 0x40) == 0x40
? Rsv.ON
: Rsv.OFF;
// RSV2
rsv2 = (buffer1[0] & 0x20) == 0x20
? Rsv.ON
: Rsv.OFF;
// RSV3
rsv3 = (buffer1[0] & 0x10) == 0x10
? Rsv.ON
: Rsv.OFF;
// opcode
opcode = (Opcode)(buffer1[0] & 0x0f);
// MASK
masked = (buffer1[1] & 0x80) == 0x80
? Mask.MASK
: Mask.UNMASK;
// Payload len
payloadLen = (byte)(buffer1[1] & 0x7f);
// Extended payload length
if (payloadLen <= 125)
{
buffer3Len = payloadLen;
}
else if (payloadLen == 126)
{
buffer2Len = 2;
}
else
{
buffer2Len = 8;
}
readExtPayloadLen(stream, frame);
readMaskingKey(stream, frame);
readPayloadData(stream, frame, unmask);
if (buffer2Len > 0)
{
buffer2 = new byte[buffer2Len];
readLen = stream.Read(buffer2, 0, buffer2Len);
if (readLen < buffer2Len)
{
return null;
}
extPayloadLen = buffer2;
switch (buffer2Len)
{
case 2:
buffer3Len = extPayloadLen.To<ushort>(ByteOrder.BIG);
break;
case 8:
buffer3Len = extPayloadLen.To<ulong>(ByteOrder.BIG);
break;
}
}
if (buffer3Len > PayloadData.MaxLength)
{
throw new WsReceivedTooBigMessageException();
}
// Masking-key
if (masked == Mask.MASK)
{
maskingKey = new byte[maskingKeyLen];
readLen = stream.Read(maskingKey, 0, maskingKeyLen);
if (readLen < maskingKeyLen)
{
return null;
}
}
// Payload Data
if (buffer3Len == 0)
{
buffer3 = new byte[]{};
}
else if (buffer3Len <= (ulong)_readBufferLen)
{
buffer3 = new byte[buffer3Len];
readLen = stream.Read(buffer3, 0, (int)buffer3Len);
if (readLen < (int)buffer3Len)
{
return null;
}
}
else
{
buffer3 = stream.ReadBytes((long)buffer3Len, _readBufferLen);
if ((ulong)buffer3.LongLength < buffer3Len)
{
return null;
}
}
if (masked == Mask.MASK)
{
payloadData = new PayloadData(buffer3, true);
if (unmask == true)
{
payloadData.Mask(maskingKey);
masked = Mask.UNMASK;
maskingKey = new byte[]{};
}
}
else
{
payloadData = new PayloadData(buffer3);
}
return new WsFrame
{
Fin = fin,
Rsv1 = rsv1,
Rsv2 = rsv2,
Rsv3 = rsv3,
Opcode = opcode,
Masked = masked,
PayloadLen = payloadLen,
ExtPayloadLen = extPayloadLen,
MaskingKey = maskingKey,
PayloadData = payloadData
};
}
#endregion
#region Public Methods
public IEnumerator<byte> GetEnumerator()
{
foreach (byte b in ToBytes())
{
yield return b;
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
return frame;
}
public void Print()
@@ -462,33 +412,27 @@ namespace WebSocketSharp.Frame
public byte[] ToBytes()
{
var bytes = new List<byte>();
var buffer = new List<byte>();
int first16 = (int)Fin;
first16 = (first16 << 1) + (int)Rsv1;
first16 = (first16 << 1) + (int)Rsv2;
first16 = (first16 << 1) + (int)Rsv3;
first16 = (first16 << 4) + (int)Opcode;
first16 = (first16 << 1) + (int)Masked;
first16 = (first16 << 7) + (int)PayloadLen;
bytes.AddRange(((ushort)first16).ToBytes(ByteOrder.BIG));
int header = (int)Fin;
header = (header << 1) + (int)Rsv1;
header = (header << 1) + (int)Rsv2;
header = (header << 1) + (int)Rsv3;
header = (header << 4) + (int)Opcode;
header = (header << 1) + (int)Masked;
header = (header << 7) + (int)PayloadLen;
buffer.AddRange(((ushort)header).ToBytes(ByteOrder.BIG));
if (PayloadLen >= 126)
{
bytes.AddRange(ExtPayloadLen);
}
buffer.AddRange(ExtPayloadLen);
if (Masked == Mask.MASK)
{
bytes.AddRange(MaskingKey);
}
buffer.AddRange(MaskingKey);
if (PayloadLen > 0)
{
bytes.AddRange(PayloadData.ToBytes());
}
buffer.AddRange(PayloadData.ToBytes());
return bytes.ToArray();
return buffer.ToArray();
}
public override string ToString()

View File

@@ -31,13 +31,13 @@
using System;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Reflection;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp.Net {
@@ -122,7 +122,7 @@ namespace WebSocketSharp.Net {
if (!secure) {
stream = net_stream;
} else {
var ssl_stream = new SslStream(net_stream);
var ssl_stream = new SslStream(net_stream, false);
ssl_stream.AuthenticateAsServer(cert);
stream = ssl_stream;
}

View File

@@ -0,0 +1,83 @@
#region MIT License
/**
* SslStream.cs
*
* The MIT License
*
* Copyright (c) 2012 sta.blockhead
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#endregion
using System;
using System.Net.Security;
using System.Net.Sockets;
namespace WebSocketSharp.Net.Security {
public class SslStream : System.Net.Security.SslStream
{
#region Constructors
public SslStream(NetworkStream innerStream)
: base(innerStream)
{
}
public SslStream(NetworkStream innerStream, bool leaveInnerStreamOpen)
: base(innerStream, leaveInnerStreamOpen)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback
) : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
}
public SslStream(
NetworkStream innerStream,
bool leaveInnerStreamOpen,
RemoteCertificateValidationCallback userCertificateValidationCallback,
LocalCertificateSelectionCallback userCertificateSelectionCallback
) : base(
innerStream,
leaveInnerStreamOpen,
userCertificateValidationCallback,
userCertificateSelectionCallback
)
{
}
#endregion
#region Property
public bool DataAvailable {
get {
return ((NetworkStream)InnerStream).DataAvailable;
}
}
#endregion
}
}

View File

@@ -643,6 +643,19 @@ namespace WebSocketSharp {
return frame;
}
private WsFrame readFrameWithTimeout()
{
if (!_wsStream.DataAvailable)
{
var timeout = 1 * 100;
Thread.Sleep(timeout);
if (!_wsStream.DataAvailable)
return null;
}
return readFrame();
}
private string[] readHandshake()
{
return _wsStream.ReadHandshake();
@@ -650,7 +663,7 @@ namespace WebSocketSharp {
private MessageEventArgs receive()
{
var frame = readFrame();
var frame = _isClient ? readFrame() : readFrameWithTimeout();
if (frame == null)
return null;

View File

@@ -31,11 +31,11 @@ using System.Collections.Generic;
using System.Configuration;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using WebSocketSharp.Frame;
using WebSocketSharp.Net.Security;
namespace WebSocketSharp
{
@@ -51,7 +51,7 @@ namespace WebSocketSharp
#endregion
#region Constructor
#region Constructors
public WsStream(NetworkStream innerStream)
{
@@ -65,7 +65,16 @@ namespace WebSocketSharp
#endregion
#region Public Property
#region Properties
public bool DataAvailable {
get {
if (_innerStreamType == typeof(SslStream))
return ((SslStream)_innerStream).DataAvailable;
return ((NetworkStream)_innerStream).DataAvailable;
}
}
public bool IsSecure {
get { return _isSecure; }
@@ -73,7 +82,7 @@ namespace WebSocketSharp
#endregion
#region Private Methods
#region Private Method
private void init(Stream innerStream)
{
@@ -98,7 +107,7 @@ namespace WebSocketSharp
if (port == 443)
{
RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
System.Net.Security.RemoteCertificateValidationCallback validationCb = (sender, certificate, chain, sslPolicyErrors) =>
{
// FIXME: Always returns true
return true;
@@ -120,7 +129,7 @@ namespace WebSocketSharp
var port = ((IPEndPoint)client.Client.LocalEndPoint).Port;
if (port == 443)
{
var sslStream = new SslStream(netStream);
var sslStream = new SslStream(netStream, false);
var certPath = ConfigurationManager.AppSettings["ServerCertPath"];
sslStream.AuthenticateAsServer(new X509Certificate2(certPath));

View File

@@ -110,6 +110,7 @@
<Compile Include="Server\IWebSocketServer.cs" />
<Compile Include="Net\Sockets\TcpListenerWebSocketContext.cs" />
<Compile Include="Server\WebSocketServerBase.cs" />
<Compile Include="Net\Security\SslStream.cs" />
</ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
<ItemGroup>
@@ -117,5 +118,6 @@
<Folder Include="Server\" />
<Folder Include="Net\" />
<Folder Include="Net\Sockets\" />
<Folder Include="Net\Security\" />
</ItemGroup>
</Project>

Binary file not shown.