rayon/iter/fold_chunks.rs

1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6
/// `FoldChunks` is an iterator that splits an underlying indexed iterator into
/// consecutive groups of `chunk_size` elements and folds each group into a single value.
///
/// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`].
///
/// [`fold_chunks()`]: IndexedParallelIterator::fold_chunks()
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunks<I, ID, F> {
    // The wrapped iterator whose items are grouped.
    base: I,
    // How many base items make up one group (the last group may be shorter).
    chunk_size: usize,
    // Accumulator step: combines the running value with the next base item.
    fold_op: F,
    // Produces the starting accumulator for each group.
    identity: ID,
}
21
22impl<I: Debug, ID, F> Debug for FoldChunks<I, ID, F> {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("Fold")
25            .field("base", &self.base)
26            .field("chunk_size", &self.chunk_size)
27            .finish()
28    }
29}
30
31impl<I, ID, F> FoldChunks<I, ID, F> {
32    /// Creates a new `FoldChunks` iterator
33    pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
34        FoldChunks {
35            base,
36            chunk_size,
37            identity,
38            fold_op,
39        }
40    }
41}
42
43impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
44where
45    I: IndexedParallelIterator,
46    ID: Fn() -> U + Send + Sync,
47    F: Fn(U, I::Item) -> U + Send + Sync,
48    U: Send,
49{
50    type Item = U;
51
52    fn drive_unindexed<C>(self, consumer: C) -> C::Result
53    where
54        C: Consumer<U>,
55    {
56        bridge(self, consumer)
57    }
58
59    fn opt_len(&self) -> Option<usize> {
60        Some(self.len())
61    }
62}
63
64impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
65where
66    I: IndexedParallelIterator,
67    ID: Fn() -> U + Send + Sync,
68    F: Fn(U, I::Item) -> U + Send + Sync,
69    U: Send,
70{
71    fn len(&self) -> usize {
72        self.base.len().div_ceil(self.chunk_size)
73    }
74
75    fn drive<C>(self, consumer: C) -> C::Result
76    where
77        C: Consumer<Self::Item>,
78    {
79        bridge(self, consumer)
80    }
81
82    fn with_producer<CB>(self, callback: CB) -> CB::Output
83    where
84        CB: ProducerCallback<Self::Item>,
85    {
86        let len = self.base.len();
87        return self.base.with_producer(Callback {
88            chunk_size: self.chunk_size,
89            len,
90            identity: self.identity,
91            fold_op: self.fold_op,
92            callback,
93        });
94
95        struct Callback<CB, ID, F> {
96            chunk_size: usize,
97            len: usize,
98            identity: ID,
99            fold_op: F,
100            callback: CB,
101        }
102
103        impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
104        where
105            CB: ProducerCallback<U>,
106            ID: Fn() -> U + Send + Sync,
107            F: Fn(U, T) -> U + Send + Sync,
108        {
109            type Output = CB::Output;
110
111            fn callback<P>(self, base: P) -> CB::Output
112            where
113                P: Producer<Item = T>,
114            {
115                let identity = &self.identity;
116                let fold_op = &self.fold_op;
117                let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
118                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
119                self.callback.callback(producer)
120            }
121        }
122    }
123}
124
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    #[test]
    fn check_fold_chunks() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks(4, String::new, |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // Named stand-ins for the identity and fold closures used by the tests below.
    fn id() -> i32 {
        0
    }

    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks(0, id, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let input: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(Vec::<i32>::new(), folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        // Exact multiple, exact multiple, remainder, short input, empty input.
        assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(1, [1].par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        // (input, chunk size, expected per-chunk sums)
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = Vec::new();
            v.par_iter()
                .fold_chunks(n, || 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {i} failed");

            res.clear();
            v.into_par_iter()
                .fold_chunks(n, || 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {i} reversed failed");
        }
    }
}