rust 實戰 - 實現一個線程工作池 ThreadPool

關注「Rust 編程指北」,一起學習 Rust,給未來投資

如何實現一個線程池

線程池:一種線程使用模式。線程過多會帶來調度開銷,進而影響緩存局部性和整體性能。而線程池維護着多個線程,等待着監督管理者分配可併發執行的任務。這避免了在處理短時間任務時創建與銷燬線程的代價。線程池不僅能夠保證內核的充分利用,還能防止過分調度。可用線程數量應該取決於可用的併發處理器、處理器內核、內存、網絡 sockets 等的數量。例如,對於計算密集型任務,線程數一般取 cpu 數量 + 2 比較合適,線程數過多會導致額外的線程切換開銷。

如何定義線程池 Pool 呢,首先最大線程數量肯定要作爲線程池的一個屬性,並且在 new Pool 時創建指定的線程。

線程池 Pool

pub struct Pool {
  max_workers: usize, // 定義最大線程數
}

impl Pool {
  fn new(max_workers: usize) -> Pool {}
  fn execute<F>(&self, f:F) where F: FnOnce() + 'static + Send {}
}

execute來執行任務,F: FnOnce() + 'static + Send 是使用 thread::spawn 線程執行需要滿足的 trait, 代表 F 是一個能在線程裏執行的閉包函數。

另一點自然而然會想到在 Pool 添加一個線程數組, 這個線程數組就是用來執行任務的。比如Vec<Thread> balabala。這裏的線程是活的,是一個個不斷接受任務然後執行的實體。
可以看作在一個線程裏不斷執行獲取任務並執行的 Worker。

struct Worker where
{
    _id: usize, // worker 編號
}

要怎麼把任務發送給 Worker 執行呢?mpsc(multi producer single consumer) 多生產者單消費者可以滿足我們的需求,let (tx, rx) = mpsc::channel() 可以獲取到一對發送端和接收端。
把發送端添加到 Pool 裏面,把接收端添加到 Worker 裏面。Pool 通過 channel 將任務發送給多個 worker 消費執行。

這裏有一點需要特別注意,channel 的接收端 receiver 需要安全的在多個線程間共享,因此需要用Arc<Mutex::<T>>來包裹起來,也就是用鎖來解決併發衝突。

Pool 的完整定義

pub struct Pool {
    workers: Vec<Worker>,
    max_workers: usize,
    sender: mpsc::Sender<Message>
}

該是時候定義我們要發給 Worker 的消息 Message 了
定義如下的枚舉值

type Job = Box<dyn FnOnce() + 'static + Send>;
enum Message {
    ByeBye,
    NewJob(Job),
}

Job 是一個要發送給 Worker 執行的閉包函數,這裏 ByeBye 用來通知 Worker 可以終止當前的執行,退出線程。

只剩下實現 Worker 和 Pool 的具體邏輯了。

Worker 的實現

impl Worker
{
    fn new(id: usize, receiver: Arc::<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let t = thread::spawn( move || {
            loop {
                let receiver = receiver.lock().unwrap();
                let message=  receiver.recv().unwrap();
                match message {
                    Message::NewJob(job) => {
                        println!("do job from worker[{}]", id);
                        job();
                    },
                    Message::ByeBye => {
                        println!("ByeBye from worker[{}]", id);
                        break
                    },
                }  
            }
        });

        Worker {
            _id: id,
            t: Some(t),
        }
    }
}

let message = receiver.lock().unwrap().recv().unwrap(); 這裏獲取鎖後從 receiver 獲取到消息體,然後 let message 結束後 rust 的生命週期會自動釋放掉鎖。
但如果寫成

while let message = receiver.lock().unwrap().recv().unwrap() {
};

while let 後面整個括號都是一個作用域,要在這個作用域結束後,鎖纔會釋放,比上面 let message 要鎖定久時間。
rust 的 mutex 鎖沒有對應的 unlock 方法,由 mutex 的生命週期管理。

我們給 Pool 實現Drop trait, 讓 Pool 被銷燬時,自動暫停掉 worker 線程的執行。

impl Drop for Pool {
    fn drop(&mut self) {
        for _ in 0..self.max_workers {
            self.sender.send(Message::ByeBye).unwrap();
        }
        for w in self.workers.iter_mut() {
            if let Some(t) = w.t.take() {
                t.join().unwrap();
            }
        }
    }
}

drop 方法裏面用了兩個循環,而不是在一個循環裏做完兩件事?

for w in self.workers.iter_mut() {
    if let Some(t) = w.t.take() {
        self.sender.send(Message::ByeBye).unwrap();
        t.join().unwrap();
    }
}

這裏面隱藏了一個會造成死鎖的陷阱,比如兩個 Worker, 在單個循環裏面迭代所有 Worker,再將終止信息發送給通道後,直接調用 join,
我們預期是第一個 worker 要收到消息,並且等他執行完。當情況可能是第二個 worker 獲取到了消息,第一個 worker 沒有獲取到,那接下來的 join 就會阻塞造成死鎖。

注意到沒有,Worker 是被包裝在 Option 內的,這裏有兩個點需要注意

  1. t.join 需要持有 t 的所有權

  2. 在我們這種情況下,self.workers 只能作爲引用被 for 循環迭代。

這裏考慮讓 Worker 持有Option<JoinHandle<()>>,後續可以通過在 Option 上調用 take 方法將 Some 變體的值移出來,並在原來的位置留下 None 變體。
換而言之,讓運行中的 worker 持有 Some 的變體,清理 worker 時,可以使用 None 替換掉 Some,從而讓 Worker 失去可以運行的線程

struct Worker where
{
    _id: usize,
    t: Option<JoinHandle<()>>,
}

要點總結

完整代碼

use std::thread::{self, JoinHandle};
use std::sync::{Arc, mpsc, Mutex};


type Job = Box<dyn FnOnce() + 'static + Send>;
enum Message {
    ByeBye,
    NewJob(Job),
}

struct Worker where
{
    _id: usize,
    t: Option<JoinHandle<()>>,
}

impl Worker
{
    fn new(id: usize, receiver: Arc::<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let t = thread::spawn( move || {
            loop {
                let message = receiver.lock().unwrap().recv().unwrap();
                match message {
                    Message::NewJob(job) => {
                        println!("do job from worker[{}]", id);
                        job();
                    },
                    Message::ByeBye => {
                        println!("ByeBye from worker[{}]", id);
                        break
                    },
                }  
            }
        });

        Worker {
            _id: id,
            t: Some(t),
        }
    }
}

pub struct Pool {
    workers: Vec<Worker>,
    max_workers: usize,
    sender: mpsc::Sender<Message>
}

impl Pool where {
    pub fn new(max_workers: usize) -> Pool {
        if max_workers == 0 {
            panic!("max_workers must be greater than zero!")
        }
        let (tx, rx) = mpsc::channel();

        let mut workers = Vec::with_capacity(max_workers);
        let receiver = Arc::new(Mutex::new(rx));
        for i in 0..max_workers {
            workers.push(Worker::new(i, Arc::clone(&receiver)));
        }

        Pool { workers: workers, max_workers: max_workers, sender: tx }
    }
    
    pub fn execute<F>(&self, f:F) where F: FnOnce() + 'static + Send
    {

        let job = Message::NewJob(Box::new(f));
        self.sender.send(job).unwrap();
    }
}

impl Drop for Pool {
    fn drop(&mut self) {
        for _ in 0..self.max_workers {
            self.sender.send(Message::ByeBye).unwrap();
        }
        for w in self.workers {
            if let Some(t) = w.t.take() {
                t.join().unwrap();
            }
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn it_works() {
        let p = Pool::new(4);
        p.execute(|| println!("do new job1"));
        p.execute(|| println!("do new job2"));
        p.execute(|| println!("do new job3"));
        p.execute(|| println!("do new job4"));
    }
}

來自:https://www.cnblogs.com/linyihai/p/15885327.html,作者 yihailin

覺得不錯,點個贊吧

掃碼關注「Rust 編程指北」

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/RE5s8X-nfm536JWXNGDqHA