1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
/// `FoldChunks` is a parallel iterator adaptor that splits the items of an
/// indexed base iterator into consecutive chunks of `chunk_size` items and
/// folds each chunk into a single value.
///
/// Every chunk is folded independently: it starts from a fresh accumulator
/// produced by `identity` and combines items into it with `fold_op`. All
/// chunks hold exactly `chunk_size` items except possibly the last one, which
/// holds whatever remains.
///
/// Returned by the `fold_chunks` parallel-iterator method.
#[derive(Clone)]
pub struct FoldChunks<I, ID, F> {
    // Underlying indexed parallel iterator whose items are chunked.
    base: I,
    // Number of items per chunk; must be non-zero.
    chunk_size: usize,
    // Folds one item into the running accumulator of its chunk.
    fold_op: F,
    // Produces the starting accumulator for each chunk.
    identity: ID,
}
21
22impl<I: Debug, ID, F> Debug for FoldChunks<I, ID, F> {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_struct("Fold")
25 .field("base", &self.base)
26 .field("chunk_size", &self.chunk_size)
27 .finish()
28 }
29}
30
31impl<I, ID, F> FoldChunks<I, ID, F> {
32 pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
34 FoldChunks {
35 base,
36 chunk_size,
37 identity,
38 fold_op,
39 }
40 }
41}
42
43impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
44where
45 I: IndexedParallelIterator,
46 ID: Fn() -> U + Send + Sync,
47 F: Fn(U, I::Item) -> U + Send + Sync,
48 U: Send,
49{
50 type Item = U;
51
52 fn drive_unindexed<C>(self, consumer: C) -> C::Result
53 where
54 C: Consumer<U>,
55 {
56 bridge(self, consumer)
57 }
58
59 fn opt_len(&self) -> Option<usize> {
60 Some(self.len())
61 }
62}
63
64impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
65where
66 I: IndexedParallelIterator,
67 ID: Fn() -> U + Send + Sync,
68 F: Fn(U, I::Item) -> U + Send + Sync,
69 U: Send,
70{
71 fn len(&self) -> usize {
72 self.base.len().div_ceil(self.chunk_size)
73 }
74
75 fn drive<C>(self, consumer: C) -> C::Result
76 where
77 C: Consumer<Self::Item>,
78 {
79 bridge(self, consumer)
80 }
81
82 fn with_producer<CB>(self, callback: CB) -> CB::Output
83 where
84 CB: ProducerCallback<Self::Item>,
85 {
86 let len = self.base.len();
87 return self.base.with_producer(Callback {
88 chunk_size: self.chunk_size,
89 len,
90 identity: self.identity,
91 fold_op: self.fold_op,
92 callback,
93 });
94
95 struct Callback<CB, ID, F> {
96 chunk_size: usize,
97 len: usize,
98 identity: ID,
99 fold_op: F,
100 callback: CB,
101 }
102
103 impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
104 where
105 CB: ProducerCallback<U>,
106 ID: Fn() -> U + Send + Sync,
107 F: Fn(U, T) -> U + Send + Sync,
108 {
109 type Output = CB::Output;
110
111 fn callback<P>(self, base: P) -> CB::Output
112 where
113 P: Producer<Item = T>,
114 {
115 let identity = &self.identity;
116 let fold_op = &self.fold_op;
117 let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
118 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
119 self.callback.callback(producer)
120 }
121 }
122 }
123}
124
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    // Identity shared by the integer cases: every chunk folds from zero.
    fn id() -> i32 {
        0
    }

    // Generic adder that works both with owned items and with `&T` items.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks(4, String::new, |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks(0, id, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let v: Vec<i32> = Vec::new();
        let folded: Vec<i32> = v.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(Vec::<i32>::new(), folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        let lens = [
            (4, (0..8).into_par_iter().fold_chunks(2, id, sum).len()),
            (3, (0..9).into_par_iter().fold_chunks(3, id, sum).len()),
            (3, (0..8).into_par_iter().fold_chunks(3, id, sum).len()),
            (1, [1].par_iter().fold_chunks(3, id, sum).len()),
            (0, (0..0).into_par_iter().fold_chunks(3, id, sum).len()),
        ];
        for (expected, actual) in lens {
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (input, size, expected)) in cases.into_iter().enumerate() {
            let mut out: Vec<u32> = Vec::new();
            input
                .par_iter()
                .fold_chunks(size, || 0, sum)
                .collect_into_vec(&mut out);
            assert_eq!(expected, out, "Case {i} failed");

            out.clear();
            input
                .into_par_iter()
                .fold_chunks(size, || 0, sum)
                .rev()
                .collect_into_vec(&mut out);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, out, "Case {i} reversed failed");
        }
    }
}