Golang的websocket使用和实现代码分析
2023-12-15 08:55:7 Author: Go语言中文网(查看原文) 阅读量:5 收藏

前言

【你不知道的websocket协议,这次给你讲明白!】中介绍了web端即时通讯的方式,以及websocket如何进行连接、验证、数据帧的格式,这些都是了解websocket的基础知识。

本期将会继续上次话题,上篇主要是理论还是停留在文字层面,今天带来的是websocket实操,分享它使用和底层实现!

📚 全文字数 : 5k+

⏳ 阅读时长 : 7min

📢 关键词 : gorilla/websocket 、数据帧、Golang

相信很多使用Golang的小伙伴都知道Gorilla这个工具包,长久以来gorilla/websocket 都是比官方包更好的websocket包。

题外话 gorilla:大猩猩(不过这个猩猩还挺可爱的)

gorilla/websocket 框架开源地址为: https://github.com/gorilla/websocket

今天小许就用【gorilla/websocket】框架来展开本期文章内容,文章会涉及到核心代码的走读,会涉及到不少代码,需要小伙伴们保持耐心往下看,然后结合之前分享的websocket基础,彻底学个明白!

简单使用

安装Gorilla Websocket Go软件包,您只需要使用即可go get

go get github.com/gorilla/websocket

在正式使用之前我们先简单了解下两个数据结构 Upgrader 和 Conn

Upgrader

Upgrader指定用于将 HTTP 连接升级到 WebSocket 连接

type Upgrader struct {
    
    HandshakeTimeout time.Duration
    
    ReadBufferSize, WriteBufferSize int

    WriteBufferPool BufferPool

    Subprotocols []string

    Error func(w http.ResponseWriter, r *http.Request, status int, reason error)

    CheckOrigin func(r *http.Request) bool

    EnableCompression bool
}

  • • HandshakeTimeout:握手完成的持续时间

  • • ReadBufferSize和WriteBufferSize:以字节为单位指定I/O缓冲区大小。如果缓冲区大小为零,则使用HTTP服务器分配的缓冲区

  • • CheckOrigin :函数应仔细验证请求来源 防止跨站点请求伪造

这里一般会设置下CheckOrigin来解决跨域问题

Conn

Conn类型表示WebSocket连接,这个结构体的组成包括两部分,写入字段(Write fields)和 读取字段(Read fields)

type Conn struct {
    conn        net.Conn
    isServer    bool
    ...

    // Write fields
    writeBuf      []byte        
    writePool     BufferPool
    writeBufSize  int
    writer        io.WriteCloser 
    isWriting     bool           
    ...
    // Read fields
    readRemaining int64
    readFinal     bool  
    readLength    int64 
    messageReader *messageReader 
    ...
}

isServer :字段来区分我们是否用Conn作为客户端还是服务端,也就是说说gorilla/websocket中同时编写客户端程序和服务器程序,但是一般是Web应用程序使用单独的前端作为客户端程序。

部分字段说明如下图:

服务端示例

出于说明的目的,我们将在Go中同时编写客户端程序和服务端程序(其实小许是前端小趴菜😅 🤭)。

当然我们在开发程序的时候基本都是单独的前端,通常使用(Javascript,vue等)实现websocket客户端,这里为了让大家有比较直观的感受,用【gorilla/websocket】分别写了服务端和客户端示例。

var upGrader = websocket.Upgrader{
    CheckOrigin: func(r *http.Request) bool {
        return true
    },
}

func main() {
    http.HandleFunc("/ws", wsUpGrader)
    err := http.ListenAndServe("localhost:8080", nil)
    if err != nil {
        log.Println("server start err", err)
    }
}

func wsUpGrader(w http.ResponseWriter, r *http.Request) {
    //转换为升级为websocket
    conn, err := upGrader.Upgrade(w, r, nil)
    if err != nil {
        log.Println(err)
        return
    }
    //释放连接
    defer conn.Close()

    for {
        //接收消息
        messageType, message, err := conn.ReadMessage()
        if err != nil {
            log.Println(err)
            return
        }
        log.Println("server receive messageType", messageType, "message", string(message))
        //发送消息
        err = conn.WriteMessage(messageType, []byte("pong"))
        if err != nil {
            log.Println(err)
            return
        }
    }
}

我们知道websocket协议是基于http协议进行upgrade升级的, 这里使用 net/http提供原始的http连接。

http.HandleFunc接受两个参数:第一个参数是字符串表示的 url 路径,第二个参数是该 url 实际的处理对象

http.ListenAndServe 监听在某个端口,启动服务,准备接受客户端的请求

HandleFunc的作用:通过类型转换让我们可以将普通的函数作为HTTP处理器使用

服务端代码流程:

  • • Gorilla在使用websocket之前是先将初始化的upGrader结构体变量调用Upgrade方法进行请求协议升级

  • • 升级后返回 *Conn(此时isServer = true),后续使用它来处理websocket连接

  • • 服务端消息读写分别用 ReadMessage()、WriteMessage()

客户端示例

import (
    "fmt"
    "github.com/gorilla/websocket"
    "log"
    "time"
)

func main() {
    //服务器地址 websocket 统一使用 ws://
    url := "ws://localhost:8080/ws" 
    //使用默认拨号器,向服务器发送连接请求
    ws, _, err := websocket.DefaultDialer.Dial(url, nil)
    if err != nil {
        log.Fatal(err)
    }
    //关闭连接
    defer conn.Close()
    //发送消息
    go func() {
        for {
            err := ws.WriteMessage(websocket.BinaryMessage, []byte("ping"))
            if err != nil {
                log.Fatal(err)
            }
            //休眠两秒
            time.Sleep(time.Second * 2)
        }
    }()

    //接收消息
    for {
        _, data, err := ws.ReadMessage()
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println("client receive message: ", string(data))
    }
}

客户端的实现看起来也是简单,先使用默认拨号器,向服务器地址发送连接请求,拨号成功时也返回一个*Conn,开启一个协程每隔两秒向服务端发送消息,同样都是使用ReadMessage和W riteMessage读写消息。

示例代码运行结果如下:

源码走读

看完上面基本的客户端和服务端案例之后,我们对整个消息发送和接收的使用已经熟悉了,实际开发中要做的就是如何结合业务去定义消息类型和发送场景了,我们接着走读下底层的实现逻辑!

代码走读我们分了四部分,主要了解协议是如何升级、已经消息如何读写、解析数据帧【 🚩 🚩核心】!

Upgrade 协议升级

Upgrade顾名思义【升级】,在进行协议升级之前是需要对协议进行校验的,之前我们知道待升级的http请求是有固定请求头的,这里列举几个:

✏️ Upgrade进行校验的目的是看该请求是否符合协议升级的规定

Upgrade的部分校验代码如下,return处进行了省略

func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {

    if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
           return ...
    }
    if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
        return ...
    }
    //必须是get请求方法
    if r.Method != http.MethodGet {
           return ...
    }

    if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
        return ...
    }

    if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
        return ...
    }
    ...
    c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
    ...
}

tokenListContainsValue的目的是校验请求的Header中是否有upgrade需要的特定参数,比如我们上图列举的一些。

newConn就是初始化部分Conn结构体的,方法中的第二个参数为true代表这是服务端

computeAcceptKey 计算接受密钥:

这个函数重点说下,在上一期中在websocket【连接确认】这一章节中知道,websocket协议升级时,需要满足如下条件:

✏️只有当请求头参数Sec-WebSocket-Key字段的值经过固定算法加密后的数据和响应头里的Sec-WebSocket-Accept的值保持一致,该连接才会被认可建立。

var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func computeAcceptKey(challengeKey string) string {
    h := sha1.New() 
    h.Write([]byte(challengeKey))
    h.Write(keyGUID)
    return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

上面 computeAcceptKey 函数的实现,验证了之前说的关于 Sec-WebSocket-Accept的生成

服务端需将Sec-WebSocket-Key和固定的 GUID 字符串( 258EAFA5-E914-47DA-95CA-C5AB0DC85B11) 拼接后使用 SHA-1 进行哈希,并采用 base64 编码后返回

ReadMessage 读消息

ReadMessage方法内部使用NextReader获取读取器并从该读取器读取到缓冲区,如果是一条消息由多个数据帧,则会拼接成完整的消息,返回给业务层。

func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
    var r io.Reader
    messageType, r, err = c.NextReader()
    if err != nil {
        return messageType, nil, err
    }
    //ReadAll从r读取,直到出现错误或EOF,并返回读取的数据
    p, err = io.ReadAll(r)
    return messageType, p, err
}

该方法,返回三个参数,分别是消息类型、内容、error

messageType是int型,值可能是 BinaryMessage(二进制消息) 或 TextMessage(文本消息)

NextReader: 该方法得到一个消息类型 messageType,io.Reader,err

func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
        ...
        for c.readErr == nil {
        //解析数据帧方法advanceFrame
        // frameType : 帧类型
        frameType, err := c.advanceFrame()
        if err != nil {
            c.readErr = hideTempErr(err)
            break
        }
        //数据类型是 文本或二进制类型
        if frameType == TextMessage || frameType == BinaryMessage {
            c.messageReader = &messageReader{c}
            c.reader = c.messageReader
            if c.readDecompress {
                c.reader = c.newDecompressionReader(c.reader)
            }
            return frameType, c.reader, nil
        }
    }
    ...
}

c.advanceFrame() 是核心代码,主要是实现解析这条消息,这里在最后章节会讲。

这里有个 c.messageReader (当前的低级读取器),赋值给c.reader,为什么要这样呢?

c.messageReader 是更低级读取器,而 c.reader 的作用是当前读取器返回到应用程序。简单就是messageReader 是实现了 c.reader 接口的结构体, 从而也实现了 io.Reader接口

图上加一个 bufio.Read方法:Read读取数据写入p。本方法返回写入p的字节数。本方法一次调用最多会调用下层Reader接口一次Read方法,因此返回值n可能小于len(p)。读取到达结尾时,返回值n将为0而err将为io.EOF

messageReader的 Read方法: 我们看下Read的具体实现,Read方法主要是读取数据帧内容,直到出现并返回io.EOF或者其他错误为止,而实际调用它的正是 io.ReadAll。

func (r *messageReader) Read(b []byte) (int, error) {
    ...
    for c.readErr == nil {
        //当前帧中剩余的字节
        if c.readRemaining > 0 {
            if int64(len(b)) > c.readRemaining {
                b = b[:c.readRemaining]
            }
            //读取到切片b中
            n, err := c.br.Read(b)
            c.readErr = hideTempErr(err)
            //当Conn是服务端
            if c.isServer {
                c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
            }
            //readRemaining字节数转int64
            rem := c.readRemaining
            rem -= int64(n)
            //跟踪连接上剩余的字节数
            if err := c.setReadRemaining(rem); err != nil {
                return 0, err
            }
            if c.readRemaining > 0 && c.readErr == io.EOF {
                c.readErr = errUnexpectedEOF
            }
            //返回读后字节数
            return n, c.readErr
        }
        //标记是否最后一个数据帧
        if c.readFinal {
            // messageRader 置为nil
            c.messageReader = nil
            return 0, io.EOF
        }
        //获取数据帧类型
        frameType, err := c.advanceFrame()
        switch {
        case err != nil:
            c.readErr = hideTempErr(err)
        case frameType == TextMessage || frameType == BinaryMessage:
            c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
        }
    }

    err := c.readErr
    if err == io.EOF && c.messageReader == r {
        err = errUnexpectedEOF
    }
    return 0, err
}

io.ReadAll : ReadAll从r读取,这里是实现如果一条消息由多个数据帧,会一直读直到最后一帧的关键。

func ReadAll(r Reader) ([]byte, error) {
    b := make([]byte, 0, 512)
    for {
        if len(b) == cap(b) {
            // 给[]byte添加更多容量
            b = append(b, 0)[:len(b)]
        }
        n, err := r.Read(b[len(b):cap(b)])
        b = b[:len(b)+n]
        if err != nil {
            if err == EOF {
                err = nil
            }
            return b, err
        }
    }
}

可以看出在for 循环中一直读取,直至读取到最后一帧,直到返回io.EOF或网络原因错误为止,否则一直进行阻塞读,这些 error 可以从上面讲到的messageReader的 Read方法可以看出来。

总结下,整个流程如下:

整个读消息的流程就结束了,我们继续看如何写消息!

WriteMessage 写消息

既然读消息是对数据帧进行解析,那么写消息就自然会联想到将数据按照数据帧的规范组装写入到一个writebuf中,然后写入到网络中。

我们继续看WriteMessage是如何实现的

func (c *Conn) WriteMessage(messageType int, data []byte) error {
    ...
    //w 是一个io.WriteCloser
    w, err := c.NextWriter(messageType)
    if err != nil {
        return err
    }
    //将data写入writeBuf中
    if _, err = w.Write(data); err != nil {
        return err
    }
    return w.Close()
}

WriteMessage方法接收一个消息类型和数据,主要逻辑是先调用Conn的NextWriter方法得到一个io.WriteCloser,然后写消息到这个Conn的writeBuf,写完消息后close它。

NextWriter实现如下:

func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
    var mw messageWriter
    if err := c.beginMessage(&mw, messageType); err != nil {
        return nil, err
    }
    c.writer = &mw
    ...
    return c.writer, nil
}

注意看这里有个messageWriter赋值给了Conn的writer,也就是说messageWriter实现了io.WriterCloser接口。

这里的实现跟读消息中的NextReader方法中的messageReader很像,也是通过实现io.Reader接口,然后赋值给了Conn的Reader,这里可以做个小联动,找到读写消息实际的实现者 messageReader、messageWriter

messageWriter的Write实现:

前置知识:如果没有设置Conn中writeBufferSize, 默认情况下会设置为 4096个字节,另外加上14字节的数据帧头部大小【这些在newConn中初始化的时候有代码说明】

func (w *messageWriter) Write(p []byte) (int, error) {
    ...
    //如果字节长度大于初始化的writeBuf空间大小
    if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
        //写入方法
        err := w.flushFrame(false, p)
        ...
    }
    //字节长度不大于初始化的writeBuf空间大小
    nn := len(p)
    for len(p) > 0 {
        //内部也是调用的flushFrame
        n, err := w.ncopy(len(p))
        ...
    }
    return nn, nil
}

messageWriter中的Write方法主要的目的是将数据写入到writeBuf中,它主要存储结构化的数据帧内容,所谓结构化就是按照数据帧的格式,用Go实现写入的。

总结下,整个流程如下:

而flushFrame方法将缓冲数据和额外数据作为帧写入网络,这个final参数表示这是消息中的最后一帧。

至于flushFrame内部是如何实现写入网络中的,你可以看看 net.Conn 是怎么Write的,因为最终就是调这个写入网络的,这里就不再深究了,有兴趣的同学可以自己挖一挖!

advanceFrame 解析数据帧

解析数据帧放在最后,前面的代码走读主要是为了方便大家能把整体流程搞清楚,而数据帧的解析,是更加需要对websocket基础有了解,特别是数据帧的组成,因为解析就是按照协定用Go代码实现的一种方式而已!

强烈推荐大家看完# 为什么有了http,还需要websocket,懂了!]

根据上图【来自网络】回顾下数据帧各部分代表的意思:

FIN : 1个bit位,用来标记当前数据帧是不是最后一个数据帧

RSV1, RSV2, RSV3 :这三个各占用一个bit位用做扩展用途,没有这个需求的话设置为0

Opcode : 该值定义的是数据帧的数据类型 1 表示文本 2 表示二进制

MASK:表示数据有没有使用掩码

Payload length :数据的长度,Payload data的长度,占7bits,7+16bits,7+64bits

Masking-key :数据掩码 (设置为0,则该部分可以省略,如果设置为1,则用来解码客户端发送给服务端的数据帧)

Payload data : 帧真正要发送的数据,可以是任意长度

advanceFrame 解析方法

实现代码会比较长,如果直接贴代码,会看不下去,该方法返回数据类型和error, 这里我们只会截取其中一部分

func (c *Conn) advanceFrame() (int, error) {
    ...
    //读取前两个字节
    p, err := c.read(2)
    if err != nil {
        return noFrame, err
    }
    //数据帧类型
    frameType := int(p[0] & 0xf)
    // FIN 标记位
    final := p[0]&finalBit != 0
    //三个扩展用
    rsv1 := p[0]&rsv1Bit != 0
    rsv2 := p[0]&rsv2Bit != 0
    rsv3 := p[0]&rsv3Bit != 0
    //mask :是否使用掩码
    mask := p[1]&maskBit != 0
    ...
    switch c.readRemaining {
    case 126:
        p, err := c.read(2)
        if err != nil {
            return noFrame, err
        }

        if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
            return noFrame, err
        }
    case 127:
        p, err := c.read(8)
        if err != nil {
            return noFrame, err
        }

        if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
            return noFrame, err
        }
    }
    ..
}

整个流程分为了 7 个部分:

  1. 1. 跳过前一帧的剩余部分,毕竟这是之前帧的数据

  2. 2. 读取并解析帧头的前两个字节(从上面图中可以看出只读取到 Payload len)

  3. 3. 根据读取和解析帧长度(根据 Payload length的值来获取Payload data的长度)

  4. 4. 处理数据帧的mask掩码

  5. 5. 如果是文本和二进制消息,强制执行读取限制并返回 (结束)

  6. 6. 读取控制帧有效载荷 即 play data,设置setReadRemaining以安全地更新此值并防止溢出

  7. 7. 过程控制帧有效载荷,如果是ping/pong/close消息类型,返回 -1 (noFrame) (结束)

advanceFrame方法的主要目的就是解析数据帧,获取数据帧的消息类型,而对于数据帧的解析都是按照上图帧格式来的!

heartbeat 心跳

WebSocket 为了确保客户端、服务端之间的 TCP 通道连接没有断开,使用心跳机制来判断连接状态。如果超时时间内没有收到应答则认为连接断开,关闭连接,释放资源。流程如下

  • • 发送方 -> 接收方:ping

  • • 接收方 -> 发送方:pong

ping、pong 消息:它们对应的是 WebSocket 的两个控制帧,opcode分别是0x9、0xA,对应的消息类型分别是PingMessage, PongMessage,前提是应用程序需要先读取连接中的消息才能处理从对等方发送的 close、ping 和 pong 消息。

⏰⏰ 当然关于源码的部分我只是拿了其中一部分比如:控制类消息、并发、缓冲等,大家要知道有这些功能,有兴趣的可以去看看

总结

本期主要和大家一起了解 gorilla/websocket 框架的使用和部分底层实现原理代码走读,通篇读下来想必大家对websocket用程序语言实现有了更深刻的认识吧!

不过流行的开源 Go 语言 Web 工具包 Gorilla 宣布已正式归档,目前已进入只读模式。“它发出的信号是,这些库在未来将不会有任何发展。

也就是说 gorilla/websocket 这个被广泛使用的 websocket 库也会停止更新了,真是个令人悲伤的消息!

正如作者所说的那样:“没有一个项目需要永远存在。这可能不会让每个人都开心,但生活就是这样。”

好了,通过两期对websocket的讲解,相信大家心里已经对它有了比较深刻的印象,还是那句话知道的越多,不知道的也越多,一起前行让自己知道的更多一点


推荐阅读

福利
我为大家整理了一份从入门到进阶的Go学习资料礼包,包含学习建议:入门看什么,进阶看什么。关注公众号 「polarisxu」,回复 ebook 获取;还可以回复「进群」,和数万 Gopher 交流学习。


文章来源: http://mp.weixin.qq.com/s?__biz=MzAxMTA4Njc0OQ==&mid=2651454848&idx=1&sn=54570e431a29c66722a77d724abb9c0b&chksm=81c507bc43170233425d998837ae36b66aa720751b26336143cb82f390915a8bfbd0f065604d&scene=0&xtrack=1#rd
如有侵权请联系:admin#unsafe.sh