如何用 Go 實現 JIT Compiler

原文作者是 Sidhartha Mani 首發於 Medium, 曾由 jiangwei161002010 翻譯後發佈在 Go 語言中文網 [1]

https://medium.com/kokster/writing-a-jit-compiler-in-golang-964b61295f

JIT(Just-In-Time[2]) Compiler 是指在運行期,實時生成機器碼的程序。機器碼 (machine code) 是 cpu 識別的機器指令

一般我們都是 Go 代碼翻譯成彙編,然後轉換成 cpu 可識別的機器碼。相比於其它代碼 (fmt.Printf), Jit Compiler 是在運行期生成,而不是編譯期 (即我們常用的 go build)

Go 是靜態類型,編譯完後生成二進制可執行文件。看起來不可能生成任意代碼,更不用說去執行這些代碼。然而,向運行中的程序注入指令是可行的,這就是所謂的 Type Magic, 類型魔法,可以將任意類型轉換成其它類型

生成 x64 機器碼

機器碼是一串字節數據,對於 cpu 有特殊的意義 (識別後執行)。本文用來測試的是 x86 處理器,因此,採用 x64 instruction set[3] 指令集,平臺相關,所以你在非 x86 上肯定無法運行

本文測試 Jit 讓我們用 x86 指令,打印 Hello World!

write(int fd, const void *buf, size_t count)

衆所周知,write 是系統調用,fd 是輸出的文件描述符,標準輸出是 1, 第二個參數是 Hello World! 字符串指針,第三個參數 count 是要打印的字符數 12

同時我們知道,系統調用是 Syscall 指令,rax 告訴 linux 要執行的具體系統調用函數編號 [4],write 是 1, 在 C 語言中函數的參數由 6 個寄存器傳遞,由於 write 只有三個參數,所以第一個參數 rdi 是 1 (文件描述符),第二個參數 rsi 是字符串地址,但是暫時無法確定值,稍後咱們再看。第三個參數 rdx 是字符串長度 12

0:  48 c7 c0 01 00 00 00  mov rax,0x1 
7:  48 c7 c7 01 00 00 00  mov rdi,0x1
e:  48 c7 c2 0c 00 00 00  mov rdx,0xc

放在一起,上面就是對應的機器碼,右面是彙編代碼 (網上有很多,根據彙編指令生成機器碼的,大家不用過於考究,知道意思就行)。那麼現在唯一的問題在於,如何確定字符串指針的地址

這個地址一定是運行時有效的,否則就會發生 segmentfault 錯誤。這個例子中,我們可以把 Hello world! 字符串放到可執行的命令 return 的後面。這是安全的,因爲 cpu 執行完就返回了,不會走到後面

由於在返回指令下達之前不知道返回後的地址,所以可以使用一個臨時的位置佔位符 placeholder, 一旦知道了數據的地址,就用正確的地址來代替。這就是鏈接器所遵循的確切程序。鏈接的過程只是將這些地址填入正確的數據或函數 (稍微瞭解 gcc 編譯原理的會很清楚,這裏面說不知道返回後地址,是不知道 lea 與最後的 ret 中間還有多少指令,所以無法確定相對地址)

15: 48 8d 35 00 00 00 00 lea rsi,[rip+0x0] # 0x15
1c: 0f 05                syscall
1e: c3                   ret

上面的代碼中,lea 指令是用來取字符串 Hello World! 地址的,但是這裏是當前的位置 (我們知道 rip 寄存器代表當前執行代碼的地址,0x0 是偏移量)。爲什麼指向自己呢?因爲字符串還沒有保存呢,不知道具體位置。其中 0F 05Syscall 對應的機器碼

1f: 48 65 6c 6c 6f 20 57 6f 72 6c 64 21   // Hello World!

我們現在把字符串放到 ret 指令後面,那麼此時字符串的相對地址就確定了

0:  48 c7 c0 01 00 00 00 mov rax,0x1
7:  48 c7 c7 01 00 00 00 mov rdi,0x1
e:  48 c7 c2 0c 00 00 00 mov rdx,0xc
15: 48 8d 35 03 00 00 00 lea rsi,[rip+0x3]# 0x1f
1c: 0f 05                syscall
1e: c3                   ret
1f: 48 65 6c 6c 6f 20 57 6f 72 6c 64 21   // Hello World!

然後爲了可讀性我們用 golang int16 slice 保存上面的機器碼

printFunction := []uint16{
0x48c7, 0xc001, 0x0, // mov %rax,$0x1
0x48, 0xc7c7, 0x100, 0x0, // mov %rdi,$0x1
0x48c7, 0xc20c, 0x0, // mov 0x13, %rdx
0x48, 0x8d35, 0x400, 0x0,// lea 0x4(%rip), %rsi
0xf05,// syscall
0xc3cc,// ret
0x4865, 0x6c6c, 0x6f20,// Hello_(whitespace)
0x576f, 0x726c, 0x6421, 0xa,// World!
}

切片保存的和上面機器碼有些偏差,是爲了對齊,更好看一些 (ret 機器碼是 c3, 後面多了一個 cc, 這是一個 no-op 指令**)。所以 lea 取地址即爲 rip+0x4**

將 slice 轉換成函數

切片中的指令,必須轉換成函數才能調用,下面的 go 代碼展示瞭如何轉換

type printFunc func()
unsafePrintFunc := (uintptr)(unsafe.Pointer(&printFunction)) 
printer := *(*printFunc)(unsafe.Pointer(&unsafePrintFunc)) 
printer()

Go 函數值只是一個指向 C 函數指針的指針(注意兩級指針。從切片到函數的轉換,首先要提取一個指向存放可執行代碼的數據結構的指針(這裏就是 slice 的指針)。該指針被存儲在 unsafePrintFunc 中。指向 unsafePrintFunc 的指針可以被轉換成所需要的函數類型

這種方法只適用於沒有參數或返回值的函數。在調用有參數或返回值的函數時,需要提前創建一個棧空(go 是 stack-based 函數調用規約,以後會改成 register-bassed, 感興趣的網上搜一下

讓函數執行

上述函數實際上不會運行。這是因爲 Go 將所有的數據存儲在二進制文件的 data 段。這一部分的數據被設置了 No-Execute 標誌,防止它被執行 (這不只是 Go, 所有都是這樣)

printFunction 片段中的數據需要存儲在一塊可執行的內存中。這可以通過移除 printFunction 片斷上的 No-Execute 標誌或將其複製到可執行的內存位置來實現

在下面的代碼中,數據被複制到一個新分配的可執行的內存中(使用 mmap)。這種方法比較好,因爲只有在整個頁面上才能設置不執行標誌 -- 很容易無意中使數據部分的其他部分成爲可執行的,變得很不安全 (mmap 就安全嘛?表示懷疑)

 executablePrintFunc, err := syscall.Mmap(
  -1,
  0,
  128,
  syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC,
  syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS)
 if err != nil {
  fmt.Printf("mmap err: %v", err)
 }
j := 0
 for i := range printFunction {
  executablePrintFunc[j] = byte(printFunction[i] >> 8)
  executablePrintFunc[j+1] = byte(printFunction[i])
  j = j + 2
 }

上面代碼用 Mmap 創建了一塊匿名的可執行私有內存區域,然後把 printFunction 代碼按照機器碼順序複製到該區域。重點是 syscall.PROT_EXEC 標記

package main

import (
 "fmt"
 "syscall"
 "unsafe"
)

type printFunc func()

func main() {
 printFunction := []uint16{
  0x48c7, 0xc001, 0x0, // mov %rax,$0x1
  0x48, 0xc7c7, 0x100, 0x0, // mov %rdi,$0x1
  0x48c7, 0xc20c, 0x0, // mov 0x13, %rdx
  0x48, 0x8d35, 0x400, 0x0, // lea 0x4(%rip), %rsi
  0xf05,                  // syscall
  0xc3cc,                 // ret
  0x4865, 0x6c6c, 0x6f20, // Hello_(whitespace)
  0x576f, 0x726c, 0x6421, 0xa, // World!
 }
 executablePrintFunc, err := syscall.Mmap(
  -1,
  0,
  128,
  syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC,
  syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS)
 if err != nil {
  fmt.Printf("mmap err: %v", err)
 }

 j := 0
 for i := range printFunction {
  executablePrintFunc[j] = byte(printFunction[i] >> 8)
  executablePrintFunc[j+1] = byte(printFunction[i])
  j = j + 2
 }

 type printFunc func()

 unsafePrintFunc := (uintptr)(unsafe.Pointer(&executablePrintFunc))
 printer := *(*printFunc)(unsafe.Pointer(&unsafePrintFunc))
 printer()
}

上面就是最終可執行的代碼,copy 時要考濾 little 小端對齊,執行後打印 Hello World!

~# strace ./jit
execve("./jit"["./jit"], 0x7ffe27f7ef70 /* 17 vars */) = 0
arch_prctl(ARCH_SET_FS, 0x54e5d0)       = 0
sched_getaffinity(0, 8192, [0, 1])      = 8

......

readlinkat(AT_FDCWD, "/proc/self/exe""/root/jit", 128) = 9
fcntl(0, F_GETFL)                       = 0x2 (flags O_RDWR)
futex(0xc000036950, FUTEX_WAKE_PRIVATE, 1) = 1
fcntl(1, F_GETFL)                       = 0x2 (flags O_RDWR)
fcntl(2, F_GETFL)                       = 0x2 (flags O_RDWR)
mmap(NULL, 128, PROT_READ|PROT_WRITE|PROT_EXEC, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0) = 0x7f5eabb00000
write(1, "Hello World!", 12Hello World!)            = 12
exit_group(0)                           = ?
+++ exited with 0 +++

strace 可以看到完整系統調用,先是 mmap 然後打印字符串到標準輸出

小結

今天的分享就這些,祝大家玩的開心!!JIT 是一個非常龐大的 topic, 還是蠻有意思的。寫文章不容易,如果對大家有所幫助和啓發,請大家幫忙點擊在看點贊分享 三連

關於 Go JIT 大家有什麼看法,歡迎留言一起討論,大牛多留言,下一篇分享生命週期 ^_^

參考資料

[1]

Go 語言中文網: https://studygolang.com/articles/12730,

[2]

Just-In-Time: https://en.wikipedia.org/wiki/Just-in-time_compilation,

[3]

x86 instruction set: https://software.intel.com/content/www/us/en/develop/articles/introduction-to-x64-assembly.html,

[4]

系統調用函數編號: https://filippo.io/linux-syscall-table/https://filippo.io/linux-syscall-table/,

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/ubxfUg9qAqlBmi-Yw0H-fw