# Source code for wsproto.handshake

# -*- coding: utf-8 -*-
"""
wsproto/handshake
~~~~~~~~~~~~~~~~~~

An implementation of WebSocket handshakes.
"""
from collections import deque
from typing import Deque, Dict, Generator, List, Optional, Union

import h11

from .connection import Connection, ConnectionState, ConnectionType
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
from .extensions import Extension
from .typing import Headers
from .utilities import (
    generate_accept_token,
    generate_nonce,
    LocalProtocolError,
    normed_header_dict,
    RemoteProtocolError,
    split_comma_header,
)

# The single WebSocket protocol version supported here. Clients send it as
# Sec-WebSocket-Version in their opening handshake, and servers compare the
# client's header against it (RFC 6455, Section 4.2.1, item 6).
WEBSOCKET_VERSION: bytes = b"13"


class H11Handshake:
    """A Handshake implementation for HTTP/1.1 connections.

    The same class drives either side of the opening handshake; which side
    it plays is fixed by the ``connection_type`` given to the constructor.
    HTTP parsing and serialisation are delegated to an ``h11.Connection``.
    """

    def __init__(self, connection_type: ConnectionType) -> None:
        self.client = connection_type is ConnectionType.CLIENT
        self._state = ConnectionState.CONNECTING

        if self.client:
            self._h11_connection = h11.Connection(h11.CLIENT)
        else:
            self._h11_connection = h11.Connection(h11.SERVER)

        # Set once the handshake succeeds (see ``_accept`` and
        # ``_establish_client_connection``).
        self._connection: Optional[Connection] = None
        # Events produced by ``receive_data``, drained by ``events``.
        self._events: Deque[Event] = deque()
        # Server: the parsed client request. Client: the request we sent.
        self._initiating_request: Optional[Request] = None
        # Client only: the Sec-WebSocket-Key we sent, used to verify the
        # server's Sec-WebSocket-Accept.
        self._nonce: Optional[bytes] = None

    @property
    def state(self) -> ConnectionState:
        """Return the current state of the handshake."""
        return self._state

    @property
    def connection(self) -> Optional[Connection]:
        """Return the established connection.

        This is ``None`` until the handshake has succeeded, meaning the
        server has sent, or the client has received, the 101 response.
        Afterwards it is the :class:`Connection` to use for WebSocket
        traffic.

        :rtype: Optional[Connection]
        """
        return self._connection

    def initiate_upgrade_connection(self, headers: Headers, path: str) -> None:
        """Initiate an upgrade connection.

        This should be used if the request has already been received and
        parsed elsewhere. The request is re-serialised as an HTTP/1.1 ``GET``
        and fed through :meth:`receive_data`, so the resulting
        :class:`Request` event becomes available from :meth:`events`.

        :param list headers: HTTP headers represented as a list of 2-tuples.
        :param str path: A URL path.
        :raises LocalProtocolError: if this handshake is acting as the client.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot initiate an upgrade connection when acting as the client"
            )
        upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
        h11_client = h11.Connection(h11.CLIENT)
        self.receive_data(h11_client.send(upgrade_request))

    def send(self, event: Event) -> bytes:
        """Send an event to the remote.

        This will return the bytes to send based on the event or raise a
        LocalProtocolError if the event is not valid given the state.

        :returns: Data to send to the WebSocket peer.
        :rtype: bytes
        """
        data = b""
        if isinstance(event, Request):
            data += self._initiate_connection(event)
        elif isinstance(event, AcceptConnection):
            data += self._accept(event)
        elif isinstance(event, RejectConnection):
            data += self._reject(event)
        elif isinstance(event, RejectData):
            data += self._send_reject_data(event)
        else:
            raise LocalProtocolError(
                "Event {} cannot be sent during the handshake".format(event)
            )
        return data

    def receive_data(self, data: Optional[bytes]) -> None:
        """Receive data from the remote.

        A list of events that the remote peer triggered by sending this data
        can be retrieved with :meth:`events`.

        :param bytes data: Data received from the WebSocket peer.
        :raises RemoteProtocolError: if h11 rejects the data as a bad HTTP
            message, or the peer's handshake is invalid.
        """
        self._h11_connection.receive_data(data)
        while True:
            try:
                event = self._h11_connection.next_event()
            except h11.RemoteProtocolError:
                raise RemoteProtocolError(
                    "Bad HTTP message", event_hint=RejectConnection()
                )
            if (
                isinstance(event, h11.ConnectionClosed)
                or event is h11.NEED_DATA
                or event is h11.PAUSED
            ):
                break

            if self.client:
                if isinstance(event, h11.InformationalResponse):
                    if event.status_code == 101:
                        self._events.append(self._establish_client_connection(event))
                    else:
                        self._events.append(
                            RejectConnection(
                                headers=event.headers,
                                status_code=event.status_code,
                                has_body=False,
                            )
                        )
                        self._state = ConnectionState.CLOSED
                elif isinstance(event, h11.Response):
                    # A final (non-1xx) response means the server refused the
                    # upgrade; its body follows as Data/EndOfMessage events.
                    self._state = ConnectionState.REJECTING
                    self._events.append(
                        RejectConnection(
                            headers=event.headers,
                            status_code=event.status_code,
                            has_body=True,
                        )
                    )
                elif isinstance(event, h11.Data):
                    self._events.append(
                        RejectData(data=event.data, body_finished=False)
                    )
                elif isinstance(event, h11.EndOfMessage):
                    self._events.append(RejectData(data=b"", body_finished=True))
                    self._state = ConnectionState.CLOSED
            else:
                if isinstance(event, h11.Request):
                    self._events.append(self._process_connection_request(event))

    def events(self) -> Generator[Event, None, None]:
        """Return a generator that provides any events that have been generated
        by protocol activity.

        :returns: a generator that yields H11 events.
        """
        while self._events:
            yield self._events.popleft()

    ############ Server mode methods

    def _process_connection_request(  # noqa: MC0001
        self, event: h11.Request
    ) -> Request:
        """Validate a client's opening handshake and build a Request event.

        The parsed :class:`Request` is stored on the handshake, because
        ``_accept`` later needs the client's key, and then returned.

        :raises RemoteProtocolError: if the request is not an acceptable
            WebSocket upgrade. ``event_hint`` carries the suggested rejection.
        """
        if event.method != b"GET":
            raise RemoteProtocolError(
                "Request method must be GET", event_hint=RejectConnection()
            )
        connection_tokens = None
        extensions: List[str] = []
        host = None
        key = None
        subprotocols: List[str] = []
        upgrade = b""
        version = None
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
            elif name == b"host":
                try:
                    host = value.decode("ascii")
                except UnicodeDecodeError:
                    # Header bytes come from the peer; a non-ASCII Host is a
                    # client error and must not surface as a decode crash.
                    raise RemoteProtocolError(
                        "Invalid 'Host' header", event_hint=RejectConnection()
                    )
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                extensions = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-key":
                key = value
            elif name == b"sec-websocket-protocol":
                subprotocols = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-version":
                version = value
            elif name == b"upgrade":
                upgrade = value
            headers.append((name, value))
        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if version is None:
            # A request without a version is malformed (400). Only a version
            # we do not speak warrants 426 Upgrade Required (RFC 6455 4.2.2).
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Version'",
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=400,
                ),
            )
        if version != WEBSOCKET_VERSION:
            raise RemoteProtocolError(
                "Unsupported 'Sec-WebSocket-Version'",
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=426,
                ),
            )
        if key is None:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        if host is None:
            raise RemoteProtocolError(
                "Missing header, 'Host'", event_hint=RejectConnection()
            )

        self._initiating_request = Request(
            extensions=extensions,
            extra_headers=headers,
            host=host,
            subprotocols=subprotocols,
            target=event.target.decode("ascii"),
        )
        return self._initiating_request

    def _accept(self, event: AcceptConnection) -> bytes:
        """Serialise the 101 response accepting the stored client request.

        :raises LocalProtocolError: if ``event.subprotocol`` was not offered
            by the client.
        """
        # _accept is always called after _process_connection_request.
        assert self._initiating_request is not None
        request_headers = normed_header_dict(self._initiating_request.extra_headers)

        nonce = request_headers[b"sec-websocket-key"]
        accept_token = generate_accept_token(nonce)

        headers = [
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Accept", accept_token),
        ]

        if event.subprotocol is not None:
            if event.subprotocol not in self._initiating_request.subprotocols:
                raise LocalProtocolError(
                    "unexpected subprotocol {}".format(event.subprotocol)
                )
            headers.append(
                (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
            )

        if event.extensions:
            accepts = server_extensions_handshake(  # type: ignore
                self._initiating_request.extensions, event.extensions
            )
            if accepts:
                headers.append((b"Sec-WebSocket-Extensions", accepts))

        response = h11.InformationalResponse(
            status_code=101, headers=headers + event.extra_headers
        )
        self._connection = Connection(
            ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
            event.extensions,
        )
        self._state = ConnectionState.OPEN
        return self._h11_connection.send(response)

    def _reject(self, event: RejectConnection) -> bytes:
        """Serialise a rejection response (and, if bodiless, its end).

        :raises LocalProtocolError: if the handshake is no longer CONNECTING.
        """
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                "Connection cannot be rejected in state %s" % self.state
            )

        # Copy so that adding content-length below does not mutate the
        # caller's event (or a list shared with other events).
        headers = list(event.headers)
        if not event.has_body:
            headers.append((b"content-length", b"0"))
        response = h11.Response(status_code=event.status_code, headers=headers)
        data = self._h11_connection.send(response)
        self._state = ConnectionState.REJECTING
        if not event.has_body:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    def _send_reject_data(self, event: RejectData) -> bytes:
        """Serialise a chunk of a rejection body, closing it if finished.

        :raises LocalProtocolError: if no rejection is in progress.
        """
        if self.state != ConnectionState.REJECTING:
            raise LocalProtocolError(
                "Cannot send rejection data in state {}".format(self.state)
            )

        data = self._h11_connection.send(h11.Data(data=event.data))
        if event.body_finished:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    ############ Client mode methods

    def _initiate_connection(self, request: Request) -> bytes:
        """Serialise the client's opening handshake for ``request``."""
        self._initiating_request = request
        self._nonce = generate_nonce()

        headers = [
            (b"Host", request.host.encode("ascii")),
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Key", self._nonce),
            (b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
        ]

        if request.subprotocols:
            headers.append(
                (
                    b"Sec-WebSocket-Protocol",
                    (", ".join(request.subprotocols)).encode("ascii"),
                )
            )

        if request.extensions:
            offers = {e.name: e.offer() for e in request.extensions}  # type: ignore
            extensions = []
            for name, params in offers.items():
                name_bytes = name.encode("ascii")
                if isinstance(params, bool):
                    if params:
                        extensions.append(name_bytes)
                elif params == "":
                    # No parameters: send the bare token instead of the
                    # malformed "name; " (matches server_extensions_handshake).
                    extensions.append(name_bytes)
                else:
                    extensions.append(b"%s; %s" % (name_bytes, params.encode("ascii")))
            if extensions:
                headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))

        upgrade = h11.Request(
            method=b"GET",
            target=request.target.encode("ascii"),
            headers=headers + request.extra_headers,
        )
        return self._h11_connection.send(upgrade)

    def _establish_client_connection(
        self, event: h11.InformationalResponse
    ) -> AcceptConnection:  # noqa: MC0001
        """Validate the server's 101 response and open the connection.

        :raises RemoteProtocolError: if the response is missing the upgrade
            headers, carries the wrong accept token, or selects a subprotocol
            or extension the client did not offer.
        """
        # _establish_client_connection is always called after _initiate_connection.
        assert self._initiating_request is not None
        assert self._nonce is not None
        accept = None
        connection_tokens = None
        accepts: List[str] = []
        raw_subprotocol: Optional[bytes] = None
        upgrade = b""
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                accepts = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-accept":
                accept = value
                continue  # Skip appending to headers
            elif name == b"sec-websocket-protocol":
                raw_subprotocol = value
                continue  # Skip appending to headers
            elif name == b"upgrade":
                upgrade = value
                continue  # Skip appending to headers
            headers.append((name, value))

        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        accept_token = generate_accept_token(self._nonce)
        if accept != accept_token:
            raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())

        subprotocol: Optional[str] = None
        if raw_subprotocol is not None:
            try:
                subprotocol = raw_subprotocol.decode("ascii")
            except UnicodeDecodeError:
                # Offered subprotocols were ASCII-encoded when sent, so a
                # non-ASCII reply cannot be one of them.
                raise RemoteProtocolError(
                    "unrecognized subprotocol {!r}".format(raw_subprotocol),
                    event_hint=RejectConnection(),
                )
            if subprotocol not in self._initiating_request.subprotocols:
                raise RemoteProtocolError(
                    "unrecognized subprotocol {}".format(subprotocol),
                    event_hint=RejectConnection(),
                )
        extensions = client_extensions_handshake(  # type: ignore
            accepts, self._initiating_request.extensions
        )

        self._connection = Connection(
            ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
            extensions,
            # Bytes received after the 101 response belong to the WebSocket
            # stream and are handed to the new Connection.
            self._h11_connection.trailing_data[0],
        )
        self._state = ConnectionState.OPEN
        return AcceptConnection(
            extensions=extensions, extra_headers=headers, subprotocol=subprotocol
        )

    def __repr__(self) -> str:
        return "{}(client={}, state={})".format(
            self.__class__.__name__, self.client, self.state
        )
def server_extensions_handshake(
    requested: List[str], supported: List[Extension]
) -> Optional[bytes]:
    """Agree on the extensions to use returning an appropriate header value.

    Every offer in ``requested`` is matched by name against ``supported``,
    and each matching extension's ``accept`` decides the outcome. ``True``
    agrees without parameters, a string agrees with those parameters, and
    ``False``/``None`` declines. When an extension is agreed more than once,
    the last agreement wins.

    :returns: The ``Sec-WebSocket-Extensions`` header value, or None if no
        extensions were agreed.
    """
    agreed: Dict[str, Union[bool, bytes]] = {}
    for offer in requested:
        offered_name = offer.split(";", 1)[0].strip()
        for candidate in supported:
            if candidate.name != offered_name:
                continue
            outcome = candidate.accept(offer)
            if outcome is None or outcome is False:
                continue  # declined
            agreed[candidate.name] = (
                True if outcome is True else outcome.encode("ascii")
            )

    if not agreed:
        return None

    # A bare token when there are no parameters, otherwise "name; params".
    rendered = [
        ext_name.encode("ascii")
        if params is True or params == b""
        else b"%s; %s" % (ext_name.encode("ascii"), params)
        for ext_name, params in agreed.items()
    ]
    return b", ".join(rendered)
def client_extensions_handshake(
    accepted: List[str], supported: List[Extension]
) -> List[Extension]:
    """Finalize the extensions the server agreed to use.

    :param accepted: Elements of the server's ``Sec-WebSocket-Extensions``
        response header, one per agreed extension.
    :param supported: The extensions the client offered.
    :returns: The supported extensions named in ``accepted``, in the server's
        order, each after ``finalize`` has been called with its element.
    :raises RemoteProtocolError: if the server accepted an extension the
        client does not support, or named the same extension more than once.
    """
    extensions: List[Extension] = []
    for accept in accepted:
        name = accept.split(";", 1)[0].strip()
        if any(existing.name == name for existing in extensions):
            # Without this check the same Extension object would be
            # finalized twice and listed twice in the result.
            raise RemoteProtocolError(
                "duplicate extension {}".format(name), event_hint=RejectConnection()
            )
        for extension in supported:
            if extension.name == name:
                extension.finalize(accept)
                extensions.append(extension)
                break
        else:
            raise RemoteProtocolError(
                "unrecognized extension {}".format(name), event_hint=RejectConnection()
            )
    return extensions