sycamore_reactive/
root.rs

1//! [`Root`] and [`Scope`].
2
3use std::cell::{Cell, RefCell};
4
5use slotmap::{Key, SlotMap};
6use smallvec::SmallVec;
7
8use crate::*;
9
10/// The struct managing the state of the reactive system. Only one should be created per running
11/// app.
12///
13/// Often times, this is intended to be leaked to be able to get a `&'static Root`. However, the
14/// `Root` is also `dispose`-able, meaning that any resources allocated in this `Root` will get
15/// deallocated. Therefore in practice, there should be no memory leak at all except for the `Root`
16/// itself. Finally, the `Root` is expected to live for the whole duration of the app so this is
17/// not a problem.
18pub(crate) struct Root {
19    /// If this is `Some`, that means we are tracking signal accesses.
20    pub tracker: RefCell<Option<DependencyTracker>>,
21    /// A temporary buffer used in `propagate_updates` to prevent allocating a new Vec every time
22    /// it is called.
23    pub rev_sorted_buf: RefCell<Vec<NodeId>>,
24    /// The current node that owns everything created in its scope.
25    /// If we are at the top-level, then this is the "null" key.
26    pub current_node: Cell<NodeId>,
27    /// The root node of the reactive graph.
28    pub root_node: Cell<NodeId>,
29    /// All the nodes created in this `Root`.
30    pub nodes: RefCell<SlotMap<NodeId, ReactiveNode>>,
31    /// A list of signals who need their values to be propagated after the batch is over.
32    pub node_update_queue: RefCell<Vec<NodeId>>,
33    /// Whether we are currently batching signal updatse. If this is true, we do not run
34    /// `effect_queue` and instead wait until the end of the batch.
35    pub batching: Cell<bool>,
36}
37
38thread_local! {
39    /// The current reactive root.
40    static GLOBAL_ROOT: Cell<Option<&'static Root>> = const { Cell::new(None) };
41}
42
43impl Root {
44    /// Get the current reactive root. Panics if no root is found.
45    #[cfg_attr(debug_assertions, track_caller)]
46    pub fn global() -> &'static Root {
47        GLOBAL_ROOT.with(|root| root.get()).expect("no root found")
48    }
49
50    /// Sets the current reactive root. Returns the previous root.
51    pub fn set_global(root: Option<&'static Root>) -> Option<&'static Root> {
52        GLOBAL_ROOT.with(|r| r.replace(root))
53    }
54
55    /// Create a new reactive root. This root is leaked and so lives until the end of the program.
56    pub fn new_static() -> &'static Self {
57        let this = Self {
58            tracker: RefCell::new(None),
59            rev_sorted_buf: RefCell::new(Vec::new()),
60            current_node: Cell::new(NodeId::null()),
61            root_node: Cell::new(NodeId::null()),
62            nodes: RefCell::new(SlotMap::default()),
63            node_update_queue: RefCell::new(Vec::new()),
64            batching: Cell::new(false),
65        };
66        let _ref = Box::leak(Box::new(this));
67        _ref.reinit();
68        _ref
69    }
70
71    /// Disposes of all the resources held on by this root and resets the state.
72    pub fn reinit(&'static self) {
73        // Dispose the root node.
74        NodeHandle(self.root_node.get(), self).dispose();
75
76        let _ = self.tracker.take();
77        let _ = self.rev_sorted_buf.take();
78        let _ = self.node_update_queue.take();
79        let _ = self.current_node.take();
80        let _ = self.root_node.take();
81        let _ = self.nodes.take();
82        self.batching.set(false);
83
84        // Create a new root node.
85        Root::set_global(Some(self));
86        let root_node = create_child_scope(|| {});
87        Root::set_global(None);
88        self.root_node.set(root_node.0);
89        self.current_node.set(root_node.0);
90    }
91
92    /// Create a new child scope. Implementation detail for [`create_child_scope`].
93    pub fn create_child_scope(&'static self, f: impl FnOnce()) -> NodeHandle {
94        let node = create_signal(()).id;
95        let prev = self.current_node.replace(node);
96        f();
97        self.current_node.set(prev);
98        NodeHandle(node, self)
99    }
100
101    /// Run the provided closure in a tracked scope. This will detect all the signals that are
102    /// accessed and track them in a dependency list.
103    pub fn tracked_scope<T>(&self, f: impl FnOnce() -> T) -> (T, DependencyTracker) {
104        let prev = self.tracker.replace(Some(DependencyTracker::default()));
105        let ret = f();
106        (ret, self.tracker.replace(prev).unwrap())
107    }
108
109    /// Run the update callback of the signal, also recreating any dependencies found by
110    /// tracking signal accesses inside the function.
111    ///
112    /// Also marks all the dependencies as dirty and marks the current node as clean.
113    ///
114    /// # Params
115    /// * `root` - The reactive root.
116    /// * `id` - The id associated with the reactive node. `SignalId` inside the state itself.
117    fn run_node_update(&'static self, current: NodeId) {
118        debug_assert_eq!(
119            self.nodes.borrow()[current].state,
120            NodeState::Dirty,
121            "should only update when dirty"
122        );
123        // Remove old dependency links.
124        let dependencies = std::mem::take(&mut self.nodes.borrow_mut()[current].dependencies);
125        for dependency in dependencies {
126            self.nodes.borrow_mut()[dependency]
127                .dependents
128                .retain(|&id| id != current);
129        }
130        // We take the callback out because that requires a mut ref and we cannot hold that while
131        // running update itself.
132        let mut nodes_mut = self.nodes.borrow_mut();
133        let mut callback = nodes_mut[current].callback.take().unwrap();
134        let mut value = nodes_mut[current].value.take().unwrap();
135        drop(nodes_mut); // End RefMut borrow.
136
137        NodeHandle(current, self).dispose_children(); // Destroy anything created in a previous update.
138
139        let prev = self.current_node.replace(current);
140        let (changed, tracker) = self.tracked_scope(|| callback(&mut value));
141        self.current_node.set(prev);
142
143        tracker.create_dependency_link(self, current);
144
145        let mut nodes_mut = self.nodes.borrow_mut();
146        nodes_mut[current].callback = Some(callback); // Put the callback back in.
147        nodes_mut[current].value = Some(value);
148
149        // Mark this node as clean.
150        nodes_mut[current].state = NodeState::Clean;
151        drop(nodes_mut);
152
153        if changed {
154            self.mark_dependents_dirty(current);
155        }
156    }
157
158    // Mark any dependent node of the current node as dirty.
159    fn mark_dependents_dirty(&self, current: NodeId) {
160        let mut nodes_mut = self.nodes.borrow_mut();
161        let dependents = std::mem::take(&mut nodes_mut[current].dependents);
162        for &dependent in &dependents {
163            if let Some(dependent) = nodes_mut.get_mut(dependent) {
164                dependent.state = NodeState::Dirty;
165            }
166        }
167        nodes_mut[current].dependents = dependents;
168    }
169
170    /// If there are no cyclic dependencies, then the reactive graph is a DAG (Directed Acylic
171    /// Graph). We can therefore use DFS to get a topological sorting of all the reactive nodes.
172    ///
173    /// We then go through every node in this topological sorting and update only those nodes which
174    /// have dependencies that were updated.
175    fn propagate_node_updates(&'static self, start_nodes: &[NodeId]) {
176        // Try to reuse the shared buffer if possible.
177        let mut rev_sorted = Vec::new();
178        let mut rev_sorted_buf = self.rev_sorted_buf.try_borrow_mut();
179        let rev_sorted = if let Ok(rev_sorted_buf) = rev_sorted_buf.as_mut() {
180            rev_sorted_buf.clear();
181            rev_sorted_buf
182        } else {
183            &mut rev_sorted
184        };
185
186        // Traverse reactive graph.
187        for &node in start_nodes {
188            Self::dfs(node, &mut self.nodes.borrow_mut(), rev_sorted);
189            self.mark_dependents_dirty(node);
190        }
191
192        for &node in rev_sorted.iter().rev() {
193            let mut nodes_mut = self.nodes.borrow_mut();
194            // Only run if node is still alive.
195            if nodes_mut.get(node).is_none() {
196                continue;
197            }
198            let node_state = &mut nodes_mut[node];
199            node_state.mark = Mark::None; // Reset value.
200
201            // Check if this node needs to be updated.
202            if nodes_mut[node].state == NodeState::Dirty {
203                drop(nodes_mut); // End RefMut borrow.
204                self.run_node_update(node)
205            };
206        }
207    }
208
209    /// Call this if `start_node` has been updated manually. This will automatically update all
210    /// signals that depend on `start_node`.
211    ///
212    /// If we are currently batching, defers updating the signal until the end of the batch.
213    pub fn propagate_updates(&'static self, start_node: NodeId) {
214        if self.batching.get() {
215            self.node_update_queue.borrow_mut().push(start_node);
216        } else {
217            // Set the global root.
218            let prev = Root::set_global(Some(self));
219            // Propagate any signal updates.
220            self.propagate_node_updates(&[start_node]);
221            Root::set_global(prev);
222        }
223    }
224
225    /// Run depth-first-search on the reactive graph starting at `current`.
226    fn dfs(current_id: NodeId, nodes: &mut SlotMap<NodeId, ReactiveNode>, buf: &mut Vec<NodeId>) {
227        let Some(current) = nodes.get_mut(current_id) else {
228            // If signal is dead, don't even visit it.
229            return;
230        };
231
232        match current.mark {
233            Mark::Temp => panic!("cyclic reactive dependency"),
234            Mark::Permanent => return,
235            Mark::None => {}
236        }
237        current.mark = Mark::Temp;
238
239        // Take the `dependents` field out temporarily to avoid borrow checker.
240        let children = std::mem::take(&mut current.dependents);
241        for child in &children {
242            Self::dfs(*child, nodes, buf);
243        }
244        nodes[current_id].dependents = children;
245
246        nodes[current_id].mark = Mark::Permanent;
247        buf.push(current_id);
248    }
249
250    /// Sets the batch flag to `true`.
251    fn start_batch(&self) {
252        self.batching.set(true);
253    }
254
255    /// Sets the batch flag to `false` and run all the queued effects.
256    fn end_batch(&'static self) {
257        self.batching.set(false);
258        let nodes = self.node_update_queue.take();
259        self.propagate_node_updates(&nodes);
260    }
261}
262
263/// A handle to a root. This lets you reinitialize or dispose the root for resource cleanup.
264///
265/// This is generally obtained from [`create_root`].
266#[derive(Clone, Copy)]
267pub struct RootHandle {
268    _ref: &'static Root,
269}
270
271impl RootHandle {
272    /// Destroy everything that was created in this scope.
273    pub fn dispose(&self) {
274        self._ref.reinit();
275    }
276
277    /// Runs the closure in the current scope of the root.
278    pub fn run_in<T>(&self, f: impl FnOnce() -> T) -> T {
279        let prev = Root::set_global(Some(self._ref));
280        let ret = f();
281        Root::set_global(prev);
282        ret
283    }
284}
285
286/// Tracks nodes that are accessed inside a reactive scope.
287#[derive(Default)]
288pub(crate) struct DependencyTracker {
289    /// A list of reactive nodes that were accessed.
290    pub dependencies: SmallVec<[NodeId; 1]>,
291}
292
293impl DependencyTracker {
294    /// Sets the `dependents` field for all the nodes that have been tracked and updates
295    /// `dependencies` of the `dependent`.
296    pub fn create_dependency_link(self, root: &Root, dependent: NodeId) {
297        for node in &self.dependencies {
298            root.nodes.borrow_mut()[*node].dependents.push(dependent);
299        }
300        // Set the signal dependencies so that it is updated automatically.
301        root.nodes.borrow_mut()[dependent].dependencies = self.dependencies;
302    }
303}
304
305/// Creates a new reactive root with a top-level reactive node. The returned [`RootHandle`] can be
306/// used to [`dispose`](RootHandle::dispose) the root.
307///
308/// # Example
309/// ```rust
310/// # use sycamore_reactive::*;
311///
312/// create_root(|| {
313///     let signal = create_signal(123);
314///
315///     let child_scope = create_child_scope(move || {
316///         // ...
317///     });
318/// });
319/// ```
320#[must_use = "root should be disposed"]
321pub fn create_root(f: impl FnOnce()) -> RootHandle {
322    let _ref = Root::new_static();
323    #[cfg(not(target_arch = "wasm32"))]
324    {
325        /// An unsafe wrapper around a raw pointer which we promise to never touch, effectively
326        /// making it thread-safe.
327        #[allow(dead_code)]
328        struct UnsafeSendPtr<T>(*const T);
329        /// We never ever touch the pointer inside so surely this is safe!
330        unsafe impl<T> Send for UnsafeSendPtr<T> {}
331
332        /// A static variable to keep on holding to the allocated `Root`s to prevent Miri and
333        /// Valgrind from complaining.
334        static KEEP_ALIVE: std::sync::Mutex<Vec<UnsafeSendPtr<Root>>> =
335            std::sync::Mutex::new(Vec::new());
336        KEEP_ALIVE
337            .lock()
338            .unwrap()
339            .push(UnsafeSendPtr(_ref as *const Root));
340    }
341
342    Root::set_global(Some(_ref));
343    NodeHandle(_ref.root_node.get(), _ref).run_in(f);
344    Root::set_global(None);
345    RootHandle { _ref }
346}
347
348/// Create a child scope.
349///
350/// Returns the created [`NodeHandle`] which can be used to dispose it.
351#[cfg_attr(debug_assertions, track_caller)]
352pub fn create_child_scope(f: impl FnOnce()) -> NodeHandle {
353    Root::global().create_child_scope(f)
354}
355
356/// Adds a callback that is called when the scope is destroyed.
357///
358/// # Example
359/// ```rust
360/// # use sycamore_reactive::*;
361/// # create_root(|| {
362/// let child_scope = create_child_scope(|| {
363///     on_cleanup(|| {
364///         println!("Child scope is being dropped");
365///     });
366/// });
367/// child_scope.dispose(); // Executes the on_cleanup callback.
368/// # });
369/// ```
370pub fn on_cleanup(f: impl FnOnce() + 'static) {
371    let root = Root::global();
372    if !root.current_node.get().is_null() {
373        root.nodes.borrow_mut()[root.current_node.get()]
374            .cleanups
375            .push(Box::new(f));
376    }
377}
378
379/// Batch updates from related signals together and only run memos and effects at the end of the
380/// scope.
381///
382/// # Example
383///
384/// ```
385/// # use sycamore_reactive::*;
386/// # let _ = create_root(|| {
387/// let state = create_signal(1);
388/// let double = create_memo(move || state.get() * 2);
389/// batch(move || {
390///     state.set(2);
391///     assert_eq!(double.get(), 2);
392/// });
393/// assert_eq!(double.get(), 4);
394/// # });
395/// ```
396pub fn batch<T>(f: impl FnOnce() -> T) -> T {
397    let root = Root::global();
398    root.start_batch();
399    let ret = f();
400    root.end_batch();
401    ret
402}
403
404/// Run the passed closure inside an untracked dependency scope.
405///
406/// See also [`ReadSignal::get_untracked`].
407///
408/// # Example
409///
410/// ```
411/// # use sycamore_reactive::*;
412/// # create_root(|| {
413/// let state = create_signal(1);
414/// let double = create_memo(move || untrack(|| state.get() * 2));
415/// assert_eq!(double.get(), 2);
416///
417/// state.set(2);
418/// // double value should still be old value because state was untracked
419/// assert_eq!(double.get(), 2);
420/// # });
421/// ```
422pub fn untrack<T>(f: impl FnOnce() -> T) -> T {
423    untrack_in_scope(f, Root::global())
424}
425
426/// Same as [`untrack`] but for a specific [`Root`].
427pub(crate) fn untrack_in_scope<T>(f: impl FnOnce() -> T, root: &'static Root) -> T {
428    let prev = root.tracker.replace(None);
429    let ret = f();
430    root.tracker.replace(prev);
431    ret
432}
433
434/// Get a handle to the current reactive scope.
435pub fn use_current_scope() -> NodeHandle {
436    let root = Root::global();
437    NodeHandle(root.current_node.get(), root)
438}
439
440/// Get a handle to the root reactive scope.
441pub fn use_global_scope() -> NodeHandle {
442    let root = Root::global();
443    NodeHandle(root.root_node.get(), root)
444}
445
#[cfg(test)]
mod tests {
    use crate::*;

    /// Disposing a child scope runs the cleanup callback registered inside it.
    #[test]
    fn cleanup() {
        let _ = create_root(|| {
            let ran = create_signal(false);
            let child = create_child_scope(|| {
                on_cleanup(move || ran.set(true));
            });

            assert!(!ran.get());
            child.dispose();
            assert!(ran.get());
        });
    }

    /// A cleanup callback registered inside an effect runs each time the effect re-runs.
    #[test]
    fn cleanup_in_effect() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let cleanups = create_signal(0);

            create_effect(move || {
                trigger.track();

                on_cleanup(move || {
                    cleanups.set(cleanups.get() + 1);
                });
            });
            assert_eq!(cleanups.get(), 0);

            // Every re-run of the effect cleans up the previous run exactly once.
            trigger.set(());
            assert_eq!(cleanups.get(), 1);

            trigger.set(());
            assert_eq!(cleanups.get(), 2);
        });
    }

    /// Signals accessed inside a cleanup callback do not become dependencies of the effect.
    #[test]
    fn cleanup_is_untracked() {
        let _ = create_root(|| {
            let trigger = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                runs.set(runs.get_untracked() + 1);

                on_cleanup(move || {
                    trigger.track(); // trigger should not be tracked
                });
            });
            assert_eq!(runs.get(), 1);

            // The effect must not have re-run.
            trigger.set(());
            assert_eq!(runs.get(), 1);
        });
    }

    /// A memo keeps its old value inside a batch and is updated once the batch is over.
    #[test]
    fn batch_memo() {
        let _ = create_root(|| {
            let state = create_signal(1);
            let double = create_memo(move || state.get() * 2);

            batch(move || {
                state.set(2);
                assert_eq!(double.get(), 2);
            });

            assert_eq!(double.get(), 4);
        });
    }

    /// An effect runs once per update outside of a batch, but only once for a whole batch.
    #[test]
    fn batch_updates_effects_at_end() {
        let _ = create_root(|| {
            let state1 = create_signal(1);
            let state2 = create_signal(2);
            let runs = create_signal(0);

            create_effect(move || {
                runs.set(runs.get_untracked() + 1);
                let _ = state1.get() + state2.get();
            });
            assert_eq!(runs.get(), 1);

            // Unbatched: one run per update.
            state1.set(2);
            state2.set(3);
            assert_eq!(runs.get(), 3);

            // Batched: no run until the batch is over.
            batch(move || {
                state1.set(3);
                assert_eq!(runs.get(), 3);
                state2.set(4);
                assert_eq!(runs.get(), 3);
            });
            assert_eq!(runs.get(), 4);
        });
    }
}
548}