diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index f16611b29a2658c..60aaa5cd548cbf0 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -10,6 +10,7 @@ import unittest from collections import deque from contextlib import _GeneratorContextManager, contextmanager, nullcontext +from _testinternalcapi import SelfInterruptingContextManager def do_with(obj): @@ -850,5 +851,21 @@ def exit_raises(): expected) +class InterruptDuringEnter(unittest.TestCase): + + def test_exit_called_after_interrupt(self): + cm = SelfInterruptingContextManager() + self.assertFalse(cm.within()) + try: + with cm: + self.assertTrue(cm.within()) + except KeyboardInterrupt: + self.assertFalse(cm.within()) + return + except: + self.fail("Wrong exception raised") + self.fail("No exception raised") + + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2026-06-04-12-53-10.gh-issue-148874.r121cG.rst b/Misc/NEWS.d/next/Core_and_Builtins/2026-06-04-12-53-10.gh-issue-148874.r121cG.rst new file mode 100644 index 000000000000000..95f93338beb9129 --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2026-06-04-12-53-10.gh-issue-148874.r121cG.rst @@ -0,0 +1,3 @@ +Ignore interrupts immediately after calling the ``__enter__`` method of a +context menager in a ``with`` statement. This ensures that the ``__exit__`` +method is always called in a ``with`` statement. diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c index e3de9006d5a427f..cb7889e614725f4 100644 --- a/Modules/_testinternalcapi.c +++ b/Modules/_testinternalcapi.c @@ -3196,6 +3196,66 @@ test_thread_state_ensure_from_view_interp_switch(PyObject *self, PyObject *unuse Py_RETURN_NONE; } +/* Self interrupting context manager */ + +typedef struct { + PyObject_HEAD + int within; +} SelfInterruptingContextManagerObject; + +static PyObject * +new_self_interrupting(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + SelfInterruptingContextManagerObject *self = + (SelfInterruptingContextManagerObject *)type->tp_alloc(type, 0); + if (self != NULL) { + self->within = 0; + } + return (PyObject *)self; +} + +static PyObject * +self_interrupting_enter(PyObject *op, PyObject *Py_UNUSED(dummy)) +{ + ((SelfInterruptingContextManagerObject *)op)->within = 1; + PyThreadState *tstate = PyThreadState_Get(); + PyObject *ki = Py_NewRef(PyExc_KeyboardInterrupt); + PyObject *old_exc = _Py_atomic_exchange_ptr(&tstate->async_exc, ki); + _Py_set_eval_breaker_bit(tstate, _PY_ASYNC_EXCEPTION_BIT); + Py_XDECREF(old_exc); + + return Py_NewRef(op); +} + +static PyObject * +self_interrupting_within(PyObject *op, PyObject *Py_UNUSED(dummy)) +{ + return PyBool_FromLong(((SelfInterruptingContextManagerObject *)op)->within); +} + +static PyObject * +self_interrupting_exit(PyObject *op, PyObject *Py_UNUSED(args)) { + ((SelfInterruptingContextManagerObject *)op)->within = 0; + Py_RETURN_NONE; +} + +static PyMethodDef self_interrupting_methods[] = { + {"__enter__", self_interrupting_enter, METH_NOARGS, NULL}, + {"within", self_interrupting_within, METH_NOARGS, NULL}, + {"__exit__", self_interrupting_exit, METH_VARARGS, NULL}, + {NULL, NULL} /* sentinel */ +}; + +static PyTypeObject SelfInterruptingContextManager_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_testcapi.SelfInterruptingContextManager", + sizeof(SelfInterruptingContextManagerObject), + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, + .tp_new = new_self_interrupting, + .tp_methods = self_interrupting_methods, +}; + + static PyMethodDef module_functions[] = { {"get_configs", get_configs, METH_NOARGS}, {"get_eval_frame_stats", get_eval_frame_stats, METH_NOARGS, NULL}, @@ -3418,6 +3478,11 @@ module_exec(PyObject *module) } #endif + if (PyType_Ready(&SelfInterruptingContextManager_Type) < 0) { + return 1; + } + PyModule_AddObject(module, "SelfInterruptingContextManager", (PyObject *)&SelfInterruptingContextManager_Type); + return 0; } diff --git a/Modules/_testinternalcapi/test_cases.c.h b/Modules/_testinternalcapi/test_cases.c.h index fdc077c9549a144..2c2c4a4182c96e4 100644 --- a/Modules/_testinternalcapi/test_cases.c.h +++ b/Modules/_testinternalcapi/test_cases.c.h @@ -1927,7 +1927,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2453,7 +2453,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2546,7 +2546,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2635,7 +2635,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2746,7 +2746,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2860,7 +2860,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -3191,7 +3191,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -3703,7 +3703,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4122,7 +4122,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4242,7 +4242,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4364,7 +4364,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4499,7 +4499,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4576,7 +4576,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4880,7 +4880,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4958,7 +4958,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -7283,7 +7283,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -7465,7 +7465,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); diff --git a/Python/bytecodes.c b/Python/bytecodes.c index 31596e0bc7a31d2..a19dc88784001fa 100644 --- a/Python/bytecodes.c +++ b/Python/bytecodes.c @@ -161,7 +161,7 @@ dummy_func( } replaced op(_CHECK_PERIODIC_AT_END, (--)) { - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); ERROR_IF(err != 0); } diff --git a/Python/ceval_macros.h b/Python/ceval_macros.h index b13884bf8214d41..f19adfa0cfcfc15 100644 --- a/Python/ceval_macros.h +++ b/Python/ceval_macros.h @@ -528,6 +528,22 @@ check_periodics(PyThreadState *tstate) { return 0; } +static inline int +check_periodics_at_end(PyThreadState *tstate, _PyInterpreterFrame *frame) { + _Py_CHECK_EMSCRIPTEN_SIGNALS_PERIODICALLY(); + QSBR_QUIESCENT_STATE(tstate); + if (_Py_atomic_load_uintptr_relaxed(&tstate->eval_breaker) & _PY_EVAL_EVENTS_MASK) { + // Do not handle pending interrupts if the previous instruction was LOAD_SPECIAL + // This may also not handle interrupts if a cache looks like LOAD_SPECIAL, + // but this is benign as we won't skip periodic checks indefinitely. + if (frame->instr_ptr[-1].op.code == LOAD_SPECIAL) { + return 0; + } + return _Py_HandlePending(tstate); + } + return 0; +} + // Mark the generator as executing. Returns true if the state was changed, // false if it was already executing or finished. static inline bool diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h index 0da86abed67f63b..bbb93e85bf26055 100644 --- a/Python/generated_cases.c.h +++ b/Python/generated_cases.c.h @@ -1927,7 +1927,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2453,7 +2453,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2546,7 +2546,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2635,7 +2635,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2746,7 +2746,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -2860,7 +2860,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -3191,7 +3191,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -3703,7 +3703,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4122,7 +4122,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4242,7 +4242,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4364,7 +4364,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4499,7 +4499,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4576,7 +4576,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4880,7 +4880,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -4958,7 +4958,7 @@ { assert(stack_pointer == _PyFrame_GetStackPointer(frame)); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -7283,7 +7283,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); @@ -7465,7 +7465,7 @@ ASSERT_WITHIN_STACK_BOUNDS(__FILE__, __LINE__); _PyFrame_SetStackPointer(frame, stack_pointer); _PyFrame_StackPointerValidate(frame); - int err = check_periodics(tstate); + int err = check_periodics_at_end(tstate, frame); _PyFrame_StackPointerInvalidate(frame); if (err != 0) { JUMP_TO_LABEL(error); diff --git a/Python/optimizer.c b/Python/optimizer.c index e95e4b5e24b2c54..c9f6ebdb62f07b2 100644 --- a/Python/optimizer.c +++ b/Python/optimizer.c @@ -956,10 +956,15 @@ _PyJit_translate_single_bytecode_to_trace( case OPARG_REPLACED: uop = _PyUOp_Replacements[uop]; assert(uop != 0); - uint32_t next_inst = target + 1 + _PyOpcode_Caches[_PyOpcode_Deopt[opcode]]; if (uop == _TIER2_RESUME_CHECK) { - target = next_inst; + if (this_instr[-1].op.code == LOAD_SPECIAL) { + // Don't check eval breaker immediately after LOAD_SPECIAL + uop = _NOP; + } + else { + target = next_inst; + } } else { int extended_arg = orig_oparg > 255; diff --git a/Tools/c-analyzer/cpython/ignored.tsv b/Tools/c-analyzer/cpython/ignored.tsv index bf08e5568205e7a..6e18593ad698570 100644 --- a/Tools/c-analyzer/cpython/ignored.tsv +++ b/Tools/c-analyzer/cpython/ignored.tsv @@ -577,6 +577,7 @@ Modules/_testimportmultiple.c - _testimportmultiple - Modules/_testinternalcapi.c - pending_identify_result - Modules/_testinternalcapi.c - Test_EvalFrame_Resumes - Modules/_testinternalcapi.c - Test_EvalFrame_Loads - +Modules/_testinternalcapi.c - SelfInterruptingContextManager_Type - Modules/_testinternalcapi/interpreter.c - Test_EvalFrame_Resumes - Modules/_testinternalcapi/interpreter.c - Test_EvalFrame_Loads - Modules/_testlimitedcapi/slots.c - TestMethods -