Go 語言 errgroup 庫:強大的併發控制工具

errgroup 是官方 Go 庫 x 中的一個實用工具,用於併發執行多個 goroutine 並處理錯誤。它基於 sync.WaitGroup 實現了 errgroup.Group,爲併發編程提供了更強大的功能。

errgroup 的優勢

與 sync.WaitGroup 相比,errgroup.Group 具有以下優勢:

sync.WaitGroup 使用示例

在介紹 errgroup.Group 之前,我們先回顧一下 sync.WaitGroup 的用法。

package main

import (
    "fmt"
    "net/http"
    "sync"
)

func main() {
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/", 
    }
    var err error

    var wg sync.WaitGroup 

    for _, url := range urls {
        wg.Add(1) 

        gofunc() {
            defer wg.Done() 

            resp, e := http.Get(url)
            if e != nil { 
                err = e
                return
            }
            defer resp.Body.Close()
            fmt.Printf("fetch url %s status %s\n", url, resp.Status)
        }()
    }

    wg.Wait()
    if err != nil { 
        fmt.Printf("Error: %s\n", err)
    }
}

執行結果:

$ go run examples/main.go
fetch url http://www.google.com/ status 200 OK
fetch url http://www.golang.org/ status 200 OK
Error: Get "http://www.somestupidname.com/": dial tcp: lookup www.somestupidname.com: no such host

sync.WaitGroup 的典型用法:

var wg sync.WaitGroup

for ... {
    wg.Add(1)

    go func() {
        defer wg.Done()
        // do something
    }()
}

wg.Wait()

errgroup.Group 使用示例

基本用法

errgroup.Group 的使用模式與 sync.WaitGroup 類似。

package main

import (
    "fmt"
    "net/http"
    "golang.org/x/sync/errgroup"
)

func main() {
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/", 
    }

    var g errgroup.Group 

    for _, url := range urls {
        g.Go(func() error {
            resp, err := http.Get(url)
            if err != nil {
                return err 
            }
            defer resp.Body.Close()
            fmt.Printf("fetch url %s status %s\n", url, resp.Status)
            returnnil
        })
    }

    if err := g.Wait(); err != nil {
        fmt.Printf("Error: %s\n", err)
    }
}

執行結果:

$ go run examples/main.go
fetch url http://www.google.com/ status 200 OK
fetch url http://www.golang.org/ status 200 OK
Error: Get "http://www.somestupidname.com/": dial tcp: lookup www.somestupidname.com: no such host

Context 取消

errgroup 提供了 errgroup.WithContext 來添加取消功能。

package main

import (
    "context"
    "fmt"
    "net/http"
    "sync"
    "golang.org/x/sync/errgroup"
)

func main() {
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/", 
    }

    g, ctx := errgroup.WithContext(context.Background())

    var result sync.Map

    for _, url := range urls {
        g.Go(func() error {
            req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
            if err != nil {
                return err 
            }

            resp, err := http.DefaultClient.Do(req)
            if err != nil {
                return err 
            }
            defer resp.Body.Close()

            result.Store(url, resp.Status)
            returnnil
        })
    }

    if err := g.Wait(); err != nil {
        fmt.Println("Error: ", err)
    }

    result.Range(func(key, value any) bool {
        fmt.Printf("fetch url %s status %s\n", key, value)
        returntrue
    })
}

執行結果:

$ go run examples/withcontext/main.go
Error:  Get "http://www.somestupidname.com/": dial tcp: lookup www.somestupidname.com: no such host
fetch url http://www.google.com/ status 200 OK

由於請求 http://www.somestupidname.com/ 報錯,程序取消了對 http://www.golang.org/ 的請求。

限制併發數量

errgroup 提供了 errgroup.SetLimit 來限制併發執行的 goroutine 數量。

package main

import (
    "fmt"
    "time"
    "golang.org/x/sync/errgroup"
)

func main() {
    var g errgroup.Group
    g.SetLimit(3)

    for i := 1; i <= 10; i++ {
        g.Go(func() error {
            fmt.Printf("Goroutine %d is starting\n", i)
            time.Sleep(2 * time.Second) 
            fmt.Printf("Goroutine %d is done\n", i)
            returnnil
        })
    }

    if err := g.Wait(); err != nil {
        fmt.Printf("Encountered an error: %v\n", err)
    }

    fmt.Println("All goroutines complete.")
}

執行結果:

$  go run examples/main.go
Goroutine 3 is starting
Goroutine 1 is starting
Goroutine 2 is starting
Goroutine 2 is done
Goroutine 1 is done
Goroutine 5 is starting
Goroutine 3 is done
Goroutine 6 is starting
Goroutine 4 is starting
Goroutine 6 is done
Goroutine 5 is done
Goroutine 8 is starting
Goroutine 4 is done
Goroutine 7 is starting
Goroutine 9 is starting
Goroutine 9 is done
Goroutine 8 is done
Goroutine 10 is starting
Goroutine 7 is done
Goroutine 10 is done
All goroutines complete.

嘗試啓動

errgroup 提供了 errgroup.TryGo 來嘗試啓動任務,需要與 errgroup.SetLimit 配合使用。

package main

import (
    "fmt"
    "time"
    "golang.org/x/sync/errgroup"
)

func main() {
    var g errgroup.Group
    g.SetLimit(3)

    for i := 1; i <= 10; i++ {
        if g.TryGo(func() error {
            fmt.Printf("Goroutine %d is starting\n", i)
            time.Sleep(2 * time.Second) 
            fmt.Printf("Goroutine %d is done\n", i)
            returnnil
        }) {
            fmt.Printf("Goroutine %d started successfully\n", i)
        } else {
            fmt.Printf("Goroutine %d could not start (limit reached)\n", i)
        }
    }

    if err := g.Wait(); err != nil {
        fmt.Printf("Encountered an error: %v\n", err)
    }

    fmt.Println("All goroutines complete.")
}

執行結果:

$ go run examples/main.go
Goroutine 1 started successfully
Goroutine 1 is starting
Goroutine 2 is starting
Goroutine 2 started successfully
Goroutine 3 started successfully
Goroutine 4 could not start (limit reached)
Goroutine 5 could not start (limit reached)
Goroutine 6 could not start (limit reached)
Goroutine 7 could not start (limit reached)
Goroutine 8 could not start (limit reached)
Goroutine 9 could not start (limit reached)
Goroutine 10 could not start (limit reached)
Goroutine 3 is starting
Goroutine 2 is done
Goroutine 3 is done
Goroutine 1 is done
All goroutines complete.

源碼解讀

errgroup 的源碼主要由 3 個文件組成:

核心結構

type token struct{}

type Group struct {
    cancel func(error)
    wg sync.WaitGroup
    sem chan token
    errOnce sync.Once
    err     error
}

主要方法

SetLimit:限制併發數量

func (g *Group) SetLimit(n int) {
    if n < 0 {
        g.sem = nil
        return
    }
    if len(g.sem) != 0 {
        panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
    }
    g.sem = make(chan token, n)
}

Go:啓動新的協程執行任務

func (g *Group) Go(f func() error) {
    if g.sem != nil {
        g.sem <- token{}
    }

    g.wg.Add(1)
    gofunc() {
        defer g.done()

        if err := f(); err != nil {
            g.errOnce.Do(func() {
                g.err = err
                if g.cancel != nil {
                    g.cancel(g.err)
                }
            })
        }
    }()
}

Wait:等待所有任務完成並返回第一個錯誤

func (g *Group) Wait() error {
    g.wg.Wait()
    if g.cancel != nil {
        g.cancel(g.err)
    }
    return g.err
}

TryGo:嘗試啓動任務

func (g *Group) TryGo(f func() error) bool {
    if g.sem != nil {
        select {
        case g.sem <- token{}:
        default:
            returnfalse
        }
    }

    g.wg.Add(1)
    gofunc() {
        defer g.done()

        if err := f(); err != nil {
            g.errOnce.Do(func() {
                g.err = err
                if g.cancel != nil {
                    g.cancel(g.err)
                }
            })
        }
    }()
    returntrue
}

結論

errgroup 是一個官方擴展庫,在 sync.WaitGroup 的基礎上增加了錯誤處理能力,提供了同步、錯誤傳播、context 取消等功能。其 WithContext 方法可以添加取消功能,SetLimit 可以限制併發數量,TryGo 可以嘗試啓動任務。源碼設計巧妙,值得參考。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/cTJzShossbGrrt6KewEnRg