diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index d0934914dfe9..c85f45508089 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -113,7 +113,12 @@ ) from mypyc.primitives.float_ops import isinstance_float from mypyc.primitives.generic_ops import generic_setattr, setup_object -from mypyc.primitives.int_ops import isinstance_int +from mypyc.primitives.int_ops import ( + int_to_big_endian_op, + int_to_bytes_op, + int_to_little_endian_op, + isinstance_int, +) from mypyc.primitives.librt_strings_ops import ( bytes_writer_adjust_index_op, bytes_writer_get_item_unsafe_op, @@ -1242,6 +1247,77 @@ def translate_object_setattr(builder: IRBuilder, expr: CallExpr, callee: RefExpr return builder.call_c(generic_setattr, [self_reg, name_reg, value], expr.line) +@specialize_function("to_bytes", int_rprimitive) +def specialize_int_to_bytes(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: + # int.to_bytes(length, byteorder, signed=False) + if any(kind not in (ARG_POS, ARG_NAMED) for kind in expr.arg_kinds): + return None + if not isinstance(callee, MemberExpr): + return None + length_expr: Expression | None = None + byteorder_expr: Expression | None = None + signed_expr: Expression | None = None + positional_index = 0 + for name, arg in zip(expr.arg_names, expr.args): + if name is None: + if positional_index == 0: + length_expr = arg + elif positional_index == 1: + byteorder_expr = arg + elif positional_index == 2: + signed_expr = arg + else: + return None + positional_index += 1 + elif name == "length": + if length_expr is not None: + return None + length_expr = arg + elif name == "byteorder": + if byteorder_expr is not None: + return None + byteorder_expr = arg + elif name == "signed": + if signed_expr is not None: + return None + signed_expr = arg + else: + return None + if length_expr is None or byteorder_expr is None: + return None + + signed_is_bool = True + if signed_expr is not None: + signed_is_bool = is_bool_rprimitive(builder.node_type(signed_expr)) + if not ( + is_int_rprimitive(builder.node_type(length_expr)) + and is_str_rprimitive(builder.node_type(byteorder_expr)) + and signed_is_bool + ): + return None + + self_arg = builder.accept(callee.expr) + length_arg = builder.accept(length_expr) + if signed_expr is None: + signed_arg = builder.false() + else: + signed_arg = builder.accept(signed_expr) + if isinstance(byteorder_expr, StrExpr): + if byteorder_expr.value == "little": + return builder.call_c( + int_to_little_endian_op, [self_arg, length_arg, signed_arg], expr.line + ) + elif byteorder_expr.value == "big": + return builder.call_c( + int_to_big_endian_op, [self_arg, length_arg, signed_arg], expr.line + ) + # Fallback to generic primitive op + byteorder_arg = builder.accept(byteorder_expr) + return builder.call_c( + int_to_bytes_op, [self_arg, length_arg, byteorder_arg, signed_arg], expr.line + ) + + def translate_getitem_with_bounds_check( builder: IRBuilder, base_expr: Expression, diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index af5247e64393..b2a48bb29ecb 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -149,6 +149,9 @@ CPyTagged CPyTagged_BitwiseLongOp_(CPyTagged a, CPyTagged b, char op); CPyTagged CPyTagged_Rshift_(CPyTagged left, CPyTagged right); CPyTagged CPyTagged_Lshift_(CPyTagged left, CPyTagged right); CPyTagged CPyTagged_BitLength(CPyTagged self); +PyObject *CPyTagged_ToBytes(CPyTagged self, Py_ssize_t length, PyObject *byteorder, int signed_flag); +PyObject *CPyTagged_ToBigEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag); +PyObject *CPyTagged_ToLittleEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag); PyObject *CPyTagged_Str(CPyTagged n); CPyTagged CPyTagged_FromFloat(double f); diff --git a/mypyc/lib-rt/int_ops.c b/mypyc/lib-rt/int_ops.c index 04538ab832fc..868b2320cbe4 100644 --- a/mypyc/lib-rt/int_ops.c +++ b/mypyc/lib-rt/int_ops.c @@ -597,6 +597,68 @@ double CPyTagged_TrueDivide(CPyTagged x, CPyTagged y) { return 1.0; } +static PyObject *CPyLong_ToBytes(PyObject *v, Py_ssize_t length, int little_endian, int signed_flag) { + // This is a wrapper for PyLong_AsByteArray and PyBytes_FromStringAndSize + PyObject *result = PyBytes_FromStringAndSize(NULL, length); + if (!result) { + return NULL; + } + unsigned char *bytes = (unsigned char *)PyBytes_AS_STRING(result); +#if PY_VERSION_HEX >= 0x030D0000 // 3.13.0 + int res = _PyLong_AsByteArray((PyLongObject *)v, bytes, length, little_endian, signed_flag, 1); +#else + int res = _PyLong_AsByteArray((PyLongObject *)v, bytes, length, little_endian, signed_flag); +#endif + if (res < 0) { + Py_DECREF(result); + return NULL; + } + return result; +} + +// int.to_bytes(length, byteorder, signed=False) +PyObject *CPyTagged_ToBytes(CPyTagged self, Py_ssize_t length, PyObject *byteorder, int signed_flag) { + PyObject *pyint = CPyTagged_AsObject(self); + if (!PyUnicode_Check(byteorder)) { + Py_DECREF(pyint); + PyErr_SetString(PyExc_TypeError, "byteorder must be str"); + return NULL; + } + const char *order = PyUnicode_AsUTF8(byteorder); + if (!order) { + Py_DECREF(pyint); + return NULL; + } + int little_endian; + if (strcmp(order, "big") == 0) { + little_endian = 0; + } else if (strcmp(order, "little") == 0) { + little_endian = 1; + } else { + PyErr_SetString(PyExc_ValueError, "byteorder must be either 'little' or 'big'"); + return NULL; + } + PyObject *result = CPyLong_ToBytes(pyint, length, little_endian, signed_flag); + Py_DECREF(pyint); + return result; +} + +// int.to_bytes(length, byteorder="little", signed=False) +PyObject *CPyTagged_ToLittleEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag) { + PyObject *pyint = CPyTagged_AsObject(self); + PyObject *result = CPyLong_ToBytes(pyint, length, 1, signed_flag); + Py_DECREF(pyint); + return result; +} + +// int.to_bytes(length, "big", signed=False) +PyObject *CPyTagged_ToBigEndianBytes(CPyTagged self, Py_ssize_t length, int signed_flag) { + PyObject *pyint = CPyTagged_AsObject(self); + PyObject *result = CPyLong_ToBytes(pyint, length, 0, signed_flag); + Py_DECREF(pyint); + return result; +} + // int.bit_length() CPyTagged CPyTagged_BitLength(CPyTagged self) { // Handle zero diff --git a/mypyc/primitives/int_ops.py b/mypyc/primitives/int_ops.py index 8f43140dd255..b40c92be388a 100644 --- a/mypyc/primitives/int_ops.py +++ b/mypyc/primitives/int_ops.py @@ -21,6 +21,7 @@ RType, bit_rprimitive, bool_rprimitive, + bytes_rprimitive, c_pyssize_t_rprimitive, float_rprimitive, int16_rprimitive, @@ -313,6 +314,34 @@ def int_unary_op(name: str, c_function_name: str) -> PrimitiveDescription: error_kind=ERR_NEVER, ) +# specialized custom_op cases for int.to_bytes + +# int.to_bytes(length, "big") +# int.to_bytes(length, "big", signed=...) +int_to_big_endian_op = custom_op( + arg_types=[int_rprimitive, c_pyssize_t_rprimitive, bool_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyTagged_ToBigEndianBytes", + error_kind=ERR_MAGIC, +) + +# int.to_bytes(length, "little") +# int.to_bytes(length, "little", signed=...) +int_to_little_endian_op = custom_op( + arg_types=[int_rprimitive, c_pyssize_t_rprimitive, bool_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyTagged_ToLittleEndianBytes", + error_kind=ERR_MAGIC, +) + +# int.to_bytes(length, byteorder, signed=...) +int_to_bytes_op = custom_op( + arg_types=[int_rprimitive, c_pyssize_t_rprimitive, str_rprimitive, bool_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyTagged_ToBytes", + error_kind=ERR_MAGIC, +) + # int.bit_length() method_op( name="bit_length", diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 5f7c1d49852d..4f8a29672911 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -87,6 +87,7 @@ def __lt__(self, n: int) -> bool: pass def __gt__(self, n: int) -> bool: pass def __le__(self, n: int) -> bool: pass def __ge__(self, n: int) -> bool: pass + def to_bytes(self, length: int, order: str, *, signed: bool = False) -> bytes: pass def bit_length(self) -> int: pass class str: diff --git a/mypyc/test-data/irbuild-int.test b/mypyc/test-data/irbuild-int.test index 184c66fafb7c..10991f26c408 100644 --- a/mypyc/test-data/irbuild-int.test +++ b/mypyc/test-data/irbuild-int.test @@ -211,6 +211,35 @@ L0: x = r0 return x +[case testIntToBytes] +def f(x: int) -> bytes: + return x.to_bytes(2, "big") +def g(x: int) -> bytes: + return x.to_bytes(4, "little", signed=True) +def h(x: int, byteorder: str) -> bytes: + return x.to_bytes(8, byteorder) + +[out] +def f(x): + x :: int + r0 :: bytes +L0: + r0 = CPyTagged_ToBigEndianBytes(x, 2, 0) + return r0 +def g(x): + x :: int + r0 :: bytes +L0: + r0 = CPyTagged_ToLittleEndianBytes(x, 4, 1) + return r0 +def h(x, byteorder): + x :: int + byteorder :: str + r0 :: bytes +L0: + r0 = CPyTagged_ToBytes(x, 8, byteorder, 0) + return r0 + [case testIntBitLength] def f(x: int) -> int: return x.bit_length() diff --git a/mypyc/test-data/run-integers.test b/mypyc/test-data/run-integers.test index c02f7d808883..cbcb0d43fc9d 100644 --- a/mypyc/test-data/run-integers.test +++ b/mypyc/test-data/run-integers.test @@ -573,6 +573,27 @@ class subc(int): class int: pass +[case testIntToBytes] +from testutil import assertRaises +def to_bytes(n: int, length: int, byteorder: str, signed: bool = False) -> bytes: + return n.to_bytes(length, byteorder, signed=signed) +def test_to_bytes() -> None: + assert to_bytes(255, 2, "big") == b'\x00\xff', to_bytes(255, 2, "big") + assert to_bytes(255, 2, "little") == b'\xff\x00', to_bytes(255, 2, "little") + assert to_bytes(-1, 2, "big", True) == b'\xff\xff', to_bytes(-1, 2, "big", True) + assert to_bytes(0, 1, "big") == b'\x00', to_bytes(0, 1, "big") + # test with a value that does not fit in 64 bits + assert to_bytes(10**30, 16, "big") == b'\x00\x00\x00\x0c\x9f,\x9c\xd0Ft\xed\xea@\x00\x00\x00', to_bytes(10**30, 16, "big") + # unsigned, too large for 1 byte + with assertRaises(OverflowError): + to_bytes(256, 1, "big") + # signed, too small for 1 byte + with assertRaises(OverflowError): + to_bytes(-129, 1, "big", True) + # signed, too large for 1 byte + with assertRaises(OverflowError): + to_bytes(128, 1, "big", True) + [case testBitLength] def bit_length(n: int) -> int: return n.bit_length()