derive_more_impl/fmt/
display.rs

1//! Implementation of [`fmt::Display`]-like derive macros.
2
3#[cfg(doc)]
4use std::fmt;
5
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote};
8use syn::{ext::IdentExt as _, parse_quote, spanned::Spanned as _};
9
10use crate::utils::{attr::ParseMultiple as _, Spanning};
11
12use super::{
13    trait_name_to_attribute_name, ContainerAttributes, ContainsGenericsExt as _,
14    FmtAttribute,
15};
16
17/// Expands a [`fmt::Display`]-like derive macro.
18///
19/// Available macros:
20/// - [`Binary`](fmt::Binary)
21/// - [`Display`](fmt::Display)
22/// - [`LowerExp`](fmt::LowerExp)
23/// - [`LowerHex`](fmt::LowerHex)
24/// - [`Octal`](fmt::Octal)
25/// - [`Pointer`](fmt::Pointer)
26/// - [`UpperExp`](fmt::UpperExp)
27/// - [`UpperHex`](fmt::UpperHex)
28pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenStream> {
29    let trait_name = normalize_trait_name(trait_name);
30    let attr_name = format_ident!("{}", trait_name_to_attribute_name(trait_name));
31
32    let attrs = ContainerAttributes::parse_attrs(&input.attrs, &attr_name)?
33        .map(Spanning::into_inner)
34        .unwrap_or_default();
35    let trait_ident = format_ident!("{trait_name}");
36    let ident = &input.ident;
37
38    let type_params = input
39        .generics
40        .params
41        .iter()
42        .filter_map(|p| match p {
43            syn::GenericParam::Type(t) => Some(&t.ident),
44            syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None,
45        })
46        .collect::<Vec<_>>();
47
48    let ctx: ExpansionCtx = (&attrs, &type_params, ident, &trait_ident, &attr_name);
49    let (bounds, body) = match &input.data {
50        syn::Data::Struct(s) => expand_struct(s, ctx),
51        syn::Data::Enum(e) => expand_enum(e, ctx),
52        syn::Data::Union(u) => expand_union(u, ctx),
53    }?;
54
55    let (impl_gens, ty_gens, where_clause) = {
56        let (impl_gens, ty_gens, where_clause) = input.generics.split_for_impl();
57        let mut where_clause = where_clause
58            .cloned()
59            .unwrap_or_else(|| parse_quote! { where });
60        where_clause.predicates.extend(bounds);
61        (impl_gens, ty_gens, where_clause)
62    };
63
64    Ok(quote! {
65        #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
66        #[automatically_derived]
67        impl #impl_gens derive_more::core::fmt::#trait_ident for #ident #ty_gens #where_clause {
68            fn fmt(
69                &self, __derive_more_f: &mut derive_more::core::fmt::Formatter<'_>
70            ) -> derive_more::core::fmt::Result {
71                #body
72            }
73        }
74    })
75}
76
77/// Type alias for an expansion context:
78/// - [`ContainerAttributes`].
79/// - Type parameters. Slice of [`syn::Ident`].
80/// - Struct/enum/union [`syn::Ident`].
81/// - Derived trait [`syn::Ident`].
82/// - Attribute name [`syn::Ident`].
83///
84/// [`syn::Ident`]: struct@syn::Ident
85type ExpansionCtx<'a> = (
86    &'a ContainerAttributes,
87    &'a [&'a syn::Ident],
88    &'a syn::Ident,
89    &'a syn::Ident,
90    &'a syn::Ident,
91);
92
93/// Expands a [`fmt::Display`]-like derive macro for the provided struct.
94fn expand_struct(
95    s: &syn::DataStruct,
96    (attrs, type_params, ident, trait_ident, _): ExpansionCtx<'_>,
97) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
98    let s = Expansion {
99        shared_attr: None,
100        attrs,
101        fields: &s.fields,
102        type_params,
103        trait_ident,
104        ident,
105    };
106    let bounds = s.generate_bounds();
107    let body = s.generate_body()?;
108
109    let vars = s.fields.iter().enumerate().map(|(i, f)| {
110        let var = f.ident.clone().unwrap_or_else(|| format_ident!("_{i}"));
111        let member = f
112            .ident
113            .clone()
114            .map_or_else(|| syn::Member::Unnamed(i.into()), syn::Member::Named);
115        quote! {
116            let #var = &self.#member;
117        }
118    });
119
120    let body = quote! {
121        #( #vars )*
122        #body
123    };
124
125    Ok((bounds, body))
126}
127
128/// Expands a [`fmt`]-like derive macro for the provided enum.
129fn expand_enum(
130    e: &syn::DataEnum,
131    (container_attrs, type_params, _, trait_ident, attr_name): ExpansionCtx<'_>,
132) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
133    if let Some(shared_fmt) = &container_attrs.fmt {
134        if shared_fmt
135            .placeholders_by_arg("_variant")
136            .any(|p| p.has_modifiers || p.trait_name != "Display")
137        {
138            // TODO: This limitation can be lifted, by analyzing the `shared_fmt` deeper and using
139            //       `&dyn fmt::TraitName` for transparency instead of just `format_args!()` in the
140            //       expansion.
141            return Err(syn::Error::new(
142                shared_fmt.span(),
143                "shared format `_variant` placeholder cannot contain format specifiers",
144            ));
145        }
146    }
147
148    let (bounds, match_arms) = e.variants.iter().try_fold(
149        (Vec::new(), TokenStream::new()),
150        |(mut bounds, mut arms), variant| {
151            let attrs = ContainerAttributes::parse_attrs(&variant.attrs, attr_name)?
152                .map(Spanning::into_inner)
153                .unwrap_or_default();
154            let ident = &variant.ident;
155
156            if attrs.fmt.is_none()
157                && variant.fields.is_empty()
158                && attr_name != "display"
159            {
160                return Err(syn::Error::new(
161                    e.variants.span(),
162                    format!(
163                        "implicit formatting of unit enum variant is supported only for `Display` \
164                         macro, use `#[{attr_name}(\"...\")]` to explicitly specify the formatting",
165                    ),
166                ));
167            }
168
169            let v = Expansion {
170                shared_attr: container_attrs.fmt.as_ref(),
171                attrs: &attrs,
172                fields: &variant.fields,
173                type_params,
174                trait_ident,
175                ident,
176            };
177            let arm_body = v.generate_body()?;
178            bounds.extend(v.generate_bounds());
179
180            let fields_idents =
181                variant.fields.iter().enumerate().map(|(i, f)| {
182                    f.ident.clone().unwrap_or_else(|| format_ident!("_{i}"))
183                });
184            let matcher = match variant.fields {
185                syn::Fields::Named(_) => {
186                    quote! { Self::#ident { #( #fields_idents ),* } }
187                }
188                syn::Fields::Unnamed(_) => {
189                    quote! { Self::#ident ( #( #fields_idents ),* ) }
190                }
191                syn::Fields::Unit => quote! { Self::#ident },
192            };
193
194            arms.extend([quote! { #matcher => { #arm_body }, }]);
195
196            Ok::<_, syn::Error>((bounds, arms))
197        },
198    )?;
199
200    let body = match_arms
201        .is_empty()
202        .then(|| quote! { match *self {} })
203        .unwrap_or_else(|| quote! { match self { #match_arms } });
204
205    Ok((bounds, body))
206}
207
208/// Expands a [`fmt::Display`]-like derive macro for the provided union.
209fn expand_union(
210    u: &syn::DataUnion,
211    (attrs, _, _, _, attr_name): ExpansionCtx<'_>,
212) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
213    let fmt = &attrs.fmt.as_ref().ok_or_else(|| {
214        syn::Error::new(
215            u.fields.span(),
216            format!("unions must have `#[{attr_name}(\"...\", ...)]` attribute"),
217        )
218    })?;
219
220    Ok((
221        attrs.bounds.0.clone().into_iter().collect(),
222        quote! { derive_more::core::write!(__derive_more_f, #fmt) },
223    ))
224}
225
226/// Helper struct to generate [`Display::fmt()`] implementation body and trait
227/// bounds for a struct or an enum variant.
228///
229/// [`Display::fmt()`]: fmt::Display::fmt()
230#[derive(Debug)]
231struct Expansion<'a> {
232    /// [`FmtAttribute`] shared between all variants of an enum.
233    ///
234    /// [`None`] for a struct.
235    shared_attr: Option<&'a FmtAttribute>,
236
237    /// Derive macro [`ContainerAttributes`].
238    attrs: &'a ContainerAttributes,
239
240    /// Struct or enum [`syn::Ident`].
241    ///
242    /// [`syn::Ident`]: struct@syn::Ident
243    ident: &'a syn::Ident,
244
245    /// Struct or enum [`syn::Fields`].
246    fields: &'a syn::Fields,
247
248    /// Type parameters in this struct or enum.
249    type_params: &'a [&'a syn::Ident],
250
251    /// [`fmt`] trait [`syn::Ident`].
252    ///
253    /// [`syn::Ident`]: struct@syn::Ident
254    trait_ident: &'a syn::Ident,
255}
256
257impl Expansion<'_> {
258    /// Checks and indicates whether a top-level shared [`FmtAttribute`] is present in this
259    /// [`Expansion`], and whether it has wrapping logic (e.g. uses `_variant` placeholder).
260    fn shared_attr_info(&self) -> (bool, bool) {
261        let shared_attr_contains_variant = self
262            .shared_attr
263            .map_or(true, |attr| attr.contains_arg("_variant"));
264        // If `shared_attr` is a transparent call to `_variant`, then we consider it being absent.
265        let has_shared_attr = self.shared_attr.is_some_and(|attr| {
266            attr.transparent_call().map_or(true, |(_, called_trait)| {
267                &called_trait != self.trait_ident || !shared_attr_contains_variant
268            })
269        });
270        (
271            has_shared_attr,
272            has_shared_attr && shared_attr_contains_variant,
273        )
274    }
275
276    /// Generates [`Display::fmt()`] implementation for a struct or an enum variant.
277    ///
278    /// # Errors
279    ///
280    /// In case [`FmtAttribute`] is [`None`] and [`syn::Fields`] length is greater than 1.
281    ///
282    /// [`Display::fmt()`]: fmt::Display::fmt()
283    fn generate_body(&self) -> syn::Result<TokenStream> {
284        let mut body = TokenStream::new();
285
286        let (has_shared_attr, shared_attr_is_wrapping) = self.shared_attr_info();
287
288        let wrap_into_shared_attr = match &self.attrs.fmt {
289            Some(fmt) => {
290                body = if shared_attr_is_wrapping {
291                    let deref_args = fmt.additional_deref_args(self.fields);
292
293                    quote! { &derive_more::core::format_args!(#fmt, #(#deref_args),*) }
294                } else if let Some((expr, trait_ident)) =
295                    fmt.transparent_call_on_fields(self.fields)
296                {
297                    quote! { derive_more::core::fmt::#trait_ident::fmt(#expr, __derive_more_f) }
298                } else {
299                    let deref_args = fmt.additional_deref_args(self.fields);
300
301                    quote! { derive_more::core::write!(__derive_more_f, #fmt, #(#deref_args),*) }
302                };
303                shared_attr_is_wrapping
304            }
305            None => {
306                if shared_attr_is_wrapping || !has_shared_attr {
307                    body = if self.fields.is_empty() {
308                        let ident_str = self.ident.unraw().to_string();
309
310                        if shared_attr_is_wrapping {
311                            quote! { #ident_str }
312                        } else {
313                            quote! { __derive_more_f.write_str(#ident_str) }
314                        }
315                    } else if self.fields.len() == 1 {
316                        let field = self
317                            .fields
318                            .iter()
319                            .next()
320                            .unwrap_or_else(|| unreachable!("count() == 1"));
321                        let ident =
322                            field.ident.clone().unwrap_or_else(|| format_ident!("_0"));
323                        let trait_ident = self.trait_ident;
324
325                        if shared_attr_is_wrapping {
326                            let placeholder =
327                                trait_name_to_default_placeholder_literal(trait_ident);
328
329                            quote! { &derive_more::core::format_args!(#placeholder, #ident) }
330                        } else {
331                            quote! {
332                                derive_more::core::fmt::#trait_ident::fmt(#ident, __derive_more_f)
333                            }
334                        }
335                    } else {
336                        return Err(syn::Error::new(
337                            self.fields.span(),
338                            format!(
339                                "struct or enum variant with more than 1 field must have \
340                                 `#[{}(\"...\", ...)]` attribute",
341                                trait_name_to_attribute_name(self.trait_ident),
342                            ),
343                        ));
344                    };
345                }
346                has_shared_attr
347            }
348        };
349        if wrap_into_shared_attr {
350            if let Some(shared_fmt) = &self.shared_attr {
351                let deref_args = shared_fmt.additional_deref_args(self.fields);
352
353                let shared_body = if let Some((expr, trait_ident)) =
354                    shared_fmt.transparent_call_on_fields(self.fields)
355                {
356                    quote! { derive_more::core::fmt::#trait_ident::fmt(#expr, __derive_more_f) }
357                } else {
358                    quote! {
359                        derive_more::core::write!(__derive_more_f, #shared_fmt, #(#deref_args),*)
360                    }
361                };
362
363                body = if body.is_empty() {
364                    shared_body
365                } else {
366                    quote! { match #body { _variant => #shared_body } }
367                }
368            }
369        }
370
371        Ok(body)
372    }
373
374    /// Generates trait bounds for a struct or an enum variant.
375    fn generate_bounds(&self) -> Vec<syn::WherePredicate> {
376        let mut bounds = vec![];
377
378        let (has_shared_attr, shared_attr_is_wrapping) = self.shared_attr_info();
379
380        let mix_shared_attr_bounds = match &self.attrs.fmt {
381            Some(attr) => {
382                bounds.extend(
383                    attr.bounded_types(self.fields)
384                        .filter_map(|(ty, trait_name)| {
385                            if !ty.contains_generics(self.type_params) {
386                                return None;
387                            }
388                            let trait_ident = format_ident!("{trait_name}");
389
390                            Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
391                        })
392                        .chain(self.attrs.bounds.0.clone()),
393                );
394                shared_attr_is_wrapping
395            }
396            None => {
397                if shared_attr_is_wrapping || !has_shared_attr {
398                    bounds.extend(self.fields.iter().next().and_then(|f| {
399                        let ty = &f.ty;
400                        if !ty.contains_generics(self.type_params) {
401                            return None;
402                        }
403                        let trait_ident = &self.trait_ident;
404                        Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
405                    }));
406                }
407                has_shared_attr
408            }
409        };
410        if mix_shared_attr_bounds {
411            bounds.extend(
412                self.shared_attr
413                    .as_ref()
414                    .unwrap()
415                    .bounded_types(self.fields)
416                    .filter_map(|(ty, trait_name)| {
417                        if !ty.contains_generics(self.type_params) {
418                            return None;
419                        }
420                        let trait_ident = format_ident!("{trait_name}");
421
422                        Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
423                    }),
424            );
425        }
426
427        bounds
428    }
429}
430
/// Maps the provided derive macro `name` onto the actual [`fmt`] trait name, returned with a
/// `'static` lifetime.
///
/// # Panics
///
/// If `name` isn't one of the supported [`fmt`] traits.
fn normalize_trait_name(name: &str) -> &'static str {
    const SUPPORTED: [&str; 8] = [
        "Binary", "Display", "LowerExp", "LowerHex", "Octal", "Pointer", "UpperExp",
        "UpperHex",
    ];

    SUPPORTED
        .iter()
        .copied()
        .find(|&supported| supported == name)
        .unwrap_or_else(|| unimplemented!())
}
445
446/// Matches the provided [`fmt`] trait `name` to its default formatting placeholder.
447fn trait_name_to_default_placeholder_literal(name: &syn::Ident) -> &'static str {
448    match () {
449        _ if name == "Binary" => "{:b}",
450        _ if name == "Debug" => "{:?}",
451        _ if name == "Display" => "{}",
452        _ if name == "LowerExp" => "{:e}",
453        _ if name == "LowerHex" => "{:x}",
454        _ if name == "Octal" => "{:o}",
455        _ if name == "Pointer" => "{:p}",
456        _ if name == "UpperExp" => "{:E}",
457        _ if name == "UpperHex" => "{:X}",
458        _ => unimplemented!(),
459    }
460}