sync-WaitGroup 設計與實現

概述

sync.WaitGroup 可以等待一個併發執行的 goroutine 集合執行結束。

示例

通過一個小例子展示 sync.WaitGroup 的使用方法。

package main

import (
 "fmt"
 "strconv"
 "sync"
 "time"
)

type Task struct {
 ID   int
 Name string
}

func main() {
 tasks := make([]*Task, 0)

 // 添加 5 個任務
 for i := 1; i <= 5; i++ {
  tasks = append(tasks, &Task{
   ID:   i,
   Name: strconv.Itoa(i),
  })
 }

 var wg sync.WaitGroup

 // 開啓多個 goroutine 並行執行任務
 for _, task := range tasks {
  wg.Add(1)

  go func(t *Task) {
   defer wg.Done() // 任務完成

   fmt.Printf("Task %s starting ...\n", t.Name)

   time.Sleep(300 * time.Millisecond) // 模擬任務執行耗時
  }(task)
 }

 wg.Wait() // 等待所有任務執行結束
}
$ go run main.go

# 輸出如下
Task 5 starting ...
Task 1 starting ...
Task 4 starting ...
Task 3 starting ...
Task 2 starting ...

從輸出的結果中可以看到,雖然任務執行完成順序和添加順序並不一致,但是最終 5 個任務全部執行完成。

內部實現

我們來探究一下 sync.WaitGroup 的內部實現,文件路徑爲 $GOROOT/src/sync/waitgroup.go,筆者的 Go 版本爲 go1.19 linux/amd64

WaitGroup 對象

WaitGroup 對象表示併發 goroutine 集合的控制器,具體的使用方法爲:

根據 Go 內存模型的約束,goroutine 調用 Done 方法時,必須在對應的 Wait 方法之前調用,否則對應的 Wait 方法將永遠阻塞。

// WaitGroup 一旦使用後,就不能再複製
type WaitGroup struct {
 noCopy noCopy // 保證編譯期間不會發生複製
 
 state1 uint64
 state2 uint32
}

兩個字段表示的三個變量

三個語義變量

state1state2 兩個字段其實表示了三個語義變量,分別爲:

爲什麼不直接設置三個變量呢?

因爲 counter 和 waiter 計數器根據內存對齊情況放進一個 64 位整數里面,這是標準庫做的一個優化,將兩個計數器放進一個變量,這樣就可以在不加鎖的情況下,支持併發場景下的原子操作了,極大地提高了性能

state 方法

state 方法返回兩個指針變量,statep 變量表示 counter 和 waiter 計數器,semap 變量表示信號量。

stete 方法會根據 state1 字段的內存對齊位數,在必要時動態 "交換" 三個語義變量的順序

64 位對齊

在 32 位架構中,WaitGroup 對象初始化時分配的內存地址是隨機的,state1 字段起始的位置不一定 64 位對齊,所以需要和 state2 字段拼接起來,實現內存連續的情況下保證 64 位對齊。

非 64 位對齊

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
 // 判斷 state1 字段是否按照 64 位對齊
 if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
  // 如果 state1 字段是 64 位對齊,直接返回
  return &wg.state1, &wg.state2
 } else {
  // 如果 state1 是 32 位而非 64 位對齊
  // 這意味着 (&state1)+4 是 64 位對齊 (state1 字段 + 4, 正好是 state2 字段)
  // (&state1)+4 等於跨了兩個字段,所以是 64 位對齊 (兩個字段的內存是連續的)
  // 最後把兩個字段地址進行連接,在連接的基礎上實現地址交換
  state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
  return (*uint64)(unsafe.Pointer(&state[1]))&state[0]
 }
}

Add 方法

Add 方法增加 delta 個計數,內部會添加到對應的 counter 計數器上,如果 counter 變爲 0,所有阻塞在 Wait 方法上的 goroutine 都會立即完成並被釋放。

具體的調用規則如下:

  1. 當 counter == 0 並且 delta > 0 時,必須在 Wait 方法之前調用 Add 方法

  2. 當 counter > 0 並且 delta < 0 時,可以在任何時候調用 Add 方法

  3. 一般情況下,Add 方法應該在創建 goroutine 時或其他阻塞場景發生前調用

  4. 如果 WaitGroup 要重複使用,應該在所有 Wait 方法返回之後再繼續調用 Add 方法

func (wg *WaitGroup) Add(delta int) {
 statep, semap := wg.state() // 調用 state() 取出計數器和信號量
 
  ...
 
 state := atomic.AddUint64(statep, uint64(delta)<<32) // 增加計數器的值 
 v := int32(state >> 32) // 獲取計數器的值 (高位字節)
 w := uint32(state)  // 獲取等待者的值 (低位字節)
 
 ...
 
 if v < 0 {
  // 計數器不能爲負數 (出現了 BUG)
  panic("sync: negative WaitGroup counter")
 }
 
 // 等待者不等於 0, 說明已經有 goroutine 調用了 Wait 方法
 // 此時不允許再調用 Add 方法了 (參考規則 4)
 if w != 0 && delta > 0 && v == int32(delta) {
  panic("sync: WaitGroup misuse: Add called concurrently with Wait")
 }
 
 if v > 0 || w == 0 {
  // 如果計數器大於 0 或者沒有等待者,直接返回
  return
 }
 
 // 當等待者大於 0 並且計數器等於 0 (所有 goroutine 都調用了 Done 方法表示其結束執行)
 // 重置計數器和等待者爲 0
 *statep = 0
 // 喚醒所有等待者 (逐個阻塞調用)
 for ; w != 0; w-- {
  runtime_Semrelease(semap, false, 0)
 }
}

Done 方法

Done 方法簡單地封裝了一下 Add 方法 (等於調用 Add(-1)),提供了一個可讀性更高的操作原語。

func (wg *WaitGroup) Done() {
 wg.Add(-1)
}

Wait 方法

Wait 方法會進入阻塞,直到計數器的值等於 0。

func (wg *WaitGroup) Wait() {
 statep, semap := wg.state() // 調用 state() 取出計數器和信號量

 ...
     
 for {
  state := atomic.LoadUint64(statep)
  v := int32(state >> 32) // 獲取計數器的值 (高位字節)
  w := uint32(state)  // 獲取等待者的值 (低位字節)
  if v == 0 {
   // 計數器等於 0,直接返回
   return
  }
  
  // 計數器不等於 0,說明存在併發
  // 增加等待者的值
  if atomic.CompareAndSwapUint64(statep, state, state+1) {
    ...
   
   // 休眠當前 goroutine 等待喚醒
   runtime_Semacquire(semap)
   if *statep != 0 {
    // 等待者不等於 0, 說明 WaitGroup 對象被重複使用了 (參考規則 4)
    panic("sync: WaitGroup is reused before previous Wait has returned")
   }
   
   return
  }
 }
}

noCopy 對象

noCopy 對象可以添加到具體的結構體中,實現 "首次使用之後,無法被複制" 的功能 (由編譯器實現)。

noCopy.Lock 方法是一個空操作,由 go vet 工具鏈中的 -copylocks checker 參數指令使用。

type noCopy struct{}

func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

小結

sync.WaitGroup 的代碼實現中,有兩個非常重要的優化技巧值得我們學習:

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/B3GCAw3qNBFfK7MJ99aXfg