Go 中看似簡單的 WaitGroup 源碼設計,竟然暗含這麼多知識?

Go 語言提供的協程 goroutine 可以讓我們很容易地寫出多線程程序,但是,如何讓這些併發執行的 goroutine 得到有效地控制,這是我們需要探討的問題。正如小菜刀在《Golang 併發控制簡述》中所述,Go 標準庫爲我們提供的同步原語中,鎖與原子操作注重控制 goroutine 之間的數據安全,WaitGroup、channel 與 Context 控制的是它們的併發行爲。關於原子操作channel 的實現原理小菜刀均有詳細地解析過。因此本文,我們將重點放在 WaitGroup 上。

初識 WaitGroup

WaitGroup 是 sync 包下的內容,用於控制協程間的同步。WaitGroup 使用場景同名字的含義一樣,當我們需要等待一組協程都執行完成以後,才能做後續的處理時,就可以考慮使用。

func main() {
    var wg sync.WaitGroup

    wg.Add(2) //worker number 2

    go func() {
        // worker 1 do something
        fmt.Println("goroutine 1 done!")
        wg.Done()
    }()

    go func() {
        // worker 2 do something
        fmt.Println("goroutine 2 done!")
        wg.Done()
    }()

    wg.Wait() // wait all waiter done
    fmt.Println("all work done!")
}

// output
goroutine 2 done
goroutine 1 done
all work done

可以看到 WaitGroup 的使用非常簡單,它提供了三個方法。雖然 goroutine 之間並不存在類似於父子關係,但是爲了方便理解,本文會將調用 Wait 函數的 goroutine 稱爲主 goroutine,調用 Done 函數的 goroutine 稱呼爲子 goroutine。

func (wg *WaitGroup) Add(delta int)  // 增加WaitGroup中的子goroutine計數值
func (wg *WaitGroup) Done()          // 當子goroutine任務完成,將計數值減1
func (wg *WaitGroup) Wait()          // 阻塞調用此方法的goroutine,直到計數值爲0

那麼它是如何實現的呢?在源碼src/sync/waitgroup.go中,我們可以看到它的核心源碼只有 100 行不到,十分地精練,非常值得學習。

前置知識

代碼少,不代表就實現簡單,易於理解。相反,如果讀者沒有下述中的前置知識,想要真正理解 WaitGroup 的實現是會比較費力的。在解析源碼之前,我們先過一遍這些知識(如果你都已經掌握,那就可以直接跳到後文的源碼解析部分)。

信號量

在學習操作系統時,我們知道信號量是一種保護共享資源的機制,用於解決多線程同步問題。信號量s是具有非負整數值的全局變量,只能由兩種特殊的操作來處理,這兩種操作稱爲PV

在 Go 的底層信號量函數中

這兩個信號量函數不止在 WaitGroup 中會用上,在《Go 精妙的互斥鎖設計》一文中,我們發現 Go 在設計互斥鎖的時候也少不了信號量的參與。

內存對齊

對於以下的結構體,你能回答出它佔用的內存是多少嗎

type Ins struct {
    x bool  // 1個字節
    y int32 // 4個字節
    z byte  // 1個字節
}

func main() {
    ins := Ins{}
    fmt.Printf("ins size: %d, align: %d\n", unsafe.Sizeof(ins), unsafe.Alignof(ins))
}

//output
ins size: 12, align: 4

按照結構體中字段的大小而言,ins對象佔用內存應該是 1+4+1=6 個字節,但是實際上確實 12 個字節,這就是內存對齊所致。從《CPU 緩存體系對 Go 程序的影響》一文中,我們知道 CPU 的內存讀取並不是一個字節一個字節地讀取的,而是一塊一塊的。因此,在類型的值在內存中對齊的情況下,計算機的加載或者寫入會很高效。

在聚合類型(結構體或數組)的內存所佔長度或許會比它元素所佔內存之和更大。編譯器會添加未使用的內存地址用於填充內存空隙,以確保連續的成員或元素相當於結構體或數組的起始地址是對齊的。

因此,在我們設計結構體時,當結構體成員的類型不同時,將相同類型的成員定義在相鄰位置可以更節省內存空間。

原子操作 CAS

CAS 是原子操作的一種,可用於在多線程編程中實現不被打斷的數據交換操作,從而避免多線程同時改寫某一數據時由於執行順序不確定性以及中斷的不可預知性產生的數據不一致問題。該操作通過將內存中的值與指定數據進行比較,當數值一樣時將內存中的數據替換爲新的值。關於 Go 中原子操作的底層實現,小菜刀在《同步原語的基石》一文中有詳細介紹。

移位運算 >> 與 <<

在之前關於鎖的文章《Go 精妙的互斥鎖設計》與《Go 更細粒度的讀寫鎖設計中》,我們能看到大量的位運算操作。靈活的位運算,能讓一個普通的數字變化出豐富的含義,這裏僅介紹下文中會用到的移位運算。

對於左移位運算 <<,按二進制形式將所有的數字向左移動對應的位數,高位捨棄,低位的空位補零。在數字沒有溢出的前提下,左移一位相當於乘以 2 的 1 次方,左移 n 位就相當於乘以 2 的 n 次方。

對於右移位運算 >>,按二進制形式把所有的數字向右移動對應位數,低位移出,高位的空位補符號位。右移一位相當於除 2,右移 n 位相當於除以 2 的 n 次方。這裏是取商,餘數就不要了。

移位運算也可以有很巧妙的操作,後文中我們會看到移位運算的高級運用。

unsafa.Pointer 指針與 uintptr

Go 中的指針可以分爲三類:1. 普通類型指針 * T,例如 * int;2. unsafe.Pointer 指針;3. uintptr。

unsafe.Pointer 是橋樑,可以讓任意類型的普通指針實現相互轉換,也可以將任意類型的指針轉換爲 uintptr 進行指針運算。但是,unsafe.Pointer 和任意類型指針的轉換可以讓我們將任意值寫入內存中,這會破壞 Go 原有的類型系統,同時由於不是所有的數值都是合法的內存地址,從 uintptr 到 unsafe.Pointer 的轉換同樣會破壞類型系統。因此,既然 Go 將該包定義爲 unsafe,那就不應該隨意使用。

源碼解析

本文基於 Go 源碼 1.15.7 版本

結構體

sync.WaitGroup 的結構體定義如下,它包括了一個 noCopy 的輔助字段,和一個具有複合意義的state1字段。

type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers do not ensure it. So we allocate 12 bytes and then use
    // the aligned 8 bytes in them as state, and the other 4 as storage
    // for the sema.
    state1 [3]uint32
}

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
  // 64位編譯器地址能被8整除,由此可判斷是否爲64位對齊
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

其中,noCopy字段是空結構體,它並不會佔用內存,編譯器也不會對其進行字節填充。它主要是爲了通過 go vet 工具來做靜態編譯檢查,防止開發者在使用 WaitGroup 過程中對其進行了複製,從而導致的安全隱患。關於這部分內容,可以閱讀《no copy 機制》詳細瞭解。

state1字段是一個長度爲 3 的uint32數組。它用於表示三部分內容:1. 通過Add()設置的子 goroutine 的計數值 counter;2. 通過Wait()陷入阻塞的 waiter 數;3. 信號量 semap。

由於後續是對 uint64 類型的statep進行操作,而 64 位整數的原子操作需要 64 位對齊,32 位的編譯器並不能保證這一點。因此,在 64 位與 32 位的環境下,state1字段的組成含義是不相同的。

需要注意的是,當我們初始化一個 WaitGroup 對象時,其 counter 值、waiter 值、semap 值均爲 0。

Add 函數

Add()函數的入參是一個整型,它可正可負,是對 counter 數值的更改。如果 counter 數值變爲 0,那麼所有阻塞在Wait()函數的 waiter 將會被喚醒;如果 counter 數值爲負值,將引起 panic。

我們將競態檢測部分的代碼去掉,Add()函數的實現源碼如下

func (wg *WaitGroup) Add(delta int) {
  // 獲取包含counter與waiter的複合狀態statep,表示信號量值的semap
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)

    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }

    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

    if v > 0 || w == 0 {
        return
    }

    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

  // 如果執行到這,一定是 counter=0,waiter>0
  // 能執行到這,一定是執行了Add(-x)的goroutine
  // 它的執行,代表所有子goroutine已經完成了任務
  // 因此,我們需要將複合狀態全部歸0,並釋放掉waiter個數的信號量
    *statep = 0
    for ; w != 0; w-- {
    // 釋放信號量,執行一次就將喚醒一個阻塞的waiter
        runtime_Semrelease(semap, false, 0)
    }
}

代碼非常精簡,我們接下來對關鍵部分進行剖析。

    state := atomic.AddUint64(statep, uint64(delta)<<32)  // 新增counter數值delta
    v := int32(state >> 32)   // 獲取counter值
    w := uint32(state)        // 獲取waiter值

此時的statep是一個uint64數值,如果此時statep中包含的 counter 數爲 2,waiter 爲 1,輸入 delta 爲 1,那麼這三行代碼的邏輯過程如下圖所示。

在得到當前 counter 數 v 與 waiter 數 w 後,會對它們的值進行判斷,分幾種情況。

    // 情況1:這是很低級的錯誤,counter值不能爲負
  if v < 0 {
        panic("sync: negative WaitGroup counter")
    }

  // 情況2:misuse引起panic 
  // 因爲wg其實是可以用複用的,但是下一次複用的基礎是需要將所有的狀態重置爲0纔可以
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

  // 情況3:本次Add操作只負責增加counter值,直接返回即可。
  // 如果此時counter值大於0,喚醒的操作留給之後的Add調用者(執行Add(negative int))
  // 如果waiter值爲0,代表此時還沒有阻塞的waiter
    if v > 0 || w == 0 {
        return
    }

  // 情況4: misuse引起的panic
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

關於 misuse 和 reused 引發 panic 的情況,如果沒有示例錯誤代碼,其實是比較難解釋的。值得高興的是,在 Go 源碼中給出了錯誤使用示範,這些例子位於src/sync/waitgroup_test.go文件下,想深入瞭解的讀者可以去看以下三個測試函數中的示例。

func TestWaitGroupMisuse(t *testing.T)
func TestWaitGroupMisuse2(t *testing.T)
func TestWaitGroupMisuse3(t *testing.T)
Done 函數

Done()函數比較簡單,就是調用Add(-1)。在實際使用時,當子 goroutine 任務完成之後,就應該調用Done()函數。

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}
Wait 函數

如果 WaitGroup 中的 counter 值大於 0,那麼執行Wait()函數的主 goroutine 會將 waiter 值加 1,並阻塞等待該值爲 0,才能繼續執行後續代碼。

我們將競態檢測部分的代碼去掉,Wait()函數的實現源碼如下

func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep) // 原子讀取複合狀態statep
        v := int32(state >> 32)            // 獲取counter值
        w := uint32(state)                 // 獲取waiter值
    // 如果此時v==0,證明已經沒有待執行任務的子goroutine,直接退出即可。
        if v == 0 {
            return
        }
        // 如果在執行CAS原子操作和讀取複合狀態之間,沒有其他goroutine更改了複合狀態
    // 那麼就將waiter值+1,否則:進入下一輪循環,重新讀取複合狀態
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
      // 對waiter值累加成功後
      // 等待Add函數中調用 runtime_Semrelease 喚醒自己
            runtime_Semacquire(semap)
      // reused 引發panic
      // 在當前goroutine被喚醒時,由於喚醒自己的goroutine通過調用Add方法時
      // 已經通過 *statep = 0 語句做了重置操作
      // 此時的複合狀態位不爲0,就是因爲還未等Waiter執行完Wait,WaitGroup就已經發生了複用
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

總結

要看懂 WaitGroup 的源碼實現,我們需要有一些前置知識,例如信號量、內存對齊、原子操作、移位運算和指針轉換等。

但其實 WaitGroup 的實現思路還是蠻簡單的,通過結構體字段state1維護了兩個計數器和一個信號量,計數器分別是通過Add()添加的子 goroutine 的計數值 counter,通過Wait()陷入阻塞的 waiter 數,信號量用於阻塞與喚醒 Waiter。當執行Add(positive n)時,counter +=n,表明新增 n 個子 goroutine 執行任務。每個子 goroutine 完成任務之後,需要調用Done()函數將 counter 值減 1,當最後一個子 goroutine 完成時,counter 值會是 0,此時就需要喚醒阻塞在Wait()調用中的 Waiter。

但是,在使用 WaitGroup 時,有幾點需要注意

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/V2x0Nw3Y8lZxHYJh_yQ0dQ