@@ -122,6 +122,7 @@ get_module_state(PyObject *mod)
122122}
123123
124124static struct PyModuleDef _decimal_module ;
125+ static PyType_Spec dec_spec ;
125126
126127static inline decimal_state *
127128get_module_state_by_def (PyTypeObject * tp )
@@ -134,10 +135,16 @@ get_module_state_by_def(PyTypeObject *tp)
134135static inline decimal_state *
135136find_state_left_or_right (PyObject * left , PyObject * right )
136137{
137- PyObject * mod = _PyType_GetModuleByDef2 (Py_TYPE (left ), Py_TYPE (right ),
138- & _decimal_module );
139- assert (mod != NULL );
140- return get_module_state (mod );
138+ PyTypeObject * base ;
139+ if (PyType_GetBaseByToken (Py_TYPE (left ), & dec_spec , & base ) != 1 ) {
140+ assert (!PyErr_Occurred ());
141+ PyType_GetBaseByToken (Py_TYPE (right ), & dec_spec , & base );
142+ }
143+ assert (base != NULL );
144+ void * state = _PyType_GetModuleState (base );
145+ assert (state != NULL );
146+ Py_DECREF (base );
147+ return (decimal_state * )state ;
141148}
142149
143150
@@ -745,7 +752,7 @@ signaldict_richcompare(PyObject *v, PyObject *w, int op)
745752{
746753 PyObject * res = Py_NotImplemented ;
747754
748- decimal_state * state = find_state_left_or_right ( v , w );
755+ decimal_state * state = get_module_state_by_def ( Py_TYPE ( v ) );
749756 assert (PyDecSignalDict_Check (state , v ));
750757
751758 if ((SdFlagAddr (v ) == NULL ) || (SdFlagAddr (w ) == NULL )) {
@@ -5041,6 +5048,7 @@ static PyMethodDef dec_methods [] =
50415048};
50425049
50435050static PyType_Slot dec_slots [] = {
5051+ {Py_tp_token , Py_TP_USE_SPEC },
50445052 {Py_tp_dealloc , dec_dealloc },
50455053 {Py_tp_getattro , PyObject_GenericGetAttr },
50465054 {Py_tp_traverse , dec_traverse },
0 commit comments