1use std::fmt::{self, Debug};
2
3use super::chunks::ChunkProducer;
4use super::plumbing::*;
5use super::*;
6
/// Parallel iterator adaptor that splits its base into consecutive chunks of
/// `chunk_size` items and folds each chunk into a single value.
///
/// Each chunk's fold starts from a fresh clone of `item` and combines the
/// chunk's elements with `fold_op`, so the adaptor yields one value per chunk,
/// in chunk order. When the base length is not a multiple of `chunk_size`, the
/// final chunk is shorter (its length is the remainder).
///
/// This struct is created by the `fold_chunks_with()` method on
/// `IndexedParallelIterator`.
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunksWith<I, U, F> {
    /// The underlying indexed parallel iterator whose items are chunked.
    base: I,
    /// Number of base items per chunk (the last chunk may hold fewer).
    chunk_size: usize,
    /// Initial accumulator value; a clone of it seeds every chunk's fold.
    item: U,
    /// Folding function, called as `fold_op(accumulator, element)`.
    fold_op: F,
}
21
22impl<I: Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_struct("Fold")
25 .field("base", &self.base)
26 .field("chunk_size", &self.chunk_size)
27 .field("item", &self.item)
28 .finish()
29 }
30}
31
32impl<I, U, F> FoldChunksWith<I, U, F> {
33 pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
35 FoldChunksWith {
36 base,
37 chunk_size,
38 item,
39 fold_op,
40 }
41 }
42}
43
44impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
45where
46 I: IndexedParallelIterator,
47 U: Send + Clone,
48 F: Fn(U, I::Item) -> U + Send + Sync,
49{
50 type Item = U;
51
52 fn drive_unindexed<C>(self, consumer: C) -> C::Result
53 where
54 C: Consumer<U>,
55 {
56 bridge(self, consumer)
57 }
58
59 fn opt_len(&self) -> Option<usize> {
60 Some(self.len())
61 }
62}
63
64impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
65where
66 I: IndexedParallelIterator,
67 U: Send + Clone,
68 F: Fn(U, I::Item) -> U + Send + Sync,
69{
70 fn len(&self) -> usize {
71 self.base.len().div_ceil(self.chunk_size)
72 }
73
74 fn drive<C>(self, consumer: C) -> C::Result
75 where
76 C: Consumer<Self::Item>,
77 {
78 bridge(self, consumer)
79 }
80
81 fn with_producer<CB>(self, callback: CB) -> CB::Output
82 where
83 CB: ProducerCallback<Self::Item>,
84 {
85 let len = self.base.len();
86 return self.base.with_producer(Callback {
87 chunk_size: self.chunk_size,
88 len,
89 item: self.item,
90 fold_op: self.fold_op,
91 callback,
92 });
93
94 struct Callback<CB, T, F> {
95 chunk_size: usize,
96 len: usize,
97 item: T,
98 fold_op: F,
99 callback: CB,
100 }
101
102 impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
103 where
104 CB: ProducerCallback<U>,
105 U: Send + Clone,
106 F: Fn(U, T) -> U + Send + Sync,
107 {
108 type Output = CB::Output;
109
110 fn callback<P>(self, base: P) -> CB::Output
111 where
112 P: Producer<Item = T>,
113 {
114 let item = self.item;
115 let fold_op = &self.fold_op;
116 let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
117 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
118 self.callback.callback(producer)
119 }
120 }
121 }
122}
123
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    /// Generic `x + y` used as a fold op. `T` and `U` may differ, e.g.
    /// `u32 + &u32` when folding the items of `par_iter()`.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks_with() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut word, ch| {
                word.push(ch);
                word
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks_with(0, 0, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks_with(3, 0, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_with_empty() {
        let input: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks_with(2, 0, sum).collect();
        assert!(folded.is_empty());
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!((0..8).into_par_iter().fold_chunks_with(2, 0, sum).len(), 4);
        assert_eq!((0..9).into_par_iter().fold_chunks_with(3, 0, sum).len(), 3);
        assert_eq!((0..8).into_par_iter().fold_chunks_with(3, 0, sum).len(), 3);
        assert_eq!([1].par_iter().fold_chunks_with(3, 0, sum).len(), 1);
        assert_eq!((0..0).into_par_iter().fold_chunks_with(3, 0, sum).len(), 0);
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            // Forward pass over borrowed items.
            let mut res: Vec<u32> = Vec::new();
            v.par_iter()
                .fold_chunks_with(n, 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {i} failed");

            // Reverse pass over owned items, reusing the same buffer.
            res.clear();
            v.into_par_iter()
                .fold_chunks_with(n, 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {i} reversed failed");
        }
    }
}