derive_more_impl/
parsing.rs

1//! Common parsing utilities for derive macros.
2//!
//! Fair parsing of [`syn::Expr`] requires [`syn`]'s `full` feature to be enabled, which
//! unnecessarily increases compile times. As we don't do complex AST manipulation, and usually
//! only need to understand where a syntax item begins and ends, simpler manual parsing is
//! implemented instead.
6
7use proc_macro2::{Spacing, TokenStream};
8use quote::ToTokens;
9use syn::{
10    buffer::Cursor,
11    parse::{Parse, ParseStream},
12};
13
14/// [`syn::Expr`] [`Parse`]ing polyfill.
15#[derive(Clone, Debug)]
16pub(crate) enum Expr {
17    /// [`syn::Expr::Path`] of length 1 [`Parse`]ing polyfill.
18    Ident(syn::Ident),
19
20    /// Every other [`syn::Expr`] variant.
21    Other(TokenStream),
22}
23
24impl Expr {
25    /// Returns a [`syn::Ident`] in case this [`Expr`] is represented only by it.
26    ///
27    /// [`syn::Ident`]: struct@syn::Ident
28    pub(crate) fn ident(&self) -> Option<&syn::Ident> {
29        match self {
30            Self::Ident(ident) => Some(ident),
31            Self::Other(_) => None,
32        }
33    }
34}
35
36impl From<syn::Ident> for Expr {
37    fn from(ident: syn::Ident) -> Self {
38        Self::Ident(ident)
39    }
40}
41
42impl Parse for Expr {
43    fn parse(input: ParseStream) -> syn::Result<Self> {
44        if let Ok(ident) = input.step(|c| {
45            c.ident()
46                .filter(|(_, c)| c.eof() || punct(',')(*c).is_some())
47                .ok_or_else(|| syn::Error::new(c.span(), "expected `ident(,|eof)`"))
48        }) {
49            Ok(Self::Ident(ident))
50        } else {
51            input.step(|c| {
52                take_until1(
53                    alt([
54                        &mut seq([
55                            &mut path_sep,
56                            &mut balanced_pair(punct('<'), punct('>')),
57                        ]),
58                        &mut seq([
59                            &mut balanced_pair(punct('<'), punct('>')),
60                            &mut path_sep,
61                        ]),
62                        &mut balanced_pair(punct('|'), punct('|')),
63                        &mut token_tree,
64                    ]),
65                    punct(','),
66                )(*c)
67                .map(|(stream, cursor)| (Self::Other(stream), cursor))
68                .ok_or_else(|| syn::Error::new(c.span(), "failed to parse expression"))
69            })
70        }
71    }
72}
73
74impl PartialEq<syn::Ident> for Expr {
75    fn eq(&self, other: &syn::Ident) -> bool {
76        self.ident().is_some_and(|i| i == other)
77    }
78}
79
80impl ToTokens for Expr {
81    fn to_tokens(&self, tokens: &mut TokenStream) {
82        match self {
83            Self::Ident(ident) => ident.to_tokens(tokens),
84            Self::Other(other) => other.to_tokens(tokens),
85        }
86    }
87}
88
89/// Result of parsing.
90type ParsingResult<'a> = Option<(TokenStream, Cursor<'a>)>;
91
92/// Tries to parse a [`token::PathSep`].
93///
94/// [`token::PathSep`]: struct@syn::token::PathSep
95pub fn path_sep(c: Cursor<'_>) -> ParsingResult<'_> {
96    seq([
97        &mut punct_with_spacing(':', Spacing::Joint),
98        &mut punct(':'),
99    ])(c)
100}
101
102/// Tries to parse a [`punct`] with [`Spacing`].
103pub fn punct_with_spacing(
104    p: char,
105    spacing: Spacing,
106) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
107    move |c| {
108        c.punct().and_then(|(punct, c)| {
109            (punct.as_char() == p && punct.spacing() == spacing)
110                .then(|| (punct.into_token_stream(), c))
111        })
112    }
113}
114
115/// Tries to parse a [`Punct`].
116///
117/// [`Punct`]: proc_macro2::Punct
118pub fn punct(p: char) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
119    move |c| {
120        c.punct().and_then(|(punct, c)| {
121            (punct.as_char() == p).then(|| (punct.into_token_stream(), c))
122        })
123    }
124}
125
126/// Tries to parse any [`TokenTree`].
127///
128/// [`TokenTree`]: proc_macro2::TokenTree
129pub fn token_tree(c: Cursor<'_>) -> ParsingResult<'_> {
130    c.token_tree().map(|(tt, c)| (tt.into_token_stream(), c))
131}
132
133/// Parses until balanced amount of `open` and `close` or eof.
134///
135/// [`Cursor`] should be pointing **right after** the first `open`ing.
136pub fn balanced_pair(
137    mut open: impl FnMut(Cursor<'_>) -> ParsingResult<'_>,
138    mut close: impl FnMut(Cursor<'_>) -> ParsingResult<'_>,
139) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
140    move |c| {
141        let (mut out, mut c) = open(c)?;
142        let mut count = 1;
143
144        while count != 0 {
145            let (stream, cursor) = if let Some(closing) = close(c) {
146                count -= 1;
147                closing
148            } else if let Some(opening) = open(c) {
149                count += 1;
150                opening
151            } else {
152                let (tt, c) = c.token_tree()?;
153                (tt.into_token_stream(), c)
154            };
155            out.extend(stream);
156            c = cursor;
157        }
158
159        Some((out, c))
160    }
161}
162
163/// Tries to execute the provided sequence of `parsers`.
164pub fn seq<const N: usize>(
165    mut parsers: [&mut dyn FnMut(Cursor<'_>) -> ParsingResult<'_>; N],
166) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> + '_ {
167    move |c| {
168        parsers.iter_mut().try_fold(
169            (TokenStream::new(), c),
170            |(mut out, mut c), parser| {
171                let (stream, cursor) = parser(c)?;
172                out.extend(stream);
173                c = cursor;
174                Some((out, c))
175            },
176        )
177    }
178}
179
180/// Tries to execute the first successful parser.
181pub fn alt<const N: usize>(
182    mut parsers: [&mut dyn FnMut(Cursor<'_>) -> ParsingResult<'_>; N],
183) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> + '_ {
184    move |c| parsers.iter_mut().find_map(|parser| parser(c))
185}
186
187/// Parses with `basic` while `until` fails. Returns [`None`] in case
188/// `until` succeeded initially or `basic` never succeeded. Doesn't consume
189/// tokens parsed by `until`.
190pub fn take_until1<P, U>(
191    mut parser: P,
192    mut until: U,
193) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_>
194where
195    P: FnMut(Cursor<'_>) -> ParsingResult<'_>,
196    U: FnMut(Cursor<'_>) -> ParsingResult<'_>,
197{
198    move |mut cursor| {
199        let mut out = TokenStream::new();
200        let mut parsed = false;
201
202        loop {
203            if cursor.eof() || until(cursor).is_some() {
204                return parsed.then_some((out, cursor));
205            }
206
207            let (stream, c) = parser(cursor)?;
208            out.extend(stream);
209            cursor = c;
210            parsed = true;
211        }
212    }
213}
214
#[cfg(test)]
mod spec {
    use std::{fmt::Debug, str::FromStr};

    use itertools::Itertools as _;
    use proc_macro2::TokenStream;
    use quote::ToTokens;
    use syn::{
        parse::{Parse, Parser as _},
        punctuated::Punctuated,
        token::Comma,
    };

    use super::Expr;

    /// Parses `input` as a [`Comma`]-terminated [`Punctuated`] sequence of `T`, and asserts that
    /// printing every parsed item back via [`ToTokens`] yields exactly the `parsed` strings, in
    /// the same order.
    fn assert<'a, T: Debug + Parse + ToTokens>(
        input: &'a str,
        parsed: impl AsRef<[&'a str]>,
    ) {
        let parsed = parsed.as_ref();
        let tokens = TokenStream::from_str(input).unwrap();
        let punctuated = Punctuated::<T, Comma>::parse_terminated
            .parse2(tokens)
            .unwrap();

        assert_eq!(
            parsed.len(),
            punctuated.len(),
            "Wrong length\n\
             Expected: {parsed:?}\n\
             Found: {punctuated:?}",
        );

        for (i, (item, expected)) in punctuated.iter().zip(parsed).enumerate() {
            let found = item.to_token_stream().to_string();
            assert_eq!(
                *expected, &found,
                "Mismatch at index {i}\n\
                 Expected: {parsed:?}\n\
                 Found: {punctuated:?}",
            );
        }
    }

    mod expr {
        use super::*;

        #[test]
        fn cases() {
            let cases = [
                "ident",
                "[a , b , c , d]",
                "counter += 1",
                "async { fut . await }",
                "a < b",
                "a > b",
                "{ let x = (a , b) ; }",
                "invoke (a , b)",
                "foo as f64",
                "| a , b | a + b",
                "obj . k",
                "for pat in expr { break pat ; }",
                "if expr { true } else { false }",
                "vector [2]",
                "1",
                "\"foo\"",
                "loop { break i ; }",
                "format ! (\"{}\" , q)",
                "match n { Some (n) => { } , None => { } }",
                "x . foo ::< T > (a , b)",
                "x . foo ::< T < [T < T >; if a < b { 1 } else { 2 }] >, { a < b } > (a , b)",
                "(a + b)",
                "i32 :: MAX",
                "1 .. 2",
                "& a",
                "[0u8 ; N]",
                "(a , b , c , d)",
                "< Ty as Trait > :: T",
                "< Ty < Ty < T >, { a < b } > as Trait < T > > :: T",
            ];

            // An empty input yields no expressions at all.
            assert::<Expr>("", []);

            // Every ordered selection of 1 to 3 cases must round-trip, both without and with a
            // trailing comma.
            for len in 1..4 {
                for exprs in cases.into_iter().permutations(len) {
                    let joined = exprs.join(",");
                    assert::<Expr>(&joined, &exprs);
                    assert::<Expr>(&format!("{joined},"), &exprs);
                }
            }
        }
    }
}