diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 8f8d74255a87..9a279ae65b16 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -331,6 +331,20 @@ def emit_line() -> None: if emitter.capi_version < (3, 12): fields["tp_dictoffset"] = base_size fields["tp_weaklistoffset"] = weak_offset + elif cl.supports_weakref and emitter.capi_version < (3, 12): + # __weakref__ lives right after the struct + # TODO: It should get a member in the struct instead of doing this nonsense. + emitter.emit_lines( + f"PyMemberDef {members_name}[] = {{", + f'{{"__weakref__", T_OBJECT_EX, {base_size}, 0, NULL}},', + "{0}", + "};", + ) + fields["tp_members"] = members_name + fields["tp_basicsize"] = f"{base_size} + sizeof(PyObject *)" + # versions >= 3.12 set Py_TPFLAGS_MANAGED_WEAKREF flag instead + # https://docs.python.org/3.12/extending/newtypes.html#weak-reference-support + fields["tp_weaklistoffset"] = base_size else: fields["tp_basicsize"] = base_size @@ -391,6 +405,9 @@ def emit_line() -> None: fields["tp_call"] = "PyVectorcall_Call" if has_managed_dict(cl, emitter): flags.append("Py_TPFLAGS_MANAGED_DICT") + if cl.supports_weakref and emitter.capi_version >= (3, 12): + flags.append("Py_TPFLAGS_MANAGED_WEAKREF") + fields["tp_flags"] = " | ".join(flags) fields["tp_doc"] = f"PyDoc_STR({native_class_doc_initializer(cl)})" @@ -886,6 +903,12 @@ def generate_traverse_for_class(cl: ClassIR, func_name: str, emitter: Emitter) - f"*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({struct_name})))", object_rprimitive, ) + elif cl.supports_weakref and emitter.capi_version < (3, 12): + struct_name = cl.struct_name(emitter.names) + # __weakref__ lives right after the struct + emitter.emit_gc_visit( + f"*((PyObject **)((char *)self + sizeof({struct_name})))", object_rprimitive + ) emitter.emit_line("return 0;") emitter.emit_line("}") @@ -909,6 +932,12 @@ def generate_clear_for_class(cl: ClassIR, func_name: str, emitter: Emitter) -> N f"*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({struct_name})))", object_rprimitive, ) + elif cl.supports_weakref and emitter.capi_version < (3, 12): + struct_name = cl.struct_name(emitter.names) + # __weakref__ lives right after the struct + emitter.emit_gc_clear( + f"*((PyObject **)((char *)self + sizeof({struct_name})))", object_rprimitive + ) emitter.emit_line("return 0;") emitter.emit_line("}") @@ -923,6 +952,8 @@ def generate_dealloc_for_class( emitter.emit_line("static void") emitter.emit_line(f"{dealloc_func_name}({cl.struct_name(emitter.names)} *self)") emitter.emit_line("{") + if cl.supports_weakref: + emitter.emit_line("PyObject_ClearWeakRefs((PyObject *) self);") if has_tp_finalize: emitter.emit_line("PyObject *type, *value, *traceback;") emitter.emit_line("PyErr_Fetch(&type, &value, &traceback);") diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 50a8225b4a68..845a2e2fce3e 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -109,6 +109,8 @@ def __init__( self.inherits_python = False # Do instances of this class have __dict__? self.has_dict = False + # Do instances of this class have __weakref__? + self.supports_weakref = False # Do we allow interpreted subclasses? Derived from a mypyc_attr. self.allow_interpreted_subclasses = False # Does this class need getseters to be generated for its attributes? (getseters are also @@ -384,6 +386,7 @@ def serialize(self) -> JsonDict: "is_final_class": self.is_final_class, "inherits_python": self.inherits_python, "has_dict": self.has_dict, + "supports_weakref": self.supports_weakref, "allow_interpreted_subclasses": self.allow_interpreted_subclasses, "needs_getseters": self.needs_getseters, "_serializable": self._serializable, @@ -444,6 +447,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR: ir.is_final_class = data["is_final_class"] ir.inherits_python = data["inherits_python"] ir.has_dict = data["has_dict"] + ir.supports_weakref = data.get("supports_weakref", False) ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"] ir.needs_getseters = data["needs_getseters"] ir._serializable = data["_serializable"] diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 9f3c7fc6f270..3acd85a3806b 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -368,6 +368,9 @@ def prepare_class_def( if attrs.get("serializable") is True: # Supports copy.copy and pickle (including subclasses) ir._serializable = True + if attrs.get("supports_weakref") is True: + # Has a tp_weakrefoffset slot allowing the creation of weak references (including subclasses) + ir.supports_weakref = True free_list_len = attrs.get("free_list_len") if free_list_len is not None: diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index 3028e940f7f9..c66db95a8774 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -33,14 +33,24 @@ from mypyc.errors import Errors MYPYC_ATTRS: Final[frozenset[MypycAttr]] = frozenset( - ["native_class", "allow_interpreted_subclasses", "serializable", "free_list_len"] + [ + "native_class", + "allow_interpreted_subclasses", + "serializable", + "supports_weakref", + "free_list_len", + ] ) DATACLASS_DECORATORS: Final = frozenset(["dataclasses.dataclass", "attr.s", "attr.attrs"]) MypycAttr = Literal[ - "native_class", "allow_interpreted_subclasses", "serializable", "free_list_len" + "native_class", + "allow_interpreted_subclasses", + "serializable", + "supports_weakref", + "free_list_len", ] @@ -48,6 +58,7 @@ class MypycAttrs(TypedDict): native_class: NotRequired[bool] allow_interpreted_subclasses: NotRequired[bool] serializable: NotRequired[bool] + supports_weakref: NotRequired[bool] free_list_len: NotRequired[int] diff --git a/mypyc/irbuild/vtable.py b/mypyc/irbuild/vtable.py index 2d4f7261e4ca..766b4086c594 100644 --- a/mypyc/irbuild/vtable.py +++ b/mypyc/irbuild/vtable.py @@ -15,6 +15,8 @@ def compute_vtable(cls: ClassIR) -> None: if not cls.is_generated: cls.has_dict = any(x.inherits_python for x in cls.mro) + # TODO: define more weakref triggers + cls.supports_weakref = cls.supports_weakref or cls.has_dict for t in cls.mro[1:]: # Make sure all ancestors are processed first diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 7954535d5dea..b05e15d15980 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1451,6 +1451,48 @@ class TestOverload: def __mypyc_generator_helper__(self, x: Any) -> Any: return x +[case testMypycAttrSupportsWeakref] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +obj = WeakrefClass() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefInheritance] +import weakref +from mypy_extensions import mypyc_attr + +@mypyc_attr(supports_weakref=True) +class WeakrefClass: + pass + +class WeakrefInheritor(WeakrefClass): + pass + +obj = WeakrefInheritor() +ref = weakref.ref(obj) +assert ref() is obj + +[case testMypycAttrSupportsWeakrefSubclass] +import weakref +from mypy_extensions import mypyc_attr + +class NativeClass: + pass + +@mypyc_attr(supports_weakref=True) +class WeakrefSubclass(NativeClass): + pass + +obj = WeakrefSubclass() +ref = weakref.ref(obj) +assert ref() is obj + [case testNativeBufferFastPath] from typing import Final from mypy_extensions import u8 @@ -2885,11 +2927,11 @@ L0: from mypy_extensions import mypyc_attr @mypyc_attr("allow_interpreted_subclasses", "invalid_arg") # E: "invalid_arg" is not a supported "mypyc_attr" \ - # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable" + # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable", "supports_weakref" class InvalidArg: pass @mypyc_attr(invalid_kwarg=True) # E: "invalid_kwarg" is not a supported "mypyc_attr" \ - # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable" + # N: supported keys: "allow_interpreted_subclasses", "free_list_len", "native_class", "serializable", "supports_weakref" class InvalidKwarg: pass @mypyc_attr(str()) # E: All "mypyc_attr" positional arguments must be string literals. diff --git a/mypyc/test-data/run-weakref.test b/mypyc/test-data/run-weakref.test index 0a0e180d635d..f64d7a12097a 100644 --- a/mypyc/test-data/run-weakref.test +++ b/mypyc/test-data/run-weakref.test @@ -15,6 +15,11 @@ class Object: _callback_called_cache = {"ref": False, "proxy": False} +@mypyc_attr(supports_weakref=True) +class NativeObject: + def some_meth(self) -> int: + return 1 + def test_weakref_ref() -> None: obj: Optional[Object] = Object() r = ref(obj) @@ -50,3 +55,10 @@ def test_weakref_proxy_with_callback() -> None: with assertRaises(ReferenceError): p.some_meth() assert _callback_called_cache["proxy"] is True + +def test_weakref_native_ref() -> None: + obj: Optional[NativeObject] = NativeObject() + r = ref(obj) + assert r() is obj + obj = None + assert r() is None