diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index bcaf0963af6f..78a784f1c571 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -54,6 +54,7 @@ is_str_rprimitive, is_tuple_rprimitive, is_uint8_rprimitive, + is_weakref_rprimitive, object_rprimitive, optional_value_type, ) @@ -665,6 +666,16 @@ def emit_cast( self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) self.emit_line("}") + elif is_weakref_rprimitive(typ): + if declare_dest: + self.emit_line(f"PyObject *{dest};") + check = "(PyWeakref_CheckRef({}))" + if likely: + check = f"(likely{check})" + self.emit_arg_check(src, dest, typ, check.format(src), optional) + self.emit_lines(f" {dest} = {src};", "else {") + self.emit_cast_error_handler(error, src, dest, typ, raise_exception) + self.emit_line("}") elif is_bytearray_rprimitive(typ): if declare_dest: self.emit_line(f"PyObject *{dest};") diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 20845a10177f..593acfc19840 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -516,6 +516,12 @@ def __hash__(self) -> int: # Python range object. range_rprimitive: Final = RPrimitive("builtins.range", is_unboxed=False, is_refcounted=True) +# Python weak reference object +weakref_rprimitive: Final = RPrimitive( + "weakref.ReferenceType", is_unboxed=False, is_refcounted=True +) + + KNOWN_NATIVE_TYPES: Final = { name: RPrimitive(name, is_unboxed=False, is_refcounted=True) for name in [ @@ -668,6 +674,10 @@ def is_immutable_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: ) +def is_weakref_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: + return isinstance(rtype, RPrimitive) and rtype.name == "weakref.ReferenceType" + + class TupleNameVisitor(RTypeVisitor[str]): """Produce a tuple name based on the concrete representations of types.""" diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 550dc6e42c9e..1fb307372648 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -48,6 +48,7 @@ str_rprimitive, tuple_rprimitive, uint8_rprimitive, + weakref_rprimitive, ) @@ -106,6 +107,8 @@ def type_to_rtype(self, typ: Type | None) -> RType: return tuple_rprimitive # Varying-length tuple elif typ.type.fullname == "builtins.range": return range_rprimitive + elif typ.type.fullname == "weakref.ReferenceType": + return weakref_rprimitive elif typ.type in self.type_to_ir: inst = RInstance(self.type_to_ir[typ.type]) # Treat protocols as Union[protocol, object], so that we can do fast diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index d0934914dfe9..55544f41ec3e 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -76,6 +76,7 @@ is_str_rprimitive, is_tagged, is_uint8_rprimitive, + is_weakref_rprimitive, list_rprimitive, object_rprimitive, set_rprimitive, @@ -136,6 +137,7 @@ str_range_check_op, ) from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op +from mypyc.primitives.weakref_ops import weakref_deref_op # Specializers are attempted before compiling the arguments to the # function. Specializers can return None to indicate that they failed @@ -186,6 +188,8 @@ def apply_function_specialization( builder: IRBuilder, expr: CallExpr, callee: RefExpr ) -> Value | None: """Invoke the Specializer callback for a function if one has been registered""" + if is_weakref_rprimitive(builder.node_type(callee)) and len(expr.args) == 0: + return builder.call_c(weakref_deref_op, [builder.accept(expr.callee)], expr.line) return _apply_specialization(builder, expr, callee, callee.fullname) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index af5247e64393..52c78cb48e3d 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -956,6 +956,7 @@ PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyOb PyObject *CPy_GetAIter(PyObject *obj); PyObject *CPy_GetANext(PyObject *aiter); +PyObject *CPyWeakref_GetRef(PyObject *ref); void CPy_SetTypeAliasTypeComputeFunction(PyObject *alias, PyObject *compute_value); void CPyTrace_LogEvent(const char *location, const char *line, const char *op, const char *details); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 64b4ff67b942..ca098ead26a8 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -1133,3 +1133,18 @@ void CPy_SetImmortal(PyObject *obj) { } #endif + + +PyObject *CPyWeakref_GetRef(PyObject *ref) +{ + PyObject *obj = NULL; + int success = PyWeakref_GetRef(ref, &obj); + if (success == -1) { + return NULL; + } else if (obj == NULL) { + Py_INCREF(Py_None); + return Py_None; + } else { + return obj; + } +} diff --git a/mypyc/primitives/weakref_ops.py b/mypyc/primitives/weakref_ops.py index 21379d3b2c82..b5fb83641aae 100644 --- a/mypyc/primitives/weakref_ops.py +++ b/mypyc/primitives/weakref_ops.py @@ -1,13 +1,13 @@ from mypyc.ir.ops import ERR_MAGIC -from mypyc.ir.rtypes import object_rprimitive, pointer_rprimitive -from mypyc.primitives.registry import function_op +from mypyc.ir.rtypes import object_rprimitive, pointer_rprimitive, weakref_rprimitive +from mypyc.primitives.registry import custom_op, function_op # Weakref operations new_ref_op = function_op( name="weakref.ReferenceType", arg_types=[object_rprimitive], - return_type=object_rprimitive, + return_type=weakref_rprimitive, c_function_name="PyWeakref_NewRef", extra_int_constants=[(0, pointer_rprimitive)], error_kind=ERR_MAGIC, @@ -16,7 +16,7 @@ new_ref__with_callback_op = function_op( name="weakref.ReferenceType", arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, + return_type=weakref_rprimitive, c_function_name="PyWeakref_NewRef", error_kind=ERR_MAGIC, ) @@ -38,3 +38,11 @@ c_function_name="PyWeakref_NewProxy", error_kind=ERR_MAGIC, ) + +# TODO: generate specialized versions of this that return the proper rtype +weakref_deref_op = custom_op( + arg_types=[weakref_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyWeakref_GetRef", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-weakref.test b/mypyc/test-data/irbuild-weakref.test index 2180b1e747aa..afe87aad6fed 100644 --- a/mypyc/test-data/irbuild-weakref.test +++ b/mypyc/test-data/irbuild-weakref.test @@ -2,53 +2,73 @@ import weakref from typing import Any, Callable def f(x: object) -> object: - return weakref.ref(x) + ref = weakref.ref(x) + return ref() [out] def f(x): - x, r0 :: object + x :: object + r0, ref :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, 0) - return r0 + ref = r0 + r1 = CPyWeakref_GetRef(ref) + return r1 [case testWeakrefRefCallback] import weakref from typing import Any, Callable def f(x: object, cb: Callable[[object], Any]) -> object: - return weakref.ref(x, cb) + ref = weakref.ref(x, cb) + return ref() [out] def f(x, cb): - x, cb, r0 :: object + x, cb :: object + r0, ref :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, cb) - return r0 + ref = r0 + r1 = CPyWeakref_GetRef(ref) + return r1 [case testFromWeakrefRef] from typing import Any, Callable from weakref import ref def f(x: object) -> object: - return ref(x) + r = ref(x) + return r() [out] def f(x): - x, r0 :: object + x :: object + r0, r :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, 0) - return r0 + r = r0 + r1 = CPyWeakref_GetRef(r) + return r1 [case testFromWeakrefRefCallback] from typing import Any, Callable from weakref import ref def f(x: object, cb: Callable[[object], Any]) -> object: - return ref(x, cb) + r = ref(x, cb) + return r() [out] def f(x, cb): - x, cb, r0 :: object + x, cb :: object + r0, r :: weakref.ReferenceType + r1 :: object L0: r0 = PyWeakref_NewRef(x, cb) - return r0 + r = r0 + r1 = CPyWeakref_GetRef(r) + return r1 [case testWeakrefProxy] import weakref