@@ -196,11 +196,16 @@ def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve)
196196 return f (host_caller , on_start , on_resolve )
197197
198198 def tick (self ):
199+ if (thread := self .find_ready_thread ()):
200+ thread .resume (Cancelled .FALSE )
201+
202+ def find_ready_thread (self ) -> Optional [Thread ]:
199203 random .shuffle (self .waiting )
200204 for thread in self .waiting :
201205 if thread .ready ():
202- thread .resume (Cancelled .FALSE )
203- return
206+ return thread
207+ return None
208+
204209
205210FuncInst : Callable [[Optional [Supertask ], OnStart , OnResolve ], Call ]
206211
@@ -286,6 +291,7 @@ class ComponentInstance:
286291 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
287292 threads : Table [Thread ]
288293 may_leave : bool
294+ may_enter : bool
289295 backpressure : int
290296 exclusive : Optional [Task ]
291297 num_waiting_to_enter : int
@@ -297,6 +303,7 @@ def __init__(self, store, parent = None):
297303 self .handles = Table ()
298304 self .threads = Table ()
299305 self .may_leave = True
306+ self .may_enter = True
300307 self .backpressure = 0
301308 self .exclusive = None
302309 self .num_waiting_to_enter = 0
@@ -503,12 +510,16 @@ def resume(self, cancelled):
503510 while thread is not None :
504511 cont = thread .cont
505512 thread .cont = None
513+ assert (thread .task .inst .may_enter )
514+ thread .task .inst .may_enter = False
506515 (thread .cont , switch_to ) = resume (cont , cancelled , thread )
516+ assert (not thread .task .inst .may_enter )
517+ thread .task .inst .may_enter = True
507518 thread = switch_to
508519 cancelled = Cancelled .FALSE
509520
510521 def suspend (self , cancellable ) -> Cancelled :
511- assert (self .running () and self . task . may_block () )
522+ assert (self .running ())
512523 if self .task .deliver_pending_cancel (cancellable ):
513524 return Cancelled .TRUE
514525 self .cancellable = cancellable
@@ -517,7 +528,7 @@ def suspend(self, cancellable) -> Cancelled:
517528 return cancelled
518529
519530 def wait_until (self , ready_func , cancellable = False ) -> Cancelled :
520- assert (self .running () and self . task . may_block () )
531+ assert (self .running ())
521532 if self .task .deliver_pending_cancel (cancellable ):
522533 return Cancelled .TRUE
523534 if ready_func () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
@@ -526,16 +537,8 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
526537 self .task .inst .store .waiting .append (self )
527538 return self .suspend (cancellable )
528539
529- def yield_until (self , ready_func , cancellable ) -> Cancelled :
530- assert (self .running ())
531- if self .task .may_block ():
532- return self .wait_until (ready_func , cancellable )
533- else :
534- assert (ready_func ())
535- return Cancelled .FALSE
536-
537540 def yield_ (self , cancellable ) -> Cancelled :
538- return self .yield_until (lambda : True , cancellable )
541+ return self .wait_until (lambda : True , cancellable )
539542
540543 def switch_to (self , cancellable , other : Thread ) -> Cancelled :
541544 assert (self .running () and other .suspended ())
@@ -673,26 +676,22 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
673676 def needs_exclusive (self ):
674677 return not self .opts .async_ or self .opts .callback
675678
676- def may_block (self ):
677- return self .ft .async_ or self .state == Task .State .RESOLVED
678-
679679 def enter (self ):
680680 thread = current_thread ()
681- if self .ft .async_ :
682- def has_backpressure ():
683- return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
684- if has_backpressure () or self .inst .num_waiting_to_enter > 0 :
685- self .inst .num_waiting_to_enter += 1
686- self .waiting_to_enter = thread
687- cancelled = thread .wait_until (lambda : not has_backpressure (), cancellable = True )
688- self .waiting_to_enter = None
689- self .inst .num_waiting_to_enter -= 1
690- if cancelled :
691- self .cancel ()
692- return False
693- if self .needs_exclusive ():
694- assert (self .inst .exclusive is None )
695- self .inst .exclusive = self
681+ def has_backpressure ():
682+ return self .inst .backpressure > 0 or (self .needs_exclusive () and bool (self .inst .exclusive ))
683+ if has_backpressure () or self .inst .num_waiting_to_enter > 0 :
684+ self .inst .num_waiting_to_enter += 1
685+ self .waiting_to_enter = thread
686+ cancelled = thread .wait_until (lambda : not has_backpressure (), cancellable = True )
687+ self .waiting_to_enter = None
688+ self .inst .num_waiting_to_enter -= 1
689+ if cancelled :
690+ self .cancel ()
691+ return False
692+ if self .needs_exclusive ():
693+ assert (self .inst .exclusive is None )
694+ self .inst .exclusive = self
696695 self .register_thread (thread )
697696 return True
698697
@@ -2041,6 +2040,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20412040
20422041def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve ) -> Call :
20432042 trap_if (call_might_be_recursive (caller , inst ))
2043+ assert (inst .may_enter ) # true?
2044+
20442045 task = Task (opts , inst , ft , caller , on_resolve )
20452046 def thread_func ():
20462047 if not task .enter ():
@@ -2077,12 +2078,11 @@ def thread_func():
20772078 inst .exclusive = None
20782079 match code :
20792080 case CallbackCode .YIELD :
2080- if thread .yield_until (lambda : not inst .exclusive , cancellable = True ) == Cancelled .TRUE :
2081+ if thread .wait_until (lambda : not inst .exclusive , cancellable = True ) == Cancelled .TRUE :
20812082 event = (EventCode .TASK_CANCELLED , 0 , 0 )
20822083 else :
20832084 event = (EventCode .NONE , 0 , 0 )
20842085 case CallbackCode .WAIT :
2085- trap_if (not task .may_block ())
20862086 wset = inst .handles .get (si )
20872087 trap_if (not isinstance (wset , WaitableSet ))
20882088 event = wset .wait_until (lambda : not inst .exclusive , cancellable = True )
@@ -2098,6 +2098,11 @@ def thread_func():
20982098
20992099 thread = Thread (task , thread_func )
21002100 thread .resume (Cancelled .FALSE )
2101+ if not task .ft .async_ :
2102+ while task .state != Task .State .RESOLVED :
2103+ other = inst .store .find_ready_thread ()
2104+ trap_if (other is None )
2105+ other .resume (Cancelled .FALSE )
21012106 return task
21022107
21032108class CallbackCode (IntEnum ):
@@ -2125,7 +2130,6 @@ def call_and_trap_on_throw(callee, args):
21252130def canon_lower (opts , ft , callee : FuncInst , flat_args ):
21262131 thread = current_thread ()
21272132 trap_if (not thread .task .inst .may_leave )
2128- trap_if (not thread .task .may_block () and ft .async_ and not opts .async_ )
21292133
21302134 subtask = Subtask ()
21312135 cx = LiftLowerContext (opts , thread .task .inst , subtask )
@@ -2311,13 +2315,12 @@ def canon_waitable_set_new():
23112315### 🔀 `canon waitable-set.wait`
23122316
23132317def canon_waitable_set_wait (cancellable , mem , si , ptr ):
2314- task = current_thread ().task
2315- trap_if (not task .inst .may_leave )
2316- trap_if (not task .may_block ())
2317- wset = task .inst .handles .get (si )
2318+ inst = current_thread ().task .inst
2319+ trap_if (not inst .may_leave )
2320+ wset = inst .handles .get (si )
23182321 trap_if (not isinstance (wset , WaitableSet ))
23192322 event = wset .wait (cancellable )
2320- return unpack_event (mem , task . inst , ptr , event )
2323+ return unpack_event (mem , inst , ptr , event )
23212324
23222325def unpack_event (mem , inst , ptr , e : EventTuple ):
23232326 event , p1 , p2 = e
@@ -2368,7 +2371,6 @@ def canon_waitable_join(wi, si):
23682371def canon_subtask_cancel (async_ , i ):
23692372 thread = current_thread ()
23702373 trap_if (not thread .task .inst .may_leave )
2371- trap_if (not thread .task .may_block () and not async_ )
23722374 subtask = thread .task .inst .handles .get (i )
23732375 trap_if (not isinstance (subtask , Subtask ))
23742376 trap_if (subtask .resolve_delivered ())
@@ -2429,7 +2431,6 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24292431def stream_copy (EndT , BufferT , event_code , stream_t , opts , i , ptr , n ):
24302432 thread = current_thread ()
24312433 trap_if (not thread .task .inst .may_leave )
2432- trap_if (not thread .task .may_block () and not opts .async_ )
24332434
24342435 e = thread .task .inst .handles .get (i )
24352436 trap_if (not isinstance (e , EndT ))
@@ -2484,7 +2485,6 @@ def canon_future_write(future_t, opts, i, ptr):
24842485def future_copy (EndT , BufferT , event_code , future_t , opts , i , ptr ):
24852486 thread = current_thread ()
24862487 trap_if (not thread .task .inst .may_leave )
2487- trap_if (not thread .task .may_block () and not opts .async_ )
24882488
24892489 e = thread .task .inst .handles .get (i )
24902490 trap_if (not isinstance (e , EndT ))
@@ -2537,7 +2537,6 @@ def canon_future_cancel_write(future_t, async_, i):
25372537def cancel_copy (EndT , event_code , stream_or_future_t , async_ , i ):
25382538 thread = current_thread ()
25392539 trap_if (not thread .task .inst .may_leave )
2540- trap_if (not thread .task .may_block () and not async_ )
25412540 e = thread .task .inst .handles .get (i )
25422541 trap_if (not isinstance (e , EndT ))
25432542 trap_if (e .shared .t != stream_or_future_t .t )
@@ -2620,7 +2619,6 @@ def canon_thread_switch_to(cancellable, i):
26202619def canon_thread_suspend (cancellable ):
26212620 thread = current_thread ()
26222621 trap_if (not thread .task .inst .may_leave )
2623- trap_if (not thread .task .may_block ())
26242622 cancelled = thread .suspend (cancellable )
26252623 return [cancelled ]
26262624
0 commit comments