rayon/iter/
chain.rs

1use super::plumbing::*;
2use super::*;
3use rayon_core::join;
4use std::iter;
5
/// A parallel iterator that yields every item of `a`, followed by every item of `b`.
///
/// Values of this type are produced by the [`chain()`] method on [`ParallelIterator`].
///
/// [`chain()`]: ParallelIterator::chain()
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Debug, Clone)]
pub struct Chain<A, B> {
    // The iterator whose items come first.
    a: A,
    // The iterator whose items follow those of `a`.
    b: B,
}
16
17impl<A, B> Chain<A, B> {
18    /// Creates a new `Chain` iterator.
19    pub(super) fn new(a: A, b: B) -> Self {
20        Chain { a, b }
21    }
22}
23
24impl<A, B> ParallelIterator for Chain<A, B>
25where
26    A: ParallelIterator,
27    B: ParallelIterator<Item = A::Item>,
28{
29    type Item = A::Item;
30
31    fn drive_unindexed<C>(self, consumer: C) -> C::Result
32    where
33        C: UnindexedConsumer<Self::Item>,
34    {
35        let Chain { a, b } = self;
36
37        // If we returned a value from our own `opt_len`, then the collect consumer in particular
38        // will balk at being treated like an actual `UnindexedConsumer`.  But when we do know the
39        // length, we can use `Consumer::split_at` instead, and this is still harmless for other
40        // truly-unindexed consumers too.
41        let (left, right, reducer) = if let Some(len) = a.opt_len() {
42            consumer.split_at(len)
43        } else {
44            let reducer = consumer.to_reducer();
45            (consumer.split_off_left(), consumer, reducer)
46        };
47
48        let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
49        reducer.reduce(a, b)
50    }
51
52    fn opt_len(&self) -> Option<usize> {
53        self.a.opt_len()?.checked_add(self.b.opt_len()?)
54    }
55}
56
57impl<A, B> IndexedParallelIterator for Chain<A, B>
58where
59    A: IndexedParallelIterator,
60    B: IndexedParallelIterator<Item = A::Item>,
61{
62    fn drive<C>(self, consumer: C) -> C::Result
63    where
64        C: Consumer<Self::Item>,
65    {
66        let Chain { a, b } = self;
67        let (left, right, reducer) = consumer.split_at(a.len());
68        let (a, b) = join(|| a.drive(left), || b.drive(right));
69        reducer.reduce(a, b)
70    }
71
72    fn len(&self) -> usize {
73        self.a.len().checked_add(self.b.len()).expect("overflow")
74    }
75
76    fn with_producer<CB>(self, callback: CB) -> CB::Output
77    where
78        CB: ProducerCallback<Self::Item>,
79    {
80        let a_len = self.a.len();
81        return self.a.with_producer(CallbackA {
82            callback,
83            a_len,
84            b: self.b,
85        });
86
87        struct CallbackA<CB, B> {
88            callback: CB,
89            a_len: usize,
90            b: B,
91        }
92
93        impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
94        where
95            B: IndexedParallelIterator,
96            CB: ProducerCallback<B::Item>,
97        {
98            type Output = CB::Output;
99
100            fn callback<A>(self, a_producer: A) -> Self::Output
101            where
102                A: Producer<Item = B::Item>,
103            {
104                self.b.with_producer(CallbackB {
105                    callback: self.callback,
106                    a_len: self.a_len,
107                    a_producer,
108                })
109            }
110        }
111
112        struct CallbackB<CB, A> {
113            callback: CB,
114            a_len: usize,
115            a_producer: A,
116        }
117
118        impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
119        where
120            A: Producer,
121            CB: ProducerCallback<A::Item>,
122        {
123            type Output = CB::Output;
124
125            fn callback<B>(self, b_producer: B) -> Self::Output
126            where
127                B: Producer<Item = A::Item>,
128            {
129                let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
130                self.callback.callback(producer)
131            }
132        }
133    }
134}
135
136// ////////////////////////////////////////////////////////////////////////
137
/// Producer for an indexed `Chain`: yields every item of `a`, then every item of `b`.
///
/// The struct itself carries no trait bounds. The `impl` blocks that need
/// `A: Producer, B: Producer<Item = A::Item>` state them directly, so the bounds
/// are not duplicated here.
struct ChainProducer<A, B> {
    /// Number of items in `a`; `split_at` compares the split index against it to
    /// decide which half the index falls into.
    a_len: usize,
    /// Producer for the leading half of the chain.
    a: A,
    /// Producer for the trailing half of the chain.
    b: B,
}
147
148impl<A, B> ChainProducer<A, B>
149where
150    A: Producer,
151    B: Producer<Item = A::Item>,
152{
153    fn new(a_len: usize, a: A, b: B) -> Self {
154        ChainProducer { a_len, a, b }
155    }
156}
157
158impl<A, B> Producer for ChainProducer<A, B>
159where
160    A: Producer,
161    B: Producer<Item = A::Item>,
162{
163    type Item = A::Item;
164    type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
165
166    fn into_iter(self) -> Self::IntoIter {
167        ChainSeq::new(self.a.into_iter(), self.b.into_iter())
168    }
169
170    fn min_len(&self) -> usize {
171        Ord::max(self.a.min_len(), self.b.min_len())
172    }
173
174    fn max_len(&self) -> usize {
175        Ord::min(self.a.max_len(), self.b.max_len())
176    }
177
178    fn split_at(self, index: usize) -> (Self, Self) {
179        if index <= self.a_len {
180            let a_rem = self.a_len - index;
181            let (a_left, a_right) = self.a.split_at(index);
182            let (b_left, b_right) = self.b.split_at(0);
183            (
184                ChainProducer::new(index, a_left, b_left),
185                ChainProducer::new(a_rem, a_right, b_right),
186            )
187        } else {
188            let (a_left, a_right) = self.a.split_at(self.a_len);
189            let (b_left, b_right) = self.b.split_at(index - self.a_len);
190            (
191                ChainProducer::new(self.a_len, a_left, b_left),
192                ChainProducer::new(0, a_right, b_right),
193            )
194        }
195    }
196
197    fn fold_with<F>(self, mut folder: F) -> F
198    where
199        F: Folder<A::Item>,
200    {
201        folder = self.a.fold_with(folder);
202        if folder.full() {
203            folder
204        } else {
205            self.b.fold_with(folder)
206        }
207    }
208}
209
210// ////////////////////////////////////////////////////////////////////////
211
/// Wrapper around `std::iter::Chain` that adds an `ExactSizeIterator` impl.
///
/// `std::iter::Chain` does not implement `ExactSizeIterator` itself. Here both halves
/// come from producers of known length, so the wrapper can provide it.
struct ChainSeq<A, B> {
    chain: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `a` followed by `b`.
    fn new(a: A, b: B) -> ChainSeq<A, B>
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        ChainSeq { chain: a.chain(b) }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.chain.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }

    // Forward to `iter::Chain::fold`, which folds each half in turn. The default
    // implementation would loop over `next()` and re-check which half is active on
    // every item. `for_each`, `sum`, `count`, etc. are built on `fold`, so they
    // benefit too.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.fold(init, f)
    }
}

impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.chain.next_back()
    }

    // Same reasoning as `fold`: use `iter::Chain`'s specialized reverse fold
    // (used by e.g. `rev().for_each(..)`).
    fn rfold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.rfold(init, f)
    }
}
258}