1、发送接收数据包流程
在介绍登录流程之前,让我们先看看mysql发送、接收数据包流程,mysql数据包有固定的协议格式,即每个数据包都包含一个4字节包头,其中前三个字节指定数据包大小,最后一个字节指定数据包序列号,序列号用于保证数据包的顺序,如下图所示:
1.1 发送数据包
//文件net_serv.cc
//发送逻辑数据包,将逻辑数据包按照大小0xffffff(16M)分割为一个或多个物理数据包,物理数据包增加数据头部,头部包括包长度、包序号。
my_bool my_net_write(NET *net, const uchar *packet, size_t len) {
...
net_write_buff(net, buff, NET_HEADER_SIZE)
net_write_buff(net, packet, z_size)
...
}
//文件net_serv.cc
//缓冲区物理数据包,缓冲区满自动发送,或者调用net_flush主动发送。
static my_bool net_write_buff(NET *net, const uchar *packet, size_t len) {
...
net_write_packet(net, net->buff, (size_t) (net->write_pos - net->buff) + left_length)
...
}
//文件net_serv.cc
//发送物理数据包。
my_bool net_write_packet(NET *net, const uchar *packet, size_t length) {
...
res= net_write_raw_loop(net, packet, length);
...
}
//文件net_serv.cc
//发送指定长度字节数据。
static my_bool net_write_raw_loop(NET *net, const uchar *buf, size_t count) {
...
size_t sentcnt= vio_write(net->vio, buf, count);
...
}
1.2 接收数据包
//文件net_serv.cc
//读取逻辑数据包,可能由多个物理包组成,通过物理包的头部包长度是否为最大值0xffffff判断是否有后继包。
ulong my_net_read(NET *net) {
...
len= net_read_packet(net, &complen);
...
}
//文件net_serv.cc
//读取一个物理数据包,调用net_read_packet_header读取数据包头。
static size_t net_read_packet(NET *net, size_t *complen) {
...
net_read_packet_header(net)
...
net_read_raw_loop(net, pkt_len)
...
}
//文件net_serv.cc
//读取物理数据包头部
static my_bool net_read_packet_header(NET *net) {
...
rc= net_read_raw_loop(net, count)
...
}
//文件net_serv.cc
//读取指定长度字节数据。
static my_bool net_read_raw_loop(NET *net, size_t count) {
...
size_t recvcnt= vio_read(net->vio, buf, count);
...
}
2、登录流程
登录流程主要包括:客户端连接服务器、服务器发送随机码到客户端、客户端发送用户密码到服务器、服务器返回校验结果,如下图所示。
登录流程调用链:
//sql_connect.cc
//登录权限校验
static bool login_connection(THD *thd) {
...
//登录握手流程
error= check_connection(thd);
//回复登录结果
thd->send_statement_status();
...
}
//sql_connect.cc
static int check_connection(THD *thd) {
...
auth_rc= acl_authenticate(thd, COM_CONNECT);
...
}
//sql_authentication.cc
int acl_authenticate(THD *thd, enum_server_command command) {
...
res= do_auth_once(thd, auth_plugin_name, &mpvio);
...
}
//sql_authentication.cc
static int do_auth_once(THD *thd, const LEX_CSTRING &auth_plugin_name, MPVIO_EXT *mpvio) {
...
//调用具体插件的校验函数,默认为mysql_native_password,调用函数native_password_authenticate
res= auth->authenticate_user(mpvio, &mpvio->auth_info);
...
}
//sql_authentication.cc
//校验实现
static int native_password_authenticate(MYSQL_PLUGIN_VIO *vio, MYSQL_SERVER_AUTH_INFO *info) {
//生成随机码字符串,长度为20
generate_user_salt(mpvio->scramble, SCRAMBLE_LENGTH + 1);
//发送随机码到客户端,调用函数server_mpvio_write_packet
mpvio->write_packet(mpvio, (uchar*) mpvio->scramble, SCRAMBLE_LENGTH + 1)
//接收客户端回复用户、密码等信息,调用函数server_mpvio_read_packet
pkt_len= mpvio->read_packet(mpvio, &pkt)
...
//校验密码
check_scramble(pkt, mpvio->scramble, mpvio->acl_user->salt)
...
}
//sql_class.cc
void THD::send_statement_status() {
...
//根据不同状态发送不同类型包
error= m_protocol->send_error(da->mysql_errno(), da->message_text(), da->returned_sqlstate());
error= m_protocol->send_eof(server_status, da->last_statement_cond_count());
error= m_protocol->send_ok(server_status, da->last_statement_cond_count(), da->affected_rows(), da->last_insert_id(), da->message_text());
...
}
3、服务器发送随机码到客户端
根据登录流程调用链可以看到发送随机码由server_mpvio_write_packet实现,该函数发送数据格式如表所示。
| 字节 | 说明 |
| 1-3 | 数据包长度,小端序 |
| 1 | 数据包序列号,用于保证数据包的顺序 |
| 1 | 协议版本号,总是10 |
| N | 服务器版本号,以0结尾 |
| 4 | 服务器线程id,小端序 |
| 8 | 服务器生成的随机串前8个字节(随机串至少20字节) |
| 1 | 0 |
| 2 | 服务器能力标志低2字节,小端序 |
| 1 | 服务器字符集,默认为latin1 |
| 2 | 服务器状态,小端序 |
| 2 | 服务器能力标志高2字节,小端序 |
| 1 | 随机串长度 |
| 10 | 保留,都是0 |
| N | 随机串剩余字节,至少12字节 |
| 1 | 0 |
| N | 插件名称,以0结尾 |
调用链:
//sql_authentication.cc
static int server_mpvio_write_packet(MYSQL_PLUGIN_VIO *param, const uchar *packet, int packet_len) {
...
//发送数据包到客户端
res= send_server_handshake_packet(mpvio, (char*) packet, packet_len);
...
}
//sql_authentication.cc
//该函数会按照指定格式发送数据包到客户端,格式如表1所示
static bool send_server_handshake_packet(MPVIO_EXT *mpvio, const char *data, uint data_len) {
...
//调用my_net_write发送数据
int res= protocol->write((uchar*) buff, (size_t) (end - buff + 1)) || protocol->flush_net();
...
}
4、客户端回复用户、密码
客户端接收到服务器发送的随机码数据后,会使用随机码加密密码,然后回复服务器,回包数据格式如表所示。
| 字节 | 说明 |
| 1-3 | 数据包长度,小端序 |
| 1 | 数据包序列号,用于保证数据包的顺序 |
| 4 | 客户端能力标志 |
| 4 | 最大数据包长度,小端序 |
| 1 | 字符集 |
| 23 | 保留,都是0 |
| N | 用户名,以0结尾 |
| N | 随机码加密后的密码:密文长度编码 + 密文 |
| N | 数据库名,以0结尾 |
| N | 插件名称,以0结尾 |
表中格式需要说明的是随机码加密后的密码,该数据由加密后的密文长度编码和密文组成。假设通过随机码加密后的密文为ens,密文长度为len,计算方式如下。
密文长度编码(见pack.c中函数net_store_length):
- len<251,编码为len的1个字节;
- 251<=len<65536,编码第一个字节为252,接着2个字节按照小端序存储len;
- 65536<=len<16777216,编码第一个字节为253,接着3个字节按照小端序存储len;
- len>=16777216,编码第一个字节为254,接着8个字节按照小端序存储len;
- 第一个字节251为NULL保留。
密文(见password.c中函数scramble):
- 计算密码的SHA1哈希值stage1;
- 计算stage1的SHA1哈希值stage2;
- 计算随机码和stage2的SHA1哈希值hash;
- 将hash与stage1异或得到密文。
MySQL自带的客户端调用链:
//client.c
static int native_password_auth_client(MYSQL_PLUGIN_VIO *vio, MYSQL *mysql) {
...
//由随机码加密密码
scramble(scrambled, (char *)pkt, mysql->passwd);
//客户端回复,调用函数client_mpvio_write_packet
vio->write_packet(vio, (uchar *)scrambled, SCRAMBLE_LENGTH)
...
}
//client.c
static int client_mpvio_write_packet(struct st_plugin_vio *mpv, const uchar *pkt, int pkt_len) {
...
res = send_client_reply_packet(mpvio, pkt, pkt_len);
...
}
//client.c
static int send_client_reply_packet(MCPVIO_EXT *mpvio, const uchar *data, int data_len) {
...
//调用my_net_write发数据
my_net_write(net, (uchar *)buff, (size_t)(end - buff)) || net_flush(net)
...
}
5、服务器校验密码
服务器收到客户端回包后,解析出用户、密码,并和数据库中的进行对比判断客户端登录合法性。
密码校验是通过对比密码的两阶段哈希值,即前一节介绍的生成密文的stage2,stage2是二进制字符串,数据库在存储时会转换为16进制的ASCII字符串,转换方式为:将每个字节拆分为2个字符:高4位和低4位,用两个16进制的0-9、A-F字符表示。例如二进制字符串:62、221,转换后为:3EDD。
校验调用链:
//sql_authentication.cc
static int server_mpvio_read_packet(MYSQL_PLUGIN_VIO *param, uchar **buf) {
...
//调用my_net_read接收数据
protocol->read_packet();
...
//解析客户端回包
pkt_len= parse_client_handshake_packet(mpvio, buf, pkt_len);
}
//sql_authentication.cc
static size_t parse_client_handshake_packet(MPVIO_EXT *mpvio, uchar **buff, size_t pkt_len) {
...
//查找用户密码用于校验
find_mpvio_user(mpvio)
...
}
//password.c
//校验密码,由check_scramble调用
my_bool check_scramble_sha1(const uchar *scramble_arg, const char *message, const uint8 *hash_stage2) {
...
//计算随机码和数据库密文哈希值
compute_sha1_hash_multi(buf, message, SCRAMBLE_LENGTH, (const char *) hash_stage2, SHA1_HASH_SIZE);
//客户端密文和上面哈希值异或,获得客户端密码一阶段SHA1哈希值
my_crypt((char *) buf, buf, scramble_arg, SCRAMBLE_LENGTH);
//计算客户端密码二阶段SHA1哈希值
compute_sha1_hash(hash_stage2_reassured, (const char *) buf, SHA1_HASH_SIZE);
//对比客户端密码二阶段哈希值和数据密文
return MY_TEST(memcmp(hash_stage2, hash_stage2_reassured, SHA1_HASH_SIZE));
}
6、模拟数据库登录
根据上面源码分析,我们可以写一个简单的登录逻辑:
package main
import (
"bytes"
"crypto/sha1"
"fmt"
"net"
"time"
)
const (
clientLongPassword = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
clientPSMultiResults
clientPluginAuth
clientConnectAttrs
clientPluginAuthLenEncClientData
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
)
//第1-3字节:数据包长度,小端序
//接着1字节:数据包序号
//接着1字节:协议版本号,总是10
//接着N字节:服务器版本号,以0结尾
//接着4字节:服务器线程id,小端序
//接着8字节:服务器生成的随机串前8个字节(随机串至少20字节)
//接着1字节:0
//接着2字节:服务器能力标志低2字节,小端序
//接着1字节:服务器字符集,默认为latin1
//接着2字节:服务器状态,小端序
//接着2字节:服务器能力标志高2字节,小端序
//接着1字节:随机串长度
//接着10字节:保留,都是0
//接着N字节:随机串剩余字节,至少12字节
//接着1字节:0
//接着N字节:插件名称,以0结尾
func readHandshakePacket(conn net.Conn) ([]byte, string, error) {
var builder bytes.Buffer
buf := make([]byte, 1024)
for {
count, e := conn.Read(buf)
if e != nil {
fmt.Println("readHandshakePacket:", e)
return nil, "", e
}
builder.Write(buf[0:count])
if count < 1024 {
break
}
}
fmt.Println(builder.Len(), builder.Bytes())
data := builder.Bytes()
index := 0
//数据包长度
packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16)
index += 3
//数据包序号
packetNum := int(data[index])
index += 1
//协议版本号
protocolVer := int(data[index])
index += 1
//服务器版本
var serverVer string
for {
if data[index] == 0 {
break
}
serverVer += string(data[index])
index += 1
}
index += 1
//服务器线程号,假设服务器是小端序
threadId := uint(data[index]) + uint(data[index+1]<<8) + uint(data[index+2]<<16) + uint(data[index+3]<<24)
index += 4
//随机串前8字节
var firstRand string
for i:=0; i<8; i++ {
firstRand += string(data[index])
index += 1
}
index += 1
//服务器能力标志低2字节,假设服务器是小端序
lowCapability := uint(data[index]) + uint(data[index+1]<<8)
index += 2
//服务器字符集
charset := int(data[index])
index += 1
//服务器状态,假设服务器是小端序
serverStatus := int(data[index]) + int(data[index+1]<<8)
index += 2
//服务器能力标志高2字节,假设服务器是小端序
highCapability := uint(data[index]) + uint(data[index+1]<<8)
index += 2
//随机串长度
randLen := int(data[index])
index += 1
//保留
index += 10
//随机串剩余字节
var lastRand string
for {
if data[index] == 0 {
break
}
lastRand += string(data[index])
index += 1
}
index += 1
//插件名称
var pluginName string
for {
if data[index] == 0 {
break
}
pluginName += string(data[index])
index += 1
}
fmt.Println("数据包长度:", packetLen)
fmt.Println("数据包序号:", packetNum)
fmt.Println("协议版本号:", protocolVer)
fmt.Println("服务器版本号:", serverVer)
fmt.Println("服务器线程号:", threadId)
fmt.Println("服务器能力标志:", lowCapability + highCapability<<16)
fmt.Println("服务器字符集:", charset)
fmt.Println("服务器状态:", serverStatus)
fmt.Println("随机串长度:", randLen)
fmt.Println("随机串:", firstRand + lastRand)
fmt.Println("插件名称:", pluginName)
return []byte(firstRand+lastRand), pluginName, nil
}
//第1-4字节:客户端能力标志
//接着4字节:最大数据包长度,小端序
//接着1字节:字符集
//接着23字节:保留,都是0
//接着N字节:用户名,以0结尾
//接着N字节:随机串加密后的密码:数据长度编码 + 数据
//接着N字节:数据库名,以0结尾
//接着N字节:插件名称,以0结尾
func writeHandshakePacket(conn net.Conn, scramble []byte, plugin string) {
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientLongFlag |
clientMultiStatements |
clientConnectWithDB
user := "xxx"
password := "xxxx"
dbName := "mysql"
var builder bytes.Buffer
//写入客户端能力标志
builder.WriteByte(byte(clientFlags))
builder.WriteByte(byte(clientFlags>>8))
builder.WriteByte(byte(clientFlags>>16))
builder.WriteByte(byte(clientFlags>>24))
//最大数据包长度
builder.Write([]byte{0x00, 0x00, 0x00, 0x00})
//字符集
builder.WriteByte(byte(0x08))
//保留
for i:=0; i<23; i++ {
builder.WriteByte(byte(0x00))
}
//用户名
builder.WriteString(user)
builder.WriteByte(byte(0x00))
//随机串加密后的密码:数据长度编码 + 数据
authData := scramblePassword(scramble, password)
builder.WriteByte(byte(len(authData)))
builder.Write(authData)
//数据库名
builder.WriteString(dbName)
builder.WriteByte(byte(0x00))
//插件名称
builder.WriteString(plugin)
builder.WriteByte(byte(0x00))
//发送数据,序列号需要在服务器发包基础上加1
n := builder.Len()
var data []byte
data = append(data, byte(n), byte(n>>8), byte(n>>16), byte(0x01))
data = append(data, builder.Bytes()...)
_, _ = conn.Write(data)
fmt.Println("发送:", data)
}
func scramblePassword(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
//stage1
crypt := sha1.New()
crypt.Write([]byte(password))
stage1 := crypt.Sum(nil)
//stage2
crypt.Reset()
crypt.Write(stage1)
stage2 := crypt.Sum(nil)
crypt.Reset()
crypt.Write(scramble)
crypt.Write(stage2)
scramble = crypt.Sum(nil)
fmt.Println("=======")
fmt.Println(stage1, stage2, scramble)
fmt.Println(string(stage1), string(stage2), string(scramble))
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
//第1-3字节:数据包长度,小端序
//接着1字节:数据包序号
//接着1字节:OK头,为0
//接着1-9字节:受影响行
//接着1-9字节:最后插入id
//接着2字节:服务器状态
//接着2字节:告警数量
func readAuthOkPacket(conn net.Conn) {
var builder bytes.Buffer
buf := make([]byte, 1024)
for {
count, e := conn.Read(buf)
if e != nil {
fmt.Println("readAuthOkPacket:", e)
return
}
builder.Write(buf[0:count])
if count < 1024 {
break
}
}
fmt.Println(builder.Len(), builder.Bytes())
index := 0
//数据包长度
packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16)
index += 3
//数据包序号
packetNum := int(data[index])
index += 1
//OK头
index += 1
//受影响行
affectedRows, n := readLength(data[index:])
index += n
//最后插入id
lastId, n := readLength(data[index:])
index += n
//服务器状态
status := uint(data[index]) | uint(data[index+1]<<8)
index += 2
//告警数量
warn := uint(data[index]) | uint(data[index+1]<<8)
index += 2
fmt.Println("登录成功")
fmt.Println("数据包长度:", packetLen)
fmt.Println("数据包序号:", packetNum)
fmt.Println("受影响行:", affectedRows)
fmt.Println("最后插入id:", lastId)
fmt.Println("服务器状态:", status)
fmt.Println("告警数量:", warn)
}
func readLength(data []byte) (uint64, int) {
switch data[0] {
case 0xfc: //252
return uint64(data[1]) | uint64(data[2])<<8, 2
case 0xfd: //253
return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16), 3
case 0xfe: //254
return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16) | uint64(data[4]<<24) |
uint64(data[5]<<32) | uint64(data[6]<<40) | uint64(data[7]<<48) | uint64(data[8]<<56), 4
}
//<251
return uint64(data[0]), 1
}
// 1、服务器发送随机码到客户端
// 2、客户端发送加密后的密码到服务器
// 3、服务器检查密码
func login(addr string) {
conn, e := net.Dial("tcp", addr)
if e != nil {
fmt.Println("login1:", e)
return
}
defer conn.Close()
scramble, plugin, e := readHandshakePacket(conn)
if e != nil {
fmt.Println("login2:", e)
return
}
writeHandshakePacket(conn, scramble, plugin)
readAuthOkPacket(conn)
time.Sleep(3600*time.Second)
}
func main() {
login("localhost:3306")
}


文章评论