unix/modsocket: Use type-checking mp_obj_get_int.

MP_OBJ_SMALL_INT_VALUE would give erroneous results, such as assertion
failures in the coverage build and other oddities like:

    >>> s = socket.socket()
    >>> s.recv(3.14)
    MemoryError: memory allocation failed, allocating 4235896656 bytes

Signed-off-by: Jeff Epler <jepler@gmail.com>
This commit is contained in:
Jeff Epler
2025-08-03 10:11:49 -05:00
committed by Damien George
parent 6d640a15ab
commit e9da4c9c98
2 changed files with 37 additions and 18 deletions

View File

@@ -280,11 +280,11 @@ static MP_DEFINE_CONST_FUN_OBJ_1(socket_accept_obj, socket_accept);
// these would be thrown as exceptions.
static mp_obj_t socket_recv(size_t n_args, const mp_obj_t *args) {
mp_obj_socket_t *self = MP_OBJ_TO_PTR(args[0]);
int sz = MP_OBJ_SMALL_INT_VALUE(args[1]);
int sz = mp_obj_get_int(args[1]);
int flags = 0;
if (n_args > 2) {
flags = MP_OBJ_SMALL_INT_VALUE(args[2]);
flags = mp_obj_get_int(args[2]);
}
byte *buf = m_new(byte, sz);
@@ -298,11 +298,11 @@ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socket_recv_obj, 2, 3, socket_recv);
static mp_obj_t socket_recvfrom(size_t n_args, const mp_obj_t *args) {
mp_obj_socket_t *self = MP_OBJ_TO_PTR(args[0]);
int sz = MP_OBJ_SMALL_INT_VALUE(args[1]);
int sz = mp_obj_get_int(args[1]);
int flags = 0;
if (n_args > 2) {
flags = MP_OBJ_SMALL_INT_VALUE(args[2]);
flags = mp_obj_get_int(args[2]);
}
struct sockaddr_storage addr;
@@ -331,7 +331,7 @@ static mp_obj_t socket_send(size_t n_args, const mp_obj_t *args) {
int flags = 0;
if (n_args > 2) {
flags = MP_OBJ_SMALL_INT_VALUE(args[2]);
flags = mp_obj_get_int(args[2]);
}
mp_buffer_info_t bufinfo;
@@ -349,7 +349,7 @@ static mp_obj_t socket_sendto(size_t n_args, const mp_obj_t *args) {
mp_obj_t dst_addr = args[2];
if (n_args > 3) {
flags = MP_OBJ_SMALL_INT_VALUE(args[2]);
flags = mp_obj_get_int(args[2]);
dst_addr = args[3];
}
@@ -366,7 +366,7 @@ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(socket_sendto_obj, 3, 4, socket_sendt
static mp_obj_t socket_setsockopt(size_t n_args, const mp_obj_t *args) {
(void)n_args; // always 4
mp_obj_socket_t *self = MP_OBJ_TO_PTR(args[0]);
int level = MP_OBJ_SMALL_INT_VALUE(args[1]);
int level = mp_obj_get_int(args[1]);
int option = mp_obj_get_int(args[2]);
const void *optval;
@@ -478,14 +478,11 @@ static mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, siz
int proto = 0;
if (n_args > 0) {
assert(mp_obj_is_small_int(args[0]));
family = MP_OBJ_SMALL_INT_VALUE(args[0]);
family = mp_obj_get_int(args[0]);
if (n_args > 1) {
assert(mp_obj_is_small_int(args[1]));
type = MP_OBJ_SMALL_INT_VALUE(args[1]);
type = mp_obj_get_int(args[1]);
if (n_args > 2) {
assert(mp_obj_is_small_int(args[2]));
proto = MP_OBJ_SMALL_INT_VALUE(args[2]);
proto = mp_obj_get_int(args[2]);
}
}
}
@@ -582,7 +579,7 @@ static mp_obj_t mod_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) {
// getaddrinfo accepts port in string notation, so however
// it may seem stupid, we need to convert int to str
if (mp_obj_is_small_int(args[1])) {
unsigned port = (unsigned short)MP_OBJ_SMALL_INT_VALUE(args[1]);
unsigned port = (unsigned short)mp_obj_get_int(args[1]);
snprintf(buf, sizeof(buf), "%u", port);
serv = buf;
hints.ai_flags = AI_NUMERICSERV;
@@ -605,13 +602,13 @@ static mp_obj_t mod_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) {
}
if (n_args > 2) {
hints.ai_family = MP_OBJ_SMALL_INT_VALUE(args[2]);
hints.ai_family = mp_obj_get_int(args[2]);
if (n_args > 3) {
hints.ai_socktype = MP_OBJ_SMALL_INT_VALUE(args[3]);
hints.ai_socktype = mp_obj_get_int(args[3]);
if (n_args > 4) {
hints.ai_protocol = MP_OBJ_SMALL_INT_VALUE(args[4]);
hints.ai_protocol = mp_obj_get_int(args[4]);
if (n_args > 5) {
hints.ai_flags = MP_OBJ_SMALL_INT_VALUE(args[5]);
hints.ai_flags = mp_obj_get_int(args[5]);
}
}
}

View File

@@ -0,0 +1,22 @@
# Test passing in bad values to socket.socket constructor.
try:
import socket
except:
print("SKIP")
raise SystemExit
try:
s = socket.socket(None)
except TypeError:
print("TypeError")
try:
s = socket.socket(socket.AF_INET, None)
except TypeError:
print("TypeError")
try:
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, None)
except TypeError:
print("TypeError")