如何用 go 實現一個 ORM

本期作者

洪勝傑

B 端技術中心高級開發工程師

爲了提高開發效率和質量,我們常常需要 ORM 來幫助我們快速實現持久層增刪改查 API,目前 go 語言實現的 ORM 有很多種,他們都有自己的優劣點,有的實現簡單,有的功能複雜,有的 API 十分優雅。在使用了多個類似的工具之後,總是會發現某些點無法滿足解決我們生產環境中碰到的實際問題,比如無法集成公司內部的監控,Trace 組件,沒有 database 層的超時設置,沒有熔斷等,所以有必要公司自己內部實現一款滿足我們可自定義開發的 ORM,好用的生產工具常常能夠對生產力產生飛躍式的提升。

爲什麼需要 ORM

直接使用 database/sql 的痛點

首先看看用 database/sql 如何查詢數據庫
我們用 user 表來做例子,一般的工作流程是先做技術方案,其中排在比較前面的是數據庫表的設計,大部分公司應該有嚴格的數據庫權限控制,不會給線上程序使用比較危險的操作權限,比如創建刪除數據庫,表,刪除數據等。
表結構如下:

CREATE TABLE `user` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'id',
  `name` varchar(100) NOT NULL COMMENT '名稱',
  `age` int(11) NOT NULL DEFAULT '0' COMMENT '年齡',
  `ctime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '創建時間',
  `mtime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '更新時間',
  PRIMARY KEY (`id`),
) ENGINE=InnoDB  DEFAULT CHARSET=utf8mb4

首先我們要寫出和表結構對應的結構體 User,如果你足夠勤奮和努力,相應的 json tag 和註釋都可以寫上,這個過程無聊且重複,因爲在設計表結構的時候你已經寫過一遍了。

type User struct {
    Id    int64     `json:"id"`   
    Name  string    `json:"name"`
    Age   int64    
    Ctime time.Time
    Mtime time.Time  // 更新時間
}

定義好結構體,我們寫一個查詢年齡在 20 以下且按照 id 字段順序排序的前 20 名用戶的 go 代碼

func FindUsers(ctx context.Context) ([]*User, error) {
    rows, err := db.QueryContext(ctx, "SELECT `id`,`name`,`age`,`ctime`,`mtime` FROM user WHERE `age`<? ORDER BY `id` LIMIT 20 ", 20)
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    result := []*User{}
    for rows.Next() {
        a := &User{}
        if err := rows.Scan(&a.Id, &a.Name, &a.Age, &a.Ctime, &a.Mtime); err != nil {
            return nil, err
        }
        result = append(result, a)
    }
    if rows.Err() != nil {
        return nil, rows.Err()
    }
    return result, nil
}

當我們寫少量這樣的代碼的時候我們可能還覺得輕鬆,但是當你業務工期排的很緊,並且要寫大量的定製化查詢的時候,這樣的重複代碼會越來越多。
上面的的代碼我們發現有這麼幾個問題:

  1. SQL 語句是硬編碼在程序裏面的,當我需要增加查詢條件的時候我需要另外再寫一個方法,整個方法需要拷貝一份,很不靈活。

  2. 在查詢表所有字段的情況下,第 2 行下面的代碼都是一樣重複的,不管 sql 語句後面的條件是怎麼樣的。

  3. 我們發現第 1 行 SQL 語句編寫和 rows.Scan() 那行,寫的枯燥層度是和表字段的數量成正比的,如果一個表有 50 個字段或者 100 個字段,手寫是非常乏味的。

  4. 在開發過程中 rows.Close() 和 rows.Err() 忘記寫是常見的錯誤。

我們總結出來用 database/sql 標準庫開發的痛點:

開發效率很低

很顯然寫上面的那種代碼是很耗費時間的,因爲手誤容易寫錯,無可避免要增加自測的時間。如果上面的結構體 User、 查詢方法 FindUsers() 代碼能夠自動生成,那麼那將會極大的提高開發效率並且減少 human error 的發生從而提高開發質量。

心智負擔很重

如果一個開發人員把大量的時間花在這些代碼上,那麼他其實是在浪費自己的時間,不管在工作中還是在個人項目中,應該把重點花在架構設計,業務邏輯設計,困難點攻堅上面,去探索和開拓自己沒有經驗的領域,這塊 Dao 層的代碼最好在 10 分鐘內完成。

ORM 的核心組成

明白了上面的痛點,爲了開發工作更舒服,更高效,我們嘗試着自己去開發一個 ORM,核心的地方在於兩個方面:

  1. SQLBuilder:SQL 語句要非硬編碼,通過某種鏈式調用構造器幫助我構建 SQL 語句。

  2. Scanner:從數據庫返回的數據可以自動映射賦值到結構體中。

SQL SelectBuilder

我們嘗試做個簡略版的查詢語句構造器, 最終我們要達到如下圖所示的效果。

我們可以通過和 SQL 關鍵字同名的方法來表達 SQL 語句的固有關鍵字,通過 go 方法參數來設置其中動態變化的元素,這樣鏈式調用和寫 SQL 語句的思維順序是一致的,只不過我們之前通過硬編碼的方式變成了方法調用。

具體代碼如下:

type SelectBuilder struct {
    builder   *strings.Builder
    column    []string
    tableName string
    where     []func(s *SelectBuilder)
    args      []interface{}
    orderby   string
    offset    *int64
    limit     *int64
}
func (s *SelectBuilder) Select(field ...string) *SelectBuilder {
    s.column = append(s.column, field...)
    return s
}
func (s *SelectBuilder) From(name string) *SelectBuilder {
    s.tabelName = name
    return s
}
func (s *SelectBuilder) Where(f ...func(s *SelectBuilder)) *SelectBuilder {
    s.where = append(s.where, f...)
    return s
}
func (s *SelectBuilder) OrderBy(field string) *SelectBuilder {
    s.orderby = field
    return s
}
func (s *SelectBuilder) Limit(offset, limit int64) *SelectBuilder {
    s.offset = &offset
    s.limit = &limit
    return s
}
func GT(field string, arg interface{}) func(s *SelectBuilder) {
    return func(s *SelectBuilder) {
        s.builder.WriteString("`" + field + "`" + " > ?")
        s.args = append(s.args, arg)
    }
}
func (s *SelectBuilder) Query() (string, []interface{}) {
    s.builder.WriteString("SELECT ")
    for k, v := range s.column {
        if k > 0 {
            s.builder.WriteString(",")
        }
        s.builder.WriteString("`" + v + "`")
    }
    s.builder.WriteString(" FROM ")
    s.builder.WriteString("`" + s.tableName + "` ")
    if len(s.where) > 0 {
        s.builder.WriteString("WHERE ")
        for k, f := range s.where {
            if k > 0 {
                s.builder.WriteString(" AND ")
            }
            f(s)
        }
    }
    if s.orderby != "" {
        s.builder.WriteString(" ORDER BY " + s.orderby)
    }
    if s.limit != nil {
        s.builder.WriteString(" LIMIT ")
        s.builder.WriteString(strconv.FormatInt(*s.limit, 10))
    }
    if s.offset != nil {
        s.builder.WriteString(" OFFSET ")
        s.builder.WriteString(strconv.FormatInt(*s.offset, 10))
    }
    return s.builder.String(), s.args
}
  1. 通過結構體上的方法調用返回自身,使其具有鏈式調用能力,並通過方法調用設置結構體中的值,用以構成 SQL 語句需要的元素。

  2. SelectBuilder 包含性能較高的 strings.Builder 來拼接字符串。

  3. Query() 方法構建出真正的 SQL 語句,返回包含佔位符的 SQL 語句和 args 參數。

  4. []func(s *SelectBuilder) 通過函數數組來創建查詢條件,可以通過函數調用的順序和層級來生成 AND OR 這種有嵌套關係的查詢條件子句。

  5. Where() 傳入的是查詢條件函數,爲可變參數列表,查詢條件之間默認是 AND 關係。

外部使用起來效果:

b := SelectBuilder{builder: &strings.Builder{}}
sql, args := b.
    Select("id", "name", "age", "ctime", "mtime").
    From("user").
    Where(GT("id", 0), GT("age", 0)).
    OrderBy("id").
    Limit(0, 20).
    Query()

Scanner 的實現

顧名思義 Scanner 的作用就是把查詢結果設置到對應的 go 對象上去,完成關係和對象的映射,關鍵核心就是通過反射獲知傳入對象的類型和字段類型,通過反射創建對象和值,並通過 golang 結構體的字段後面的 tag 來和查詢結果的表頭一一對應,達到動態給結構字段賦值的能力。

具體實現如下:

func ScanSlice(rows *sql.Rows, dst interface{}) error {
    defer rows.Close()
    // dst的地址
    val := reflect.ValueOf(dst) //  &[]*main.User
    // 判斷是否是指針類型,go是值傳遞,只有傳指針才能讓更改生效
    if val.Kind() != reflect.Ptr {
        return errors.New("dst not a pointer")
    }
    // 指針指向的Value
    val = reflect.Indirect(val) // []*main.User
    if val.Kind() != reflect.Slice {
        return errors.New("dst not a pointer to slice")
    }
    // 獲取slice中的類型
    struPointer := val.Type().Elem() // *main.User
    // 指針指向的類型 具體結構體
    stru := struPointer.Elem()      //  main.User
    cols, err := rows.Columns()  // [id,name,age,ctime,mtime]
    if err != nil {
        return err
    }
    // 判斷查詢的字段數是否大於 結構體的字段數
    if stru.NumField() < len(cols) { // 5,5
        return errors.New("NumField and cols not match")
    }
    //結構體的json tag的value對應字段在結構體中的index
    tagIdx := make(map[string]int) //map tag -> field idx
    for i := 0; i < stru.NumField(); i++ {
        tagname := stru.Field(i).Tag.Get("json")
        if tagname != "" {
            tagIdx[tagname] = i
        }
    }
    resultType := make([]reflect.Type, 0, len(cols)) // [int64,string,int64,time.Time,time.Time]
    index := make([]int, 0, len(cols))               // [0,1,2,3,4,5]
    // 查找和列名相對應的結構體jsontag name的字段類型,保存類型和序號到resultType和index中
    for _, v := range cols {
        if i, ok := tagIdx[v]; ok {
            resultType = append(resultType, stru.Field(i).Type)
            index = append(index, i)
        }
    }
    for rows.Next() {
        // 創建結構體指針,獲取指針指向的對象
        obj := reflect.New(stru).Elem()                   // main.User
        result := make([]interface{}, 0, len(resultType)) //[]
        // 創建結構體字段類型實例的指針,並轉化爲interface{} 類型
        for _, v := range resultType {
            result = append(result, reflect.New(v).Interface()) // *Int64 ,*string ....
        }
        // 掃描結果
        err := rows.Scan(result...)
        if err != nil {
            return err
        }
        for i, v := range result {
            // 找到對應的結構體index
            fieldIndex := index[i]
            // 把scan 後的值通過反射得到指針指向的value,賦值給對應的結構體字段
            obj.Field(fieldIndex).Set(reflect.ValueOf(v).Elem()) // 給obj 的每個字段賦值
        }
        // append 到slice
        vv := reflect.Append(val, obj.Addr()) // append到 []*main.User, maybe addr change
        val.Set(vv)                           // []*main.User
    }
    return rows.Err()
}

通過反射賦值流程,如果想知道具體的實現細節可以仔細閱讀上面代碼裏面的註釋

  1. 以上主要的思想就是通過 reflect 包來獲取傳入 dst 的 Slice 類型,並通過反射創建其包含的對象,具體的步驟和解釋請仔細閱讀註釋和圖例。

  2. 通過指定的 json tag 可以把查詢結果和結構體字段 mapping 起來,即使查詢語句中字段不按照表結構順序。

  3. ScanSlice 是通用的 Scanner。

  4. 使用反射創建對象明顯創建了多餘的對象,沒有傳統的方式賦值高效,但是換來的巨大的靈活性在某些場景下是值得的。

有了 SQLBuilder 和 Scanner 我們就可以這樣寫查詢函數了:

func FindUserReflect() ([]*User, error) {
    b := SelectBuilder{builder: &strings.Builder{}}
    sql, args := b.
        Select("id", "name", "age", "ctime", "mtime").
        From("user").
        Where(GT("id", 0), GT("age", 0)).
        OrderBy("id").
        Limit(0, 20).
        Query()
    rows, err := db.QueryContext(ctx, sql, args...)
    if err != nil {
        return nil, err
    }
    result := []*User{}
    err = ScanSlice(rows, &result)
    if err != nil {
        return nil, err
    }
    return result, nil
}

生成的查詢 SQL 語句和 args 如下:

SELECT `id`,`name`,`age`,`ctime`,`mtime` FROM `user` WHERE `id` > ? AND `age` > ? ORDER BY id LIMIT 20 OFFSET 0  [0 0]

自動生成

通過上面的使用的例子來看,我們的工作輕鬆了不少:

着實幫我們省了很大的麻煩。但是查詢字段還需要我們自己手寫,像這種

Select("id", "name", "age", "ctime", "mtime").

Table 對象如下:

type Table struct {
    TableName   string    // table name
    GoTableName string    // go struct name
    PackageName string    // package name
    Fields      []*Column // columns
}
type Column struct {
    ColumnName    string // column_name
    ColumnType    string // column_type
    ColumnComment string // column_comment
}

使用以上 Table 對象的模板代碼:

type {{.GoTableName}} struct {
    {{- range .Fields }}
        {{ .GoColumnName }} {{  .GoColumnType }} `json:"{{ .ColumnName }}"` // {{ .ColumnComment }}
    {{- end}}
}
const (
    table = "{{.TableName}}"
    {{- range .Fields}}
        {{ .GoColumnName}} = "{{.ColumnName}}" 
    {{- end }}
)
var columns = []string{
    {{- range .Fields}}
    {{ .GoColumnName}},
    {{- end }}
}

通過上面的模板我們用 user 表的建表 SQL 語句生成如下代碼:

type User struct {
    Id    int64     `json:"id"`    // id字段
    Name  string    `json:"name"`  // 名稱
    Age   int64     `json:"age"`   // 年齡
    Ctime time.Time `json:"ctime"` // 創建時間
    Mtime time.Time `json:"mtime"` // 更新時間
}
const (
    table = "user"
    Id = "id"
    Name = "name"
    Age = "age"
    Ctime = "ctime"
    Mtime = "mtime"
)
var Columns = []string{"id","name","age","ctime","mtime"}

那麼我們在查詢的時候就可以這樣使用

Select(Columns...)

通過模板自動生成代碼,可以大大的減輕開發編碼負擔,使我們從繁重的代碼中解放出來。

reflect 真的有必要嗎?

由於我們 SELECT 時選擇查找的字段和順序是不固定的,我們有可能 SELECT id, name, age FROM user,也可能 SELECT name, id FROM user,有很大的任意性,這種情況使用反射出來的結構體 tag 和查詢的列名來確定映射關係是必須的。但是有一種情況我們不需要用到反射,而且是一種最常用的情況,即:查詢的字段名和表結構的列名一致,且順序一致。這時候我們可以這麼寫,通過 DeepEqual 來判斷查詢字段和表結構字段是否一致且順序一致來決定是否通過反射還是通過傳統方法來創建對象。用傳統方式創建對象(如下圖第 12 行)令我們編碼痛苦,不過可以通過模板來自動生成下面的代碼,以避免手寫,這樣既靈活方便好用,性能又沒有損耗,看起來是一個比較完美的解決方案。

func FindUserNoReflect(b *SelectBuilder) ([]*User, error) {
    sql, args := b.Query()
    rows, err := db.QueryContext(ctx, sql, args...)
    if err != nil {
        return nil, err
    }
    result := []*User{}
    if DeepEqual(b.column, Columns) {
        defer rows.Close()
        for rows.Next() {
            a := &User{}
            if err := rows.Scan(&a.Id, &a.Name, &a.Age, &a.Ctime, &a.Mtime); err != nil {
                return nil, err
            }
            result = append(result, a)
        }
        if rows.Err() != nil {
            return nil, rows.Err()
        }
        return result, nil
    }
    err = ScanSlice(rows, &result)
    if err != nil {
        return nil, err
    }
    return result, nil
}

總結

  1. 通過 database/sql 庫開發有較大痛點,ORM 就是爲了解決以上問題而生,其存在是有意義的。

  2. ORM 兩個關鍵的部分是 SQLBuilder 和 Scanner 的實現。

  3. ORM Scanner 使用反射創建對象在性能上肯定會有一定的損失,但是帶來極大的靈活性, 同時在查詢全表字段這種特殊情況下規避使用反射來提高性能。

展望

通過表結構,我們可以生成對應的結構體和持久層增刪改查代碼,我們再往前擴展一步,能否通過表結構生成的 proto 格式的 message,以及一些常用的 CRUD GRPC rpc 接口定義。通過工具,我們甚至可以把前端的代碼都生成好,實現半自動化編程。我想這個是值得期待的。

參考資料

[1] https://github.com/ent/ent

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/06pZl4GpM0wAnyZmn7Hjfw