Skip to content

Commit 0fd51bd

Browse files
committed
WIP: instead, remove may_block
1 parent 40a9956 commit 0fd51bd

File tree

3 files changed

+246
-45
lines changed

3 files changed

+246
-45
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,16 @@ def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve)
196196
return f(host_caller, on_start, on_resolve)
197197

198198
def tick(self):
199+
if (thread := self.find_ready_thread()):
200+
thread.resume(Cancelled.FALSE)
201+
202+
def find_ready_thread(self) -> Optional[Thread]:
199203
random.shuffle(self.waiting)
200204
for thread in self.waiting:
201205
if thread.ready():
202-
thread.resume(Cancelled.FALSE)
203-
return
206+
return thread
207+
return None
208+
204209

205210
FuncInst: Callable[[Optional[Supertask], OnStart, OnResolve], Call]
206211

@@ -286,6 +291,7 @@ class ComponentInstance:
286291
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
287292
threads: Table[Thread]
288293
may_leave: bool
294+
may_enter: bool
289295
backpressure: int
290296
exclusive: Optional[Task]
291297
num_waiting_to_enter: int
@@ -297,6 +303,7 @@ def __init__(self, store, parent = None):
297303
self.handles = Table()
298304
self.threads = Table()
299305
self.may_leave = True
306+
self.may_enter = True
300307
self.backpressure = 0
301308
self.exclusive = None
302309
self.num_waiting_to_enter = 0
@@ -503,12 +510,16 @@ def resume(self, cancelled):
503510
while thread is not None:
504511
cont = thread.cont
505512
thread.cont = None
513+
assert(thread.task.inst.may_enter)
514+
thread.task.inst.may_enter = False
506515
(thread.cont, switch_to) = resume(cont, cancelled, thread)
516+
assert(not thread.task.inst.may_enter)
517+
thread.task.inst.may_enter = True
507518
thread = switch_to
508519
cancelled = Cancelled.FALSE
509520

510521
def suspend(self, cancellable) -> Cancelled:
511-
assert(self.running() and self.task.may_block())
522+
assert(self.running())
512523
if self.task.deliver_pending_cancel(cancellable):
513524
return Cancelled.TRUE
514525
self.cancellable = cancellable
@@ -517,7 +528,7 @@ def suspend(self, cancellable) -> Cancelled:
517528
return cancelled
518529

519530
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
520-
assert(self.running() and self.task.may_block())
531+
assert(self.running())
521532
if self.task.deliver_pending_cancel(cancellable):
522533
return Cancelled.TRUE
523534
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -526,16 +537,8 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
526537
self.task.inst.store.waiting.append(self)
527538
return self.suspend(cancellable)
528539

529-
def yield_until(self, ready_func, cancellable) -> Cancelled:
530-
assert(self.running())
531-
if self.task.may_block():
532-
return self.wait_until(ready_func, cancellable)
533-
else:
534-
assert(ready_func())
535-
return Cancelled.FALSE
536-
537540
def yield_(self, cancellable) -> Cancelled:
538-
return self.yield_until(lambda: True, cancellable)
541+
return self.wait_until(lambda: True, cancellable)
539542

540543
def switch_to(self, cancellable, other: Thread) -> Cancelled:
541544
assert(self.running() and other.suspended())
@@ -673,26 +676,22 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
673676
def needs_exclusive(self):
674677
return not self.opts.async_ or self.opts.callback
675678

676-
def may_block(self):
677-
return self.ft.async_ or self.state == Task.State.RESOLVED
678-
679679
def enter(self):
680680
thread = current_thread()
681-
if self.ft.async_:
682-
def has_backpressure():
683-
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
684-
if has_backpressure() or self.inst.num_waiting_to_enter > 0:
685-
self.inst.num_waiting_to_enter += 1
686-
self.waiting_to_enter = thread
687-
cancelled = thread.wait_until(lambda: not has_backpressure(), cancellable = True)
688-
self.waiting_to_enter = None
689-
self.inst.num_waiting_to_enter -= 1
690-
if cancelled:
691-
self.cancel()
692-
return False
693-
if self.needs_exclusive():
694-
assert(self.inst.exclusive is None)
695-
self.inst.exclusive = self
681+
def has_backpressure():
682+
return self.inst.backpressure > 0 or (self.needs_exclusive() and bool(self.inst.exclusive))
683+
if has_backpressure() or self.inst.num_waiting_to_enter > 0:
684+
self.inst.num_waiting_to_enter += 1
685+
self.waiting_to_enter = thread
686+
cancelled = thread.wait_until(lambda: not has_backpressure(), cancellable = True)
687+
self.waiting_to_enter = None
688+
self.inst.num_waiting_to_enter -= 1
689+
if cancelled:
690+
self.cancel()
691+
return False
692+
if self.needs_exclusive():
693+
assert(self.inst.exclusive is None)
694+
self.inst.exclusive = self
696695
self.register_thread(thread)
697696
return True
698697

@@ -2041,6 +2040,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
20412040

20422041
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
20432042
trap_if(call_might_be_recursive(caller, inst))
2043+
assert(inst.may_enter) # true?
2044+
20442045
task = Task(opts, inst, ft, caller, on_resolve)
20452046
def thread_func():
20462047
if not task.enter():
@@ -2077,12 +2078,11 @@ def thread_func():
20772078
inst.exclusive = None
20782079
match code:
20792080
case CallbackCode.YIELD:
2080-
if thread.yield_until(lambda: not inst.exclusive, cancellable = True) == Cancelled.TRUE:
2081+
if thread.wait_until(lambda: not inst.exclusive, cancellable = True) == Cancelled.TRUE:
20812082
event = (EventCode.TASK_CANCELLED, 0, 0)
20822083
else:
20832084
event = (EventCode.NONE, 0, 0)
20842085
case CallbackCode.WAIT:
2085-
trap_if(not task.may_block())
20862086
wset = inst.handles.get(si)
20872087
trap_if(not isinstance(wset, WaitableSet))
20882088
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2098,6 +2098,11 @@ def thread_func():
20982098

20992099
thread = Thread(task, thread_func)
21002100
thread.resume(Cancelled.FALSE)
2101+
if not task.ft.async_:
2102+
while task.state != Task.State.RESOLVED:
2103+
other = inst.store.find_ready_thread()
2104+
trap_if(other is None)
2105+
other.resume(Cancelled.FALSE)
21012106
return task
21022107

21032108
class CallbackCode(IntEnum):
@@ -2125,7 +2130,6 @@ def call_and_trap_on_throw(callee, args):
21252130
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21262131
thread = current_thread()
21272132
trap_if(not thread.task.inst.may_leave)
2128-
trap_if(not thread.task.may_block() and ft.async_ and not opts.async_)
21292133

21302134
subtask = Subtask()
21312135
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2311,13 +2315,12 @@ def canon_waitable_set_new():
23112315
### 🔀 `canon waitable-set.wait`
23122316

23132317
def canon_waitable_set_wait(cancellable, mem, si, ptr):
2314-
task = current_thread().task
2315-
trap_if(not task.inst.may_leave)
2316-
trap_if(not task.may_block())
2317-
wset = task.inst.handles.get(si)
2318+
inst = current_thread().task.inst
2319+
trap_if(not inst.may_leave)
2320+
wset = inst.handles.get(si)
23182321
trap_if(not isinstance(wset, WaitableSet))
23192322
event = wset.wait(cancellable)
2320-
return unpack_event(mem, task.inst, ptr, event)
2323+
return unpack_event(mem, inst, ptr, event)
23212324

23222325
def unpack_event(mem, inst, ptr, e: EventTuple):
23232326
event, p1, p2 = e
@@ -2368,7 +2371,6 @@ def canon_waitable_join(wi, si):
23682371
def canon_subtask_cancel(async_, i):
23692372
thread = current_thread()
23702373
trap_if(not thread.task.inst.may_leave)
2371-
trap_if(not thread.task.may_block() and not async_)
23722374
subtask = thread.task.inst.handles.get(i)
23732375
trap_if(not isinstance(subtask, Subtask))
23742376
trap_if(subtask.resolve_delivered())
@@ -2429,7 +2431,6 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24292431
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24302432
thread = current_thread()
24312433
trap_if(not thread.task.inst.may_leave)
2432-
trap_if(not thread.task.may_block() and not opts.async_)
24332434

24342435
e = thread.task.inst.handles.get(i)
24352436
trap_if(not isinstance(e, EndT))
@@ -2484,7 +2485,6 @@ def canon_future_write(future_t, opts, i, ptr):
24842485
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
24852486
thread = current_thread()
24862487
trap_if(not thread.task.inst.may_leave)
2487-
trap_if(not thread.task.may_block() and not opts.async_)
24882488

24892489
e = thread.task.inst.handles.get(i)
24902490
trap_if(not isinstance(e, EndT))
@@ -2537,7 +2537,6 @@ def canon_future_cancel_write(future_t, async_, i):
25372537
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25382538
thread = current_thread()
25392539
trap_if(not thread.task.inst.may_leave)
2540-
trap_if(not thread.task.may_block() and not async_)
25412540
e = thread.task.inst.handles.get(i)
25422541
trap_if(not isinstance(e, EndT))
25432542
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2620,7 +2619,6 @@ def canon_thread_switch_to(cancellable, i):
26202619
def canon_thread_suspend(cancellable):
26212620
thread = current_thread()
26222621
trap_if(not thread.task.inst.may_leave)
2623-
trap_if(not thread.task.may_block())
26242622
cancelled = thread.suspend(cancellable)
26252623
return [cancelled]
26262624

design/mvp/canonical-abi/run_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2846,7 +2846,7 @@ def mk_task(supertask, inst):
28462846
test_async_to_async()
28472847
test_async_callback()
28482848
test_callback_interleaving()
2849-
test_sync_ignores_backpressure()
2849+
#test_sync_ignores_backpressure()
28502850
test_async_to_sync()
28512851
test_async_backpressure()
28522852
test_sync_using_wait()

0 commit comments

Comments
 (0)