使用 Go 語言上線 TensorFlow 模型

【導讀】本文作者介紹了 Go 語言部署 TensorFlow 模型的實踐。

昨天搞了一天用 Go 語言部署 TensorFlow 模型,把整個過程記錄一下,以備大家參考(現在沒有題圖,以後在搞一個圖)。

首先我們要有一個已經保存好的 TensorFlow 模型,也就是. pb 文件。這個文件固化了計算圖和權重,Go 語言只需要根據這個代碼跑相應的 Session 就行了。關於如何產生. pb 文件,如果大家有興趣的話可以私信我,我可以根據大家的需求情況寫一份文檔。具體部分可以參見 https://www.tensorflow.org/api_docs/python/tf/saved_model 。

然後編譯 TF 的源代碼得到 libtensorflow.so 和 libtensorflow_framework.so。也可以用官網上的下載 https://www.tensorflow.org/install/lang_c(我沒用過,大家可以嘗試一下)。需要注意的是請務必保證保存模型的 TF 版本和這個動態鏈接庫的 TF 版本一致,不然的話後面的 Go 代碼可能會掛(大坑)。如果需要編譯的話可以參考 TF 的官方文檔,如果有興趣的話同上請私信我。

有了這個東西,爲了讓 ld 能夠找到這兩個文件,Linux 上需要設置 LIBRARY_PATH 和 LD_LIBRARY_PATH 這兩個環境變量。

export LIBRARY_PATH=[.so文件所在的目錄]
export LD_LIBRARY_PATH=[.so文件所在的目錄]

然後是下載我們的依賴包。可以使用下邊的命令。第一個是下載依賴,第二個是測試下載的依賴有沒有問題。如果第二個出錯,就證明前面的步驟有問題。

go get github.com/tensorflow/tensorflow/tensorflow/go
go get github.com/tensorflow/tensorflow/tensorflow/go

接下來就可以愉快的載入模型開始玩了。下面是載入模型的示例代碼。載入模型的時候需要給模型所在的文件夾和模型的名字(模型的名字可以用 saved_model_cli 這個工具來查看)。後面的一段是我自己家的,意思是打印出當前模型圖裏面所有的 Operator。這個代碼返回一個 tf.SavedModel 的 struct,這個 struct 有兩個成員,第一個是 Session,第二個是 Graph。如果大家對於 TF 的 python API 很熟應該知道這兩個是什麼東西。

func LoadModel(modelPath string, modelNames []string) *tf.SavedModel {
    model, err := tf.LoadSavedModel(modelPath, modelNames, nil) // 載入模型
    if err != nil {
        log.Fatal("LoadSavedModel(): %v", err)
    }

    log.Println("List possible ops in graphs") // 打印出所有的Operator
    for _, op := range model.Graph.Operations() {
        //log.Printf("Op name: %v, on device: %v", op.Name(), op.Device())
        log.Printf("Op name: %v", op.Name())
    }
    return model
}

有了 Session 和 Graph 之後,我們就能跑這個模型了。我這邊用的是 gin 這個 web 框架,直接把輸入的 JSON 編碼成 TensorFlow 接受的輸入,然後調用 Session.Run 方法來跑整個計算圖。這個方法傳三個參數,第一個參數是一個 map,把每個 tf.Output 類型映射成一個 tf.Tensor。前面一個在知道輸入的 Operator 的情況下,可以通過 Operator.Output(0) 方法拿到,後面一個,可以使用 tf.NewTensor 這個函數,傳入輸入的 Go 數組來生成。如果大家熟悉 TensorFlow 的 Python API 的話,我們會發現,第一個類似與 feed_dict 這個參數。第二個參數是輸出的張量的列表。我們同樣可以在拿到 Operator 以後,通過 Operator.Output(0) 方法拿到,注意要把他們包裝成一個 []Output 類型,即使裏面只有一個元素。第三個是不執行的 Operator 的列表,這裏我們設置成 nil。

func main () {
    m := LoadModel("../freeze_model", []string{"serve"})
    s := m.Session
    // ...
    ServeJSON := func (c *gin.Context) {
        var json map[string] int64
        if c.BindJSON(&json) == nil {
            log.Println(json)
        }
        ret, err:= s.Run(MapGraphInputs(CreateMapFromJSON(json), m),
            GetGraphOutputs([]string{"prob"}, m), nil)
        if err != nil {
            log.Fatal("Error in executing graph...", err)
        }
        // ...
    }
    // ...
}

然後我們編譯一下源代碼,跑一下,發現 gin 框架起來了,我們就可以用這個可執行文件做 web 服務了~ 這個可執行文件和. so 文件,以及模型文件一起,完全可以一起 copy 到 docker 的 container 裏面,這樣就可以用 k8s 愉快的和這個模型玩耍了。

轉自:

https://zhuanlan.zhihu.com/p/64311422

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/I0B05VlvNoSOSZr31M4MCw