编程技术分享

  • 首页
  1. 首页
  2. MySQL
  3. 正文

MySQL源码分析系列3——登录协议解析

2021年12月25日 2100点热度 0人点赞 19条评论

1、发送接收数据包流程

在介绍登录流程之前,让我们先看看mysql发送、接收数据包流程,mysql数据包有固定的协议格式,即每个数据包都包含一个4字节包头,其中前三个字节指定数据包大小,最后一个字节指定数据包序列号,序列号用于保证数据包的顺序,如下图所示:

1.1  发送数据包

//文件net_serv.cc
//发送逻辑数据包,将逻辑数据包按照大小0xffffff(16M)分割为一个或多个物理数据包,物理数据包增加数据头部,头部包括包长度、包序号。
my_bool my_net_write(NET *net, const uchar *packet, size_t len) {
    ...
    net_write_buff(net, buff, NET_HEADER_SIZE)
    net_write_buff(net, packet, z_size)
    ...
}
//文件net_serv.cc
//缓冲区物理数据包,缓冲区满自动发送,或者调用net_flush主动发送。
static my_bool net_write_buff(NET *net, const uchar *packet, size_t len) {
    ...
    net_write_packet(net, net->buff, (size_t) (net->write_pos - net->buff) + left_length)	
    ...
}
//文件net_serv.cc
//发送物理数据包。
my_bool net_write_packet(NET *net, const uchar *packet, size_t length) {
    ...
    res= net_write_raw_loop(net, packet, length);
    ...
}
//文件net_serv.cc
//发送指定长度字节数据。
static my_bool net_write_raw_loop(NET *net, const uchar *buf, size_t count) {
    ...
    size_t sentcnt= vio_write(net->vio, buf, count);
    ...
}

1.2 接收数据包

//文件net_serv.cc
//读取逻辑数据包,可能由多个物理包组成,通过物理包的头部包长度是否为最大值0xffffff判断是否有后继包。
ulong my_net_read(NET *net) {
    ...
    len= net_read_packet(net, &complen);
    ...
}
//文件net_serv.cc
//读取一个物理数据包,调用net_read_packet_header读取数据包头。
static size_t net_read_packet(NET *net, size_t *complen) {
    ...
    net_read_packet_header(net)
    ...
    net_read_raw_loop(net, pkt_len)
    ...
}
//文件net_serv.cc
//读取物理数据包头部
static my_bool net_read_packet_header(NET *net) {
    ...
    rc= net_read_raw_loop(net, count)
    ...
}
//文件net_serv.cc
//读取指定长度字节数据。
static my_bool net_read_raw_loop(NET *net, size_t count) {
    ...
    size_t recvcnt= vio_read(net->vio, buf, count);
    ...
}

2、登录流程

登录流程主要包括:客户端连接服务器、服务器发送随机码到客户端、客户端发送用户密码到服务器、服务器返回校验结果,如下图所示。

登录流程调用链:

//sql_connect.cc
//登录权限校验
static bool login_connection(THD *thd) {
    ...
    //登录握手流程
    error= check_connection(thd);
    //回复登录结果
    thd->send_statement_status();
    ...
}
//sql_connect.cc
static int check_connection(THD *thd) {
    ...
    auth_rc= acl_authenticate(thd, COM_CONNECT);
    ...
}
//sql_authentication.cc
int acl_authenticate(THD *thd, enum_server_command command) {
    ...
    res= do_auth_once(thd, auth_plugin_name, &mpvio);
    ...
}
//sql_authentication.cc
static int do_auth_once(THD *thd, const LEX_CSTRING &auth_plugin_name, MPVIO_EXT *mpvio) {
    ...
    //调用具体插件的校验函数,默认为mysql_native_password,调用函数native_password_authenticate
    res= auth->authenticate_user(mpvio, &mpvio->auth_info);
    ...
}
//sql_authentication.cc
//校验实现
static int native_password_authenticate(MYSQL_PLUGIN_VIO *vio, MYSQL_SERVER_AUTH_INFO *info) {
    //生成随机码字符串,长度为20
    generate_user_salt(mpvio->scramble, SCRAMBLE_LENGTH + 1);
    //发送随机码到客户端,调用函数server_mpvio_write_packet
    mpvio->write_packet(mpvio, (uchar*) mpvio->scramble, SCRAMBLE_LENGTH + 1)
    //接收客户端回复用户、密码等信息,调用函数server_mpvio_read_packet
    pkt_len= mpvio->read_packet(mpvio, &pkt)
    ...
    //校验密码
    check_scramble(pkt, mpvio->scramble, mpvio->acl_user->salt)
    ...
}
//sql_class.cc
void THD::send_statement_status() {
    ...
    //根据不同状态发送不同类型包
    error= m_protocol->send_error(da->mysql_errno(), da->message_text(), da->returned_sqlstate());
    error= m_protocol->send_eof(server_status, da->last_statement_cond_count());
    error= m_protocol->send_ok(server_status, da->last_statement_cond_count(), da->affected_rows(), da->last_insert_id(), da->message_text());
    ...
}

3、服务器发送随机码到客户端

根据登录流程调用链可以看到发送随机码由server_mpvio_write_packet实现,该函数发送数据格式如表所示。

字节 说明
1-3 数据包长度,小端序
1 数据包序列号,用于保证数据包的顺序
1 协议版本号,总是10
N 服务器版本号,以0结尾
4 服务器线程id,小端序
8 服务器生成的随机串前8个字节(随机串至少20字节)
1 0
2 服务器能力标志低2字节,小端序
1 服务器字符集,默认为latin1
2 服务器状态,小端序
2 服务器能力标志高2字节,小端序
1 随机串长度
10 保留,都是0
N 随机串剩余字节,至少12字节
1 0
N 插件名称,以0结尾

调用链:

//sql_authentication.cc
static int server_mpvio_write_packet(MYSQL_PLUGIN_VIO *param, const uchar *packet, int packet_len) {
    ...
    //发送数据包到客户端
    res= send_server_handshake_packet(mpvio, (char*) packet, packet_len);
    ...
}
//sql_authentication.cc
//该函数会按照指定格式发送数据包到客户端,格式如表1所示
static bool send_server_handshake_packet(MPVIO_EXT *mpvio, const char *data, uint data_len) {
...
//调用my_net_write发送数据
int res= protocol->write((uchar*) buff, (size_t) (end - buff + 1)) || protocol->flush_net();
...
}

4、客户端回复用户、密码

客户端接收到服务器发送的随机码数据后,会使用随机码加密密码,然后回复服务器,回包数据格式如表所示。

字节 说明
1-3 数据包长度,小端序
1 数据包序列号,用于保证数据包的顺序
4 客户端能力标志
4 最大数据包长度,小端序
1 字符集
23 保留,都是0
N 用户名,以0结尾
N 随机码加密后的密码:密文长度编码 + 密文
N 数据库名,以0结尾
N 插件名称,以0结尾

表中格式需要说明的是随机码加密后的密码,该数据由加密后的密文长度编码和密文组成。假设通过随机码加密后的密文为ens,密文长度为len,计算方式如下。

密文长度编码(见pack.c中函数net_store_length):

  • len<251,编码为len的1个字节;
  • 251<=len<65536,编码第一个字节为252,接着2个字节按照小端序存储len;
  • 65536<=len<16777216,编码第一个字节为253,接着3个字节按照小端序存储len;
  • len>=16777216,编码第一个字节为254,接着8个字节按照小端序存储len;
  • 第一个字节251为NULL保留。

密文(见password.c中函数scramble):

  • 计算密码的SHA1哈希值stage1;
  • 计算stage1的SHA1哈希值stage2;
  • 计算随机码和stage2的SHA1哈希值hash;
  • 将hash与stage1异或得到密文。

MySQL自带的客户端调用链:

//client.c
static int native_password_auth_client(MYSQL_PLUGIN_VIO *vio, MYSQL *mysql) {
    ...
    //由随机码加密密码
    scramble(scrambled, (char *)pkt, mysql->passwd);
    //客户端回复,调用函数client_mpvio_write_packet
    vio->write_packet(vio, (uchar *)scrambled, SCRAMBLE_LENGTH)
    ...
}
//client.c
static int client_mpvio_write_packet(struct st_plugin_vio *mpv, const uchar *pkt, int pkt_len) {
    ...
    res = send_client_reply_packet(mpvio, pkt, pkt_len);
    ...
}
//client.c
static int send_client_reply_packet(MCPVIO_EXT *mpvio, const uchar *data, int data_len) {
    ...
    //调用my_net_write发数据
    my_net_write(net, (uchar *)buff, (size_t)(end - buff)) || net_flush(net)
    ...
}

5、服务器校验密码

服务器收到客户端回包后,解析出用户、密码,并和数据库中的进行对比判断客户端登录合法性。

密码校验是通过对比密码的两阶段哈希值,即前一节介绍的生成密文的stage2,stage2是二进制字符串,数据库在存储时会转换为16进制的ASCII字符串,转换方式为:将每个字节拆分为2个字符:高4位和低4位,用两个16进制的0-9、A-F字符表示。例如二进制字符串:62、221,转换后为:3EDD。

校验调用链:

//sql_authentication.cc
static int server_mpvio_read_packet(MYSQL_PLUGIN_VIO *param, uchar **buf) {
    ...
    //调用my_net_read接收数据
    protocol->read_packet();
    ...
    //解析客户端回包
    pkt_len= parse_client_handshake_packet(mpvio, buf, pkt_len);
}
//sql_authentication.cc
static size_t parse_client_handshake_packet(MPVIO_EXT *mpvio, uchar **buff, size_t pkt_len) {
    ...
    //查找用户密码用于校验
    find_mpvio_user(mpvio)
    ...
}
//password.c
//校验密码,由check_scramble调用
my_bool check_scramble_sha1(const uchar *scramble_arg, const char *message, const uint8 *hash_stage2) {
    ...
    //计算随机码和数据库密文哈希值
    compute_sha1_hash_multi(buf, message, SCRAMBLE_LENGTH, (const char *) hash_stage2, SHA1_HASH_SIZE);
    //客户端密文和上面哈希值异或,获得客户端密码一阶段SHA1哈希值
    my_crypt((char *) buf, buf, scramble_arg, SCRAMBLE_LENGTH);
    //计算客户端密码二阶段SHA1哈希值
    compute_sha1_hash(hash_stage2_reassured, (const char *) buf, SHA1_HASH_SIZE);
    //对比客户端密码二阶段哈希值和数据密文
    return MY_TEST(memcmp(hash_stage2, hash_stage2_reassured, SHA1_HASH_SIZE));
}

6、模拟数据库登录

根据上面源码分析,我们可以写一个简单的登录逻辑:

package main

import (
    "bytes"
    "crypto/sha1"
    "fmt"
    "net"
    "time"
)

const (
    clientLongPassword = 1 << iota
    clientFoundRows
    clientLongFlag
    clientConnectWithDB
    clientNoSchema
    clientCompress
    clientODBC
    clientLocalFiles
    clientIgnoreSpace
    clientProtocol41
    clientInteractive
    clientSSL
    clientIgnoreSIGPIPE
    clientTransactions
    clientReserved
    clientSecureConn
    clientMultiStatements
    clientMultiResults
    clientPSMultiResults
    clientPluginAuth
    clientConnectAttrs
    clientPluginAuthLenEncClientData
    clientCanHandleExpiredPasswords
    clientSessionTrack
    clientDeprecateEOF
)

//第1-3字节:数据包长度,小端序
//接着1字节:数据包序号
//接着1字节:协议版本号,总是10
//接着N字节:服务器版本号,以0结尾
//接着4字节:服务器线程id,小端序
//接着8字节:服务器生成的随机串前8个字节(随机串至少20字节)
//接着1字节:0
//接着2字节:服务器能力标志低2字节,小端序
//接着1字节:服务器字符集,默认为latin1
//接着2字节:服务器状态,小端序
//接着2字节:服务器能力标志高2字节,小端序
//接着1字节:随机串长度
//接着10字节:保留,都是0
//接着N字节:随机串剩余字节,至少12字节
//接着1字节:0
//接着N字节:插件名称,以0结尾
func readHandshakePacket(conn net.Conn) ([]byte, string, error) {
    var builder bytes.Buffer
    buf := make([]byte, 1024)
    for {
        count, e := conn.Read(buf)
        if e != nil {
            fmt.Println("readHandshakePacket:", e)
            return nil, "", e
        }
        builder.Write(buf[0:count])
        if count < 1024 {
            break
        }
    }
    fmt.Println(builder.Len(), builder.Bytes())

    data := builder.Bytes()
    index := 0
    //数据包长度
    packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16)
    index += 3
    //数据包序号
    packetNum := int(data[index])
    index += 1
    //协议版本号
    protocolVer := int(data[index])
    index += 1
    //服务器版本
    var serverVer string
    for {
        if data[index] == 0 {
            break
        }
        serverVer += string(data[index])
        index += 1
    }
    index += 1
    //服务器线程号,假设服务器是小端序
    threadId := uint(data[index]) + uint(data[index+1]<<8) + uint(data[index+2]<<16) + uint(data[index+3]<<24)
    index += 4
    //随机串前8字节
    var firstRand string
    for i:=0; i<8; i++ {
        firstRand += string(data[index])
        index += 1
    }
    index += 1
    //服务器能力标志低2字节,假设服务器是小端序
    lowCapability := uint(data[index]) + uint(data[index+1]<<8)
    index += 2
    //服务器字符集
    charset := int(data[index])
    index += 1
    //服务器状态,假设服务器是小端序
    serverStatus := int(data[index]) + int(data[index+1]<<8)
    index += 2
    //服务器能力标志高2字节,假设服务器是小端序
    highCapability := uint(data[index]) + uint(data[index+1]<<8)
    index += 2
    //随机串长度
    randLen := int(data[index])
    index += 1
    //保留
    index += 10
    //随机串剩余字节
    var lastRand string
    for {
        if data[index] == 0 {
            break
        }
        lastRand += string(data[index])
        index += 1
    }
    index += 1
    //插件名称
    var pluginName string
    for {
        if data[index] == 0 {
            break
        }
        pluginName += string(data[index])
        index += 1
    }
    fmt.Println("数据包长度:", packetLen)
    fmt.Println("数据包序号:", packetNum)
    fmt.Println("协议版本号:", protocolVer)
    fmt.Println("服务器版本号:", serverVer)
    fmt.Println("服务器线程号:", threadId)
    fmt.Println("服务器能力标志:", lowCapability + highCapability<<16)
    fmt.Println("服务器字符集:", charset)
    fmt.Println("服务器状态:", serverStatus)
    fmt.Println("随机串长度:", randLen)
    fmt.Println("随机串:", firstRand + lastRand)
    fmt.Println("插件名称:", pluginName)
    return []byte(firstRand+lastRand), pluginName, nil
}

//第1-4字节:客户端能力标志
//接着4字节:最大数据包长度,小端序
//接着1字节:字符集
//接着23字节:保留,都是0
//接着N字节:用户名,以0结尾
//接着N字节:随机串加密后的密码:数据长度编码 + 数据
//接着N字节:数据库名,以0结尾
//接着N字节:插件名称,以0结尾
func writeHandshakePacket(conn net.Conn, scramble []byte, plugin string) {
    clientFlags := clientProtocol41 |
        clientSecureConn |
        clientLongPassword |
        clientTransactions |
        clientLocalFiles |
        clientPluginAuth |
        clientMultiResults |
        clientLongFlag |
        clientMultiStatements |
        clientConnectWithDB

    user := "xxx"
    password := "xxxx"
    dbName := "mysql"

    var builder bytes.Buffer
    //写入客户端能力标志
    builder.WriteByte(byte(clientFlags))
    builder.WriteByte(byte(clientFlags>>8))
    builder.WriteByte(byte(clientFlags>>16))
    builder.WriteByte(byte(clientFlags>>24))
    //最大数据包长度
    builder.Write([]byte{0x00, 0x00, 0x00, 0x00})
    //字符集
    builder.WriteByte(byte(0x08))
    //保留
    for i:=0; i<23; i++ {
        builder.WriteByte(byte(0x00))
    }
    //用户名
    builder.WriteString(user)
    builder.WriteByte(byte(0x00))
    //随机串加密后的密码:数据长度编码 + 数据
    authData := scramblePassword(scramble, password)
    builder.WriteByte(byte(len(authData)))
    builder.Write(authData)
    //数据库名
    builder.WriteString(dbName)
    builder.WriteByte(byte(0x00))
    //插件名称
    builder.WriteString(plugin)
    builder.WriteByte(byte(0x00))

    //发送数据,序列号需要在服务器发包基础上加1
    n := builder.Len()
    var data []byte
    data = append(data, byte(n), byte(n>>8), byte(n>>16), byte(0x01))
    data = append(data, builder.Bytes()...)
    _, _ = conn.Write(data)
    fmt.Println("发送:", data)
}

func scramblePassword(scramble []byte, password string) []byte {
    if len(password) == 0 {
        return nil
    }
    //stage1
    crypt := sha1.New()
    crypt.Write([]byte(password))
    stage1 := crypt.Sum(nil)
    //stage2
    crypt.Reset()
    crypt.Write(stage1)
    stage2 := crypt.Sum(nil)

    crypt.Reset()
    crypt.Write(scramble)
    crypt.Write(stage2)
    scramble = crypt.Sum(nil)
    fmt.Println("=======")
    fmt.Println(stage1, stage2, scramble)
    fmt.Println(string(stage1), string(stage2), string(scramble))

    for i := range scramble {
        scramble[i] ^= stage1[i]
    }
    return scramble
}

//第1-3字节:数据包长度,小端序
//接着1字节:数据包序号
//接着1字节:OK头,为0
//接着1-9字节:受影响行
//接着1-9字节:最后插入id
//接着2字节:服务器状态
//接着2字节:告警数量
func readAuthOkPacket(conn net.Conn) {
    var builder bytes.Buffer
    buf := make([]byte, 1024)
    for {
        count, e := conn.Read(buf)
        if e != nil {
            fmt.Println("readAuthOkPacket:", e)
            return
        }
        builder.Write(buf[0:count])
        if count < 1024 {
            break
        }
    }
    fmt.Println(builder.Len(), builder.Bytes())

    index := 0
    //数据包长度
    packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16)
    index += 3
    //数据包序号
    packetNum := int(data[index])
    index += 1
    //OK头
    index += 1
    //受影响行
    affectedRows, n := readLength(data[index:])
    index += n
    //最后插入id
    lastId, n := readLength(data[index:])
    index += n
    //服务器状态
    status := uint(data[index]) | uint(data[index+1]<<8)
    index += 2
    //告警数量
    warn := uint(data[index]) | uint(data[index+1]<<8)
    index += 2

    fmt.Println("登录成功")
    fmt.Println("数据包长度:", packetLen)
    fmt.Println("数据包序号:", packetNum)
    fmt.Println("受影响行:", affectedRows)
    fmt.Println("最后插入id:", lastId)
    fmt.Println("服务器状态:", status)
    fmt.Println("告警数量:", warn)
}

func readLength(data []byte) (uint64, int) {
    switch data[0] {
    case 0xfc: //252
        return uint64(data[1]) | uint64(data[2])<<8, 2
    case 0xfd: //253
        return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16), 3
    case 0xfe: //254
        return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16) | uint64(data[4]<<24) |
            uint64(data[5]<<32) | uint64(data[6]<<40) | uint64(data[7]<<48) | uint64(data[8]<<56), 4
    }
    //<251
    return uint64(data[0]), 1
}

// 1、服务器发送随机码到客户端
// 2、客户端发送加密后的密码到服务器
// 3、服务器检查密码
func login(addr string) {
    conn, e := net.Dial("tcp", addr)
    if e != nil {
        fmt.Println("login1:", e)
        return
    }
    defer conn.Close()
    scramble, plugin, e := readHandshakePacket(conn)
    if e != nil {
        fmt.Println("login2:", e)
        return
    }
    writeHandshakePacket(conn, scramble, plugin)
    readAuthOkPacket(conn)
    time.Sleep(3600*time.Second)
}

func main() {
    login("localhost:3306")
}

 

标签: 暂无
最后更新:2022年8月12日

jemuel

这个人很懒,什么都没留下

点赞
< 上一篇
下一篇 >

文章评论

您需要 登录 之后才可以评论
文章目录
  • 1、发送接收数据包流程
    • 1.1  发送数据包
    • 1.2 接收数据包
  • 2、登录流程
  • 3、服务器发送随机码到客户端
  • 4、客户端回复用户、密码
  • 5、服务器校验密码
  • 6、模拟数据库登录
最新 热点 随机
最新 热点 随机
Volcano源码分析系列—调度篇 K8S源码分析系列1—搭建K8S调试集群 K8S Controller开发 6.5840 Lab 1: MapReduce MongoDB源码分析系列1——编译环境搭建 GraphQL介绍及使用
Go channel源码分析 Golang优先级调度 Volcano源码分析系列—调度篇 MySQL源码分析系列2——启动流程 K8S源码分析系列1—搭建K8S调试集群 大数据平台之binlog采集方案

COPYRIGHT © 2021 www.miaozhouguang.com. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS

粤ICP备2022006024号

粤公网安备 44030602006568号