Check if the received frame is correctly masked or not

This commit is contained in:
sta 2014-09-16 14:53:26 +09:00
parent dd25d1d0a4
commit 130b0a08d0
2 changed files with 39 additions and 9 deletions

View File

@ -648,7 +648,19 @@ namespace WebSocketSharp
private bool concatenateFragmentsInto (Stream destination) private bool concatenateFragmentsInto (Stream destination)
{ {
while (true) { while (true) {
var frame = WebSocketFrame.Read (_stream, true); var frame = WebSocketFrame.Read (_stream, false);
var masked = frame.IsMasked;
if (_client && masked)
return processUnsupportedFrame (
frame, CloseStatusCode.ProtocolError, "A frame from the server is masked.");
if (!_client && !masked)
return processUnsupportedFrame (
frame, CloseStatusCode.ProtocolError, "A frame from a client isn't masked.");
if (masked)
frame.Unmask ();
if (frame.IsFinal) { if (frame.IsFinal) {
/* FINAL */ /* FINAL */
@ -1002,6 +1014,18 @@ namespace WebSocketSharp
private bool processWebSocketFrame (WebSocketFrame frame) private bool processWebSocketFrame (WebSocketFrame frame)
{ {
var masked = frame.IsMasked;
if (_client && masked)
return processUnsupportedFrame (
frame, CloseStatusCode.ProtocolError, "A frame from the server is masked.");
if (!_client && !masked)
return processUnsupportedFrame (
frame, CloseStatusCode.ProtocolError, "A frame from a client isn't masked.");
if (masked)
frame.Unmask ();
return frame.IsCompressed && _compression == CompressionMethod.None return frame.IsCompressed && _compression == CompressionMethod.None
? processUnsupportedFrame ( ? processUnsupportedFrame (
frame, frame,
@ -1285,7 +1309,7 @@ namespace WebSocketSharp
Action receive = null; Action receive = null;
receive = () => WebSocketFrame.ReadAsync ( receive = () => WebSocketFrame.ReadAsync (
_stream, _stream,
true, false,
frame => { frame => {
if (processWebSocketFrame (frame) && _readyState != WebSocketState.Closed) { if (processWebSocketFrame (frame) && _readyState != WebSocketState.Closed) {
receive (); receive ();

View File

@ -516,14 +516,10 @@ Extended Payload Length: {7}
data = new byte[0]; data = new byte[0];
} }
var payload = new PayloadData (data, masked); frame._payloadData = new PayloadData (data, masked);
if (masked && unmask) { if (unmask && masked)
payload.Mask (maskingKey); frame.Unmask ();
frame._mask = Mask.Unmask;
frame._maskingKey = new byte[0];
}
frame._payloadData = payload;
return frame; return frame;
} }
@ -606,6 +602,16 @@ Extended Payload Length: {7}
error); error);
} }
internal void Unmask ()
{
if (_mask == Mask.Unmask)
return;
_payloadData.Mask (_maskingKey);
_maskingKey = new byte[0];
_mask = Mask.Unmask;
}
#endregion #endregion
#region Public Methods #region Public Methods