Skip to content

一、说明

先前写过相关文章,传送门👉 基于Socket的WebSocket协议实现👈

此源码是学习 socketserver 库后改造的版本。

二、源码实现

python
import base64
import hashlib
import struct
from socketserver import BaseRequestHandler, ThreadingTCPServer


class BaseWebSocketRequetHander(BaseRequestHandler):
    """Base ``socketserver`` handler that speaks a minimal WebSocket protocol.

    ``setup`` performs the HTTP upgrade handshake; subclasses override
    ``handle`` and use ``websocket_decode`` / ``websocket_encode`` to read
    and write single text frames on ``self.request``.
    """

    # GUID appended to the client key when computing Sec-WebSocket-Accept.
    WEBSOCKET_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    # Upper bound on the bytes buffered while waiting for the end of the
    # handshake request, so a misbehaving peer cannot grow it without limit.
    MAX_HANDSHAKE_SIZE = 65536

    def __init__(self, request, client_address, server):
        # ``key`` must exist before the base initializer runs, because the
        # base initializer calls setup(), which assigns it.
        self.key = None
        super().__init__(request, client_address, server)

    def get_headers(self, data):
        """Extract the ``Sec-WebSocket-Key`` header from a raw HTTP request.

        :param data: raw request bytes (UTF-8).
        :return: dict holding ``"Sec-WebSocket-Key"`` when the header is
            present, otherwise an empty dict.
        """
        header_dict = {}
        for line in data.decode("utf8").split("\r\n"):
            name, sep, value = line.partition(":")
            # Header field names are case-insensitive; compare the full name
            # so that e.g. "Sec-WebSocket-Key-Other" does not match.
            if sep and name.strip().lower() == "sec-websocket-key":
                header_dict["Sec-WebSocket-Key"] = value.strip()
        return header_dict

    def upgrade_websocket(self):
        """Send the ``101 Switching Protocols`` response to the client.

        :raises Exception: if ``self.key`` has not been computed yet.
        """
        if self.key is None:
            raise Exception("Sec-WebSocket-Accept key has not been computed.")

        response_content = (
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            "Sec-WebSocket-Accept: {key}\r\n"
            "Websocket-Location: ws://{host}:{port}\r\n\r\n"
        ).format(
            key=self.key.decode("utf-8"),
            host=self.server.server_address[0],
            # Index 1 is the port for both IPv4 2-tuples and IPv6 4-tuples;
            # the last element of an IPv6 address tuple is the scope id.
            port=self.server.server_address[1],
        )

        # sendall() keeps writing until the whole response has been sent.
        self.request.sendall(response_content.encode("utf-8"))

    def _read_handshake(self):
        """Read from the socket until the blank line ending the HTTP headers."""
        buffer = b""
        while (b"\r\n\r\n" not in buffer
               and len(buffer) < self.MAX_HANDSHAKE_SIZE):
            chunk = self.request.recv(1024)
            if not chunk:  # peer closed the connection
                break
            buffer += chunk
        return buffer

    def setup(self):
        """Perform the WebSocket opening handshake.

        Reads the client's upgrade request, for example::

            GET / HTTP/1.1
            Host: 127.0.0.1:8888
            Connection: Upgrade
            Upgrade: websocket
            Sec-WebSocket-Version: 13
            Sec-WebSocket-Key: 7iezRLm8rSVYShdK/uYauQ==

        then stores the derived accept key in ``self.key`` and sends the
        upgrade response.

        :raises ValueError: if the request has no ``Sec-WebSocket-Key``.
        """
        data = self._read_handshake()
        client_key = self.get_headers(data).get("Sec-WebSocket-Key")
        if client_key is None:
            raise ValueError(
                "handshake request has no Sec-WebSocket-Key header")

        # Accept key = base64(sha1(client key + fixed GUID)).
        value = client_key + BaseWebSocketRequetHander.WEBSOCKET_STRING
        self.key = base64.b64encode(
            hashlib.sha1(value.encode("utf-8")).digest())
        self.upgrade_websocket()

    def websocket_decode(self, data):
        """Decode the payload of the first WebSocket frame in ``data``.

        Only the length, mask bit and payload are interpreted; the opcode and
        FIN bit are ignored, and bytes after the declared payload length are
        not decoded.

        :param data: raw bytes received from the socket.
        :return: the frame payload decoded as UTF-8 text.
        :raises ValueError: if ``data`` is shorter than the 2-byte header.
        """
        if len(data) < 2:
            raise ValueError("incomplete WebSocket frame")

        masked = bool(data[1] & 0x80)
        length = data[1] & 0x7F
        offset = 2

        if length == 126:
            # 16-bit extended payload length.
            length = int.from_bytes(data[2:4], "big")
            offset = 4
        elif length == 127:
            # 64-bit extended payload length.
            length = int.from_bytes(data[2:10], "big")
            offset = 10

        if masked:
            mask = data[offset:offset + 4]
            offset += 4
            payload = bytes(
                byte ^ mask[i % 4]
                for i, byte in enumerate(data[offset:offset + length]))
        else:
            payload = bytes(data[offset:offset + length])

        return payload.decode("utf-8")

    def websocket_encode(self, data):
        """Wrap text in a single unmasked, final WebSocket text frame.

        :param data: text to send.
        :return: the encoded frame as bytes.
        """
        byte = data.encode("utf-8")
        token = b"\x81"  # FIN bit set, opcode 0x1 (text)
        length = len(byte)
        if length < 126:
            token += struct.pack("B", length)
        elif length <= 0xFFFF:
            token += struct.pack("!BH", 126, length)
        else:
            token += struct.pack("!BQ", 127, length)
        return token + byte

    def handle(self):
        """Process the connection; must be overridden by subclasses."""
        raise NotImplementedError("please overridden")


class DemoWebSocketRequetHander(BaseWebSocketRequetHander):
    """Demo handler: print each received message and reply with a greeting."""

    def handle(self):
        """Serve one client until it disconnects or the socket fails."""
        import random

        while True:
            try:
                data = self.request.recv(8096)
                if not data:
                    # recv() returns b"" once the peer has closed the
                    # connection; looping on would spin forever.
                    break
                print(self.websocket_decode(data))
                self.request.sendall(self.websocket_encode(
                    "hello~{}".format(random.randint(1, 10000))))
            except OSError as e:
                # The socket is no longer usable; stop instead of retrying.
                print(e.args)
                break
            except Exception as e:
                # A frame that cannot be decoded is reported and skipped.
                print(e.args)
            except KeyboardInterrupt:
                self.request.close()
                break


def main(server_address, RequestHandlerClass):
    """Serve ``RequestHandlerClass`` on ``server_address`` with a threaded TCP server.

    Blocks inside ``serve_forever()``; the server is used as a context
    manager so it is closed when the block exits.
    """
    server = ThreadingTCPServer(server_address, RequestHandlerClass)
    with server:
        server.serve_forever()


if __name__ == "__main__":
    # Run the demo handler on localhost:8881 when executed as a script.
    listen_address = ("localhost", 8881)
    main(listen_address, DemoWebSocketRequetHander)

最后更新: