rayon/iter/
fold_chunks_with.rs

1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6
/// `FoldChunksWith` is an iterator that groups elements of an underlying iterator and applies a
/// function over them, producing a single value for each group.
///
/// Each group is folded starting from its own clone of the seed `item`. The last group may be
/// shorter than `chunk_size` when the base length is not a multiple of it.
///
/// This struct is created by the [`fold_chunks_with()`] method on [`IndexedParallelIterator`].
///
/// [`fold_chunks_with()`]: IndexedParallelIterator::fold_chunks_with()
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunksWith<I, U, F> {
    /// The underlying indexed iterator whose elements are grouped.
    base: I,
    /// Maximum number of base elements in each group.
    chunk_size: usize,
    /// Seed value; every group's fold starts from a fresh clone of it.
    item: U,
    /// Folding function applied to the accumulator and each element of a group.
    fold_op: F,
}
21
22impl<I: Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("Fold")
25            .field("base", &self.base)
26            .field("chunk_size", &self.chunk_size)
27            .field("item", &self.item)
28            .finish()
29    }
30}
31
32impl<I, U, F> FoldChunksWith<I, U, F> {
33    /// Creates a new `FoldChunksWith` iterator
34    pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
35        FoldChunksWith {
36            base,
37            chunk_size,
38            item,
39            fold_op,
40        }
41    }
42}
43
44impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
45where
46    I: IndexedParallelIterator,
47    U: Send + Clone,
48    F: Fn(U, I::Item) -> U + Send + Sync,
49{
50    type Item = U;
51
52    fn drive_unindexed<C>(self, consumer: C) -> C::Result
53    where
54        C: Consumer<U>,
55    {
56        bridge(self, consumer)
57    }
58
59    fn opt_len(&self) -> Option<usize> {
60        Some(self.len())
61    }
62}
63
64impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
65where
66    I: IndexedParallelIterator,
67    U: Send + Clone,
68    F: Fn(U, I::Item) -> U + Send + Sync,
69{
70    fn len(&self) -> usize {
71        self.base.len().div_ceil(self.chunk_size)
72    }
73
74    fn drive<C>(self, consumer: C) -> C::Result
75    where
76        C: Consumer<Self::Item>,
77    {
78        bridge(self, consumer)
79    }
80
81    fn with_producer<CB>(self, callback: CB) -> CB::Output
82    where
83        CB: ProducerCallback<Self::Item>,
84    {
85        let len = self.base.len();
86        return self.base.with_producer(Callback {
87            chunk_size: self.chunk_size,
88            len,
89            item: self.item,
90            fold_op: self.fold_op,
91            callback,
92        });
93
94        struct Callback<CB, T, F> {
95            chunk_size: usize,
96            len: usize,
97            item: T,
98            fold_op: F,
99            callback: CB,
100        }
101
102        impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
103        where
104            CB: ProducerCallback<U>,
105            U: Send + Clone,
106            F: Fn(U, T) -> U + Send + Sync,
107        {
108            type Output = CB::Output;
109
110            fn callback<P>(self, base: P) -> CB::Output
111            where
112                P: Producer<Item = T>,
113            {
114                let item = self.item;
115                let fold_op = &self.fold_op;
116                let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
117                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
118                self.callback.callback(producer)
119            }
120        }
121    }
122}
123
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    // Generic addition, usable as the `fold_op` argument in the tests below.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks_with() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks_with(0, 0, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10)
            .into_par_iter()
            .fold_chunks_with(3, 0, sum)
            .collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_with_empty() {
        let input: Vec<i32> = Vec::new();
        let expected: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input
            .into_par_iter()
            .fold_chunks_with(2, 0, sum)
            .collect();
        assert_eq!(expected, folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        // Chunk counts round up: a trailing partial chunk still counts as one.
        assert_eq!(4, (0..8).into_par_iter().fold_chunks_with(2, 0, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(1, [1].par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks_with(3, 0, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            // Forward, via a borrowing iterator.
            let mut res: Vec<u32> = Vec::new();
            v.par_iter()
                .fold_chunks_with(n, 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {i} failed");

            // Reversed, via an owning iterator.
            res.clear();
            v.into_par_iter()
                .fold_chunks_with(n, 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {i} reversed failed");
        }
    }
}