1616
1717class Trap (BaseException ): pass
1818class CoreWebAssemblyException (BaseException ): pass
19+ class ThreadExit (BaseException ): pass
1920
2021def trap ():
2122 raise Trap ()
@@ -294,7 +295,7 @@ class ComponentInstance:
294295 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
295296 threads : Table [Thread ]
296297 may_leave : bool
297- may_block : bool
298+ sync_before_return : bool
298299 backpressure : int
299300 exclusive : Optional [Task ]
300301 num_waiting_to_enter : int
@@ -306,11 +307,17 @@ def __init__(self, store, parent = None):
306307 self .handles = Table ()
307308 self .threads = Table ()
308309 self .may_leave = True
309- self .may_block = True
310+ self .sync_before_return = False
310311 self .backpressure = 0
311312 self .exclusive = None
312313 self .num_waiting_to_enter = 0
313314
315+ def ready_threads (self ) -> list [Thread ]:
316+ return [t for t in self .threads .array if t and t .waiting () and t .ready ()]
317+
318+ def may_block (self ):
319+ return not self .sync_before_return or len (self .ready_threads ()) > 0
320+
314321 def reflexive_ancestors (self ) -> set [ComponentInstance ]:
315322 s = set ()
316323 inst = self
@@ -487,7 +494,10 @@ def ready(self):
487494 def __init__ (self , task , thread_func ):
488495 def cont_func (cancelled ):
489496 assert (self .running () and not cancelled )
490- thread_func ()
497+ try :
498+ thread_func ()
499+ except ThreadExit :
500+ pass
491501 return None
492502 self .cont = cont_new (cont_func )
493503 self .ready_func = None
@@ -497,7 +507,7 @@ def cont_func(cancelled):
497507 self .storage = [0 ,0 ]
498508 assert (self .suspended ())
499509
500- def resume_later (self ):
510+ def unsuspend (self ):
501511 assert (self .suspended ())
502512 self .ready_func = lambda : True
503513 self .task .inst .store .waiting .append (self )
@@ -507,18 +517,25 @@ def resume(self, cancelled):
507517 assert (not self .running () and (self .cancellable or not cancelled ))
508518 if self .waiting ():
509519 assert (cancelled or self .ready ())
510- self .ready_func = None
511- self .task .inst .store .waiting .remove (self )
520+ self .stop_waiting ()
512521 thread = self
513522 while thread is not None :
514523 cont = thread .cont
515524 thread .cont = None
516525 (thread .cont , switch_to ) = resume (cont , cancelled , thread )
526+ if switch_to is None and self .task .inst .sync_before_return :
527+ switch_to = random .choice (self .task .inst .ready_threads ())
528+ switch_to .stop_waiting ()
517529 thread = switch_to
518530 cancelled = Cancelled .FALSE
519531
532+ def stop_waiting (self ):
533+ assert (self .waiting ())
534+ self .ready_func = None
535+ self .task .inst .store .waiting .remove (self )
536+
520537 def suspend (self , cancellable ) -> Cancelled :
521- assert (self .running () and self . task . inst . may_block )
538+ assert (self .running ())
522539 if self .task .deliver_pending_cancel (cancellable ):
523540 return Cancelled .TRUE
524541 self .cancellable = cancellable
@@ -527,7 +544,7 @@ def suspend(self, cancellable) -> Cancelled:
527544 return cancelled
528545
529546 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
530- assert (self .running () and self . task . inst . may_block )
547+ assert (self .running ())
531548 if self .task .deliver_pending_cancel (cancellable ):
532549 return Cancelled .TRUE
533550 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -536,18 +553,10 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
536553 self .task .inst .store .waiting .append (self )
537554 return self .suspend (cancellable )
538555
539- def yield_until (self , ready_func , cancellable ) -> Cancelled :
540- assert (self .running ())
541- if self .task .inst .may_block :
542- return self .wait_until (ready_func , cancellable )
543- else :
544- assert (ready_func ())
545- return Cancelled .FALSE
546-
547556 def yield_ (self , cancellable ) -> Cancelled :
548- return self .yield_until (lambda : True , cancellable )
557+ return self .wait_until (lambda : True , cancellable )
549558
550- def switch_to (self , cancellable , other : Thread ) -> Cancelled :
559+ def suspend_to_suspended (self , cancellable , other : Thread ) -> ResumeArg :
551560 assert (self .running () and other .suspended ())
552561 if self .task .deliver_pending_cancel (cancellable ):
553562 return Cancelled .TRUE
@@ -556,11 +565,31 @@ def switch_to(self, cancellable, other: Thread) -> Cancelled:
556565 assert (self .running () and (cancellable or not cancelled ))
557566 return cancelled
558567
559- def yield_to (self , cancellable , other : Thread ) -> Cancelled :
568+ def yield_to_suspended (self , cancellable , other : Thread ) -> ResumeArg :
560569 assert (self .running () and other .suspended ())
561570 self .ready_func = lambda : True
562571 self .task .inst .store .waiting .append (self )
563- return self .switch_to (cancellable , other )
572+ return self .suspend_to_suspended (cancellable , other )
573+
574+ def suspend_then_promote (self , cancellable , other : Thread ) -> ResumeArg :
575+ assert (self .running ())
576+ if other .waiting () and other .ready ():
577+ other .stop_waiting ()
578+ return self .suspend_to_suspended (cancellable , other )
579+ else :
580+ return self .suspend (cancellable )
581+
582+ def yield_then_promote (self , cancellable , other : Thread ) -> ResumeArg :
583+ assert (self .running ())
584+ if other .waiting () and other .ready ():
585+ other .stop_waiting ()
586+ return self .yield_to_suspended (cancellable , other )
587+ else :
588+ return self .yield_ (cancellable )
589+
590+ def exit (self ):
591+ assert (self .running () and self .task .inst .may_block ())
592+ raise ThreadExit ()
564593
565594#### Waitable State
566595
@@ -701,8 +730,8 @@ def has_backpressure():
701730 assert (self .inst .exclusive is None )
702731 self .inst .exclusive = self
703732 else :
704- assert (self .inst .may_block )
705- self .inst .may_block = False
733+ assert (not self .inst .sync_before_return )
734+ self .inst .sync_before_return = True
706735 self .register_thread (thread )
707736 return True
708737
@@ -753,8 +782,8 @@ def return_(self, result):
753782 trap_if (self .state == Task .State .RESOLVED )
754783 trap_if (self .num_borrows > 0 )
755784 if not self .ft .async_ :
756- assert (not self .inst .may_block )
757- self .inst .may_block = True
785+ assert (self .inst .sync_before_return )
786+ self .inst .sync_before_return = False
758787 assert (result is not None )
759788 self .on_resolve (result )
760789 self .state = Task .State .RESOLVED
@@ -2091,12 +2120,12 @@ def thread_func():
20912120 inst .exclusive = None
20922121 match code :
20932122 case CallbackCode .YIELD :
2094- if thread .yield_until (lambda : not inst .exclusive , cancellable = True ) == Cancelled .TRUE :
2123+ if thread .wait_until (lambda : not inst .exclusive , cancellable = True ) == Cancelled .TRUE :
20952124 event = (EventCode .TASK_CANCELLED , 0 , 0 )
20962125 else :
20972126 event = (EventCode .NONE , 0 , 0 )
20982127 case CallbackCode .WAIT :
2099- trap_if (not inst .may_block )
2128+ trap_if (not inst .may_block () )
21002129 wset = inst .handles .get (si )
21012130 trap_if (not isinstance (wset , WaitableSet ))
21022131 event = wset .wait_until (lambda : not inst .exclusive , cancellable = True )
@@ -2140,7 +2169,7 @@ def call_and_trap_on_throw(callee, args):
21402169def canon_lower (opts , ft , callee : FuncInst , flat_args ):
21412170 thread = current_thread ()
21422171 trap_if (not thread .task .inst .may_leave )
2143- trap_if (not thread .task .inst .may_block and ft .async_ and not opts .async_ )
2172+ trap_if (not thread .task .inst .may_block () and ft .async_ and not opts .async_ )
21442173
21452174 subtask = Subtask ()
21462175 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2328,7 +2357,7 @@ def canon_waitable_set_new():
23282357def canon_waitable_set_wait (cancellable , mem , si , ptr ):
23292358 inst = current_thread ().task .inst
23302359 trap_if (not inst .may_leave )
2331- trap_if (not inst .may_block )
2360+ trap_if (not inst .may_block () )
23322361 wset = inst .handles .get (si )
23332362 trap_if (not isinstance (wset , WaitableSet ))
23342363 event = wset .wait (cancellable )
@@ -2383,7 +2412,7 @@ def canon_waitable_join(wi, si):
23832412def canon_subtask_cancel (async_ , i ):
23842413 thread = current_thread ()
23852414 trap_if (not thread .task .inst .may_leave )
2386- trap_if (not thread .task .inst .may_block and not async_ )
2415+ trap_if (not thread .task .inst .may_block () and not async_ )
23872416 subtask = thread .task .inst .handles .get (i )
23882417 trap_if (not isinstance (subtask , Subtask ))
23892418 trap_if (subtask .resolve_delivered ())
@@ -2444,7 +2473,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24442473def stream_copy (EndT , BufferT , event_code , stream_t , opts , i , ptr , n ):
24452474 thread = current_thread ()
24462475 trap_if (not thread .task .inst .may_leave )
2447- trap_if (not thread .task .inst .may_block and not opts .async_ )
2476+ trap_if (not thread .task .inst .may_block () and not opts .async_ )
24482477
24492478 e = thread .task .inst .handles .get (i )
24502479 trap_if (not isinstance (e , EndT ))
@@ -2499,7 +2528,7 @@ def canon_future_write(future_t, opts, i, ptr):
24992528def future_copy (EndT , BufferT , event_code , future_t , opts , i , ptr ):
25002529 thread = current_thread ()
25012530 trap_if (not thread .task .inst .may_leave )
2502- trap_if (not thread .task .inst .may_block and not opts .async_ )
2531+ trap_if (not thread .task .inst .may_block () and not opts .async_ )
25032532
25042533 e = thread .task .inst .handles .get (i )
25052534 trap_if (not isinstance (e , EndT ))
@@ -2552,7 +2581,7 @@ def canon_future_cancel_write(future_t, async_, i):
25522581def cancel_copy (EndT , event_code , stream_or_future_t , async_ , i ):
25532582 thread = current_thread ()
25542583 trap_if (not thread .task .inst .may_leave )
2555- trap_if (not thread .task .inst .may_block and not async_ )
2584+ trap_if (not thread .task .inst .may_block () and not async_ )
25562585 e = thread .task .inst .handles .get (i )
25572586 trap_if (not isinstance (e , EndT ))
25582587 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2620,22 +2649,22 @@ def thread_func():
26202649 task .register_thread (new_thread )
26212650 return [new_thread .index ]
26222651
2623- ### 🧵 `canon thread.resume-later `
2652+ ### 🧵 `canon thread.unsuspend `
26242653
2625- def canon_thread_resume_later (i ):
2654+ def canon_thread_unsuspend (i ):
26262655 thread = current_thread ()
26272656 trap_if (not thread .task .inst .may_leave )
26282657 other_thread = thread .task .inst .threads .get (i )
26292658 trap_if (not other_thread .suspended ())
2630- other_thread .resume_later ()
2659+ other_thread .unsuspend ()
26312660 return []
26322661
26332662### 🧵 `canon thread.suspend`
26342663
26352664def canon_thread_suspend (cancellable ):
26362665 thread = current_thread ()
26372666 trap_if (not thread .task .inst .may_leave )
2638- trap_if (not thread .task .inst .may_block )
2667+ trap_if (not thread .task .inst .may_block () )
26392668 cancelled = thread .suspend (cancellable )
26402669 return [cancelled ]
26412670
@@ -2647,26 +2676,54 @@ def canon_thread_yield(cancellable):
26472676 cancelled = thread .yield_ (cancellable )
26482677 return [cancelled ]
26492678
2650- ### 🧵 `canon thread.switch -to`
2679+ ### 🧵 `canon thread.suspend -to-suspended `
26512680
2652- def canon_thread_switch_to (cancellable , i ):
2681+ def canon_thread_suspend_to_suspended (cancellable , i ):
26532682 thread = current_thread ()
26542683 trap_if (not thread .task .inst .may_leave )
26552684 other_thread = thread .task .inst .threads .get (i )
26562685 trap_if (not other_thread .suspended ())
2657- cancelled = thread .switch_to (cancellable , other_thread )
2686+ cancelled = thread .suspend_to_suspended (cancellable , other_thread )
26582687 return [cancelled ]
26592688
2660- ### 🧵 `canon thread.yield-to`
2689+ ### 🧵 `canon thread.yield-to-suspended `
26612690
2662- def canon_thread_yield_to (cancellable , i ):
2691+ def canon_thread_yield_to_suspended (cancellable , i ):
26632692 thread = current_thread ()
26642693 trap_if (not thread .task .inst .may_leave )
26652694 other_thread = thread .task .inst .threads .get (i )
26662695 trap_if (not other_thread .suspended ())
2667- cancelled = thread .yield_to (cancellable , other_thread )
2696+ cancelled = thread .yield_to_suspended (cancellable , other_thread )
2697+ return [cancelled ]
2698+
2699+ ### 🧵 `canon thread.suspend-then-promote`
2700+
2701+ def canon_thread_suspend_then_promote (cancellable , i ):
2702+ thread = current_thread ()
2703+ trap_if (not thread .task .inst .may_leave )
2704+ trap_if (not thread .task .inst .may_block ())
2705+ other_thread = thread .task .inst .threads .get (i )
2706+ cancelled = thread .suspend_then_promote (cancellable , other_thread )
2707+ return [cancelled ]
2708+
2709+ ### 🧵 `canon thread.yield-then-promote`
2710+
2711+ def canon_thread_yield_then_promote (cancellable , i ):
2712+ thread = current_thread ()
2713+ trap_if (not thread .task .inst .may_leave )
2714+ other_thread = thread .task .inst .threads .get (i )
2715+ cancelled = thread .yield_then_promote (cancellable , other_thread )
26682716 return [cancelled ]
26692717
2718+ ### 🧵 `canon thread.exit`
2719+
2720+ def canon_thread_exit ():
2721+ thread = current_thread ()
2722+ trap_if (not thread .task .inst .may_leave )
2723+ trap_if (not thread .task .inst .may_block ())
2724+ thread .exit ()
2725+ assert (False )
2726+
26702727### 📝 `canon error-context.new`
26712728
26722729@dataclass
0 commit comments