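//! `sycamore_reactive/root.rs`: the reactive `Root` and the scope, cleanup, batching and
//! untracking helpers built on top of it.
use std::cell::{Cell, RefCell};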
4
5use slotmap::{Key, SlotMap};
6use smallvec::SmallVec;
7
8use crate::*;
9
/// State of one reactive system: the node arena plus the bookkeeping used for dependency
/// tracking, update propagation and batching.
///
/// Instances are created with `Root::new_static`, which leaks the allocation to obtain a
/// `&'static Root`; `Root::reinit` disposes the nodes it owns and resets all state.
pub(crate) struct Root {
    /// `Some` while signal accesses are being recorded (see `Root::tracked_scope`); `None`
    /// when accesses are not tracked (e.g. inside `untrack`).
    pub tracker: RefCell<Option<DependencyTracker>>,
    /// Scratch buffer reused by `Root::propagate_node_updates` for the DFS ordering, so that a
    /// new `Vec` does not have to be allocated for every propagation.
    pub rev_sorted_buf: RefCell<Vec<NodeId>>,
    /// The node acting as the current scope (set by `create_child_scope` and while a node's
    /// callback runs); `on_cleanup` attaches cleanups to it. The null key means "no scope".
    pub current_node: Cell<NodeId>,
    /// The root node of the reactive graph, (re)created by `Root::reinit`.
    pub root_node: Cell<NodeId>,
    /// Arena holding every node created in this root.
    pub nodes: RefCell<SlotMap<NodeId, ReactiveNode>>,
    /// Start nodes queued by `Root::propagate_updates` while batching; propagated together
    /// when the batch ends.
    pub node_update_queue: RefCell<Vec<NodeId>>,
    /// Whether signal updates are currently being batched. While `true`, updates are queued in
    /// `node_update_queue` instead of being propagated immediately.
    pub batching: Cell<bool>,
}
37
thread_local! {
    /// The reactive root currently installed on this thread, if any. Read by `Root::global`
    /// and swapped by `Root::set_global`.
    static GLOBAL_ROOT: Cell<Option<&'static Root>> = const { Cell::new(None) };
}
42
43impl Root {
44 #[cfg_attr(debug_assertions, track_caller)]
46 pub fn global() -> &'static Root {
47 GLOBAL_ROOT.with(|root| root.get()).expect("no root found")
48 }
49
50 pub fn set_global(root: Option<&'static Root>) -> Option<&'static Root> {
52 GLOBAL_ROOT.with(|r| r.replace(root))
53 }
54
55 pub fn new_static() -> &'static Self {
57 let this = Self {
58 tracker: RefCell::new(None),
59 rev_sorted_buf: RefCell::new(Vec::new()),
60 current_node: Cell::new(NodeId::null()),
61 root_node: Cell::new(NodeId::null()),
62 nodes: RefCell::new(SlotMap::default()),
63 node_update_queue: RefCell::new(Vec::new()),
64 batching: Cell::new(false),
65 };
66 let _ref = Box::leak(Box::new(this));
67 _ref.reinit();
68 _ref
69 }
70
71 pub fn reinit(&'static self) {
73 NodeHandle(self.root_node.get(), self).dispose();
75
76 let _ = self.tracker.take();
77 let _ = self.rev_sorted_buf.take();
78 let _ = self.node_update_queue.take();
79 let _ = self.current_node.take();
80 let _ = self.root_node.take();
81 let _ = self.nodes.take();
82 self.batching.set(false);
83
84 Root::set_global(Some(self));
86 let root_node = create_child_scope(|| {});
87 Root::set_global(None);
88 self.root_node.set(root_node.0);
89 self.current_node.set(root_node.0);
90 }
91
92 pub fn create_child_scope(&'static self, f: impl FnOnce()) -> NodeHandle {
94 let node = create_signal(()).id;
95 let prev = self.current_node.replace(node);
96 f();
97 self.current_node.set(prev);
98 NodeHandle(node, self)
99 }
100
101 pub fn tracked_scope<T>(&self, f: impl FnOnce() -> T) -> (T, DependencyTracker) {
104 let prev = self.tracker.replace(Some(DependencyTracker::default()));
105 let ret = f();
106 (ret, self.tracker.replace(prev).unwrap())
107 }
108
109 fn run_node_update(&'static self, current: NodeId) {
118 debug_assert_eq!(
119 self.nodes.borrow()[current].state,
120 NodeState::Dirty,
121 "should only update when dirty"
122 );
123 let dependencies = std::mem::take(&mut self.nodes.borrow_mut()[current].dependencies);
125 for dependency in dependencies {
126 self.nodes.borrow_mut()[dependency]
127 .dependents
128 .retain(|&id| id != current);
129 }
130 let mut nodes_mut = self.nodes.borrow_mut();
133 let mut callback = nodes_mut[current].callback.take().unwrap();
134 let mut value = nodes_mut[current].value.take().unwrap();
135 drop(nodes_mut); NodeHandle(current, self).dispose_children(); let prev = self.current_node.replace(current);
140 let (changed, tracker) = self.tracked_scope(|| callback(&mut value));
141 self.current_node.set(prev);
142
143 tracker.create_dependency_link(self, current);
144
145 let mut nodes_mut = self.nodes.borrow_mut();
146 nodes_mut[current].callback = Some(callback); nodes_mut[current].value = Some(value);
148
149 nodes_mut[current].state = NodeState::Clean;
151 drop(nodes_mut);
152
153 if changed {
154 self.mark_dependents_dirty(current);
155 }
156 }
157
158 fn mark_dependents_dirty(&self, current: NodeId) {
160 let mut nodes_mut = self.nodes.borrow_mut();
161 let dependents = std::mem::take(&mut nodes_mut[current].dependents);
162 for &dependent in &dependents {
163 if let Some(dependent) = nodes_mut.get_mut(dependent) {
164 dependent.state = NodeState::Dirty;
165 }
166 }
167 nodes_mut[current].dependents = dependents;
168 }
169
170 fn propagate_node_updates(&'static self, start_nodes: &[NodeId]) {
176 let mut rev_sorted = Vec::new();
178 let mut rev_sorted_buf = self.rev_sorted_buf.try_borrow_mut();
179 let rev_sorted = if let Ok(rev_sorted_buf) = rev_sorted_buf.as_mut() {
180 rev_sorted_buf.clear();
181 rev_sorted_buf
182 } else {
183 &mut rev_sorted
184 };
185
186 for &node in start_nodes {
188 Self::dfs(node, &mut self.nodes.borrow_mut(), rev_sorted);
189 self.mark_dependents_dirty(node);
190 }
191
192 for &node in rev_sorted.iter().rev() {
193 let mut nodes_mut = self.nodes.borrow_mut();
194 if nodes_mut.get(node).is_none() {
196 continue;
197 }
198 let node_state = &mut nodes_mut[node];
199 node_state.mark = Mark::None; if nodes_mut[node].state == NodeState::Dirty {
203 drop(nodes_mut); self.run_node_update(node)
205 };
206 }
207 }
208
209 pub fn propagate_updates(&'static self, start_node: NodeId) {
214 if self.batching.get() {
215 self.node_update_queue.borrow_mut().push(start_node);
216 } else {
217 let prev = Root::set_global(Some(self));
219 self.propagate_node_updates(&[start_node]);
221 Root::set_global(prev);
222 }
223 }
224
225 fn dfs(current_id: NodeId, nodes: &mut SlotMap<NodeId, ReactiveNode>, buf: &mut Vec<NodeId>) {
227 let Some(current) = nodes.get_mut(current_id) else {
228 return;
230 };
231
232 match current.mark {
233 Mark::Temp => panic!("cyclic reactive dependency"),
234 Mark::Permanent => return,
235 Mark::None => {}
236 }
237 current.mark = Mark::Temp;
238
239 let children = std::mem::take(&mut current.dependents);
241 for child in &children {
242 Self::dfs(*child, nodes, buf);
243 }
244 nodes[current_id].dependents = children;
245
246 nodes[current_id].mark = Mark::Permanent;
247 buf.push(current_id);
248 }
249
250 fn start_batch(&self) {
252 self.batching.set(true);
253 }
254
255 fn end_batch(&'static self) {
257 self.batching.set(false);
258 let nodes = self.node_update_queue.take();
259 self.propagate_node_updates(&nodes);
260 }
261}
262
/// A handle to a reactive root created by `create_root`.
///
/// The handle only holds a `&'static Root`, so copies of it all refer to the same root.
#[derive(Clone, Copy)]
pub struct RootHandle {
    _ref: &'static Root,
}
270
271impl RootHandle {
272 pub fn dispose(&self) {
274 self._ref.reinit();
275 }
276
277 pub fn run_in<T>(&self, f: impl FnOnce() -> T) -> T {
279 let prev = Root::set_global(Some(self._ref));
280 let ret = f();
281 Root::set_global(prev);
282 ret
283 }
284}
285
/// Records the nodes accessed while a tracked scope runs (see `Root::tracked_scope`).
#[derive(Default)]
pub(crate) struct DependencyTracker {
    /// Ids of the nodes accessed during tracking (presumably pushed by signal reads elsewhere
    /// in the crate — TODO confirm).
    pub dependencies: SmallVec<[NodeId; 1]>,
}
292
293impl DependencyTracker {
294 pub fn create_dependency_link(self, root: &Root, dependent: NodeId) {
297 for node in &self.dependencies {
298 root.nodes.borrow_mut()[*node].dependents.push(dependent);
299 }
300 root.nodes.borrow_mut()[dependent].dependencies = self.dependencies;
302 }
303}
304
305#[must_use = "root should be disposed"]
321pub fn create_root(f: impl FnOnce()) -> RootHandle {
322 let _ref = Root::new_static();
323 #[cfg(not(target_arch = "wasm32"))]
324 {
325 #[allow(dead_code)]
328 struct UnsafeSendPtr<T>(*const T);
329 unsafe impl<T> Send for UnsafeSendPtr<T> {}
331
332 static KEEP_ALIVE: std::sync::Mutex<Vec<UnsafeSendPtr<Root>>> =
335 std::sync::Mutex::new(Vec::new());
336 KEEP_ALIVE
337 .lock()
338 .unwrap()
339 .push(UnsafeSendPtr(_ref as *const Root));
340 }
341
342 Root::set_global(Some(_ref));
343 NodeHandle(_ref.root_node.get(), _ref).run_in(f);
344 Root::set_global(None);
345 RootHandle { _ref }
346}
347
348#[cfg_attr(debug_assertions, track_caller)]
352pub fn create_child_scope(f: impl FnOnce()) -> NodeHandle {
353 Root::global().create_child_scope(f)
354}
355
356pub fn on_cleanup(f: impl FnOnce() + 'static) {
371 let root = Root::global();
372 if !root.current_node.get().is_null() {
373 root.nodes.borrow_mut()[root.current_node.get()]
374 .cleanups
375 .push(Box::new(f));
376 }
377}
378
379pub fn batch<T>(f: impl FnOnce() -> T) -> T {
397 let root = Root::global();
398 root.start_batch();
399 let ret = f();
400 root.end_batch();
401 ret
402}
403
404pub fn untrack<T>(f: impl FnOnce() -> T) -> T {
423 untrack_in_scope(f, Root::global())
424}
425
426pub(crate) fn untrack_in_scope<T>(f: impl FnOnce() -> T, root: &'static Root) -> T {
428 let prev = root.tracker.replace(None);
429 let ret = f();
430 root.tracker.replace(prev);
431 ret
432}
433
434pub fn use_current_scope() -> NodeHandle {
436 let root = Root::global();
437 NodeHandle(root.current_node.get(), root)
438}
439
440pub fn use_global_scope() -> NodeHandle {
442 let root = Root::global();
443 NodeHandle(root.root_node.get(), root)
444}
445
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn cleanup() {
        let _ = create_root(|| {
            let was_cleaned = create_signal(false);
            let child = create_child_scope(|| {
                on_cleanup(move || {
                    was_cleaned.set(true);
                });
            });
            // The cleanup only runs once its owning scope is disposed.
            assert!(!was_cleaned.get());
            child.dispose();
            assert!(was_cleaned.get());
        });
    }

    #[test]
    fn cleanup_in_effect() {
        let _ = create_root(|| {
            let rerun = create_signal(());
            let cleanups = create_signal(0);

            create_effect(move || {
                rerun.track();
                on_cleanup(move || {
                    cleanups.set(cleanups.get() + 1);
                });
            });
            assert_eq!(cleanups.get(), 0);

            // Every re-run of the effect runs the cleanup registered by the previous run.
            for expected in 1..=2 {
                rerun.set(());
                assert_eq!(cleanups.get(), expected);
            }
        });
    }

    #[test]
    fn cleanup_is_untracked() {
        let _ = create_root(|| {
            let rerun = create_signal(());
            let runs = create_signal(0);

            create_effect(move || {
                runs.set(runs.get_untracked() + 1);
                on_cleanup(move || {
                    // Reading a signal inside a cleanup must not subscribe the effect to it.
                    rerun.track();
                });
            });
            assert_eq!(runs.get(), 1);

            rerun.set(());
            assert_eq!(runs.get(), 1);
        });
    }

    #[test]
    fn batch_memo() {
        let _ = create_root(|| {
            let base = create_signal(1);
            let doubled = create_memo(move || base.get() * 2);
            batch(move || {
                base.set(2);
                // The memo is not recomputed until the batch ends.
                assert_eq!(doubled.get(), 2);
            });
            assert_eq!(doubled.get(), 4);
        });
    }

    #[test]
    fn batch_updates_effects_at_end() {
        let _ = create_root(|| {
            let a = create_signal(1);
            let b = create_signal(2);
            let runs = create_signal(0);
            create_effect(move || {
                runs.set(runs.get_untracked() + 1);
                let _ = a.get() + b.get();
            });
            assert_eq!(runs.get(), 1);

            // Outside a batch, every write re-runs the effect.
            a.set(2);
            b.set(3);
            assert_eq!(runs.get(), 3);

            // Inside a batch, the effect runs once, after the batch ends.
            batch(move || {
                a.set(3);
                assert_eq!(runs.get(), 3);
                b.set(4);
                assert_eq!(runs.get(), 3);
            });
            assert_eq!(runs.get(), 4);
        });
    }
}