From 384cc627bc633af3bb092f173d447c33c1b71952 Mon Sep 17 00:00:00 2001 From: Damien George Date: Thu, 15 Jan 2026 11:42:04 +1100 Subject: [PATCH] py/persistentcode: Support saving functions with children. This adds support to `mp_raw_code_save_fun_to_bytes()` so that it can handle saving functions that have children. It does this by inspecting the MP_BC_MAKE_FUNCTION/etc opcodes to work out how many children there are, and creating a tree of simplified raw code information. Signed-off-by: Damien George --- py/persistentcode.c | 129 ++++++++++++++++++++++++++++++++------------ py/persistentcode.h | 2 +- 2 files changed, 97 insertions(+), 34 deletions(-) diff --git a/py/persistentcode.c b/py/persistentcode.c index d83386736b..2c7a425656 100644 --- a/py/persistentcode.c +++ b/py/persistentcode.c @@ -844,8 +844,28 @@ static mp_opcode_t mp_opcode_decode(const uint8_t *ip) { return op; } -mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, const uint8_t *bytecode) { - const uint8_t *fun_data = bytecode; +typedef struct _mp_raw_code_simplified_t { + const uint8_t *fun_data; + struct _mp_raw_code_simplified_t *children; + size_t fun_data_len; + size_t n_children; +} mp_raw_code_simplified_t; + +static void proto_fun_to_raw_code_simplified(const void *proto_fun, bit_vector_t *qstr_table_used, bit_vector_t *obj_table_used, mp_raw_code_simplified_t *rcs) { + const uint8_t *fun_data; + mp_raw_code_t **children; + if (mp_proto_fun_is_bytecode(proto_fun)) { + fun_data = proto_fun; + children = NULL; + } else { + const mp_raw_code_t *rc = proto_fun; + if (rc->kind != MP_CODE_BYTECODE) { + mp_raise_ValueError(MP_ERROR_TEXT("function must be bytecode")); + } + fun_data = rc->fun_data; + children = rc->children; + } + const uint8_t *fun_data_top = fun_data + gc_nbytes(fun_data); // Extract function information. 
@@ -853,6 +873,71 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons MP_BC_PRELUDE_SIG_DECODE(ip); MP_BC_PRELUDE_SIZE_DECODE(ip); + const byte *ip_names = ip; + mp_uint_t simple_name = mp_decode_uint(&ip_names); + bit_vector_set(qstr_table_used, simple_name); + for (size_t i = 0; i < n_pos_args + n_kwonly_args; ++i) { + mp_uint_t arg_name = mp_decode_uint(&ip_names); + bit_vector_set(qstr_table_used, arg_name); + } + + // Skip past source code info and cell info. + // Then ip points to the start of the opcodes. + ip += n_info + n_cell; + + // Decode bytecode, marking used qstrs/objects and sizing the child table from MAKE_FUNCTION/MAKE_CLOSURE args. + size_t n_children = 0; + while (ip < fun_data_top) { + mp_opcode_t op = mp_opcode_decode(ip); + if (op.opcode == MP_BC_BASE_RESERVED) { + // End of opcodes; the saved function data stops here. + fun_data_top = ip; + } else if (op.format == MP_BC_FORMAT_QSTR) { + bit_vector_set(qstr_table_used, op.arg); + } else if (op.opcode == MP_BC_LOAD_CONST_OBJ) { + bit_vector_set(obj_table_used, op.arg); + } else if (op.opcode == MP_BC_MAKE_FUNCTION + || op.opcode == MP_BC_MAKE_FUNCTION_DEFARGS + || op.opcode == MP_BC_MAKE_CLOSURE + || op.opcode == MP_BC_MAKE_CLOSURE_DEFARGS) { + if ((mp_uint_t)op.arg + 1 > n_children) { + n_children = (mp_uint_t)op.arg + 1; + } + } + ip += op.size; + } + + rcs->fun_data = fun_data; + rcs->fun_data_len = fun_data_top - fun_data; + rcs->n_children = n_children; + rcs->children = NULL; + + if (n_children) { + rcs->children = m_new(mp_raw_code_simplified_t, n_children); + for (size_t i = 0; i < n_children; ++i) { + proto_fun_to_raw_code_simplified(children[i], qstr_table_used, obj_table_used, &rcs->children[i]); + } + } +} + +static void save_raw_code_simplified(mp_print_t *print, const mp_raw_code_simplified_t *rcs) { + // Save function kind and data length (bit 2 is set when children follow). + mp_print_uint(print, rcs->fun_data_len << 3 | (rcs->n_children != 0) << 2); + + // Save function code. + mp_print_bytes(print, rcs->fun_data, rcs->fun_data_len); + + // Save (and free) children: their count, then each child recursively. 
+ if (rcs->n_children) { + mp_print_uint(print, rcs->n_children); + for (size_t i = 0; i < rcs->n_children; ++i) { + save_raw_code_simplified(print, &rcs->children[i]); + } + m_del(mp_raw_code_simplified_t, rcs->children, rcs->n_children); + } +} + +mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, mp_proto_fun_t proto_fun) { // Track the qstrs used by the function. bit_vector_t qstr_table_used; bit_vector_init(&qstr_table_used); @@ -861,33 +946,14 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons bit_vector_t obj_table_used; bit_vector_init(&obj_table_used); - const byte *ip_names = ip; - mp_uint_t simple_name = mp_decode_uint(&ip_names); - bit_vector_set(&qstr_table_used, simple_name); - for (size_t i = 0; i < n_pos_args + n_kwonly_args; ++i) { - mp_uint_t arg_name = mp_decode_uint(&ip_names); - bit_vector_set(&qstr_table_used, arg_name); - } + #if MICROPY_PY_BUILTINS_CODE >= MICROPY_PY_BUILTINS_CODE_FULL + // Make sure the filename appears in the qstr table. + bit_vector_set(&qstr_table_used, 0); + #endif - // Skip pass source code info and cell info. - // Then ip points to the start of the opcodes. - ip += n_info + n_cell; - - // Decode bytecode. - while (ip < fun_data_top) { - mp_opcode_t op = mp_opcode_decode(ip); - if (op.opcode == MP_BC_BASE_RESERVED) { - // End of opcodes. - fun_data_top = ip; - } else if (op.opcode == MP_BC_LOAD_CONST_OBJ) { - bit_vector_set(&obj_table_used, op.arg); - } else if (op.format == MP_BC_FORMAT_QSTR) { - bit_vector_set(&qstr_table_used, op.arg); - } - ip += op.size; - } - - mp_uint_t fun_data_len = fun_data_top - fun_data; + // Convert function into a simplified raw code tree (raises ValueError if any function in it is not bytecode). 
+ mp_raw_code_simplified_t rcs; + proto_fun_to_raw_code_simplified(proto_fun, &qstr_table_used, &obj_table_used, &rcs); mp_print_t print; vstr_t vstr; @@ -922,11 +988,8 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons bit_vector_clear(&qstr_table_used); bit_vector_clear(&obj_table_used); - // Save function kind and data length. - mp_print_uint(&print, fun_data_len << 3); - - // Save function code. - mp_print_bytes(&print, fun_data, fun_data_len); + // Save the bytecode data (also free the simplified raw code tree at the same time). + save_raw_code_simplified(&print, &rcs); // Create and return bytes representing the .mpy data. return mp_obj_new_bytes_from_vstr(&vstr); diff --git a/py/persistentcode.h b/py/persistentcode.h index 1e0bbbd272..23cb0936f2 100644 --- a/py/persistentcode.h +++ b/py/persistentcode.h @@ -139,7 +139,7 @@ void mp_raw_code_load_file(qstr filename, mp_compiled_module_t *ctx); void mp_raw_code_save(mp_compiled_module_t *cm, mp_print_t *print); void mp_raw_code_save_file(mp_compiled_module_t *cm, qstr filename); -mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, const uint8_t *bytecode); +mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, mp_proto_fun_t proto_fun); void mp_native_relocate(void *reloc, uint8_t *text, uintptr_t reloc_text);