1、发送接收数据包流程
在介绍登录流程之前,让我们先看看mysql发送、接收数据包流程,mysql数据包有固定的协议格式,即每个数据包都包含一个4字节包头,其中前三个字节指定数据包大小,最后一个字节指定数据包序列号,序列号用于保证数据包的顺序,如下图所示:
1.1 发送数据包
//文件net_serv.cc //发送逻辑数据包,将逻辑数据包按照大小0xffffff(16M)分割为一个或多个物理数据包,物理数据包增加数据头部,头部包括包长度、包序号。 my_bool my_net_write(NET *net, const uchar *packet, size_t len) { ... net_write_buff(net, buff, NET_HEADER_SIZE) net_write_buff(net, packet, z_size) ... } //文件net_serv.cc //缓冲区物理数据包,缓冲区满自动发送,或者调用net_flush主动发送。 static my_bool net_write_buff(NET *net, const uchar *packet, size_t len) { ... net_write_packet(net, net->buff, (size_t) (net->write_pos - net->buff) + left_length) ... } //文件net_serv.cc //发送物理数据包。 my_bool net_write_packet(NET *net, const uchar *packet, size_t length) { ... res= net_write_raw_loop(net, packet, length); ... } //文件net_serv.cc //发送指定长度字节数据。 static my_bool net_write_raw_loop(NET *net, const uchar *buf, size_t count) { ... size_t sentcnt= vio_write(net->vio, buf, count); ... }
1.2 接收数据包
//文件net_serv.cc //读取逻辑数据包,可能由多个物理包组成,通过物理包的头部包长度是否为最大值0xffffff判断是否有后继包。 ulong my_net_read(NET *net) { ... len= net_read_packet(net, &complen); ... } //文件net_serv.cc //读取一个物理数据包,调用net_read_packet_header读取数据包头。 static size_t net_read_packet(NET *net, size_t *complen) { ... net_read_packet_header(net) ... net_read_raw_loop(net, pkt_len) ... } //文件net_serv.cc //读取物理数据包头部 static my_bool net_read_packet_header(NET *net) { ... rc= net_read_raw_loop(net, count) ... } //文件net_serv.cc //读取指定长度字节数据。 static my_bool net_read_raw_loop(NET *net, size_t count) { ... size_t recvcnt= vio_read(net->vio, buf, count); ... }
2、登录流程
登录流程主要包括:客户端连接服务器、服务器发送随机码到客户端、客户端发送用户密码到服务器、服务器返回校验结果,如下图所示。
登录流程调用链:
//sql_connect.cc //登录权限校验 static bool login_connection(THD *thd) { ... //登录握手流程 error= check_connection(thd); //回复登录结果 thd->send_statement_status(); ... } //sql_connect.cc static int check_connection(THD *thd) { ... auth_rc= acl_authenticate(thd, COM_CONNECT); ... } //sql_authentication.cc int acl_authenticate(THD *thd, enum_server_command command) { ... res= do_auth_once(thd, auth_plugin_name, &mpvio); ... } //sql_authentication.cc static int do_auth_once(THD *thd, const LEX_CSTRING &auth_plugin_name, MPVIO_EXT *mpvio) { ... //调用具体插件的校验函数,默认为mysql_native_password,调用函数native_password_authenticate res= auth->authenticate_user(mpvio, &mpvio->auth_info); ... } //sql_authentication.cc //校验实现 static int native_password_authenticate(MYSQL_PLUGIN_VIO *vio, MYSQL_SERVER_AUTH_INFO *info) { //生成随机码字符串,长度为20 generate_user_salt(mpvio->scramble, SCRAMBLE_LENGTH + 1); //发送随机码到客户端,调用函数server_mpvio_write_packet mpvio->write_packet(mpvio, (uchar*) mpvio->scramble, SCRAMBLE_LENGTH + 1) //接收客户端回复用户、密码等信息,调用函数server_mpvio_read_packet pkt_len= mpvio->read_packet(mpvio, &pkt) ... //校验密码 check_scramble(pkt, mpvio->scramble, mpvio->acl_user->salt) ... } //sql_class.cc void THD::send_statement_status() { ... //根据不同状态发送不同类型包 error= m_protocol->send_error(da->mysql_errno(), da->message_text(), da->returned_sqlstate()); error= m_protocol->send_eof(server_status, da->last_statement_cond_count()); error= m_protocol->send_ok(server_status, da->last_statement_cond_count(), da->affected_rows(), da->last_insert_id(), da->message_text()); ... }
3、服务器发送随机码到客户端
根据登录流程调用链可以看到发送随机码由server_mpvio_write_packet实现,该函数发送数据格式如表所示。
字节 | 说明 |
1-3 | 数据包长度,小端序 |
1 | 数据包序列号,用于保证数据包的顺序 |
1 | 协议版本号,总是10 |
N | 服务器版本号,以0结尾 |
4 | 服务器线程id,小端序 |
8 | 服务器生成的随机串前8个字节(随机串至少20字节) |
1 | 0 |
2 | 服务器能力标志低2字节,小端序 |
1 | 服务器字符集,默认为latin1 |
2 | 服务器状态,小端序 |
2 | 服务器能力标志高2字节,小端序 |
1 | 随机串长度 |
10 | 保留,都是0 |
N | 随机串剩余字节,至少12字节 |
1 | 0 |
N | 插件名称,以0结尾 |
调用链:
//sql_authentication.cc static int server_mpvio_write_packet(MYSQL_PLUGIN_VIO *param, const uchar *packet, int packet_len) { ... //发送数据包到客户端 res= send_server_handshake_packet(mpvio, (char*) packet, packet_len); ... } //sql_authentication.cc //该函数会按照指定格式发送数据包到客户端,格式如表1所示 static bool send_server_handshake_packet(MPVIO_EXT *mpvio, const char *data, uint data_len) { ... //调用my_net_write发送数据 int res= protocol->write((uchar*) buff, (size_t) (end - buff + 1)) || protocol->flush_net(); ... }
4、客户端回复用户、密码
客户端接收到服务器发送的随机码数据后,会使用随机码加密密码,然后回复服务器,回包数据格式如表所示。
字节 | 说明 |
1-3 | 数据包长度,小端序 |
1 | 数据包序列号,用于保证数据包的顺序 |
4 | 客户端能力标志 |
4 | 最大数据包长度,小端序 |
1 | 字符集 |
23 | 保留,都是0 |
N | 用户名,以0结尾 |
N | 随机码加密后的密码:密文长度编码 + 密文 |
N | 数据库名,以0结尾 |
N | 插件名称,以0结尾 |
表中格式需要说明的是随机码加密后的密码,该数据由加密后的密文长度编码和密文组成。假设通过随机码加密后的密文为ens,密文长度为len,计算方式如下。
密文长度编码(见pack.c中函数net_store_length):
- len<251,编码为len的1个字节;
- 251<=len<65536,编码第一个字节为252,接着2个字节按照小端序存储len;
- 65536<=len<16777216,编码第一个字节为253,接着3个字节按照小端序存储len;
- len>=16777216,编码第一个字节为254,接着8个字节按照小端序存储len;
- 第一个字节251为NULL保留。
密文(见password.c中函数scramble):
- 计算密码的SHA1哈希值stage1;
- 计算stage1的SHA1哈希值stage2;
- 计算随机码和stage2的SHA1哈希值hash;
- 将hash与stage1异或得到密文。
MySQL自带的客户端调用链:
//client.c static int native_password_auth_client(MYSQL_PLUGIN_VIO *vio, MYSQL *mysql) { ... //由随机码加密密码 scramble(scrambled, (char *)pkt, mysql->passwd); //客户端回复,调用函数client_mpvio_write_packet vio->write_packet(vio, (uchar *)scrambled, SCRAMBLE_LENGTH) ... } //client.c static int client_mpvio_write_packet(struct st_plugin_vio *mpv, const uchar *pkt, int pkt_len) { ... res = send_client_reply_packet(mpvio, pkt, pkt_len); ... } //client.c static int send_client_reply_packet(MCPVIO_EXT *mpvio, const uchar *data, int data_len) { ... //调用my_net_write发数据 my_net_write(net, (uchar *)buff, (size_t)(end - buff)) || net_flush(net) ... }
5、服务器校验密码
服务器收到客户端回包后,解析出用户、密码,并和数据库中的进行对比判断客户端登录合法性。
密码校验是通过对比密码的两阶段哈希值,即前一节介绍的生成密文的stage2,stage2是二进制字符串,数据库在存储时会转换为16进制的ASCII字符串,转换方式为:将每个字节拆分为2个字符:高4位和低4位,用两个16进制的0-9、A-F字符表示。例如二进制字符串:62、221,转换后为:3EDD。
校验调用链:
//sql_authentication.cc static int server_mpvio_read_packet(MYSQL_PLUGIN_VIO *param, uchar **buf) { ... //调用my_net_read接收数据 protocol->read_packet(); ... //解析客户端回包 pkt_len= parse_client_handshake_packet(mpvio, buf, pkt_len); } //sql_authentication.cc static size_t parse_client_handshake_packet(MPVIO_EXT *mpvio, uchar **buff, size_t pkt_len) { ... //查找用户密码用于校验 find_mpvio_user(mpvio) ... } //password.c //校验密码,由check_scramble调用 my_bool check_scramble_sha1(const uchar *scramble_arg, const char *message, const uint8 *hash_stage2) { ... //计算随机码和数据库密文哈希值 compute_sha1_hash_multi(buf, message, SCRAMBLE_LENGTH, (const char *) hash_stage2, SHA1_HASH_SIZE); //客户端密文和上面哈希值异或,获得客户端密码一阶段SHA1哈希值 my_crypt((char *) buf, buf, scramble_arg, SCRAMBLE_LENGTH); //计算客户端密码二阶段SHA1哈希值 compute_sha1_hash(hash_stage2_reassured, (const char *) buf, SHA1_HASH_SIZE); //对比客户端密码二阶段哈希值和数据密文 return MY_TEST(memcmp(hash_stage2, hash_stage2_reassured, SHA1_HASH_SIZE)); }
6、模拟数据库登录
根据上面源码分析,我们可以写一个简单的登录逻辑:
package main import ( "bytes" "crypto/sha1" "fmt" "net" "time" ) const ( clientLongPassword = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB clientNoSchema clientCompress clientODBC clientLocalFiles clientIgnoreSpace clientProtocol41 clientInteractive clientSSL clientIgnoreSIGPIPE clientTransactions clientReserved clientSecureConn clientMultiStatements clientMultiResults clientPSMultiResults clientPluginAuth clientConnectAttrs clientPluginAuthLenEncClientData clientCanHandleExpiredPasswords clientSessionTrack clientDeprecateEOF ) //第1-3字节:数据包长度,小端序 //接着1字节:数据包序号 //接着1字节:协议版本号,总是10 //接着N字节:服务器版本号,以0结尾 //接着4字节:服务器线程id,小端序 //接着8字节:服务器生成的随机串前8个字节(随机串至少20字节) //接着1字节:0 //接着2字节:服务器能力标志低2字节,小端序 //接着1字节:服务器字符集,默认为latin1 //接着2字节:服务器状态,小端序 //接着2字节:服务器能力标志高2字节,小端序 //接着1字节:随机串长度 //接着10字节:保留,都是0 //接着N字节:随机串剩余字节,至少12字节 //接着1字节:0 //接着N字节:插件名称,以0结尾 func readHandshakePacket(conn net.Conn) ([]byte, string, error) { var builder bytes.Buffer buf := make([]byte, 1024) for { count, e := conn.Read(buf) if e != nil { fmt.Println("readHandshakePacket:", e) return nil, "", e } builder.Write(buf[0:count]) if count < 1024 { break } } fmt.Println(builder.Len(), builder.Bytes()) data := builder.Bytes() index := 0 //数据包长度 packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16) index += 3 //数据包序号 packetNum := int(data[index]) index += 1 //协议版本号 protocolVer := int(data[index]) index += 1 //服务器版本 var serverVer string for { if data[index] == 0 { break } serverVer += string(data[index]) index += 1 } index += 1 //服务器线程号,假设服务器是小端序 threadId := uint(data[index]) + uint(data[index+1]<<8) + uint(data[index+2]<<16) + uint(data[index+3]<<24) index += 4 //随机串前8字节 var firstRand string for i:=0; i<8; i++ { firstRand += string(data[index]) index += 1 } index += 1 //服务器能力标志低2字节,假设服务器是小端序 lowCapability := uint(data[index]) + uint(data[index+1]<<8) index += 2 //服务器字符集 charset := int(data[index]) index += 1 //服务器状态,假设服务器是小端序 serverStatus := int(data[index]) + int(data[index+1]<<8) index += 2 //服务器能力标志高2字节,假设服务器是小端序 highCapability := uint(data[index]) + uint(data[index+1]<<8) index += 2 //随机串长度 randLen := int(data[index]) index += 1 //保留 index += 10 //随机串剩余字节 var lastRand string for { if data[index] == 0 { break } lastRand += string(data[index]) index += 1 } index += 1 //插件名称 var pluginName string for { if data[index] == 0 { break } pluginName += string(data[index]) index += 1 } fmt.Println("数据包长度:", packetLen) fmt.Println("数据包序号:", packetNum) fmt.Println("协议版本号:", protocolVer) fmt.Println("服务器版本号:", serverVer) fmt.Println("服务器线程号:", threadId) fmt.Println("服务器能力标志:", lowCapability + highCapability<<16) fmt.Println("服务器字符集:", charset) fmt.Println("服务器状态:", serverStatus) fmt.Println("随机串长度:", randLen) fmt.Println("随机串:", firstRand + lastRand) fmt.Println("插件名称:", pluginName) return []byte(firstRand+lastRand), pluginName, nil } //第1-4字节:客户端能力标志 //接着4字节:最大数据包长度,小端序 //接着1字节:字符集 //接着23字节:保留,都是0 //接着N字节:用户名,以0结尾 //接着N字节:随机串加密后的密码:数据长度编码 + 数据 //接着N字节:数据库名,以0结尾 //接着N字节:插件名称,以0结尾 func writeHandshakePacket(conn net.Conn, scramble []byte, plugin string) { clientFlags := clientProtocol41 | clientSecureConn | clientLongPassword | clientTransactions | clientLocalFiles | clientPluginAuth | clientMultiResults | clientLongFlag | clientMultiStatements | clientConnectWithDB user := "xxx" password := "xxxx" dbName := "mysql" var builder bytes.Buffer //写入客户端能力标志 builder.WriteByte(byte(clientFlags)) builder.WriteByte(byte(clientFlags>>8)) builder.WriteByte(byte(clientFlags>>16)) builder.WriteByte(byte(clientFlags>>24)) //最大数据包长度 builder.Write([]byte{0x00, 0x00, 0x00, 0x00}) //字符集 builder.WriteByte(byte(0x08)) //保留 for i:=0; i<23; i++ { builder.WriteByte(byte(0x00)) } //用户名 builder.WriteString(user) builder.WriteByte(byte(0x00)) //随机串加密后的密码:数据长度编码 + 数据 authData := scramblePassword(scramble, password) builder.WriteByte(byte(len(authData))) builder.Write(authData) //数据库名 builder.WriteString(dbName) builder.WriteByte(byte(0x00)) //插件名称 builder.WriteString(plugin) builder.WriteByte(byte(0x00)) //发送数据,序列号需要在服务器发包基础上加1 n := builder.Len() var data []byte data = append(data, byte(n), byte(n>>8), byte(n>>16), byte(0x01)) data = append(data, builder.Bytes()...) _, _ = conn.Write(data) fmt.Println("发送:", data) } func scramblePassword(scramble []byte, password string) []byte { if len(password) == 0 { return nil } //stage1 crypt := sha1.New() crypt.Write([]byte(password)) stage1 := crypt.Sum(nil) //stage2 crypt.Reset() crypt.Write(stage1) stage2 := crypt.Sum(nil) crypt.Reset() crypt.Write(scramble) crypt.Write(stage2) scramble = crypt.Sum(nil) fmt.Println("=======") fmt.Println(stage1, stage2, scramble) fmt.Println(string(stage1), string(stage2), string(scramble)) for i := range scramble { scramble[i] ^= stage1[i] } return scramble } //第1-3字节:数据包长度,小端序 //接着1字节:数据包序号 //接着1字节:OK头,为0 //接着1-9字节:受影响行 //接着1-9字节:最后插入id //接着2字节:服务器状态 //接着2字节:告警数量 func readAuthOkPacket(conn net.Conn) { var builder bytes.Buffer buf := make([]byte, 1024) for { count, e := conn.Read(buf) if e != nil { fmt.Println("readAuthOkPacket:", e) return } builder.Write(buf[0:count]) if count < 1024 { break } } fmt.Println(builder.Len(), builder.Bytes()) index := 0 //数据包长度 packetLen := int(data[index]) + int(data[index+1]<<8) + int(data[index+2]<<16) index += 3 //数据包序号 packetNum := int(data[index]) index += 1 //OK头 index += 1 //受影响行 affectedRows, n := readLength(data[index:]) index += n //最后插入id lastId, n := readLength(data[index:]) index += n //服务器状态 status := uint(data[index]) | uint(data[index+1]<<8) index += 2 //告警数量 warn := uint(data[index]) | uint(data[index+1]<<8) index += 2 fmt.Println("登录成功") fmt.Println("数据包长度:", packetLen) fmt.Println("数据包序号:", packetNum) fmt.Println("受影响行:", affectedRows) fmt.Println("最后插入id:", lastId) fmt.Println("服务器状态:", status) fmt.Println("告警数量:", warn) } func readLength(data []byte) (uint64, int) { switch data[0] { case 0xfc: //252 return uint64(data[1]) | uint64(data[2])<<8, 2 case 0xfd: //253 return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16), 3 case 0xfe: //254 return uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3]<<16) | uint64(data[4]<<24) | uint64(data[5]<<32) | uint64(data[6]<<40) | uint64(data[7]<<48) | uint64(data[8]<<56), 4 } //<251 return uint64(data[0]), 1 } // 1、服务器发送随机码到客户端 // 2、客户端发送加密后的密码到服务器 // 3、服务器检查密码 func login(addr string) { conn, e := net.Dial("tcp", addr) if e != nil { fmt.Println("login1:", e) return } defer conn.Close() scramble, plugin, e := readHandshakePacket(conn) if e != nil { fmt.Println("login2:", e) return } writeHandshakePacket(conn, scramble, plugin) readAuthOkPacket(conn) time.Sleep(3600*time.Second) } func main() { login("localhost:3306") }
文章评论