wright/source_tracking/
source.rs

1//! Structure and implementation for storing source code items used by the wright compiler, including
2//! source files from disk, source strings used in test cases, and source strings created at
3//! run-time by an API consumer.
4
5use super::SourceRef;
6use super::{filename::FileName, fragment::Fragment, immutable_string::ImmutableString};
7use std::io;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::{fs::File, sync::Arc};
11
12#[cfg(feature = "file_memmap")]
13use std::{sync::mpsc, thread, time::Duration};
14
15#[cfg(feature = "file_memmap")]
16use fs4::fs_std::FileExt;
17
18#[cfg(feature = "file_memmap")]
19use memmap2::Mmap;
20
21#[cfg(feature = "file_memmap")]
22use crate::reporting::Diagnostic;
23
/// How long acquiring a file lock may take before the user is warned that it is taking too long.
///
/// Only available with the "file_memmap" feature.
#[cfg(feature = "file_memmap")]
pub const FILE_LOCK_WARNING_TIME: Duration = Duration::from_secs(5);
27
/// The global [Source::id] generator.
///
/// This is just a global [u64] counter, starting at 1, that gets incremented every time a new source is
/// instantiated.
static SOURCE_ID_GENERATOR: AtomicU64 = AtomicU64::new(1);
32
/// A process-unique source id that is atomically generated and assigned to each [Source] on creation.
///
/// The wrapped [u64] is private; ids can be copied, compared, ordered, and hashed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SourceId(u64);
36
/// A full source. This is usually a file, but may also be passed in the form of a string for testing.
#[derive(Debug)]
pub struct Source {
    /// Globally (process-wide) unique [Source] ID.
    ///
    /// It is frequently useful to have a consistent way to sort sources and check for equality between sources.
    /// This cannot be done with the [Source::name] since that can be [FileName::None], and checking for equality
    /// of content can be an expensive process.
    ///
    /// The id of a [Source] is an identifier that's globally unique for the runtime of the program, and is assigned to
    /// the [Source] when it is instantiated.
    pub id: SourceId,

    /// The name of this source file.
    name: FileName,

    /// The content of this source file.
    source: ImmutableString,

    /// A list of byte indices into the [Source::source] indicating where lines start.
    line_starts: Vec<usize>,
}
59
60impl Source {
61    /// Construct a new [Source].
62    fn new(name: FileName, source: ImmutableString) -> Self {
63        Source {
64            // I believe we can use relaxed ordering here, since as long as all operations are atomic,
65            // we're not really worried about another thread's `fetch_add` being re-ordered before this one, since
66            // neither will get the same number.
67            id: SourceId(SOURCE_ID_GENERATOR.fetch_add(1, Ordering::Relaxed)),
68            name,
69            line_starts: source.line_starts().collect(),
70            source,
71        }
72    }
73
74    /// Create a [Source] using a heap allocated [String].
75    pub fn new_from_string(name: FileName, source: String) -> Self {
76        Source::new(name, ImmutableString::new_owned(source.into_boxed_str()))
77    }
78
79    /// Create a [Source] from a [`&'static str`].
80    ///
81    /// [`&'static str`]: str
82    pub fn new_from_static_str(name: FileName, source: &'static str) -> Self {
83        Source::new(name, ImmutableString::new_static(source))
84    }
85
86    /// Attempt to memory map a file from the disk into a [Source].
87    /// This will likely be faster than reading the file in some cases, and almost always more memory efficient.
88    ///
89    /// This requires the "file_memmap" feature.
90    #[cfg(feature = "file_memmap")]
91    pub fn new_mapped_from_disk(path: PathBuf) -> io::Result<Self> {
92        use crate::source_tracking::SourceMap;
93
94        // Make a one-off enum here to use for channel messages.
95        enum ChannelMessage {
96            /// The file was successfully locked.
97            FileLocked(File),
98            /// There was an error locking the file.
99            LockingError(io::Error),
100            /// File is taking a long time to lock.
101            FiveSecondWarning,
102        }
103
104        // Open the file for reading.
105        let file: File = File::open(&path)?;
106
107        // Create two threads and a mpsc channel for warning the user if
108        // locking the file takes longer than 5 seconds.
109        let (tx, rx) = mpsc::sync_channel::<ChannelMessage>(0);
110        let timeout_tx = tx.clone();
111
112        // Thread to lock the file
113        thread::spawn(move || match file.lock_exclusive() {
114            Ok(_) => tx.send(ChannelMessage::FileLocked(file)),
115            Err(err) => tx.send(ChannelMessage::LockingError(err)),
116        });
117
118        // Thread to warn user if it takes too long.
119        thread::spawn(move || {
120            thread::sleep(FILE_LOCK_WARNING_TIME);
121            timeout_tx.send(ChannelMessage::FiveSecondWarning)
122        });
123
124        // Use an infinite loop to make sure we recieve all the messages from the senders.
125        loop {
126            match rx.recv() {
127                // Emit the diagnostic for the 5-second warning.
128                Ok(ChannelMessage::FiveSecondWarning) => {
129                    // Make the diagnostic to show to the user.
130                    let message = format!(
131                        "Getting a file lock on {} has taken more than {} seconds.",
132                        path.display(),
133                        FILE_LOCK_WARNING_TIME.as_secs()
134                    );
135
136                    // Wrap the message in a warning diagnostic and print it.
137                    // Add a note to describe what is going on.
138                    Diagnostic::warning()
139                        .with_message(message)
140                        .with_notes(["This may be caused by another process holding or failing to release a lock on this file."])
141                        // Create a dummy empty source map here, since this diagnostic does not have any highlights.
142                        .print(&SourceMap::new())
143                        // If printing a diagnostic fails, we just crash. We should never have to deal with this failing 
144                        // in practice.
145                        .expect("codespan-reporting error");
146                }
147
148                // Handle any io errors locking the file by returning them.
149                Ok(ChannelMessage::LockingError(io_err)) => Err(io_err)?,
150
151                // Handle success by finishing adding the file to the FileMap.
152                Ok(ChannelMessage::FileLocked(file)) => {
153                    // The file is now locked, we can memmory map it and add it to the vec.
154                    // SAFETY: The file should be locked at this point so undefined behaviour from concurrent
155                    // modification is avoided.
156                    let mem_map: Mmap = unsafe {
157                        Mmap::map(&file)
158                            // Make sure we (at least try to) unlock the file if there's an issue memory mapping it.
159                            .inspect_err(|_| {
160                                FileExt::unlock(&file)
161                                    .map_err(|err| eprintln!("Error unlocking file: {:?}", err))
162                                    .ok();
163                            })
164                    }?;
165
166                    // Double check that the file is valid utf-8. If not, return an IO error.
167                    let raw_data: &[u8] = mem_map.as_ref();
168
169                    if let Err(utf8_error) = std::str::from_utf8(raw_data) {
170                        // The file is not valid for us so we should unlock it and return an error.
171                        FileExt::unlock(&file)
172                            .map_err(|err| eprintln!("Error unlocking file: {:?}", err))
173                            .ok();
174
175                        Err(io::Error::new(io::ErrorKind::InvalidData, utf8_error))?;
176                    }
177
178                    // If we get here, the file is valid UTF-8 -- put the memory mapped file in an Immutable string object.
179                    return Ok(Source::new(
180                        FileName::Real(path),
181                        ImmutableString::new_locked_file(file, mem_map),
182                    ));
183                }
184
185                Err(_) => unreachable!(
186                    "The reciever should never reach a state where both senders are closed."
187                ),
188            }
189        }
190    }
191
192    /// Read a file from the disk into a source. This reads the file, which may take longer than memory mapping it
193    /// as done in [Self::new_mapped_from_disk]. This does not require the same features and dependencies as memory
194    /// mapped operations though. This stores the whole file in memory, rather than mapping virtual memory to the disk.
195    /// That makes this less memory efficient than [Self::new_mapped_from_disk], which may be important on systems
196    /// where ram is constrained.
197    ///
198    /// Use this if the "file_memmap" is not available for some reason.
199    pub fn new_read_from_disk(path: PathBuf) -> io::Result<Self> {
200        // Open the file for reading.
201        let file: File = File::open(&path)?;
202        // Read the file to a string.
203        let content: String = io::read_to_string(&file)?;
204        // Turn that into a Source.
205        Ok(Self::new(
206            FileName::Real(path),
207            ImmutableString::new_owned(content.into_boxed_str()),
208        ))
209    }
210
211    /// Attempt to open a file using [Self::new_mapped_from_disk] if that feature is designated
212    /// as available at compile time. If that feature is not available at compile time, or if an [io::Error] is
213    /// returned from [memmap2], fallback to reading the file from disk using [Self::new_read_from_disk].
214    pub fn new_mapped_or_read(path: PathBuf) -> io::Result<Self> {
215        #[cfg(feature = "file_memmap")]
216        match Self::new_mapped_from_disk(path.clone()) {
217            ok @ Ok(_) => return ok,
218            Err(e) => {
219                eprintln!(
220                    "warn: attempted to map file at {} and got {e}, falling back to read",
221                    path.display()
222                );
223            }
224        };
225
226        Self::new_read_from_disk(path)
227    }
228
229    /// Get byte indices of where lines start in this [Source].
230    pub fn line_starts(&self) -> &[usize] {
231        self.line_starts.as_slice()
232    }
233
234    /// Get the number of lines in this [Source]. This is identical to [`Self::line_starts`] length.
235    pub fn count_lines(&self) -> usize {
236        self.line_starts.len()
237    }
238
239    /// Get the line index that a byte index is on in this [Source].
240    ///
241    /// If the byte index is greater than the length of the [Source] then the highest possible index will be returned.
242    pub fn line_index(&self, byte_index: usize) -> usize {
243        // Get a list of the byte indices that lines start on.
244        let line_starts: &[usize] = self.line_starts();
245
246        // We just want the exact line index if the byte index is at the beginning of a line, otherwise, give us the
247        // index of the line-start before it.
248        line_starts
249            .binary_search(&byte_index)
250            // Subtract 1 here to make sure we get the index of the line start before the byte index instead of
251            // after.
252            .unwrap_or_else(|not_found_index| not_found_index.saturating_sub(1))
253    }
254
255    /// Get a line of this [Source] as a [Fragment].
256    /// The returned [Fragment] will contain the line terminating characters at the end of it. If you don't want those,
257    /// use [Fragment::trim_end].
258    ///
259    /// *Note* that this uses `line_index` which is considered 0-indexed -- when displaying line numbers to the user,
260    /// remember to add 1.
261    ///
262    /// # Panics
263    /// - This will panic if you ask for a line index that's higher than or equal to the number returned
264    ///     by [`Self::count_lines`].
265    pub fn get_line(self: Arc<Source>, line_index: usize) -> Fragment {
266        if line_index >= self.count_lines() {
267            panic!("{} is greater than the number of lines in {}", line_index, self.name);
268        }
269
270        // Get the starting byte index of the line.
271        let start_byte_index: usize = self.line_starts[line_index];
272
273        // Get the ending byte index of the line / the starting index of the next line/the index of the end of the file.
274        let end_byte_index: usize = if line_index + 1 == self.count_lines() {
275            self.source.len()
276        } else {
277            self.line_starts[line_index + 1]
278        };
279
280        // Construct the resultant fragment.
281        let frag = Fragment {
282            source: Arc::clone(&self),
283            range: start_byte_index..end_byte_index,
284        };
285
286        // Debug assert that the fragment is valid. This should always be true but might be useful for testing.
287        debug_assert!(frag.is_valid());
288        // Return constructed fragment.
289        frag
290    }
291
292    /// Get an iterator over all the lines of this [Source]. This calls [Source::get_line] for each element of
293    /// the returned iterator.
294    ///
295    /// The returned [Fragment]s will contain the line terminating characters at the end of them. If you don't want
296    /// those, use [Iterator::map] and [Fragment::trim_end].
297    pub fn lines(self: SourceRef) -> impl Iterator<Item = Fragment> {
298        (0..self.count_lines()).map(move |line_index| self.clone().get_line(line_index))
299    }
300
301    /// Get the the source code stored.
302    pub const fn source(&self) -> &ImmutableString {
303        &self.source
304    }
305
306    /// Get the name of this [Source].
307    pub const fn name(&self) -> &FileName {
308        &self.name
309    }
310
311    /// Get the entire content of this [Source] as a [Fragment].
312    pub fn as_fragment(self: SourceRef) -> Fragment {
313        let len = self.source.len();
314        Fragment {
315            source: self,
316            range: 0..len,
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use std::{sync::mpsc, thread};

    use crate::source_tracking::filename::FileName;

    use super::Source;

    /// Sources created concurrently on a dozen threads must all receive distinct ids.
    #[test]
    fn dozen_threads_dont_share_gids() {
        const THREAD_COUNT: usize = 12;

        let (tx, rx) = mpsc::channel();

        for i in 0..THREAD_COUNT {
            let tx = tx.clone();
            thread::spawn(move || {
                let source = Source::new_from_string(FileName::None, format!("{i}"));
                tx.send(source.id).unwrap();
            });
        }

        // Drop our own sender so the channel closes once every worker is done. Without this, a worker that
        // panics before sending would leave the receiver blocked forever instead of failing the test.
        drop(tx);

        let mut gids = rx.iter().collect::<Vec<_>>();

        let original_len = gids.len();
        assert_eq!(original_len, THREAD_COUNT, "every thread reports an id");

        println!("{gids:?}");
        gids.sort();
        gids.dedup();
        let dedup_len = gids.len();

        assert_eq!(original_len, dedup_len, "global ids are not duplicated");
    }
}