Skip to content

Commit e25388c

Browse files
committed
CABI: improve and add cooperative thread built-ins
1 parent 39ae5c2 commit e25388c

File tree

2 files changed

+244
-57
lines changed

2 files changed

+244
-57
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 99 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
class Trap(BaseException): pass
1818
class CoreWebAssemblyException(BaseException): pass
19+
class ThreadExit(BaseException): pass
1920

2021
def trap():
2122
raise Trap()
@@ -294,7 +295,7 @@ class ComponentInstance:
294295
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
295296
threads: Table[Thread]
296297
may_leave: bool
297-
may_block: bool
298+
sync_before_return: bool
298299
backpressure: int
299300
exclusive: Optional[Task]
300301
num_waiting_to_enter: int
@@ -306,11 +307,17 @@ def __init__(self, store, parent = None):
306307
self.handles = Table()
307308
self.threads = Table()
308309
self.may_leave = True
309-
self.may_block = True
310+
self.sync_before_return = False
310311
self.backpressure = 0
311312
self.exclusive = None
312313
self.num_waiting_to_enter = 0
313314

315+
def ready_threads(self) -> list[Thread]:
316+
return [t for t in self.threads.array if t and t.waiting() and t.ready()]
317+
318+
def may_block(self):
319+
return not self.sync_before_return or len(self.ready_threads()) > 0
320+
314321
def reflexive_ancestors(self) -> set[ComponentInstance]:
315322
s = set()
316323
inst = self
@@ -487,7 +494,10 @@ def ready(self):
487494
def __init__(self, task, thread_func):
488495
def cont_func(cancelled):
489496
assert(self.running() and not cancelled)
490-
thread_func()
497+
try:
498+
thread_func()
499+
except ThreadExit:
500+
pass
491501
return None
492502
self.cont = cont_new(cont_func)
493503
self.ready_func = None
@@ -497,7 +507,7 @@ def cont_func(cancelled):
497507
self.storage = [0,0]
498508
assert(self.suspended())
499509

500-
def resume_later(self):
510+
def unsuspend(self):
501511
assert(self.suspended())
502512
self.ready_func = lambda: True
503513
self.task.inst.store.waiting.append(self)
@@ -507,18 +517,25 @@ def resume(self, cancelled):
507517
assert(not self.running() and (self.cancellable or not cancelled))
508518
if self.waiting():
509519
assert(cancelled or self.ready())
510-
self.ready_func = None
511-
self.task.inst.store.waiting.remove(self)
520+
self.stop_waiting()
512521
thread = self
513522
while thread is not None:
514523
cont = thread.cont
515524
thread.cont = None
516525
(thread.cont, switch_to) = resume(cont, cancelled, thread)
526+
if switch_to is None and self.task.inst.sync_before_return:
527+
switch_to = random.choice(self.task.inst.ready_threads())
528+
switch_to.stop_waiting()
517529
thread = switch_to
518530
cancelled = Cancelled.FALSE
519531

532+
def stop_waiting(self):
533+
assert(self.waiting())
534+
self.ready_func = None
535+
self.task.inst.store.waiting.remove(self)
536+
520537
def suspend(self, cancellable) -> Cancelled:
521-
assert(self.running() and self.task.inst.may_block)
538+
assert(self.running())
522539
if self.task.deliver_pending_cancel(cancellable):
523540
return Cancelled.TRUE
524541
self.cancellable = cancellable
@@ -527,7 +544,7 @@ def suspend(self, cancellable) -> Cancelled:
527544
return cancelled
528545

529546
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
530-
assert(self.running() and self.task.inst.may_block)
547+
assert(self.running())
531548
if self.task.deliver_pending_cancel(cancellable):
532549
return Cancelled.TRUE
533550
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -536,18 +553,10 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
536553
self.task.inst.store.waiting.append(self)
537554
return self.suspend(cancellable)
538555

539-
def yield_until(self, ready_func, cancellable) -> Cancelled:
540-
assert(self.running())
541-
if self.task.inst.may_block:
542-
return self.wait_until(ready_func, cancellable)
543-
else:
544-
assert(ready_func())
545-
return Cancelled.FALSE
546-
547556
def yield_(self, cancellable) -> Cancelled:
548-
return self.yield_until(lambda: True, cancellable)
557+
return self.wait_until(lambda: True, cancellable)
549558

550-
def switch_to(self, cancellable, other: Thread) -> Cancelled:
559+
def suspend_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
551560
assert(self.running() and other.suspended())
552561
if self.task.deliver_pending_cancel(cancellable):
553562
return Cancelled.TRUE
@@ -556,11 +565,31 @@ def switch_to(self, cancellable, other: Thread) -> Cancelled:
556565
assert(self.running() and (cancellable or not cancelled))
557566
return cancelled
558567

559-
def yield_to(self, cancellable, other: Thread) -> Cancelled:
568+
def yield_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
560569
assert(self.running() and other.suspended())
561570
self.ready_func = lambda: True
562571
self.task.inst.store.waiting.append(self)
563-
return self.switch_to(cancellable, other)
572+
return self.suspend_to_suspended(cancellable, other)
573+
574+
def suspend_then_promote(self, cancellable, other: Thread) -> ResumeArg:
575+
assert(self.running())
576+
if other.waiting() and other.ready():
577+
other.stop_waiting()
578+
return self.suspend_to_suspended(cancellable, other)
579+
else:
580+
return self.suspend(cancellable)
581+
582+
def yield_then_promote(self, cancellable, other: Thread) -> ResumeArg:
583+
assert(self.running())
584+
if other.waiting() and other.ready():
585+
other.stop_waiting()
586+
return self.yield_to_suspended(cancellable, other)
587+
else:
588+
return self.yield_(cancellable)
589+
590+
def exit(self):
591+
assert(self.running() and self.task.inst.may_block())
592+
raise ThreadExit()
564593

565594
#### Waitable State
566595

@@ -701,8 +730,8 @@ def has_backpressure():
701730
assert(self.inst.exclusive is None)
702731
self.inst.exclusive = self
703732
else:
704-
assert(self.inst.may_block)
705-
self.inst.may_block = False
733+
assert(not self.inst.sync_before_return)
734+
self.inst.sync_before_return = True
706735
self.register_thread(thread)
707736
return True
708737

@@ -753,8 +782,8 @@ def return_(self, result):
753782
trap_if(self.state == Task.State.RESOLVED)
754783
trap_if(self.num_borrows > 0)
755784
if not self.ft.async_:
756-
assert(not self.inst.may_block)
757-
self.inst.may_block = True
785+
assert(self.inst.sync_before_return)
786+
self.inst.sync_before_return = False
758787
assert(result is not None)
759788
self.on_resolve(result)
760789
self.state = Task.State.RESOLVED
@@ -2091,12 +2120,12 @@ def thread_func():
20912120
inst.exclusive = None
20922121
match code:
20932122
case CallbackCode.YIELD:
2094-
if thread.yield_until(lambda: not inst.exclusive, cancellable = True) == Cancelled.TRUE:
2123+
if thread.wait_until(lambda: not inst.exclusive, cancellable = True) == Cancelled.TRUE:
20952124
event = (EventCode.TASK_CANCELLED, 0, 0)
20962125
else:
20972126
event = (EventCode.NONE, 0, 0)
20982127
case CallbackCode.WAIT:
2099-
trap_if(not inst.may_block)
2128+
trap_if(not inst.may_block())
21002129
wset = inst.handles.get(si)
21012130
trap_if(not isinstance(wset, WaitableSet))
21022131
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2140,7 +2169,7 @@ def call_and_trap_on_throw(callee, args):
21402169
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21412170
thread = current_thread()
21422171
trap_if(not thread.task.inst.may_leave)
2143-
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
2172+
trap_if(not thread.task.inst.may_block() and ft.async_ and not opts.async_)
21442173

21452174
subtask = Subtask()
21462175
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2328,7 +2357,7 @@ def canon_waitable_set_new():
23282357
def canon_waitable_set_wait(cancellable, mem, si, ptr):
23292358
inst = current_thread().task.inst
23302359
trap_if(not inst.may_leave)
2331-
trap_if(not inst.may_block)
2360+
trap_if(not inst.may_block())
23322361
wset = inst.handles.get(si)
23332362
trap_if(not isinstance(wset, WaitableSet))
23342363
event = wset.wait(cancellable)
@@ -2383,7 +2412,7 @@ def canon_waitable_join(wi, si):
23832412
def canon_subtask_cancel(async_, i):
23842413
thread = current_thread()
23852414
trap_if(not thread.task.inst.may_leave)
2386-
trap_if(not thread.task.inst.may_block and not async_)
2415+
trap_if(not thread.task.inst.may_block() and not async_)
23872416
subtask = thread.task.inst.handles.get(i)
23882417
trap_if(not isinstance(subtask, Subtask))
23892418
trap_if(subtask.resolve_delivered())
@@ -2444,7 +2473,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24442473
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24452474
thread = current_thread()
24462475
trap_if(not thread.task.inst.may_leave)
2447-
trap_if(not thread.task.inst.may_block and not opts.async_)
2476+
trap_if(not thread.task.inst.may_block() and not opts.async_)
24482477

24492478
e = thread.task.inst.handles.get(i)
24502479
trap_if(not isinstance(e, EndT))
@@ -2499,7 +2528,7 @@ def canon_future_write(future_t, opts, i, ptr):
24992528
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
25002529
thread = current_thread()
25012530
trap_if(not thread.task.inst.may_leave)
2502-
trap_if(not thread.task.inst.may_block and not opts.async_)
2531+
trap_if(not thread.task.inst.may_block() and not opts.async_)
25032532

25042533
e = thread.task.inst.handles.get(i)
25052534
trap_if(not isinstance(e, EndT))
@@ -2552,7 +2581,7 @@ def canon_future_cancel_write(future_t, async_, i):
25522581
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25532582
thread = current_thread()
25542583
trap_if(not thread.task.inst.may_leave)
2555-
trap_if(not thread.task.inst.may_block and not async_)
2584+
trap_if(not thread.task.inst.may_block() and not async_)
25562585
e = thread.task.inst.handles.get(i)
25572586
trap_if(not isinstance(e, EndT))
25582587
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2620,22 +2649,22 @@ def thread_func():
26202649
task.register_thread(new_thread)
26212650
return [new_thread.index]
26222651

2623-
### 🧵 `canon thread.resume-later`
2652+
### 🧵 `canon thread.unsuspend`
26242653

2625-
def canon_thread_resume_later(i):
2654+
def canon_thread_unsuspend(i):
26262655
thread = current_thread()
26272656
trap_if(not thread.task.inst.may_leave)
26282657
other_thread = thread.task.inst.threads.get(i)
26292658
trap_if(not other_thread.suspended())
2630-
other_thread.resume_later()
2659+
other_thread.unsuspend()
26312660
return []
26322661

26332662
### 🧵 `canon thread.suspend`
26342663

26352664
def canon_thread_suspend(cancellable):
26362665
thread = current_thread()
26372666
trap_if(not thread.task.inst.may_leave)
2638-
trap_if(not thread.task.inst.may_block)
2667+
trap_if(not thread.task.inst.may_block())
26392668
cancelled = thread.suspend(cancellable)
26402669
return [cancelled]
26412670

@@ -2647,26 +2676,54 @@ def canon_thread_yield(cancellable):
26472676
cancelled = thread.yield_(cancellable)
26482677
return [cancelled]
26492678

2650-
### 🧵 `canon thread.switch-to`
2679+
### 🧵 `canon thread.suspend-to-suspended`
26512680

2652-
def canon_thread_switch_to(cancellable, i):
2681+
def canon_thread_suspend_to_suspended(cancellable, i):
26532682
thread = current_thread()
26542683
trap_if(not thread.task.inst.may_leave)
26552684
other_thread = thread.task.inst.threads.get(i)
26562685
trap_if(not other_thread.suspended())
2657-
cancelled = thread.switch_to(cancellable, other_thread)
2686+
cancelled = thread.suspend_to_suspended(cancellable, other_thread)
26582687
return [cancelled]
26592688

2660-
### 🧵 `canon thread.yield-to`
2689+
### 🧵 `canon thread.yield-to-suspended`
26612690

2662-
def canon_thread_yield_to(cancellable, i):
2691+
def canon_thread_yield_to_suspended(cancellable, i):
26632692
thread = current_thread()
26642693
trap_if(not thread.task.inst.may_leave)
26652694
other_thread = thread.task.inst.threads.get(i)
26662695
trap_if(not other_thread.suspended())
2667-
cancelled = thread.yield_to(cancellable, other_thread)
2696+
cancelled = thread.yield_to_suspended(cancellable, other_thread)
2697+
return [cancelled]
2698+
2699+
### 🧵 `canon thread.suspend-then-promote`
2700+
2701+
def canon_thread_suspend_then_promote(cancellable, i):
2702+
thread = current_thread()
2703+
trap_if(not thread.task.inst.may_leave)
2704+
trap_if(not thread.task.inst.may_block())
2705+
other_thread = thread.task.inst.threads.get(i)
2706+
cancelled = thread.suspend_then_promote(cancellable, other_thread)
2707+
return [cancelled]
2708+
2709+
### 🧵 `canon thread.yield-then-promote`
2710+
2711+
def canon_thread_yield_then_promote(cancellable, i):
2712+
thread = current_thread()
2713+
trap_if(not thread.task.inst.may_leave)
2714+
other_thread = thread.task.inst.threads.get(i)
2715+
cancelled = thread.yield_then_promote(cancellable, other_thread)
26682716
return [cancelled]
26692717

2718+
### 🧵 `canon thread.exit`
2719+
2720+
def canon_thread_exit():
2721+
thread = current_thread()
2722+
trap_if(not thread.task.inst.may_leave)
2723+
trap_if(not thread.task.inst.may_block())
2724+
thread.exit()
2725+
assert(False)
2726+
26702727
### 📝 `canon error-context.new`
26712728

26722729
@dataclass

0 commit comments

Comments
 (0)