use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::sync::Mutex;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
    Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::fmt;
use std::hash::{DefaultHasher, Hasher};
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;

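/// Thread builder used for customization via
/// `ThreadPoolBuilder::spawn_handler`.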
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until
    /// the thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

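/// Generalized trait for spawning a thread in the `Registry`.
///
/// This is part of the public API of `ThreadPoolBuilder`, but intended only
/// for internal use: the `private_decl!` marker prevents implementations
/// outside this crate.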
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

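/// Spawns a thread in the "normal" way, with `std::thread::Builder`.
/// This is the default spawn handler for `ThreadPoolBuilder`.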
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

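/// Spawns a thread with a user-supplied function, for
/// `ThreadPoolBuilder::spawn_handler`.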
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

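/// The central registry for a thread pool: owns the per-thread state, the
/// injector queue for jobs arriving from outside the pool, and the
/// sleep/wake machinery shared by all worker threads.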
pub(super) struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    /// Counter tracking outstanding "terminate" obligations. It starts at 1
    /// for the registry itself; `increment_terminate_count` adds to it, and
    /// each matching `terminate` subtracts. When it reaches zero, the
    /// terminate latch is set for every worker thread, and they shut down
    /// once out of work.
    terminate_count: AtomicUsize,
}

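// The global registry. Written at most once, guarded by `THE_REGISTRY_SET`;
// after initialization, only shared references are handed out.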
static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

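/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default configuration.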
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| {
            // SAFETY: we only create a shared reference to `THE_REGISTRY` after the `call_once`
            // that initializes it, and there will be no more mutable accesses at all.
            debug_assert!(THE_REGISTRY_SET.is_completed());
            let the_registry = unsafe { &*ptr::addr_of!(THE_REGISTRY) };
            the_registry.as_ref().ok_or(err)
        })
        .expect("The global thread pool has not been initialized.")
}

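/// Starts the worker threads (if that has not already happened) with the
/// given builder.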
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

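/// Starts the worker threads (if that has not already happened) by creating
/// a registry with the given callback.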
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry().map(|registry: Arc<Registry>| {
            // SAFETY: this is the only mutable access to `THE_REGISTRY`, thanks to `Once`, and
            // `global_registry()` only takes a shared reference **after** this `call_once`.
            unsafe {
                ptr::addr_of_mut!(THE_REGISTRY).write(Some(registry));
                (*ptr::addr_of!(THE_REGISTRY)).as_ref().unwrap_unchecked()
            }
        })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all,
    // we can fall back to using the current thread alone. This is crude, and
    // probably won't work for non-blocking calls like `spawn`, but a lot of
    // stuff does work fine.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread();
        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

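/// Guard that terminates the registry when dropped, so that the worker
/// threads are shut down if `Registry::new` returns early with an error.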
struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };

            if index == 0 && builder.use_current_thread {
                if !WorkerThread::current().is_null() {
                    return Err(ThreadPoolBuildError::new(
                        ErrorKind::CurrentThreadAlreadyInPool,
                    ));
                }
                // Rather than starting a new thread, we're just taking over the current thread
                // *without* running the main loop, so we can still return from here. The
                // `WorkerThread` is leaked, but we never shut down the global pool anyway.
                let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread)));

                unsafe {
                    WorkerThread::set_current(worker_thread);
                    Latch::set(&registry.thread_infos[index].primed);
                }
                continue;
            }

            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Everything went well; defuse the terminator.
        mem::forget(t1000);

        Ok(registry)
    }

    /// Returns the registry of the current worker thread, or the global
    /// registry if this is not a worker thread.
    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry. This is better
    /// than `Registry::current().num_threads()` because it avoids incrementing
    /// the `Arc` reference count.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc`.
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is used for
    /// testing purposes because it lets us confirm that the workers have
    /// been created.
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

    /// Push a job into the given `registry`. If we are running on a worker
    /// thread for the registry then we push onto the local deque, else we
    /// inject into the global queue.
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do.
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true here.
        // It is only set to true when the user creates (and drops) a
        // `ThreadPool`; and, in that case, they cannot be calling `inject()`
        // later, since they dropped their `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally.
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here, for the same reason as described in `inject()`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker thread of this registry, just execute `op`.
    /// Otherwise, inject `op` into this thread pool. Either way, block until
    /// `op` completes and return its return value. If `op` panics, that panic
    /// will be propagated as well. The second argument to `op` indicates
    /// `true` if injection was performed, `false` if executing directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the current
                // thread, so we know any handles stored in it will be valid.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = const { LockLatch::new() });

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());

            // Make sure we can use the same latch again next time.
            job.latch.wait_and_reset();

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while it waits for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

    /// Increments the terminate counter. This increment should be balanced
    /// by a call to `terminate`, which will decrement. This is used when
    /// spawning asynchronous work, which needs to prevent the registry from
    /// terminating so long as it is active.
    ///
    /// Blocking functions such as `join` and `scope` do not need to concern
    /// themselves with this; their context is responsible for ensuring the
    /// current thread pool will not terminate until they return.
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(previous != usize::MAX, "overflow in registry ref count");
    }

    /// Signals that the thread pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been set.
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that terminate has been requested.
    /// It is set by `Registry::terminate` and waited on (indirectly)
    /// by the worker's main loop.
    terminate: OnceLatch,

    /// The "stealer" half of the worker's deque.
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

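/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers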
pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

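// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.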
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`.
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns NULL if this
    /// is not a worker thread. This pointer is valid anywhere on the current
    /// thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.get()
    }

    /// Sets `thread` as the worker thread for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means popping from
    /// the top of the stack, though if we are configured for breadth-first
    /// execution, it would mean dequeuing from the bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // The code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*.
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourselves idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                self.execute(job);
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    self.execute(job);
                    // The job might have injected local work, so go back to the outer loop.
                    continue 'outer;
                } else {
                    self.registry
                        .sleep
                        .no_work_found(&mut idle_state, latch, || self.has_injected_job())
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        self.wait_until(&registry.thread_infos[index].terminate);

        // There should not be any work left in our queue.
        debug_assert!(self.take_local_job().is_none());

        // Let the registry know we are done.
        Latch::set(&registry.thread_infos[index].stopped);
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the outside.
        // The idea is to finish what we started before we take on
        // something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job())
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // We only steal when we don't have any work to do locally.
        debug_assert!(self.local_deque_is_empty());

        // Otherwise, try to steal from other workers, starting at a
        // random victim to avoid contention.
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

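/// ////////////////////////////////////////////////////////////////////////
/// Code to actually run a worker thread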
unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // Let the registry know we are ready to do work.
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the thread pool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    worker_thread.wait_until_out_of_work();

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
    }
}

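/// If already in a worker thread, just execute `op`. Otherwise, execute `op`
/// in the default thread pool. Either way, block until `op` completes and
/// return its return value. If `op` panics, that panic will be propagated as
/// well. The second argument indicates `true` if injection was performed,
/// `false` if executing directly.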
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the current
            // thread, so we know any handles stored in it will be valid.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

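/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*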
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}