py/persistentcode: Support saving functions with children.

This adds support to `mp_raw_code_save_fun_to_bytes()` so that it can
handle saving functions that have children.  It does this by inspecting
the MP_BC_MAKE_FUNCTION/etc opcodes to work out how many children there
are, and creating a tree of simplified raw code information.

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George
2026-01-15 11:42:04 +11:00
parent 03f50d39da
commit 384cc627bc
2 changed files with 97 additions and 34 deletions

View File

@@ -844,8 +844,28 @@ static mp_opcode_t mp_opcode_decode(const uint8_t *ip) {
return op;
}
mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, const uint8_t *bytecode) {
const uint8_t *fun_data = bytecode;
typedef struct _mp_raw_code_simplified_t {
const uint8_t *fun_data;
struct _mp_raw_code_simplified_t *children;
size_t fun_data_len;
size_t n_children;
} mp_raw_code_simplified_t;
static void proto_fun_to_raw_code_simplified(const void *proto_fun, bit_vector_t *qstr_table_used, bit_vector_t *obj_table_used, mp_raw_code_simplified_t *rcs) {
const uint8_t *fun_data;
mp_raw_code_t **children;
if (mp_proto_fun_is_bytecode(proto_fun)) {
fun_data = proto_fun;
children = NULL;
} else {
const mp_raw_code_t *rc = proto_fun;
if (rc->kind != MP_CODE_BYTECODE) {
mp_raise_ValueError(MP_ERROR_TEXT("function must be bytecode"));
}
fun_data = rc->fun_data;
children = rc->children;
}
const uint8_t *fun_data_top = fun_data + gc_nbytes(fun_data);
// Extract function information.
@@ -853,6 +873,71 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons
MP_BC_PRELUDE_SIG_DECODE(ip);
MP_BC_PRELUDE_SIZE_DECODE(ip);
const byte *ip_names = ip;
mp_uint_t simple_name = mp_decode_uint(&ip_names);
bit_vector_set(qstr_table_used, simple_name);
for (size_t i = 0; i < n_pos_args + n_kwonly_args; ++i) {
mp_uint_t arg_name = mp_decode_uint(&ip_names);
bit_vector_set(qstr_table_used, arg_name);
}
// Skip pass source code info and cell info.
// Then ip points to the start of the opcodes.
ip += n_info + n_cell;
// Decode bytecode.
size_t n_children = 0;
while (ip < fun_data_top) {
mp_opcode_t op = mp_opcode_decode(ip);
if (op.opcode == MP_BC_BASE_RESERVED) {
// End of opcodes.
fun_data_top = ip;
} else if (op.format == MP_BC_FORMAT_QSTR) {
bit_vector_set(qstr_table_used, op.arg);
} else if (op.opcode == MP_BC_LOAD_CONST_OBJ) {
bit_vector_set(obj_table_used, op.arg);
} else if (op.opcode == MP_BC_MAKE_FUNCTION
|| op.opcode == MP_BC_MAKE_FUNCTION_DEFARGS
|| op.opcode == MP_BC_MAKE_CLOSURE
|| op.opcode == MP_BC_MAKE_CLOSURE_DEFARGS) {
if ((mp_uint_t)op.arg + 1 > n_children) {
n_children = (mp_uint_t)op.arg + 1;
}
}
ip += op.size;
}
rcs->fun_data = fun_data;
rcs->fun_data_len = fun_data_top - fun_data;
rcs->n_children = n_children;
rcs->children = NULL;
if (n_children) {
rcs->children = m_new(mp_raw_code_simplified_t, n_children);
for (size_t i = 0; i < n_children; ++i) {
proto_fun_to_raw_code_simplified(children[i], qstr_table_used, obj_table_used, &rcs->children[i]);
}
}
}
static void save_raw_code_simplified(mp_print_t *print, const mp_raw_code_simplified_t *rcs) {
// Save function kind and data length.
mp_print_uint(print, rcs->fun_data_len << 3 | (rcs->n_children != 0) << 2);
// Save function code.
mp_print_bytes(print, rcs->fun_data, rcs->fun_data_len);
// Save (and free) children.
if (rcs->n_children) {
mp_print_uint(print, rcs->n_children);
for (size_t i = 0; i < rcs->n_children; ++i) {
save_raw_code_simplified(print, &rcs->children[i]);
}
m_del(mp_raw_code_simplified_t, rcs->children, rcs->n_children);
}
}
mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, mp_proto_fun_t proto_fun) {
// Track the qstrs used by the function.
bit_vector_t qstr_table_used;
bit_vector_init(&qstr_table_used);
@@ -861,33 +946,14 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons
bit_vector_t obj_table_used;
bit_vector_init(&obj_table_used);
const byte *ip_names = ip;
mp_uint_t simple_name = mp_decode_uint(&ip_names);
bit_vector_set(&qstr_table_used, simple_name);
for (size_t i = 0; i < n_pos_args + n_kwonly_args; ++i) {
mp_uint_t arg_name = mp_decode_uint(&ip_names);
bit_vector_set(&qstr_table_used, arg_name);
}
#if MICROPY_PY_BUILTINS_CODE >= MICROPY_PY_BUILTINS_CODE_FULL
// Make sure the filename appears in the qstr table.
bit_vector_set(&qstr_table_used, 0);
#endif
// Skip pass source code info and cell info.
// Then ip points to the start of the opcodes.
ip += n_info + n_cell;
// Decode bytecode.
while (ip < fun_data_top) {
mp_opcode_t op = mp_opcode_decode(ip);
if (op.opcode == MP_BC_BASE_RESERVED) {
// End of opcodes.
fun_data_top = ip;
} else if (op.opcode == MP_BC_LOAD_CONST_OBJ) {
bit_vector_set(&obj_table_used, op.arg);
} else if (op.format == MP_BC_FORMAT_QSTR) {
bit_vector_set(&qstr_table_used, op.arg);
}
ip += op.size;
}
mp_uint_t fun_data_len = fun_data_top - fun_data;
// Convert function into a simplified raw code tree.
mp_raw_code_simplified_t rcs;
proto_fun_to_raw_code_simplified(proto_fun, &qstr_table_used, &obj_table_used, &rcs);
mp_print_t print;
vstr_t vstr;
@@ -922,11 +988,8 @@ mp_obj_t mp_raw_code_save_fun_to_bytes(const mp_module_constants_t *consts, cons
bit_vector_clear(&qstr_table_used);
bit_vector_clear(&obj_table_used);
// Save function kind and data length.
mp_print_uint(&print, fun_data_len << 3);
// Save function code.
mp_print_bytes(&print, fun_data, fun_data_len);
// Save the bytecode data (also free the simplified raw code tree at the same time).
save_raw_code_simplified(&print, &rcs);
// Create and return bytes representing the .mpy data.
return mp_obj_new_bytes_from_vstr(&vstr);