rayon/slice/sort.rs
1//! Parallel slice sorting.
2//!
3//! This implementation is mostly copied from the `core::slice::sort` module, with minimal changes
4//! to support stable Rust and parallel `is_less` (e.g. `Fn` rather than `FnMut`).
5//!
6//! ---
7//!
8//! This module contains a sorting algorithm based on Orson Peters' pattern-defeating quicksort,
9//! published at: <https://github.com/orlp/pdqsort>
10//!
11//! Unstable sorting is compatible with core because it doesn't allocate memory, unlike our
12//! stable sorting implementation.
13//!
14//! It also contains the core logic of the stable sort used by `slice::sort`, which is based on
15//! TimSort.
16
17use core::cmp;
18use core::mem::{self, MaybeUninit};
19use core::ptr;
20use core::slice;
21
22use crate::iter::{IndexedParallelIterator, ParallelIterator};
23use crate::slice::ParallelSliceMut;
24use crate::SendPtr;
25
26// When dropped, copies from `src` into `dest`.
27struct InsertionHole<T> {
28 src: *const T,
29 dest: *mut T,
30}
31
32impl<T> Drop for InsertionHole<T> {
33 fn drop(&mut self) {
34 // SAFETY: This is a helper struct. Please refer to its usage for correctness. Namely, one
35 // must be sure that `src` and `dest` do not overlap as required by
36 // `ptr::copy_nonoverlapping`, and that `src` is valid for reads and `dest` for writes.
37 unsafe {
38 ptr::copy_nonoverlapping(self.src, self.dest, 1);
39 }
40 }
41}
42
43/// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]`
44/// becomes sorted.
45unsafe fn insert_tail<T, F>(v: &mut [T], is_less: &F)
46where
47 F: Fn(&T, &T) -> bool,
48{
49 debug_assert!(v.len() >= 2);
50
51 let arr_ptr = v.as_mut_ptr();
52 let i = v.len() - 1;
53
54 // SAFETY: caller must ensure v is at least len 2.
55 unsafe {
56 // See insert_head which talks about why this approach is beneficial.
57 let i_ptr = arr_ptr.add(i);
58
59 // It's important that we use `i_ptr` here. If this check is positive and we continue,
60 // we want to be sure that no other copy of the value was seen by `is_less`;
61 // otherwise we would have to copy it back.
62 if is_less(&*i_ptr, &*i_ptr.sub(1)) {
63 // It's important that we use `tmp` for comparisons from now on, as it is the value that
64 // will be copied back. Otherwise we could create a divergence by copying
65 // back the wrong value.
66 let tmp = mem::ManuallyDrop::new(ptr::read(i_ptr));
67 // Intermediate state of the insertion process is always tracked by `hole`, which
68 // serves two purposes:
69 // 1. Protects integrity of `v` from panics in `is_less`.
70 // 2. Fills the remaining hole in `v` in the end.
71 //
72 // Panic safety:
73 //
74 // If `is_less` panics at any point during the process, `hole` will get dropped and
75 // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
76 // initially held exactly once.
77 let mut hole = InsertionHole {
78 src: &*tmp,
79 dest: i_ptr.sub(1),
80 };
81 ptr::copy_nonoverlapping(hole.dest, i_ptr, 1);
82
83 // SAFETY: We know i is at least 1.
84 for j in (0..(i - 1)).rev() {
85 let j_ptr = arr_ptr.add(j);
86 if !is_less(&*tmp, &*j_ptr) {
87 break;
88 }
89
90 ptr::copy_nonoverlapping(j_ptr, hole.dest, 1);
91 hole.dest = j_ptr;
92 }
93 // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
94 }
95 }
96}
97
98/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted.
99///
100/// This is the integral subroutine of insertion sort.
101unsafe fn insert_head<T, F>(v: &mut [T], is_less: &F)
102where
103 F: Fn(&T, &T) -> bool,
104{
105 debug_assert!(v.len() >= 2);
106
107 // SAFETY: caller must ensure v is at least len 2.
108 unsafe {
109 if is_less(v.get_unchecked(1), v.get_unchecked(0)) {
110 let arr_ptr = v.as_mut_ptr();
111
112 // There are three ways to implement insertion here:
113 //
114 // 1. Swap adjacent elements until the first one gets to its final destination.
115 // However, this way we copy data around more than is necessary. If elements are big
116 // structures (costly to copy), this method will be slow.
117 //
118 // 2. Iterate until the right place for the first element is found. Then shift the
119 // elements succeeding it to make room for it and finally place it into the
120 // remaining hole. This is a good method.
121 //
122 // 3. Copy the first element into a temporary variable. Iterate until the right place
123 // for it is found. As we go along, copy every traversed element into the slot
124 // preceding it. Finally, copy data from the temporary variable into the remaining
125 // hole. This method is very good. Benchmarks demonstrated slightly better
126 // performance than with the 2nd method.
127 //
128 // All methods were benchmarked, and the 3rd showed best results. So we chose that one.
129 let tmp = mem::ManuallyDrop::new(ptr::read(arr_ptr));
130
131 // Intermediate state of the insertion process is always tracked by `hole`, which
132 // serves two purposes:
133 // 1. Protects integrity of `v` from panics in `is_less`.
134 // 2. Fills the remaining hole in `v` in the end.
135 //
136 // Panic safety:
137 //
138 // If `is_less` panics at any point during the process, `hole` will get dropped and
139 // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
140 // initially held exactly once.
141 let mut hole = InsertionHole {
142 src: &*tmp,
143 dest: arr_ptr.add(1),
144 };
145 ptr::copy_nonoverlapping(arr_ptr.add(1), arr_ptr.add(0), 1);
146
147 for i in 2..v.len() {
148 if !is_less(v.get_unchecked(i), &*tmp) {
149 break;
150 }
151 ptr::copy_nonoverlapping(arr_ptr.add(i), arr_ptr.add(i - 1), 1);
152 hole.dest = arr_ptr.add(i);
153 }
154 // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
155 }
156 }
157}
158
159/// Sort `v` assuming `v[..offset]` is already sorted.
160///
161/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
162/// performance impact, and even improves performance in some cases.
163#[inline(never)]
164fn insertion_sort_shift_left<T, F>(v: &mut [T], offset: usize, is_less: &F)
165where
166 F: Fn(&T, &T) -> bool,
167{
168 let len = v.len();
169
170 // Using assert here improves performance.
171 assert!(offset != 0 && offset <= len);
172
173 // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted.
174 for i in offset..len {
175 // SAFETY: we tested that `offset` must be at least 1, so this loop is only entered if len
176 // >= 2. The range is exclusive and we know `i` must be at least 1, so this slice has at
177 // least len 2.
178 unsafe {
179 insert_tail(&mut v[..=i], is_less);
180 }
181 }
182}
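// Illustrative sketch (added for exposition, not part of the upstream source): a small
// `#[cfg(test)]` check showing the `v[..offset]`-is-sorted contract, and the common
// `offset == 1` case that turns this into a plain insertion sort.
#[cfg(test)]
#[test]
fn insertion_sort_shift_left_sketch() {
    let is_less = |a: &i32, b: &i32| a < b;

    // `v[..3]` is sorted; the tail is not.
    let mut v = vec![1, 4, 9, 3, 7, 2, 8];
    insertion_sort_shift_left(&mut v, 3, &is_less);
    assert_eq!(v, vec![1, 2, 3, 4, 7, 8, 9]);

    // With `offset == 1` (a single element is trivially sorted) this is an ordinary
    // insertion sort.
    let mut v = vec![5, 3, 1, 4, 2];
    insertion_sort_shift_left(&mut v, 1, &is_less);
    assert_eq!(v, vec![1, 2, 3, 4, 5]);
}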
183
184/// Sort `v` assuming `v[offset..]` is already sorted.
185///
186/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
187/// performance impact, and even improves performance in some cases.
188#[inline(never)]
189fn insertion_sort_shift_right<T, F>(v: &mut [T], offset: usize, is_less: &F)
190where
191 F: Fn(&T, &T) -> bool,
192{
193 let len = v.len();
194
195 // Using assert here improves performance.
196 assert!(offset != 0 && offset <= len && len >= 2);
197
198 // Shift each element of the unsorted region v[..offset] as far right as is needed to make v sorted.
199 for i in (0..offset).rev() {
200 // SAFETY: we tested that `offset` must be at least 1, so this loop is only entered if len
201 // >= 2. We also ensured that the slice length is at least 2. The only caller in this file
202 // passes `offset == 1` with `len >= 2`, so the loop index is always 0 and `v[i..len]` has
203 // length at least 2, which is what `insert_head` requires.
204 unsafe {
205 insert_head(&mut v[i..len], is_less);
206 }
207 }
208}
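// Illustrative sketch (added for exposition, not part of the upstream source): the mirror
// image of the function above, with the sorted suffix `v[offset..]` given up front.
#[cfg(test)]
#[test]
fn insertion_sort_shift_right_sketch() {
    // `v[2..]` is sorted; the first two elements are not.
    let mut v = vec![9, 5, 1, 3, 6, 8];
    insertion_sort_shift_right(&mut v, 2, &|a: &i32, b: &i32| a < b);
    assert_eq!(v, vec![1, 3, 5, 6, 8, 9]);
}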
209
210/// Partially sorts a slice by shifting several out-of-order elements around.
211///
212/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case.
213#[cold]
214fn partial_insertion_sort<T, F>(v: &mut [T], is_less: &F) -> bool
215where
216 F: Fn(&T, &T) -> bool,
217{
218 // Maximum number of adjacent out-of-order pairs that will get shifted.
219 const MAX_STEPS: usize = 5;
220 // If the slice is shorter than this, don't shift any elements.
221 const SHORTEST_SHIFTING: usize = 50;
222
223 let len = v.len();
224 let mut i = 1;
225
226 for _ in 0..MAX_STEPS {
227 // SAFETY: We already explicitly did the bound checking with `i < len`.
228 // All our subsequent indexing is only in the range `0 <= index < len`
229 unsafe {
230 // Find the next pair of adjacent out-of-order elements.
231 while i < len && !is_less(v.get_unchecked(i), v.get_unchecked(i - 1)) {
232 i += 1;
233 }
234 }
235
236 // Are we done?
237 if i == len {
238 return true;
239 }
240
241 // Don't shift elements on short arrays; that has a performance cost.
242 if len < SHORTEST_SHIFTING {
243 return false;
244 }
245
246 // Swap the found pair of elements. This puts them in correct order.
247 v.swap(i - 1, i);
248
249 if i >= 2 {
250 // Shift the smaller element to the left.
251 insertion_sort_shift_left(&mut v[..i], i - 1, is_less);
252
253 // Shift the greater element to the right.
254 insertion_sort_shift_right(&mut v[..i], 1, is_less);
255 }
256 }
257
258 // Didn't manage to sort the slice in the limited number of steps.
259 false
260}
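// Illustrative sketch (added for exposition, not part of the upstream source): on slices
// shorter than `SHORTEST_SHIFTING` the function never shifts, so it returns `true` only
// for an already-sorted input and otherwise bails out without modifying anything.
#[cfg(test)]
#[test]
fn partial_insertion_sort_sketch() {
    let is_less = |a: &i32, b: &i32| a < b;

    let mut sorted = vec![1, 2, 3, 4, 5];
    assert!(partial_insertion_sort(&mut sorted, &is_less));

    let mut unsorted = vec![2, 1, 3, 4, 5];
    let before = unsorted.clone();
    assert!(!partial_insertion_sort(&mut unsorted, &is_less));
    assert_eq!(unsorted, before);
}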
261
262/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case.
263#[cold]
264fn heapsort<T, F>(v: &mut [T], is_less: F)
265where
266 F: Fn(&T, &T) -> bool,
267{
268 // This binary heap respects the invariant `parent >= child`.
269 let sift_down = |v: &mut [T], mut node| {
270 loop {
271 // Children of `node`.
272 let mut child = 2 * node + 1;
273 if child >= v.len() {
274 break;
275 }
276
277 // Choose the greater child.
278 if child + 1 < v.len() {
279 // We need a branch to be sure not to index out of bounds,
280 // but it's highly predictable. The comparison, however,
281 // is better done branchless, especially for primitives.
282 child += is_less(&v[child], &v[child + 1]) as usize;
283 }
284
285 // Stop if the invariant holds at `node`.
286 if !is_less(&v[node], &v[child]) {
287 break;
288 }
289
290 // Swap `node` with the greater child, move one step down, and continue sifting.
291 v.swap(node, child);
292 node = child;
293 }
294 };
295
296 // Build the heap in linear time.
297 for i in (0..v.len() / 2).rev() {
298 sift_down(v, i);
299 }
300
301 // Pop maximal elements from the heap.
302 for i in (1..v.len()).rev() {
303 v.swap(0, i);
304 sift_down(&mut v[..i], 0);
305 }
306}
307
308/// Partitions `v` into elements smaller than `pivot`, followed by elements greater than or equal
309/// to `pivot`.
310///
311/// Returns the number of elements smaller than `pivot`.
312///
313/// Partitioning is performed block-by-block in order to minimize the cost of branching operations.
314/// This idea is presented in the [BlockQuicksort][pdf] paper.
315///
316/// [pdf]: https://drops.dagstuhl.de/opus/volltexte/2016/6389/pdf/LIPIcs-ESA-2016-38.pdf
317fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &F) -> usize
318where
319 F: Fn(&T, &T) -> bool,
320{
321 // Number of elements in a typical block.
322 const BLOCK: usize = 128;
323
324 // The partitioning algorithm repeats the following steps until completion:
325 //
326 // 1. Trace a block from the left side to identify elements greater than or equal to the pivot.
327 // 2. Trace a block from the right side to identify elements smaller than the pivot.
328 // 3. Exchange the identified elements between the left and right side.
329 //
330 // We keep the following variables for a block of elements:
331 //
332 // 1. `block` - Number of elements in the block.
333 // 2. `start` - Start pointer into the `offsets` array.
334 // 3. `end` - End pointer into the `offsets` array.
335 // 4. `offsets` - Indices of out-of-order elements within the block.
336
337 // The current block on the left side (from `l` to `l.add(block_l)`).
338 let mut l = v.as_mut_ptr();
339 let mut block_l = BLOCK;
340 let mut start_l = ptr::null_mut();
341 let mut end_l = ptr::null_mut();
342 let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
343
344 // The current block on the right side (from `r.sub(block_r)` to `r`).
345 // SAFETY: The documentation for `.add()` specifically mentions that `vec.as_ptr().add(vec.len())` is always safe.
346 let mut r = unsafe { l.add(v.len()) };
347 let mut block_r = BLOCK;
348 let mut start_r = ptr::null_mut();
349 let mut end_r = ptr::null_mut();
350 let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
351
352 // FIXME: When we get VLAs, try creating one array of length `min(v.len(), 2 * BLOCK)` rather
353 // than two fixed-size arrays of length `BLOCK`. VLAs might be more cache-efficient.
354
355 // Returns the number of elements between pointers `l` (inclusive) and `r` (exclusive).
356 fn width<T>(l: *mut T, r: *mut T) -> usize {
357 assert!(size_of::<T>() > 0);
358 // FIXME: this should *likely* use `offset_from`, but more
359 // investigation is needed (including running tests in miri).
360 (r as usize - l as usize) / size_of::<T>()
361 }
362
363 loop {
364 // We are done with partitioning block-by-block when `l` and `r` get very close. Then we do
365 // some patch-up work in order to partition the remaining elements in between.
366 let is_done = width(l, r) <= 2 * BLOCK;
367
368 if is_done {
369 // Number of remaining elements (still not compared to the pivot).
370 let mut rem = width(l, r);
371 if start_l < end_l || start_r < end_r {
372 rem -= BLOCK;
373 }
374
375 // Adjust block sizes so that the left and right block don't overlap, but get perfectly
376 // aligned to cover the whole remaining gap.
377 if start_l < end_l {
378 block_r = rem;
379 } else if start_r < end_r {
380 block_l = rem;
381 } else {
382 // There were the same number of elements to switch on both blocks during the last
383 // iteration, so there are no remaining elements on either block. Cover the remaining
384 // items with roughly equally-sized blocks.
385 block_l = rem / 2;
386 block_r = rem - block_l;
387 }
388 debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
389 debug_assert!(width(l, r) == block_l + block_r);
390 }
391
392 if start_l == end_l {
393 // Trace `block_l` elements from the left side.
394 start_l = offsets_l.as_mut_ptr() as *mut u8;
395 end_l = start_l;
396 let mut elem = l;
397
398 for i in 0..block_l {
399 // SAFETY: The unsafe operations below involve the usage of `offset`.
400 // According to the conditions required by the function, we satisfy them because:
401 // 1. `offsets_l` is stack-allocated, and thus considered a separate allocated object.
402 // 2. The function `is_less` returns a `bool`.
403 // Casting a `bool` will never overflow `isize`.
404 // 3. We have guaranteed that `block_l` will be `<= BLOCK`.
405 // Plus, `end_l` was initially set to the begin pointer of `offsets_l`, which was declared on the stack.
406 // Thus, we know that even in the worst case (all invocations of `is_less` return false) we will only be at most 1 byte past the end.
407 // Another unsafe operation here is dereferencing `elem`.
408 // However, `elem` was initially the begin pointer of the slice, which is always valid.
409 unsafe {
410 // Branchless comparison.
411 *end_l = i as u8;
412 end_l = end_l.add(!is_less(&*elem, pivot) as usize);
413 elem = elem.add(1);
414 }
415 }
416 }
417
418 if start_r == end_r {
419 // Trace `block_r` elements from the right side.
420 start_r = offsets_r.as_mut_ptr() as *mut u8;
421 end_r = start_r;
422 let mut elem = r;
423
424 for i in 0..block_r {
425 // SAFETY: The unsafe operations below involve the usage of `offset`.
426 // According to the conditions required by the function, we satisfy them because:
427 // 1. `offsets_r` is stack-allocated, and thus considered a separate allocated object.
428 // 2. The function `is_less` returns a `bool`.
429 // Casting a `bool` will never overflow `isize`.
430 // 3. We have guaranteed that `block_r` will be `<= BLOCK`.
431 // Plus, `end_r` was initially set to the begin pointer of `offsets_r`, which was declared on the stack.
432 // Thus, we know that even in the worst case (all invocations of `is_less` return true) we will only be at most 1 byte past the end.
433 // Another unsafe operation here is dereferencing `elem`.
434 // However, `elem` was initially `1 * sizeof(T)` past the end, and we decrement it by `1 * sizeof(T)` before accessing it.
435 // Plus, `block_r` was asserted to be at most `BLOCK`, so `elem` will at most be pointing to the beginning of the slice.
436 unsafe {
437 // Branchless comparison.
438 elem = elem.sub(1);
439 *end_r = i as u8;
440 end_r = end_r.add(is_less(&*elem, pivot) as usize);
441 }
442 }
443 }
444
445 // Number of out-of-order elements to swap between the left and right side.
446 let count = cmp::min(width(start_l, end_l), width(start_r, end_r));
447
448 if count > 0 {
449 macro_rules! left {
450 () => {
451 l.add(usize::from(*start_l))
452 };
453 }
454 macro_rules! right {
455 () => {
456 r.sub(usize::from(*start_r) + 1)
457 };
458 }
459
460 // Instead of swapping one pair at a time, it is more efficient to perform a cyclic
461 // permutation. This is not strictly equivalent to swapping, but produces a similar
462 // result using fewer memory operations.
463
464 // SAFETY: The use of `ptr::read` is valid because there is at least one element in
465 // both `offsets_l` and `offsets_r`, so `left!` is a valid pointer to read from.
466 //
467 // The uses of `left!` involve calls to `offset` on `l`, which points to the
468 // beginning of `v`. All the offsets pointed-to by `start_l` are at most `block_l`, so
469 // these `offset` calls are safe as all reads are within the block. The same argument
470 // applies for the uses of `right!`.
471 //
472 // The calls to `start_l.offset` are valid because there are at most `count-1` of them,
473 // plus the final one at the end of the unsafe block, where `count` is the minimum number
474 // of collected offsets in `offsets_l` and `offsets_r`, so there is no risk of there not
475 // being enough elements. The same reasoning applies to the calls to `start_r.offset`.
476 //
477 // The calls to `copy_nonoverlapping` are safe because `left!` and `right!` are guaranteed
478 // not to overlap, and are valid because of the reasoning above.
479 unsafe {
480 let tmp = ptr::read(left!());
481 ptr::copy_nonoverlapping(right!(), left!(), 1);
482
483 for _ in 1..count {
484 start_l = start_l.add(1);
485 ptr::copy_nonoverlapping(left!(), right!(), 1);
486 start_r = start_r.add(1);
487 ptr::copy_nonoverlapping(right!(), left!(), 1);
488 }
489
490 ptr::copy_nonoverlapping(&tmp, right!(), 1);
491 mem::forget(tmp);
492 start_l = start_l.add(1);
493 start_r = start_r.add(1);
494 }
495 }
496
497 if start_l == end_l {
498 // All out-of-order elements in the left block were moved. Move to the next block.
499
500 // block-width-guarantee
501 // SAFETY: if `!is_done` then the slice width is guaranteed to be at least `2*BLOCK` wide. There
502 // are at most `BLOCK` elements in `offsets_l` because of its size, so the `offset` operation is
503 // safe. Otherwise, the debug assertions in the `is_done` case guarantee that
504 // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account
505 // for the smaller number of remaining elements.
506 l = unsafe { l.add(block_l) };
507 }
508
509 if start_r == end_r {
510 // All out-of-order elements in the right block were moved. Move to the previous block.
511
512 // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide,
513 // or `block_r` has been adjusted for the last handful of elements.
514 r = unsafe { r.sub(block_r) };
515 }
516
517 if is_done {
518 break;
519 }
520 }
521
522 // All that remains now is at most one block (either the left or the right) with out-of-order
523 // elements that need to be moved. Such remaining elements can be simply shifted to the end
524 // within their block.
525
526 if start_l < end_l {
527 // The left block remains.
528 // Move its remaining out-of-order elements to the far right.
529 debug_assert_eq!(width(l, r), block_l);
530 while start_l < end_l {
531 // remaining-elements-safety
532 // SAFETY: while the loop condition holds there are still elements in `offsets_l`, so it
533 // is safe to point `end_l` to the previous element.
534 //
535 // The `ptr::swap` is safe if both its arguments are valid for reads and writes:
536 // - Per the debug assert above, the distance between `l` and `r` is `block_l`
537 // elements, so there can be at most `block_l` remaining offsets between `start_l`
538 // and `end_l`. This means `r` will be moved at most `block_l` steps back, which
539 // makes the `r.offset` calls valid (at that point `l == r`).
540 // - `offsets_l` contains valid offsets into `v` collected during the partitioning of
541 // the last block, so the `l.offset` calls are valid.
542 unsafe {
543 end_l = end_l.sub(1);
544 ptr::swap(l.add(usize::from(*end_l)), r.sub(1));
545 r = r.sub(1);
546 }
547 }
548 width(v.as_mut_ptr(), r)
549 } else if start_r < end_r {
550 // The right block remains.
551 // Move its remaining out-of-order elements to the far left.
552 debug_assert_eq!(width(l, r), block_r);
553 while start_r < end_r {
554 // SAFETY: See the reasoning in [remaining-elements-safety].
555 unsafe {
556 end_r = end_r.sub(1);
557 ptr::swap(l, r.sub(usize::from(*end_r) + 1));
558 l = l.add(1);
559 }
560 }
561 width(v.as_mut_ptr(), l)
562 } else {
563 // Nothing else to do, we're done.
564 width(v.as_mut_ptr(), l)
565 }
566}
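// Illustrative sketch (added for exposition, not part of the upstream source): the pivot
// is passed by reference and must not alias the slice; the return value is the number of
// elements strictly smaller than it. The fixed permutation below (gcd(37, 500) == 1) is
// longer than `2 * BLOCK`, so the block-by-block path is actually exercised.
#[cfg(test)]
#[test]
fn partition_in_blocks_sketch() {
    let mut v: Vec<usize> = (0..500).map(|i| (i * 37) % 500).collect();
    let pivot = 250;
    let mid = partition_in_blocks(&mut v, &pivot, &|a: &usize, b: &usize| a < b);
    assert_eq!(mid, 250);
    assert!(v[..mid].iter().all(|&x| x < pivot));
    assert!(v[mid..].iter().all(|&x| x >= pivot));
}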
567
568/// Partitions `v` into elements smaller than `v[pivot]`, followed by elements greater than or
569/// equal to `v[pivot]`.
570///
571/// Returns a tuple of:
572///
573/// 1. Number of elements smaller than `v[pivot]`.
574/// 2. True if `v` was already partitioned.
575fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> (usize, bool)
576where
577 F: Fn(&T, &T) -> bool,
578{
579 let (mid, was_partitioned) = {
580 // Place the pivot at the beginning of slice.
581 v.swap(0, pivot);
582 let (pivot, v) = v.split_at_mut(1);
583 let pivot = &mut pivot[0];
584
585 // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
586 // operation panics, the pivot will be automatically written back into the slice.
587
588 // SAFETY: `pivot` is a reference to the first element of `v`, so `ptr::read` is safe.
589 let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
590 let _pivot_guard = InsertionHole {
591 src: &*tmp,
592 dest: pivot,
593 };
594 let pivot = &*tmp;
595
596 // Find the first pair of out-of-order elements.
597 let mut l = 0;
598 let mut r = v.len();
599
600 // SAFETY: The unsafety below involves indexing an array.
601 // For the first loop: we do the bounds check explicitly with `l < r` before indexing.
602 // For the second loop: we initially have `l == 0` and `r == v.len()`, and `l < r` is checked
603 // before every indexing operation, so every index used stays within `0..v.len()`.
604 unsafe {
605 // Find the first element greater than or equal to the pivot.
606 while l < r && is_less(v.get_unchecked(l), pivot) {
607 l += 1;
608 }
609
610 // Find the last element smaller than the pivot.
611 while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
612 r -= 1;
613 }
614 }
615
616 (
617 l + partition_in_blocks(&mut v[l..r], pivot, is_less),
618 l >= r,
619 )
620
621 // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated
622 // variable) back into the slice where it originally was. This step is critical in ensuring
623 // safety!
624 };
625
626 // Place the pivot between the two partitions.
627 v.swap(0, mid);
628
629 (mid, was_partitioned)
630}
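// Illustrative sketch (added for exposition, not part of the upstream source): `partition`
// places the pivot element between the two halves and reports its final index together
// with whether the input already happened to be partitioned around it.
#[cfg(test)]
#[test]
fn partition_sketch() {
    let mut v = vec![5, 9, 1, 7, 3, 8, 2];
    let (mid, was_partitioned) = partition(&mut v, 0, &|a: &i32, b: &i32| a < b);
    assert_eq!(mid, 3);
    assert!(!was_partitioned);
    assert_eq!(v[mid], 5);
    assert!(v[..mid].iter().all(|&x| x < 5));
    assert!(v[mid..].iter().all(|&x| x >= 5));
}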
631
632/// Partitions `v` into elements equal to `v[pivot]` followed by elements greater than `v[pivot]`.
633///
634/// Returns the number of elements equal to the pivot. It is assumed that `v` does not contain
635/// elements smaller than the pivot.
636fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &F) -> usize
637where
638 F: Fn(&T, &T) -> bool,
639{
640 // Place the pivot at the beginning of slice.
641 v.swap(0, pivot);
642 let (pivot, v) = v.split_at_mut(1);
643 let pivot = &mut pivot[0];
644
645 // Read the pivot into a stack-allocated variable for efficiency. If a following comparison
646 // operation panics, the pivot will be automatically written back into the slice.
647 // SAFETY: The pointer here is valid because it is obtained from a reference to a slice.
648 let tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
649 let _pivot_guard = InsertionHole {
650 src: &*tmp,
651 dest: pivot,
652 };
653 let pivot = &*tmp;
654
655 let len = v.len();
656 if len == 0 {
657 return 0;
658 }
659
660 // Now partition the slice.
661 let mut l = 0;
662 let mut r = len;
663 loop {
664 // SAFETY: The unsafety below involves indexing an array.
665 // For the first loop: we do the bounds check explicitly with `l < r` before indexing.
666 // For the second loop: we initially have `l == 0` and `r == v.len()`, and `l < r` is checked
667 // before every indexing operation, so every index used stays within `0..v.len()`.
668 unsafe {
669 // Find the first element greater than the pivot.
670 while l < r && !is_less(pivot, v.get_unchecked(l)) {
671 l += 1;
672 }
673
674 // Find the last element equal to the pivot.
675 loop {
676 r -= 1;
677 if l >= r || !is_less(pivot, v.get_unchecked(r)) {
678 break;
679 }
680 }
681
682 // Are we done?
683 if l >= r {
684 break;
685 }
686
687 // Swap the found pair of out-of-order elements.
688 let ptr = v.as_mut_ptr();
689 ptr::swap(ptr.add(l), ptr.add(r));
690 l += 1;
691 }
692 }
693
694 // We found `l` elements equal to the pivot. Add 1 to account for the pivot itself.
695 l + 1
696
697 // `_pivot_guard` goes out of scope and writes the pivot (which is a stack-allocated variable)
698 // back into the slice where it originally was. This step is critical in ensuring safety!
699}
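// Illustrative sketch (added for exposition, not part of the upstream source): the pivot
// must be a minimum of the slice (that is the caller's precondition); the return value
// counts the pivot itself plus every other element equal to it.
#[cfg(test)]
#[test]
fn partition_equal_sketch() {
    // `3` is the smallest value, so the precondition holds.
    let mut v = vec![3, 3, 5, 3, 7, 3];
    let count = partition_equal(&mut v, 0, &|a: &i32, b: &i32| a < b);
    assert_eq!(count, 4);
    assert!(v[..count].iter().all(|&x| x == 3));
    assert!(v[count..].iter().all(|&x| x > 3));
}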
700
701/// Scatters some elements around in an attempt to break patterns that might cause imbalanced
702/// partitions in quicksort.
703#[cold]
704fn break_patterns<T>(v: &mut [T]) {
705 let len = v.len();
706 if len >= 8 {
707 let mut seed = len;
708 let mut gen_usize = || {
709 // Pseudorandom number generator from the "Xorshift RNGs" paper by George Marsaglia.
710 if usize::BITS <= 32 {
711 let mut r = seed as u32;
712 r ^= r << 13;
713 r ^= r >> 17;
714 r ^= r << 5;
715 seed = r as usize;
716 seed
717 } else {
718 let mut r = seed as u64;
719 r ^= r << 13;
720 r ^= r >> 7;
721 r ^= r << 17;
722 seed = r as usize;
723 seed
724 }
725 };
726
727 // Take random numbers modulo this number.
728 // The number fits into `usize` because `len` is not greater than `isize::MAX`.
729 let modulus = len.next_power_of_two();
730
731 // Some pivot candidates will be near this index. Let's randomize them.
732 let pos = len / 4 * 2;
733
734 for i in 0..3 {
735 // Generate a random number modulo `len`. However, in order to avoid costly operations
736 // we first take it modulo a power of two, and then decrease by `len` until it fits
737 // into the range `[0, len - 1]`.
738 let mut other = gen_usize() & (modulus - 1);
739
740 // `other` is guaranteed to be less than `2 * len`.
741 if other >= len {
742 other -= len;
743 }
744
745 v.swap(pos - 1 + i, other);
746 }
747 }
748}
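// Illustrative sketch (added for exposition, not part of the upstream source): the shuffle
// only swaps a few elements near the middle, so whatever it does, the multiset of values
// is preserved.
#[cfg(test)]
#[test]
fn break_patterns_sketch() {
    let mut v: Vec<i32> = (0..64).collect();
    let original = v.clone();
    break_patterns(&mut v);
    let mut sorted = v.clone();
    sorted.sort();
    assert_eq!(sorted, original);
}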
749
750/// Chooses a pivot in `v` and returns the index and `true` if the slice is likely already sorted.
751///
752/// Elements in `v` might be reordered in the process.
753fn choose_pivot<T, F>(v: &mut [T], is_less: &F) -> (usize, bool)
754where
755 F: Fn(&T, &T) -> bool,
756{
757 // Minimum length to choose the median-of-medians method.
758 // Shorter slices use the simple median-of-three method.
759 const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
760 // Maximum number of swaps that can be performed in this function.
761 const MAX_SWAPS: usize = 4 * 3;
762
763 let len = v.len();
764
765 // Three indices near which we are going to choose a pivot.
766 #[allow(clippy::identity_op)]
767 let mut a = len / 4 * 1;
768 let mut b = len / 4 * 2;
769 let mut c = len / 4 * 3;
770
771 // Counts the total number of swaps we are about to perform while sorting indices.
772 let mut swaps = 0;
773
774 if len >= 8 {
775 // Swaps indices so that `v[a] <= v[b]`.
776 // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of
777 // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in
778 // corresponding calls to `sort3` with valid 3-item neighborhoods around each
779 // pointer, which in turn means the calls to `sort2` are done with valid
780 // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap`
781 // call.
782 let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
783 if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
784 ptr::swap(a, b);
785 swaps += 1;
786 }
787 };
788
789 // Swaps indices so that `v[a] <= v[b] <= v[c]`.
790 let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
791 sort2(a, b);
792 sort2(b, c);
793 sort2(a, b);
794 };
795
796 if len >= SHORTEST_MEDIAN_OF_MEDIANS {
797 // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`.
798 let mut sort_adjacent = |a: &mut usize| {
799 let tmp = *a;
800 sort3(&mut (tmp - 1), a, &mut (tmp + 1));
801 };
802
803 // Find medians in the neighborhoods of `a`, `b`, and `c`.
804 sort_adjacent(&mut a);
805 sort_adjacent(&mut b);
806 sort_adjacent(&mut c);
807 }
808
809 // Find the median among `a`, `b`, and `c`.
810 sort3(&mut a, &mut b, &mut c);
811 }
812
813 if swaps < MAX_SWAPS {
814 (b, swaps == 0)
815 } else {
816 // The maximum number of swaps was performed. Chances are the slice is descending or mostly
817 // descending, so reversing will probably help sort it faster.
818 v.reverse();
819 (len - 1 - b, true)
820 }
821}
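// Illustrative sketch (added for exposition, not part of the upstream source): on an
// already-sorted slice none of the index-sorting swaps fire, so the slice is left intact,
// the pivot index is in bounds, and the "likely sorted" flag is set.
#[cfg(test)]
#[test]
fn choose_pivot_sketch() {
    let mut v: Vec<i32> = (0..100).collect();
    let (pivot, likely_sorted) = choose_pivot(&mut v, &|a: &i32, b: &i32| a < b);
    assert!(pivot < v.len());
    assert!(likely_sorted);
    assert!(v.windows(2).all(|w| w[0] <= w[1]));
}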
822
823/// Sorts `v` recursively.
824///
825/// If the slice had a predecessor in the original array, it is specified as `pred`.
826///
827/// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
828/// this function will immediately switch to heapsort.
829fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &F, mut pred: Option<&'a mut T>, mut limit: u32)
830where
831 T: Send,
832 F: Fn(&T, &T) -> bool + Sync,
833{
834 // Slices of up to this length get sorted using insertion sort.
835 const MAX_INSERTION: usize = 20;
836
837 // If both partitions are up to this length, we continue sequentially. This number is as small
838 // as possible while keeping the overhead of Rayon's task scheduling negligible.
839 const MAX_SEQUENTIAL: usize = 2000;
840
841 // True if the last partitioning was reasonably balanced.
842 let mut was_balanced = true;
843 // True if the last partitioning didn't shuffle elements (the slice was already partitioned).
844 let mut was_partitioned = true;
845
846 loop {
847 let len = v.len();
848
849 // Very short slices get sorted using insertion sort.
850 if len <= MAX_INSERTION {
851 if len >= 2 {
852 insertion_sort_shift_left(v, 1, is_less);
853 }
854 return;
855 }
856
857 // If too many bad pivot choices were made, simply fall back to heapsort in order to
858 // guarantee `O(n * log(n))` worst-case.
859 if limit == 0 {
860 heapsort(v, is_less);
861 return;
862 }
863
864 // If the last partitioning was imbalanced, try breaking patterns in the slice by shuffling
865 // some elements around. Hopefully we'll choose a better pivot this time.
866 if !was_balanced {
867 break_patterns(v);
868 limit -= 1;
869 }
870
871 // Choose a pivot and try guessing whether the slice is already sorted.
872 let (pivot, likely_sorted) = choose_pivot(v, is_less);
873
874 // If the last partitioning was decently balanced and didn't shuffle elements, and if pivot
875 // selection predicts the slice is likely already sorted...
876 if was_balanced && was_partitioned && likely_sorted {
877 // Try identifying several out-of-order elements and shifting them to correct
878 // positions. If the slice ends up being completely sorted, we're done.
879 if partial_insertion_sort(v, is_less) {
880 return;
881 }
882 }
883
884 // If the chosen pivot is equal to the predecessor, then it's the smallest element in the
885 // slice. Partition the slice into elements equal to and elements greater than the pivot.
886 // This case is usually hit when the slice contains many duplicate elements.
887 if let Some(&mut ref p) = pred {
888 if !is_less(p, &v[pivot]) {
889 let mid = partition_equal(v, pivot, is_less);
890
891 // Continue sorting elements greater than the pivot.
892 v = &mut v[mid..];
893 continue;
894 }
895 }
896
897 // Partition the slice.
898 let (mid, was_p) = partition(v, pivot, is_less);
899 was_balanced = cmp::min(mid, len - mid) >= len / 8;
900 was_partitioned = was_p;
901
902 // Split the slice into `left`, `pivot`, and `right`.
903 let (left, right) = v.split_at_mut(mid);
904 let (pivot, right) = right.split_at_mut(1);
905 let pivot = &mut pivot[0];
906
907 if Ord::max(left.len(), right.len()) <= MAX_SEQUENTIAL {
908 // Recurse into the shorter side only in order to minimize the total number of recursive
909 // calls and consume less stack space. Then just continue with the longer side (this is
910 // akin to tail recursion).
911 if left.len() < right.len() {
912 recurse(left, is_less, pred, limit);
913 v = right;
914 pred = Some(pivot);
915 } else {
916 recurse(right, is_less, Some(pivot), limit);
917 v = left;
918 }
919 } else {
920 // Sort the left and right half in parallel.
921 rayon_core::join(
922 || recurse(left, is_less, pred, limit),
923 || recurse(right, is_less, Some(pivot), limit),
924 );
925 break;
926 }
927 }
928}
929
930/// Sorts `v` using pattern-defeating quicksort in parallel.
931///
932/// The algorithm is unstable, in-place, and *O*(*n* \* log(*n*)) worst-case.
933pub(super) fn par_quicksort<T, F>(v: &mut [T], is_less: F)
934where
935 T: Send,
936 F: Fn(&T, &T) -> bool + Sync,
937{
938 // Sorting has no meaningful behavior on zero-sized types.
939 if size_of::<T>() == 0 {
940 return;
941 }
942
943 // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`.
944 let limit = usize::BITS - v.len().leading_zeros();
945
946 recurse(v, &is_less, None, limit);
947}
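// Illustrative sketch (added for exposition, not part of the upstream source): the
// `pub(super)` entry point for the unstable parallel sorts. Any strict-weak-ordering
// `is_less` works; here a reversed range is sorted back into ascending order.
#[cfg(test)]
#[test]
fn par_quicksort_sketch() {
    let mut v: Vec<i32> = (0..10_000).rev().collect();
    par_quicksort(&mut v, |a, b| a < b);
    assert!(v.windows(2).all(|w| w[0] <= w[1]));
}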
948
949/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
950/// stores the result into `v[..]`.
951///
952/// # Safety
953///
954/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
955/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
956unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &F)
957where
958 F: Fn(&T, &T) -> bool,
959{
960 let len = v.len();
961 let v = v.as_mut_ptr();
962
963 // SAFETY: mid and len must be in-bounds of v.
964 let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };
965
966 // The merge process first copies the shorter run into `buf`. Then it traces the newly copied
967 // run and the longer run forwards (or backwards), comparing their next unconsumed elements and
968 // copying the lesser (or greater) one into `v`.
969 //
970 // As soon as the shorter run is fully consumed, the process is done. If the longer run gets
971 // consumed first, then we must copy whatever is left of the shorter run into the remaining
972 // hole in `v`.
973 //
974 // Intermediate state of the process is always tracked by `hole`, which serves two purposes:
975 // 1. Protects integrity of `v` from panics in `is_less`.
976 // 2. Fills the remaining hole in `v` if the longer run gets consumed first.
977 //
978 // Panic safety:
979 //
980 // If `is_less` panics at any point during the process, `hole` will get dropped and fill the
981 // hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
982 // object it initially held exactly once.
983 let mut hole;
984
985 if mid <= len - mid {
986 // The left run is shorter.
987
988 // SAFETY: buf must have enough capacity for `v[..mid]`.
989 unsafe {
990 ptr::copy_nonoverlapping(v, buf, mid);
991 hole = MergeHole {
992 start: buf,
993 end: buf.add(mid),
994 dest: v,
995 };
996 }
997
998 // Initially, these pointers point to the beginnings of their arrays.
999 let left = &mut hole.start;
1000 let mut right = v_mid;
1001 let out = &mut hole.dest;
1002
1003 while *left < hole.end && right < v_end {
1004 // Consume the lesser side.
1005 // If equal, prefer the left run to maintain stability.
1006
1007 // SAFETY: `left` and `right` must be valid pointers into `v`, and the same holds for `out`.
1008 unsafe {
1009 let is_l = is_less(&*right, &**left);
1010 let to_copy = if is_l { right } else { *left };
1011 ptr::copy_nonoverlapping(to_copy, *out, 1);
1012 *out = out.add(1);
1013 right = right.add(is_l as usize);
1014 *left = left.add(!is_l as usize);
1015 }
1016 }
1017 } else {
1018 // The right run is shorter.
1019
1020 // SAFETY: buf must have enough capacity for `v[mid..]`.
1021 unsafe {
1022 ptr::copy_nonoverlapping(v_mid, buf, len - mid);
1023 hole = MergeHole {
1024 start: buf,
1025 end: buf.add(len - mid),
1026 dest: v_mid,
1027 };
1028 }
1029
1030 // Initially, these pointers point past the ends of their arrays.
1031 let left = &mut hole.dest;
1032 let right = &mut hole.end;
1033 let mut out = v_end;
1034
1035 while v < *left && buf < *right {
1036 // Consume the greater side.
1037 // If equal, prefer the right run to maintain stability.
1038
1039 // SAFETY: `left` and `right` must be valid pointers into `v`, and the same holds for `out`.
1040 unsafe {
1041 let is_l = is_less(&*right.sub(1), &*left.sub(1));
1042 *left = left.sub(is_l as usize);
1043 *right = right.sub(!is_l as usize);
1044 let to_copy = if is_l { *left } else { *right };
1045 out = out.sub(1);
1046 ptr::copy_nonoverlapping(to_copy, out, 1);
1047 }
1048 }
1049 }
1050 // Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
1051 // it will now be copied into the hole in `v`.
1052}
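// Illustrative sketch (added for exposition, not part of the upstream source): a call that
// upholds the documented contract, i.e. both runs are non-empty, `mid` is in bounds, the
// buffer can hold the shorter run (reserving the full length is more than enough), and
// `i32` is not zero-sized.
#[cfg(test)]
#[test]
fn merge_sketch() {
    let mut v = vec![1, 3, 5, 2, 4, 6];
    let mut buf: Vec<i32> = Vec::with_capacity(v.len());
    // SAFETY: see the contract recap above.
    unsafe {
        merge(&mut v, 3, buf.as_mut_ptr(), &|a: &i32, b: &i32| a < b);
    }
    assert_eq!(v, vec![1, 2, 3, 4, 5, 6]);
}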
1053
1054// When dropped, copies the range `start..end` into `dest..`.
1055struct MergeHole<T> {
1056 start: *mut T,
1057 end: *mut T,
1058 dest: *mut T,
1059}
1060
1061impl<T> Drop for MergeHole<T> {
1062 fn drop(&mut self) {
1063 // SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
1064 unsafe {
1065 let len = self.end.offset_from(self.start) as usize;
1066 ptr::copy_nonoverlapping(self.start, self.dest, len);
1067 }
1068 }
1069}
1070
1071/// The result of merge sort.
1072#[must_use]
1073#[derive(Clone, Copy, PartialEq, Eq)]
1074enum MergeSortResult {
1075 /// The slice was already non-descending, so it was left intact.
1076 NonDescending,
1077 /// The slice was strictly descending, and therefore it was left intact (the caller may reverse it).
1078 Descending,
1079 /// The slice was neither of the above and has been sorted by this call.
1080 Sorted,
1081}
1082
1083/// This merge sort borrows some (but not all) ideas from TimSort, which used to be described in
1084/// detail [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt). However, Python
1085/// has since switched to a Powersort-based implementation.
1086///
1087/// The algorithm identifies strictly descending and non-descending subsequences, which are called
1088/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed
1089/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are
1090/// satisfied:
1091///
1092/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len`
1093/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len`
1094///
1095/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case.
1096///
1097/// # Safety
1098///
1099/// The argument `buf` is used as a temporary buffer and must have room for at least `v.len() / 2` elements.
1100unsafe fn merge_sort<T, CmpF>(v: &mut [T], buf_ptr: *mut T, is_less: &CmpF) -> MergeSortResult
1101where
1102 CmpF: Fn(&T, &T) -> bool,
1103{
1104 // The caller should have already checked that `T` is not a zero-sized type.
1105 debug_assert_ne!(size_of::<T>(), 0);
1106
1107 let len = v.len();
1108
1109 let mut runs = Vec::new();
1110
1111 let mut end = 0;
1112 let mut start = 0;
1113
1114 // Scan forward. Memory pre-fetching prefers forward scanning vs backwards scanning, and the
1115 // code-gen is usually better. For the most sensitive types such as integers, these are merged
1116 // bidirectionally at once. So there is no benefit in scanning backwards.
1117 while end < len {
1118 let (streak_end, was_reversed) = find_streak(&v[start..], is_less);
1119 end += streak_end;
1120 if start == 0 && end == len {
1121 return if was_reversed {
1122 MergeSortResult::Descending
1123 } else {
1124 MergeSortResult::NonDescending
1125 };
1126 }
1127 if was_reversed {
1128 v[start..end].reverse();
1129 }
1130
1131 // Insert some more elements into the run if it's too short. Insertion sort is faster than
1132 // merge sort on short sequences, so this significantly improves performance.
1133 end = provide_sorted_batch(v, start, end, is_less);
1134
1135 // Push this run onto the stack.
1136 runs.push(TimSortRun {
1137 start,
1138 len: end - start,
1139 });
1140 start = end;
1141
1142 // Merge some pairs of adjacent runs to satisfy the invariants.
1143 while let Some(r) = collapse(runs.as_slice(), len) {
1144 let left = runs[r];
1145 let right = runs[r + 1];
1146 let merge_slice = &mut v[left.start..right.start + right.len];
1147 // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and
1148 // neither side may have length 0.
1149 unsafe {
1150 merge(merge_slice, left.len, buf_ptr, is_less);
1151 }
1152 runs[r + 1] = TimSortRun {
1153 start: left.start,
1154 len: left.len + right.len,
1155 };
1156 runs.remove(r);
1157 }
1158 }
1159
1160 // Finally, exactly one run must remain in the stack.
1161 debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
1162
1163 // The original order of the slice was neither non-descending nor descending.
1164 return MergeSortResult::Sorted;
1165
1166 // Examines the stack of runs and identifies the next pair of runs to merge. More specifically,
1167 // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the
1168 // algorithm should continue building a new run instead, `None` is returned.
1169 //
1170 // TimSort is infamous for its buggy implementations, as described here:
1171 // http://envisage-project.eu/timsort-specification-and-verification/
1172 //
1173 // The gist of the story is: we must enforce the invariants on the top four runs on the stack.
1174 // Enforcing them on just top three is not sufficient to ensure that the invariants will still
1175 // hold for *all* runs in the stack.
1176 //
1177 // This function correctly checks invariants for the top four runs. Additionally, if the top
1178 // run starts at index 0, it will always demand a merge operation until the stack is fully
1179 // collapsed, in order to complete the sort.
1180 #[inline]
1181 fn collapse(runs: &[TimSortRun], stop: usize) -> Option<usize> {
1182 let n = runs.len();
1183 if n >= 2
1184 && (runs[n - 1].start + runs[n - 1].len == stop
1185 || runs[n - 2].len <= runs[n - 1].len
1186 || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
1187 || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
1188 {
1189 if n >= 3 && runs[n - 3].len < runs[n - 1].len {
1190 Some(n - 3)
1191 } else {
1192 Some(n - 2)
1193 }
1194 } else {
1195 None
1196 }
1197 }
1198}
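// Illustrative sketch (added for exposition, not part of the upstream source): the three
// `MergeSortResult` variants in action. The scratch buffer only needs `len / 2` capacity,
// but reserving the full length (as `par_mergesort` does) is harmless.
#[cfg(test)]
#[test]
fn merge_sort_sketch() {
    let is_less = |a: &i32, b: &i32| a < b;
    let run = |v: &mut Vec<i32>| {
        let mut buf: Vec<i32> = Vec::with_capacity(v.len());
        // SAFETY: `i32` is not zero-sized and `buf` has capacity for the whole slice.
        unsafe { merge_sort(v, buf.as_mut_ptr(), &is_less) }
    };

    let mut ascending: Vec<i32> = (0..100).collect();
    assert!(run(&mut ascending) == MergeSortResult::NonDescending);

    let mut descending: Vec<i32> = (0..100).rev().collect();
    assert!(run(&mut descending) == MergeSortResult::Descending);
    // A strictly descending input is detected but deliberately left intact.
    assert!(descending.windows(2).all(|w| w[0] > w[1]));

    let mut mixed: Vec<i32> = (0..100).map(|i| (i * 37) % 100).collect();
    assert!(run(&mut mixed) == MergeSortResult::Sorted);
    assert!(mixed.windows(2).all(|w| w[0] <= w[1]));
}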
1199
1200/// Internal type used by merge_sort.
1201#[derive(Clone, Copy, Debug)]
1202struct TimSortRun {
1203 len: usize,
1204 start: usize,
1205}
1206
1207/// Takes a range denoted by `start` and `end` that is already sorted, and extends it to the right if
1208/// necessary using sorts optimized for small ranges, such as insertion sort.
1209fn provide_sorted_batch<T, F>(v: &mut [T], start: usize, mut end: usize, is_less: &F) -> usize
1210where
1211 F: Fn(&T, &T) -> bool,
1212{
1213 let len = v.len();
1214 assert!(end >= start && end <= len);
1215
1216 // This value is a balance between fewest comparisons and best performance, as
1217 // influenced by, for example, cache locality.
1218 const MIN_INSERTION_RUN: usize = 10;
1219
1220 // Insert some more elements into the run if it's too short. Insertion sort is faster than
1221 // merge sort on short sequences, so this significantly improves performance.
1222 let start_end_diff = end - start;
1223
1224 if start_end_diff < MIN_INSERTION_RUN && end < len {
1225 // `v[start..end]` holds elements that are already sorted in the input. We extend the run
1226 // to the right by up to `MIN_INSERTION_RUN - 1` elements and insert them into it, which is
1227 // more efficient than trying to push the already-sorted elements to the left.
1228 end = cmp::min(start + MIN_INSERTION_RUN, len);
1229 let presorted_start = cmp::max(start_end_diff, 1);
1230
1231 insertion_sort_shift_left(&mut v[start..end], presorted_start, is_less);
1232 }
1233
1234 end
1235}
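// Illustrative sketch (added for exposition, not part of the upstream source): a presorted
// run of only two elements is below `MIN_INSERTION_RUN`, so it is grown to ten elements
// (or to the end of the slice) and the appended tail is insertion-sorted into it.
#[cfg(test)]
#[test]
fn provide_sorted_batch_sketch() {
    let mut v = vec![0, 1, 9, 8, 7, 6, 5, 4, 3, 2, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10];
    let end = provide_sorted_batch(&mut v, 0, 2, &|a: &i32, b: &i32| a < b);
    assert_eq!(end, 10);
    assert!(v[..end].windows(2).all(|w| w[0] <= w[1]));
}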
1236
1237/// Finds a streak of presorted elements starting at the beginning of the slice. Returns the index of
1238/// the first element that is not part of said streak, and a bool denoting whether the streak was
1239/// descending. Streaks can be increasing or decreasing.
1240fn find_streak<T, F>(v: &[T], is_less: &F) -> (usize, bool)
1241where
1242 F: Fn(&T, &T) -> bool,
1243{
1244 let len = v.len();
1245
1246 if len < 2 {
1247 return (len, false);
1248 }
1249
1250 let mut end = 2;
1251
1252 // SAFETY: See the specific comments below.
1253 unsafe {
1254 // SAFETY: We checked that len >= 2, so 0 and 1 are valid indices.
1255 let assume_reverse = is_less(v.get_unchecked(1), v.get_unchecked(0));
1256
1257 // SAFETY: We know end >= 2 and check end < len.
1258 // From that follows that accessing v at end and end - 1 is safe.
1259 if assume_reverse {
1260 while end < len && is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) {
1261 end += 1;
1262 }
1263
1264 (end, true)
1265 } else {
1266 while end < len && !is_less(v.get_unchecked(end), v.get_unchecked(end - 1)) {
1267 end += 1;
1268 }
1269 (end, false)
1270 }
1271 }
1272}
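// Illustrative sketch (added for exposition, not part of the upstream source): the returned
// index is the length of the leading run, and the flag tells whether that run was strictly
// descending.
#[cfg(test)]
#[test]
fn find_streak_sketch() {
    let is_less = |a: &i32, b: &i32| a < b;
    assert_eq!(find_streak(&[1, 2, 3, 2, 1], &is_less), (3, false));
    assert_eq!(find_streak(&[3, 2, 1, 4], &is_less), (3, true));
    assert_eq!(find_streak(&[7], &is_less), (1, false));
}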
1273
1274////////////////////////////////////////////////////////////////////////////
1275// Everything above this line is copied from `core::slice::sort` (with very minor tweaks).
1276// Everything below this line is custom parallelization for rayon.
1277////////////////////////////////////////////////////////////////////////////
1278
1279/// Splits two sorted slices so that they can be merged in parallel.
1280///
1281/// Returns two indices `(a, b)` so that slices `left[..a]` and `right[..b]` come before
1282/// `left[a..]` and `right[b..]`.
1283fn split_for_merge<T, F>(left: &[T], right: &[T], is_less: &F) -> (usize, usize)
1284where
1285 F: Fn(&T, &T) -> bool,
1286{
1287 let left_len = left.len();
1288 let right_len = right.len();
1289
1290 if left_len >= right_len {
1291 let left_mid = left_len / 2;
1292
1293 // Find the first element in `right` that is greater than or equal to `left[left_mid]`.
1294 let mut a = 0;
1295 let mut b = right_len;
1296 while a < b {
1297 let m = a + (b - a) / 2;
1298 if is_less(&right[m], &left[left_mid]) {
1299 a = m + 1;
1300 } else {
1301 b = m;
1302 }
1303 }
1304
1305 (left_mid, a)
1306 } else {
1307 let right_mid = right_len / 2;
1308
1309 // Find the first element in `left` that is greater than `right[right_mid]`.
1310 let mut a = 0;
1311 let mut b = left_len;
1312 while a < b {
1313 let m = a + (b - a) / 2;
1314 if is_less(&right[right_mid], &left[m]) {
1315 b = m;
1316 } else {
1317 a = m + 1;
1318 }
1319 }
1320
1321 (a, right_mid)
1322 }
1323}
1324
1325/// Merges slices `left` and `right` in parallel and stores the result into `dest`.
1326///
1327/// # Safety
1328///
1329/// The `dest` pointer must have enough space to store the result.
1330///
1331/// Even if `is_less` panics at any point during the merge process, this function will fully copy
1332/// all elements from `left` and `right` into `dest` (not necessarily in sorted order).
1333unsafe fn par_merge<T, F>(left: &mut [T], right: &mut [T], dest: *mut T, is_less: &F)
1334where
1335 T: Send,
1336 F: Fn(&T, &T) -> bool + Sync,
1337{
1338 // Slices whose lengths sum up to this value are merged sequentially. This number is slightly
1339 // larger than `CHUNK_LENGTH`, and the reason is that merging is faster than merge sorting, so
1340 // merging needs a bit coarser granularity in order to hide the overhead of Rayon's task
1341 // scheduling.
1342 const MAX_SEQUENTIAL: usize = 5000;
1343
1344 let left_len = left.len();
1345 let right_len = right.len();
1346
1347 // Intermediate state of the merge process, which serves two purposes:
1348 // 1. Protects integrity of `dest` from panics in `is_less`.
1349 // 2. Copies the remaining elements as soon as one of the two sides is exhausted.
1350 //
1351 // Panic safety:
1352 //
1353 // If `is_less` panics at any point during the merge process, `s` will get dropped and copy the
1354 // remaining parts of `left` and `right` into `dest`.
1355 let mut s = State {
1356 left_start: left.as_mut_ptr(),
1357 left_end: left.as_mut_ptr().add(left_len),
1358 right_start: right.as_mut_ptr(),
1359 right_end: right.as_mut_ptr().add(right_len),
1360 dest,
1361 };
1362
1363 if left_len == 0 || right_len == 0 || left_len + right_len < MAX_SEQUENTIAL {
1364 while s.left_start < s.left_end && s.right_start < s.right_end {
1365 // Consume the lesser side.
1366 // If equal, prefer the left run to maintain stability.
1367 let is_l = is_less(&*s.right_start, &*s.left_start);
1368 let to_copy = if is_l { s.right_start } else { s.left_start };
1369 ptr::copy_nonoverlapping(to_copy, s.dest, 1);
1370 s.dest = s.dest.add(1);
1371 s.right_start = s.right_start.add(is_l as usize);
1372 s.left_start = s.left_start.add(!is_l as usize);
1373 }
1374 } else {
1375 // Function `split_for_merge` might panic. If that happens, `s` will get destructed and copy
1376 // the whole `left` and `right` into `dest`.
1377 let (left_mid, right_mid) = split_for_merge(left, right, is_less);
1378 let (left_l, left_r) = left.split_at_mut(left_mid);
1379 let (right_l, right_r) = right.split_at_mut(right_mid);
1380
1381 // Prevent the destructor of `s` from running. Rayon will ensure that both calls to
1382 // `par_merge` happen. If one of the two calls panics, they will ensure that elements still
1383 // get copied into `dest_l` and `dest_r`.
1384 mem::forget(s);
1385
1386 // Wrap pointers in SendPtr so that they can be sent to another thread
1387 // See the documentation of SendPtr for a full explanation
1388 let dest_l = SendPtr(dest);
1389 let dest_r = SendPtr(dest.add(left_l.len() + right_l.len()));
1390 rayon_core::join(
1391 move || par_merge(left_l, right_l, dest_l.get(), is_less),
1392 move || par_merge(left_r, right_r, dest_r.get(), is_less),
1393 );
1394 }
1395 // Finally, `s` gets dropped if we used sequential merge, thus copying the remaining elements
1396 // all at once.
1397
1398 // When dropped, copies arrays `left_start..left_end` and `right_start..right_end` into `dest`,
1399 // in that order.
1400 struct State<T> {
1401 left_start: *mut T,
1402 left_end: *mut T,
1403 right_start: *mut T,
1404 right_end: *mut T,
1405 dest: *mut T,
1406 }
1407
1408 impl<T> Drop for State<T> {
1409 fn drop(&mut self) {
1410 // Copy array `left`, followed by `right`.
1411 unsafe {
1412 let left_len = self.left_end.offset_from(self.left_start) as usize;
1413 ptr::copy_nonoverlapping(self.left_start, self.dest, left_len);
1414 self.dest = self.dest.add(left_len);
1415
1416 let right_len = self.right_end.offset_from(self.right_start) as usize;
1417 ptr::copy_nonoverlapping(self.right_start, self.dest, right_len);
1418 }
1419 }
1420 }
1421}
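// Illustrative sketch (added for exposition, not part of the upstream source): `par_merge`
// writes into raw destination memory, here the spare capacity of a fresh `Vec`. The inputs
// are short, so the sequential branch is taken.
#[cfg(test)]
#[test]
fn par_merge_sketch() {
    let mut left = vec![1, 3, 5, 7];
    let mut right = vec![2, 4, 6];
    let mut dest: Vec<i32> = Vec::with_capacity(left.len() + right.len());
    // SAFETY: `dest` has capacity for every element of `left` and `right`, and the merge
    // fully initializes that prefix before the length is set.
    unsafe {
        par_merge(&mut left, &mut right, dest.as_mut_ptr(), &|a: &i32, b: &i32| a < b);
        dest.set_len(left.len() + right.len());
    }
    assert_eq!(dest, vec![1, 2, 3, 4, 5, 6, 7]);
}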
1422
1423/// Recursively merges pre-sorted chunks inside `v`.
1424///
1425/// Chunks of `v` are stored in `chunks` as intervals (inclusive left and exclusive right bound).
1426/// Argument `buf` is an auxiliary buffer that will be used during the procedure.
1427/// If `into_buf` is true, the result will be stored into `buf`, otherwise it will be in `v`.
1428///
1429/// # Safety
1430///
1431/// The number of chunks must be positive and they must be adjacent: the right bound of each chunk
1432/// must equal the left bound of the following chunk.
1433///
1434/// The buffer must be at least as long as `v`.
1435unsafe fn merge_recurse<T, F>(
1436 v: *mut T,
1437 buf: *mut T,
1438 chunks: &[(usize, usize)],
1439 into_buf: bool,
1440 is_less: &F,
1441) where
1442 T: Send,
1443 F: Fn(&T, &T) -> bool + Sync,
1444{
1445 let len = chunks.len();
1446 debug_assert!(len > 0);
1447
1448 // Base case of the algorithm.
1449 // If only one chunk is remaining, there's no more work to split and merge.
1450 if len == 1 {
1451 if into_buf {
1452 // Copy the chunk from `v` into `buf`.
1453 let (start, end) = chunks[0];
1454 let src = v.add(start);
1455 let dest = buf.add(start);
1456 ptr::copy_nonoverlapping(src, dest, end - start);
1457 }
1458 return;
1459 }
1460
1461 // Split the chunks into two halves.
1462 let (start, _) = chunks[0];
1463 let (mid, _) = chunks[len / 2];
1464 let (_, end) = chunks[len - 1];
1465 let (left, right) = chunks.split_at(len / 2);
1466
1467 // After recursive calls finish we'll have to merge chunks `(start, mid)` and `(mid, end)` from
1468 // `src` into `dest`. If the current invocation has to store the result into `buf`, we'll
1469 // merge chunks from `v` into `buf`, and vice versa.
1470 //
1471 // Recursive calls flip `into_buf` at each level of recursion. More concretely, `par_merge`
1472 // merges chunks from `buf` into `v` at the first level, from `v` into `buf` at the second
1473 // level etc.
1474 let (src, dest) = if into_buf { (v, buf) } else { (buf, v) };
1475
1476 // Panic safety:
1477 //
1478 // If `is_less` panics at any point during the recursive calls, the destructor of `guard` will
1479 // be executed, thus copying everything from `src` into `dest`. This way we ensure that all
1480 // chunks are in fact copied into `dest`, even if the merge process doesn't finish.
1481 let guard = MergeHole {
1482 start: src.add(start),
1483 end: src.add(end),
1484 dest: dest.add(start),
1485 };
1486
1487 // Wrap pointers in SendPtr so that they can be sent to another thread
1488 // See the documentation of SendPtr for a full explanation
1489 let v = SendPtr(v);
1490 let buf = SendPtr(buf);
1491 rayon_core::join(
1492 move || merge_recurse(v.get(), buf.get(), left, !into_buf, is_less),
1493 move || merge_recurse(v.get(), buf.get(), right, !into_buf, is_less),
1494 );
1495
1496 // Everything went all right - recursive calls didn't panic.
1497 // Forget the guard in order to prevent its destructor from running.
1498 mem::forget(guard);
1499
1500 // Merge chunks `(start, mid)` and `(mid, end)` from `src` into `dest`.
1501 let src_left = slice::from_raw_parts_mut(src.add(start), mid - start);
1502 let src_right = slice::from_raw_parts_mut(src.add(mid), end - mid);
1503 par_merge(src_left, src_right, dest.add(start), is_less);
1504}
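// Illustrative sketch (added for exposition, not part of the upstream source): two adjacent
// presorted chunks covering all of `v`, merged with `into_buf == false` so the final result
// lands back in `v`. The buffer only needs to be as long as `v`.
#[cfg(test)]
#[test]
fn merge_recurse_sketch() {
    let mut v = vec![2, 4, 6, 8, 1, 3, 5, 7];
    let mut buf: Vec<i32> = Vec::with_capacity(v.len());
    let chunks = [(0, 4), (4, 8)];
    // SAFETY: the chunks are adjacent, cover `v` exactly, and `buf` is as long as `v`.
    unsafe {
        merge_recurse(
            v.as_mut_ptr(),
            buf.as_mut_ptr(),
            &chunks,
            false,
            &|a: &i32, b: &i32| a < b,
        );
    }
    assert_eq!(v, vec![1, 2, 3, 4, 5, 6, 7, 8]);
}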
1505
1506/// Sorts `v` using merge sort in parallel.
1507///
1508/// The algorithm is stable, allocates memory, and is *O*(*n* \* log(*n*)) worst-case.
1509/// The allocated temporary buffer has the same length as `v`.
1510pub(super) fn par_mergesort<T, F>(v: &mut [T], is_less: F)
1511where
1512 T: Send,
1513 F: Fn(&T, &T) -> bool + Sync,
1514{
1515 // Slices of up to this length get sorted using insertion sort in order to avoid the cost of
1516 // buffer allocation.
1517 const MAX_INSERTION: usize = 20;
1518
1519 // The length of initial chunks. This number is as small as possible while keeping the overhead
1520 // of Rayon's task scheduling negligible.
1521 const CHUNK_LENGTH: usize = 2000;
1522
1523 // Sorting has no meaningful behavior on zero-sized types.
1524 if size_of::<T>() == 0 {
1525 return;
1526 }
1527
1528 let len = v.len();
1529
1530 // Short slices get sorted in-place via insertion sort to avoid allocations.
1531 if len <= MAX_INSERTION {
1532 if len >= 2 {
1533 insertion_sort_shift_left(v, 1, &is_less);
1534 }
1535 return;
1536 }
1537
1538 // Allocate a buffer to use as scratch memory. We keep its length at 0 so that it can hold
1539 // shallow copies of the contents of `v` without risking destructors running on those copies if
1540 // `is_less` panics.
1541 let mut buf = Vec::<T>::with_capacity(len);
1542 let buf = buf.as_mut_ptr();
1543
1544 // If the slice is not longer than one chunk would be, do sequential merge sort and return.
1545 if len <= CHUNK_LENGTH {
1546 let res = unsafe { merge_sort(v, buf, &is_less) };
1547 if res == MergeSortResult::Descending {
1548 v.reverse();
1549 }
1550 return;
1551 }
1552
1553 // Split the slice into chunks and merge sort them in parallel.
1554 // However, descending chunks will not be sorted - they will be simply left intact.
1555 let mut iter = {
1556 // Wrap pointer in SendPtr so that it can be sent to another thread
1557 // See the documentation of SendPtr for a full explanation
1558 let buf = SendPtr(buf);
1559 let is_less = &is_less;
1560
1561 v.par_chunks_mut(CHUNK_LENGTH)
1562 .with_max_len(1)
1563 .enumerate()
1564 .map(move |(i, chunk)| {
1565 let l = CHUNK_LENGTH * i;
1566 let r = l + chunk.len();
1567 unsafe {
1568 let buf = buf.get().add(l);
1569 (l, r, merge_sort(chunk, buf, is_less))
1570 }
1571 })
1572 .collect::<Vec<_>>()
1573 .into_iter()
1574 .peekable()
1575 };
1576
1577 // Now attempt to concatenate adjacent chunks that were left intact.
1578 let mut chunks = Vec::with_capacity(iter.len());
1579
1580 while let Some((a, mut b, res)) = iter.next() {
1581 // If this chunk was not modified by the sort procedure...
1582 if res != MergeSortResult::Sorted {
1583 while let Some(&(x, y, r)) = iter.peek() {
1584 // If the following chunk is of the same type and can be concatenated...
1585 if r == res && (r == MergeSortResult::Descending) == is_less(&v[x], &v[x - 1]) {
1586 // Concatenate them.
1587 b = y;
1588 iter.next();
1589 } else {
1590 break;
1591 }
1592 }
1593 }
1594
1595 // Descending chunks must be reversed.
1596 if res == MergeSortResult::Descending {
1597 v[a..b].reverse();
1598 }
1599
1600 chunks.push((a, b));
1601 }
1602
1603 // All chunks are properly sorted.
1604 // Now we just have to merge them together.
1605 unsafe {
1606 merge_recurse(v.as_mut_ptr(), buf, &chunks, false, &is_less);
1607 }
1608}
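// Illustrative sketch (added for exposition, not part of the upstream source): the
// `pub(super)` entry point for the stable parallel sorts. Sorting `(key, original_index)`
// pairs by key only shows the stability guarantee: equal keys keep their original order.
#[cfg(test)]
#[test]
fn par_mergesort_sketch() {
    let mut v: Vec<(usize, usize)> = (0..10_000).map(|i| ((i * 37) % 10, i)).collect();
    par_mergesort(&mut v, |a, b| a.0 < b.0);
    assert!(v
        .windows(2)
        .all(|w| w[0].0 < w[1].0 || (w[0].0 == w[1].0 && w[0].1 < w[1].1)));
}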
1609
1610#[cfg(test)]
1611mod tests {
1612 use super::heapsort;
1613 use super::split_for_merge;
1614 use rand::distr::Uniform;
1615 use rand::{rng, Rng};
1616
1617 #[test]
1618 fn test_heapsort() {
1619 let rng = &mut rng();
1620
1621 for len in (0..25).chain(500..501) {
1622 for &modulus in &[5, 10, 100] {
1623 let dist = Uniform::new(0, modulus).unwrap();
1624 for _ in 0..100 {
1625 let v: Vec<i32> = rng.sample_iter(&dist).take(len).collect();
1626
1627 // Test heapsort using `<` operator.
1628 let mut tmp = v.clone();
1629 heapsort(&mut tmp, |a, b| a < b);
1630 assert!(tmp.windows(2).all(|w| w[0] <= w[1]));
1631
1632 // Test heapsort using `>` operator.
1633 let mut tmp = v.clone();
1634 heapsort(&mut tmp, |a, b| a > b);
1635 assert!(tmp.windows(2).all(|w| w[0] >= w[1]));
1636 }
1637 }
1638 }
1639
1640 // Sort using a completely random comparison function.
1641 // This will reorder the elements *somehow*, but won't panic.
1642 let mut v: Vec<_> = (0..100).collect();
1643 heapsort(&mut v, |_, _| rand::rng().random());
1644 heapsort(&mut v, |a, b| a < b);
1645
1646 for (i, &entry) in v.iter().enumerate() {
1647 assert_eq!(entry, i);
1648 }
1649 }
1650
1651 #[test]
1652 fn test_split_for_merge() {
1653 fn check(left: &[u32], right: &[u32]) {
1654 let (l, r) = split_for_merge(left, right, &|&a, &b| a < b);
1655 assert!(left[..l]
1656 .iter()
1657 .all(|&x| right[r..].iter().all(|&y| x <= y)));
1658 assert!(right[..r].iter().all(|&x| left[l..].iter().all(|&y| x < y)));
1659 }
1660
1661 check(&[1, 2, 2, 2, 2, 3], &[1, 2, 2, 2, 2, 3]);
1662 check(&[1, 2, 2, 2, 2, 3], &[]);
1663 check(&[], &[1, 2, 2, 2, 2, 3]);
1664
1665 let rng = &mut rng();
1666
1667 for _ in 0..100 {
1668 let limit: u32 = rng.random_range(1..21);
1669 let left_len: usize = rng.random_range(0..20);
1670 let right_len: usize = rng.random_range(0..20);
1671
1672 let mut left = rng
1673 .sample_iter(&Uniform::new(0, limit).unwrap())
1674 .take(left_len)
1675 .collect::<Vec<_>>();
1676 let mut right = rng
1677 .sample_iter(&Uniform::new(0, limit).unwrap())
1678 .take(right_len)
1679 .collect::<Vec<_>>();
1680
1681 left.sort();
1682 right.sort();
1683 check(&left, &right);
1684 }
1685 }
1686}