Skip to content

Commit 7703906

Browse files
committed
Fix PyThreadState_EnsureFromView() and add a test.
1 parent 3320241 commit 7703906

4 files changed

Lines changed: 98 additions & 22 deletions

File tree

Include/pystate.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ PyAPI_FUNC(void) PyInterpreterView_Close(PyInterpreterView *view);
134134
PyAPI_FUNC(PyInterpreterView *) PyInterpreterView_FromMain(void);
135135

136136
PyAPI_FUNC(PyThreadState *) PyThreadState_Ensure(PyInterpreterGuard *guard);
137+
PyAPI_FUNC(PyThreadState *) PyThreadState_EnsureFromView(PyInterpreterView *view);
137138
PyAPI_FUNC(void) PyThreadState_Release(PyThreadState *tstate);
138139

139140
#ifndef Py_LIMITED_API

Modules/_testcapimodule.c

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2767,11 +2767,14 @@ test_interp_view_after_shutdown(PyObject *self, PyObject *unused)
27672767
PyThreadState *save_tstate = PyThreadState_Swap(NULL);
27682768
PyThreadState *interp_tstate = Py_NewInterpreter();
27692769
if (interp_tstate == NULL) {
2770+
PyThreadState_Swap(save_tstate);
27702771
return PyErr_NoMemory();
27712772
}
27722773

27732774
PyInterpreterView *view = PyInterpreterView_FromCurrent();
27742775
if (view == NULL) {
2776+
Py_EndInterpreter(interp_tstate);
2777+
PyThreadState_Swap(save_tstate);
27752778
return PyErr_NoMemory();
27762779
}
27772780

@@ -2789,6 +2792,64 @@ test_interp_view_after_shutdown(PyObject *self, PyObject *unused)
27892792
Py_RETURN_NONE;
27902793
}
27912794

2795+
static PyObject *
2796+
test_thread_state_ensure_view(PyObject *self, PyObject *unused)
2797+
{
2798+
// For simplicity's sake, we assume that functions won't fail due to being
2799+
// out of memory.
2800+
PyThreadState *save_tstate = PyThreadState_Swap(NULL);
2801+
PyThreadState *interp_tstate = Py_NewInterpreter();
2802+
assert(interp_tstate != NULL);
2803+
assert(PyInterpreterState_Get() == PyThreadState_GetInterpreter(interp_tstate));
2804+
2805+
PyInterpreterView *main_view = PyInterpreterView_FromMain();
2806+
assert(main_view != NULL);
2807+
2808+
PyInterpreterView *view = PyInterpreterView_FromCurrent();
2809+
assert(view != NULL);
2810+
2811+
Py_BEGIN_ALLOW_THREADS;
2812+
PyThreadState *tstate = PyThreadState_EnsureFromView(view);
2813+
assert(tstate != NULL);
2814+
assert(PyThreadState_Get() == interp_tstate);
2815+
2816+
// Test a nested call
2817+
PyThreadState *tstate2 = PyThreadState_EnsureFromView(view);
2818+
assert(PyThreadState_Get() == interp_tstate);
2819+
2820+
// We're in a new interpreter now. PyThreadState_EnsureFromView() should
2821+
// now create a new thread state.
2822+
PyThreadState *main_tstate = PyThreadState_EnsureFromView(main_view);
2823+
assert(main_tstate == interp_tstate); // The old thread state
2824+
assert(PyInterpreterState_Get() == PyInterpreterState_Main());
2825+
2826+
// Going back to the old interpreter should create a new thread state again.
2827+
PyThreadState *tstate3 = PyThreadState_EnsureFromView(view);
2828+
assert(PyInterpreterState_Get() == PyThreadState_GetInterpreter(interp_tstate));
2829+
assert(PyThreadState_Get() != interp_tstate);
2830+
PyThreadState_Release(tstate3);
2831+
PyThreadState_Release(main_tstate);
2832+
2833+
// We're back in the original interpreter. PyThreadState_EnsureFromView() should
2834+
// no longer create a new thread state.
2835+
assert(PyThreadState_Get() == interp_tstate);
2836+
PyThreadState *tstate4 = PyThreadState_EnsureFromView(view);
2837+
assert(PyThreadState_Get() == interp_tstate);
2838+
PyThreadState_Release(tstate4);
2839+
PyThreadState_Release(tstate2);
2840+
PyThreadState_Release(tstate);
2841+
assert(PyThreadState_GetUnchecked() == NULL);
2842+
Py_END_ALLOW_THREADS;
2843+
2844+
assert(PyThreadState_Get() == interp_tstate);
2845+
PyInterpreterView_Close(view);
2846+
PyInterpreterView_Close(main_view);
2847+
Py_EndInterpreter(interp_tstate);
2848+
PyThreadState_Swap(save_tstate);
2849+
2850+
Py_RETURN_NONE;
2851+
}
2852+
27922853

27932854
static PyObject*
27942855
test_soft_deprecated_macros(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(args))
@@ -2927,6 +2988,7 @@ static PyMethodDef TestMethods[] = {
29272988
{"test_thread_state_ensure_nested", test_thread_state_ensure_nested, METH_NOARGS},
29282989
{"test_thread_state_ensure_crossinterp", test_thread_state_ensure_crossinterp, METH_NOARGS},
29292990
{"test_interp_view_after_shutdown", test_interp_view_after_shutdown, METH_NOARGS},
2991+
{"test_thread_state_ensure_view", test_thread_state_ensure_view, METH_NOARGS},
29302992
{NULL, NULL} /* sentinel */
29312993
};
29322994

Python/pylifecycle.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2340,7 +2340,7 @@ make_pre_finalization_calls(PyThreadState *tstate, int subinterpreters)
23402340
|| interp_has_atexit_callbacks(interp)
23412341
|| interp_has_pending_calls(interp)
23422342
|| has_subinterpreters
2343-
|| interp->finalization_guards.countdown > 0);
2343+
|| _Py_atomic_load_ssize_acquire(&interp->finalization_guards.countdown) > 0);
23442344
if (!should_continue) {
23452345
break;
23462346
}

Python/pystate.c

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3344,29 +3344,23 @@ _PyInterpreterGuard_GetInterpreter(PyInterpreterGuard *guard)
33443344
return guard->interp;
33453345
}
33463346

3347-
static PyInterpreterGuard *
3348-
try_acquire_interp_guard(PyInterpreterState *interp)
3347+
static int
3348+
try_acquire_interp_guard(PyInterpreterState *interp, PyInterpreterGuard *guard)
33493349
{
33503350
assert(interp != NULL);
33513351
_PyRWMutex_RLock(&interp->finalization_guards.lock);
33523352

33533353
if (_PyInterpreterState_GetFinalizing(interp) != NULL) {
33543354
_PyRWMutex_RUnlock(&interp->finalization_guards.lock);
33553355
assert(_Py_atomic_load_ssize_relaxed(&interp->finalization_guards.countdown) == 0);
3356-
return NULL;
3357-
}
3358-
3359-
PyInterpreterGuard *guard = PyMem_RawMalloc(sizeof(PyInterpreterGuard));
3360-
if (guard == NULL) {
3361-
_PyRWMutex_RUnlock(&interp->finalization_guards.lock);
3362-
return NULL;
3356+
return -1;
33633357
}
33643358

33653359
_Py_atomic_add_ssize(&interp->finalization_guards.countdown, 1);
33663360
_PyRWMutex_RUnlock(&interp->finalization_guards.lock);
33673361

33683362
guard->interp = interp;
3369-
return guard;
3363+
return 0;
33703364
}
33713365

33723366
PyInterpreterGuard *
@@ -3375,8 +3369,14 @@ PyInterpreterGuard_FromCurrent(void)
33753369
PyInterpreterState *interp = _PyInterpreterState_GET();
33763370
assert(interp != NULL);
33773371

3378-
PyInterpreterGuard *guard = try_acquire_interp_guard(interp);
3372+
PyInterpreterGuard *guard = PyMem_RawMalloc(sizeof(PyInterpreterGuard));
33793373
if (guard == NULL) {
3374+
PyErr_NoMemory();
3375+
return NULL;
3376+
}
3377+
3378+
if (try_acquire_interp_guard(interp, guard) < 0) {
3379+
PyMem_RawFree(guard);
33803380
PyErr_SetString(PyExc_PythonFinalizationError,
33813381
"cannot acquire finalization guard anymore");
33823382
return NULL;
@@ -3396,11 +3396,6 @@ PyInterpreterGuard_Close(PyInterpreterGuard *guard)
33963396
_PyRWMutex_RUnlock(&interp->finalization_guards.lock);
33973397

33983398
assert(old > 0);
3399-
if (old <= 0) {
3400-
Py_FatalError("interpreter has negative guard count, likely due"
3401-
" to an extra PyInterpreterGuard_Close() call");
3402-
}
3403-
34043399
PyMem_RawFree(guard);
34053400
}
34063401

@@ -3440,18 +3435,32 @@ PyInterpreterGuard_FromView(PyInterpreterView *view)
34403435
int64_t interp_id = view->id;
34413436
assert(interp_id >= 0);
34423437

3438+
// This allocation has to happen before we acquire the runtime lock, because
3439+
// PyMem_RawMalloc() might call some weird callback (such as tracemalloc)
3440+
// that tries to re-entrantly acquire the lock.
3441+
PyInterpreterGuard *guard = PyMem_RawMalloc(sizeof(PyInterpreterGuard));
3442+
if (guard == NULL) {
3443+
return NULL;
3444+
}
3445+
34433446
// Interpreters cannot be deleted while we hold the runtime lock.
34443447
_PyRuntimeState *runtime = &_PyRuntime;
34453448
HEAD_LOCK(runtime);
34463449
PyInterpreterState *interp = interp_look_up_id(runtime, interp_id);
34473450
if (interp == NULL) {
34483451
HEAD_UNLOCK(runtime);
3449-
return 0;
3452+
PyMem_RawFree(guard);
3453+
return NULL;
34503454
}
34513455

3452-
PyInterpreterGuard *guard = try_acquire_interp_guard(interp);
3456+
int result = try_acquire_interp_guard(interp, guard);
34533457
HEAD_UNLOCK(runtime);
34543458

3459+
if (result < 0) {
3460+
PyMem_RawFree(guard);
3461+
return NULL;
3462+
}
3463+
34553464
assert(guard == NULL || guard->interp != NULL);
34563465
return guard;
34573466
}
@@ -3528,20 +3537,24 @@ PyThreadState_EnsureFromView(PyInterpreterView *view)
35283537
return NULL;
35293538
}
35303539

3531-
PyThreadState *tstate = PyThreadState_Ensure(guard);
3532-
if (tstate == NULL) {
3540+
PyThreadState *result_tstate = PyThreadState_Ensure(guard);
3541+
if (result_tstate == NULL) {
35333542
PyInterpreterGuard_Close(guard);
35343543
return NULL;
35353544
}
35363545

3546+
PyThreadState *tstate = current_fast_get();
3547+
assert(tstate != NULL);
3548+
35373549
if (tstate->ensure.owned_guard != NULL) {
35383550
assert(tstate->ensure.owned_guard->interp == guard->interp);
35393551
PyInterpreterGuard_Close(guard);
35403552
} else {
3553+
assert(tstate->ensure.owned_guard == NULL);
35413554
tstate->ensure.owned_guard = guard;
35423555
}
35433556

3544-
return tstate;
3557+
return result_tstate;
35453558
}
35463559

35473560
void

0 commit comments

Comments
 (0)