幾種限流算法的 go 語言實現

【導讀】不依賴外部庫的情況下,限流算法有什麼實現的思路?本文介紹了 3 種實現限流的方式。

一、漏桶算法

package main

import (
 "fmt"
 "sync"
 "time"
)

// 每個請求來了,把需要執行的業務邏輯封裝成Task,放入木桶,等待worker取出執行
type Task struct {
 handler func() Result // worker從木桶中取出請求對象後要執行的業務邏輯函數
 resChan chan Result   // 等待worker執行並返回結果的channel
 taskID  int
}

// 封裝業務邏輯的執行結果
type Result struct {
}

// 模擬業務邏輯的函數
func handler() Result {
 time.Sleep(300 * time.Millisecond)
 return Result{}
}

func NewTask(id int) Task {
 return Task{
  handler: handler,
  resChan: make(chan Result),
  taskID:  id,
 }
}

// 漏桶
type LeakyBucket struct {
 BucketSize int       // 木桶的大小
 NumWorker  int       // 同時從木桶中獲取任務執行的worker數量
 bucket     chan Task // 存方任務的木桶
}

func NewLeakyBucket(bucketSize int, numWorker int) *LeakyBucket {
 return &LeakyBucket{
  BucketSize: bucketSize,
  NumWorker:  numWorker,
  bucket:     make(chan Task, bucketSize),
 }
}

func (b *LeakyBucket) validate(task Task) bool {
 // 如果木桶已經滿了,返回false
 select {
 case b.bucket <- task:
 default:
  fmt.Printf("request[id=%d] is refused\n", task.taskID)
  return false
 }

 // 等待worker執行
 <-task.resChan
 fmt.Printf("request[id=%d] is run\n", task.taskID)
 return true
}

func (b *LeakyBucket) Start() {
 // 開啓worker從木桶拉取任務執行
 go func() {
  for i := 0; i < b.NumWorker; i++ {
   go func() {
    for {
     task := <-b.bucket
     result := task.handler()
     task.resChan <- result
    }
   }()
  }
 }()
}

func main() {
 bucket := NewLeakyBucket(10, 4)
 bucket.Start()

 var wg sync.WaitGroup
 for i := 0; i < 20; i++ {
  wg.Add(1)
  go func(id int) {
   defer wg.Done()
   task := NewTask(id)
   bucket.validate(task)
  }(i)
 }
 wg.Wait()
}

二、令牌桶算法

package main

import (
 "fmt"
 "sync"
 "time"
)

// 併發訪問同一個user_id/ip的記錄需要上鎖
var recordMu map[string]*sync.RWMutex

func init() {
 recordMu = make(map[string]*sync.RWMutex)
}

func max(a, b int) int {
 if a > b {
  return a
 }
 return b
}

type TokenBucket struct {
 BucketSize int // 木桶內的容量:最多可以存放多少個令牌
 TokenRate time.Duration // 多長時間生成一個令牌
 records map[string]*record // 報錯user_id/ip的訪問記錄
}

// 上次訪問時的時間戳和令牌數
type record struct {
 last time.Time
 token int
}

func NewTokenBucket(bucketSize int, tokenRate time.Duration) *TokenBucket {
 return &TokenBucket{
  BucketSize: bucketSize,
  TokenRate:  tokenRate,
  records:    make(map[string]*record),
 }
}

func (t *TokenBucket) getUidOrIp() string {
 // 獲取請求用戶的user_id或者ip地址
 return "127.0.0.1"
}

// 獲取這個user_id/ip上次訪問時的時間戳和令牌數
func (t *TokenBucket) getRecord(uidOrIp string) *record {
 if r, ok := t.records[uidOrIp]; ok {
  return r
 }
 return &record{}
}

// 保存user_id/ip最近一次請求時的時間戳和令牌數量
func (t *TokenBucket) storeRecord(uidOrIp string, r *record) {
 t.records[uidOrIp] = r
}

// 驗證是否能獲取一個令牌
func (t *TokenBucket) validate(uidOrIp string) bool {
 // 併發修改同一個用戶的記錄上寫鎖
 rl, ok := recordMu[uidOrIp]
 if !ok {
  var mu sync.RWMutex
  rl = &mu
  recordMu[uidOrIp] = rl
 }
 rl.Lock()
 defer rl.Unlock()

 r := t.getRecord(uidOrIp)
 now := time.Now()
 if r.last.IsZero() {
  // 第一次訪問初始化爲最大令牌數
  r.last, r.token = now, t.BucketSize
 } else {
  if r.last.Add(t.TokenRate).Before(now) {
   // 如果與上次請求的間隔超過了token rate
   // 則增加令牌,更新last
   r.token += max(int(now.Sub(r.last) / t.TokenRate), t.BucketSize)
   r.last = now
  }
 }
 var result bool
 if r.token > 0 {
  // 如果令牌數大於1,取走一個令牌,validate結果爲true
  r.token--
  result = true
 }

 // 保存最新的record
 t.storeRecord(uidOrIp, r)
 return result
}

// 返回是否被限流
func (t *TokenBucket) IsLimited() bool {
 return !t.validate(t.getUidOrIp())
}

func main() {
 tokenBucket := NewTokenBucket(5, 100*time.Millisecond)
 for i := 0; i< 6; i++ {
  fmt.Println(tokenBucket.IsLimited())
 }
 time.Sleep(100 * time.Millisecond)
 fmt.Println(tokenBucket.IsLimited())
}

三、滑動時間窗口算法

package main

import (
 "fmt"
 "sync"
 "time"
)

var winMu map[string]*sync.RWMutex

func init() {
 winMu = make(map[string]*sync.RWMutex)
}

type timeSlot struct {
 timestamp time.Time // 這個timeSlot的時間起點
 count     int       // 落在這個timeSlot內的請求數
}

func countReq(win []*timeSlot) int {
 var count int
 for _, ts := range win {
  count += ts.count
 }
 return count
}

type SlidingWindowLimiter struct {
 SlotDuration time.Duration // time slot的長度
 WinDuration  time.Duration // sliding window的長度
 numSlots     int           // window內最多有多少個slot
 windows      map[string][]*timeSlot
 maxReq       int // win duration內允許的最大請求數
}

func NewSliding(slotDuration time.Duration, winDuration time.Duration, maxReq int) *SlidingWindowLimiter {
 return &SlidingWindowLimiter{
  SlotDuration: slotDuration,
  WinDuration:  winDuration,
  numSlots:     int(winDuration / slotDuration),
  windows:      make(map[string][]*timeSlot),
  maxReq:       maxReq,
 }
}

// 獲取user_id/ip的時間窗口
func (l *SlidingWindowLimiter) getWindow(uidOrIp string) []*timeSlot {
 win, ok := l.windows[uidOrIp]
 if !ok {
  win = make([]*timeSlot, 0, l.numSlots)
 }
 return win
}

func (l *SlidingWindowLimiter) storeWindow(uidOrIp string, win []*timeSlot) {
 l.windows[uidOrIp] = win
}

func (l *SlidingWindowLimiter) validate(uidOrIp string) bool {
 // 同一user_id/ip併發安全
 mu, ok := winMu[uidOrIp]
 if !ok {
  var m sync.RWMutex
  mu = &m
  winMu[uidOrIp] = mu
 }
 mu.Lock()
 defer mu.Unlock()

 win := l.getWindow(uidOrIp)
 now := time.Now()
 // 已經過期的time slot移出時間窗
 timeoutOffset := -1
 for i, ts := range win {
  if ts.timestamp.Add(l.WinDuration).After(now) {
   break
  }
  timeoutOffset = i
 }
 if timeoutOffset > -1 {
  win = win[timeoutOffset+1:]
 }

 // 判斷請求是否超限
 var result bool
 if countReq(win) < l.maxReq {
  result = true
 }

 // 記錄這次的請求數
 var lastSlot *timeSlot
 if len(win) > 0 {
  lastSlot = win[len(win)-1]
  if lastSlot.timestamp.Add(l.SlotDuration).Before(now) {
   lastSlot = &timeSlot{timestamp: now, count: 1}
   win = append(win, lastSlot)
  } else {
   lastSlot.count++
  }
 } else {
  lastSlot = &timeSlot{timestamp: now, count: 1}
  win = append(win, lastSlot)
 }

 l.storeWindow(uidOrIp, win)

 return result
}

func (l *SlidingWindowLimiter) getUidOrIp() string {
 return "127.0.0.1"
}

func (l *SlidingWindowLimiter) IsLimited() bool {
 return !l.validate(l.getUidOrIp())
}

func main() {
 limiter := NewSliding(100*time.Millisecond, time.Second, 10)
 for i := 0; i < 5; i++ {
  fmt.Println(limiter.IsLimited())
 }
 time.Sleep(100 * time.Millisecond)
 for i := 0; i < 5; i++ {
  fmt.Println(limiter.IsLimited())
 }
 fmt.Println(limiter.IsLimited())
 for _, v := range limiter.windows[limiter.getUidOrIp()] {
  fmt.Println(v.timestamp, v.count)
 }

 fmt.Println("a thousand years later...")
 time.Sleep(time.Second)
 for i := 0; i < 7; i++ {
  fmt.Println(limiter.IsLimited())
 }
 for _, v := range limiter.windows[limiter.getUidOrIp()] {
  fmt.Println(v.timestamp, v.count)
 }
}

轉自:mikellxy

juejin.cn/post/6844904051344146439

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/5wPpHi8wwaGjen71qXon_A