1use proc_macro2::{Spacing, TokenStream};
8use quote::ToTokens;
9use syn::{
10 buffer::Cursor,
11 parse::{Parse, ParseStream},
12};
13
14#[derive(Clone, Debug)]
16pub(crate) enum Expr {
17 Ident(syn::Ident),
19
20 Other(TokenStream),
22}
23
24impl Expr {
25 pub(crate) fn ident(&self) -> Option<&syn::Ident> {
29 match self {
30 Self::Ident(ident) => Some(ident),
31 Self::Other(_) => None,
32 }
33 }
34}
35
36impl From<syn::Ident> for Expr {
37 fn from(ident: syn::Ident) -> Self {
38 Self::Ident(ident)
39 }
40}
41
42impl Parse for Expr {
43 fn parse(input: ParseStream) -> syn::Result<Self> {
44 if let Ok(ident) = input.step(|c| {
45 c.ident()
46 .filter(|(_, c)| c.eof() || punct(',')(*c).is_some())
47 .ok_or_else(|| syn::Error::new(c.span(), "expected `ident(,|eof)`"))
48 }) {
49 Ok(Self::Ident(ident))
50 } else {
51 input.step(|c| {
52 take_until1(
53 alt([
54 &mut seq([
55 &mut path_sep,
56 &mut balanced_pair(punct('<'), punct('>')),
57 ]),
58 &mut seq([
59 &mut balanced_pair(punct('<'), punct('>')),
60 &mut path_sep,
61 ]),
62 &mut balanced_pair(punct('|'), punct('|')),
63 &mut token_tree,
64 ]),
65 punct(','),
66 )(*c)
67 .map(|(stream, cursor)| (Self::Other(stream), cursor))
68 .ok_or_else(|| syn::Error::new(c.span(), "failed to parse expression"))
69 })
70 }
71 }
72}
73
74impl PartialEq<syn::Ident> for Expr {
75 fn eq(&self, other: &syn::Ident) -> bool {
76 self.ident().is_some_and(|i| i == other)
77 }
78}
79
80impl ToTokens for Expr {
81 fn to_tokens(&self, tokens: &mut TokenStream) {
82 match self {
83 Self::Ident(ident) => ident.to_tokens(tokens),
84 Self::Other(other) => other.to_tokens(tokens),
85 }
86 }
87}
88
/// Outcome of a single parsing step: the tokens it consumed, together with a [`Cursor`] pointing
/// right after them, or [`None`] if the step did not match.
type ParsingResult<'a> = Option<(TokenStream, Cursor<'a>)>;
91
92pub fn path_sep(c: Cursor<'_>) -> ParsingResult<'_> {
96 seq([
97 &mut punct_with_spacing(':', Spacing::Joint),
98 &mut punct(':'),
99 ])(c)
100}
101
102pub fn punct_with_spacing(
104 p: char,
105 spacing: Spacing,
106) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
107 move |c| {
108 c.punct().and_then(|(punct, c)| {
109 (punct.as_char() == p && punct.spacing() == spacing)
110 .then(|| (punct.into_token_stream(), c))
111 })
112 }
113}
114
115pub fn punct(p: char) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
119 move |c| {
120 c.punct().and_then(|(punct, c)| {
121 (punct.as_char() == p).then(|| (punct.into_token_stream(), c))
122 })
123 }
124}
125
126pub fn token_tree(c: Cursor<'_>) -> ParsingResult<'_> {
130 c.token_tree().map(|(tt, c)| (tt.into_token_stream(), c))
131}
132
133pub fn balanced_pair(
137 mut open: impl FnMut(Cursor<'_>) -> ParsingResult<'_>,
138 mut close: impl FnMut(Cursor<'_>) -> ParsingResult<'_>,
139) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> {
140 move |c| {
141 let (mut out, mut c) = open(c)?;
142 let mut count = 1;
143
144 while count != 0 {
145 let (stream, cursor) = if let Some(closing) = close(c) {
146 count -= 1;
147 closing
148 } else if let Some(opening) = open(c) {
149 count += 1;
150 opening
151 } else {
152 let (tt, c) = c.token_tree()?;
153 (tt.into_token_stream(), c)
154 };
155 out.extend(stream);
156 c = cursor;
157 }
158
159 Some((out, c))
160 }
161}
162
163pub fn seq<const N: usize>(
165 mut parsers: [&mut dyn FnMut(Cursor<'_>) -> ParsingResult<'_>; N],
166) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> + '_ {
167 move |c| {
168 parsers.iter_mut().try_fold(
169 (TokenStream::new(), c),
170 |(mut out, mut c), parser| {
171 let (stream, cursor) = parser(c)?;
172 out.extend(stream);
173 c = cursor;
174 Some((out, c))
175 },
176 )
177 }
178}
179
180pub fn alt<const N: usize>(
182 mut parsers: [&mut dyn FnMut(Cursor<'_>) -> ParsingResult<'_>; N],
183) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_> + '_ {
184 move |c| parsers.iter_mut().find_map(|parser| parser(c))
185}
186
187pub fn take_until1<P, U>(
191 mut parser: P,
192 mut until: U,
193) -> impl FnMut(Cursor<'_>) -> ParsingResult<'_>
194where
195 P: FnMut(Cursor<'_>) -> ParsingResult<'_>,
196 U: FnMut(Cursor<'_>) -> ParsingResult<'_>,
197{
198 move |mut cursor| {
199 let mut out = TokenStream::new();
200 let mut parsed = false;
201
202 loop {
203 if cursor.eof() || until(cursor).is_some() {
204 return parsed.then_some((out, cursor));
205 }
206
207 let (stream, c) = parser(cursor)?;
208 out.extend(stream);
209 cursor = c;
210 parsed = true;
211 }
212 }
213}
214
#[cfg(test)]
mod spec {
    use std::{fmt::Debug, str::FromStr};

    use itertools::Itertools as _;
    use proc_macro2::TokenStream;
    use quote::ToTokens;
    use syn::{
        parse::{Parse, Parser as _},
        punctuated::Punctuated,
        token::Comma,
    };

    use super::Expr;

    /// Parses `input` as a `,`-separated list of `T` (a trailing `,` is allowed). Asserts that
    /// it yields exactly the items in `parsed`, compared by their printed tokens.
    ///
    /// # Panics
    ///
    /// - If `input` is not a valid token stream or cannot be parsed as such a list.
    /// - If the parsed items differ from `parsed` in number or content.
    fn assert<'a, T: Debug + Parse + ToTokens>(
        input: &'a str,
        parsed: impl AsRef<[&'a str]>,
    ) {
        let parsed = parsed.as_ref();
        let tokens = TokenStream::from_str(input).unwrap();
        let punctuated = Punctuated::<T, Comma>::parse_terminated
            .parse2(tokens)
            .unwrap();

        assert_eq!(
            parsed.len(),
            punctuated.len(),
            "Wrong length\n\
             Expected: {parsed:?}\n\
             Found: {punctuated:?}",
        );

        for (i, (item, expected)) in punctuated.iter().zip(parsed).enumerate() {
            let found = item.to_token_stream().to_string();
            assert_eq!(
                *expected, &found,
                "Mismatch at index {i}\n\
                 Expected: {parsed:?}\n\
                 Found: {punctuated:?}",
            );
        }
    }

    mod expr {
        use super::*;

        #[test]
        fn cases() {
            let cases = [
                "ident",
                "[a , b , c , d]",
                "counter += 1",
                "async { fut . await }",
                "a < b",
                "a > b",
                "{ let x = (a , b) ; }",
                "invoke (a , b)",
                "foo as f64",
                "| a , b | a + b",
                "obj . k",
                "for pat in expr { break pat ; }",
                "if expr { true } else { false }",
                "vector [2]",
                "1",
                "\"foo\"",
                "loop { break i ; }",
                "format ! (\"{}\" , q)",
                "match n { Some (n) => { } , None => { } }",
                "x . foo ::< T > (a , b)",
                "x . foo ::< T < [T < T >; if a < b { 1 } else { 2 }] >, { a < b } > (a , b)",
                "(a + b)",
                "i32 :: MAX",
                "1 .. 2",
                "& a",
                "[0u8 ; N]",
                "(a , b , c , d)",
                "< Ty as Trait > :: T",
                "< Ty < Ty < T >, { a < b } > as Trait < T > > :: T",
            ];

            // An empty input parses into an empty list.
            assert::<Expr>("", []);

            // Every ordered selection of 1 to 3 cases must round-trip, with and without a
            // trailing comma.
            for len in 1..4 {
                for permutation in cases.into_iter().permutations(len) {
                    let mut input = permutation.join(",");
                    assert::<Expr>(&input, &permutation);
                    input.push(',');
                    assert::<Expr>(&input, &permutation);
                }
            }
        }
    }
}