diff --git a/extmod/modlwip.c b/extmod/modlwip.c index e4bb720dba..bcec34bca3 100644 --- a/extmod/modlwip.c +++ b/extmod/modlwip.c @@ -348,8 +348,9 @@ typedef struct _lwip_socket_obj_t { #define STATE_LISTENING 1 #define STATE_CONNECTING 2 #define STATE_CONNECTED 3 - #define STATE_PEER_CLOSED 4 - #define STATE_ACTIVE_UDP 5 + #define STATE_ACTIVE_UDP 4 + #define STATE_PEER_CLOSED 5 // Values higher than this must also be closed by peer + #define STATE_PEER_RST_HANDLED 6 // Negative value is lwIP error int8_t state; } lwip_socket_obj_t; @@ -370,10 +371,10 @@ static struct tcp_pcb *volatile *lwip_socket_incoming_array(lwip_socket_obj_t *s } } -static void lwip_socket_free_incoming(lwip_socket_obj_t *socket) { +static void lwip_socket_free_incoming(lwip_socket_obj_t *socket, bool free_queued_stream_data) { if (socket->state != STATE_LISTENING) { if (socket->type == MOD_NETWORK_SOCK_STREAM) { - if (socket->incoming.tcp.pbuf != NULL) { + if (free_queued_stream_data && socket->incoming.tcp.pbuf != NULL) { pbuf_free(socket->incoming.tcp.pbuf); socket->incoming.tcp.pbuf = NULL; } @@ -399,6 +400,8 @@ static void lwip_socket_free_incoming(lwip_socket_obj_t *socket) { tcp_array[i] = NULL; } } + // This socket is now a non-listening stream, so clear the relevant state. + socket->incoming.tcp.pbuf = NULL; } } @@ -487,8 +490,9 @@ static void _lwip_udp_incoming(void *arg, struct udp_pcb *upcb, struct pbuf *p, static void _lwip_tcp_error(void *arg, err_t err) { lwip_socket_obj_t *socket = (lwip_socket_obj_t *)arg; - // Free any incoming buffers or connections that are stored - lwip_socket_free_incoming(socket); + // Free any incoming buffers or connections that are stored, but keep potential + // queued TCP data in case it's read later. Will be freed by MP_STREAM_CLOSE. + lwip_socket_free_incoming(socket, false); // Pass the error code back via the connection variable. socket->state = err; // If we got here, the lwIP stack either has deallocated or will deallocate the pcb. @@ -818,9 +822,6 @@ static mp_uint_t lwip_tcp_send(lwip_socket_obj_t *socket, const byte *buf, mp_ui // Helper function for recv/recvfrom to handle TCP packets static mp_uint_t lwip_tcp_receive(lwip_socket_obj_t *socket, byte *buf, mp_uint_t len, mp_int_t flags, int *_errno) { - // Check for any pending errors - STREAM_ERROR_CHECK(socket); - if (socket->state == STATE_LISTENING) { // original socket in listening state, not the accepted connection. *_errno = MP_ENOTCONN; @@ -828,10 +829,20 @@ static mp_uint_t lwip_tcp_receive(lwip_socket_obj_t *socket, byte *buf, mp_uint_ } if (socket->incoming.tcp.pbuf == NULL) { + // Check for any pending errors that should propagate out on socket read. + if (socket->state < 0) { + *_errno = error_lookup_table[-socket->state]; + if (*_errno == MP_ECONNRESET) { + socket->state = STATE_PEER_RST_HANDLED; + } else { + socket->state = _ERR_BADF; + } + return MP_STREAM_ERROR; + } // Non-blocking socket or flag if (socket->timeout == 0 || (flags & MSG_DONTWAIT)) { - if (socket->state == STATE_PEER_CLOSED) { + if (socket->state >= STATE_PEER_CLOSED) { return 0; } *_errno = MP_EAGAIN; @@ -847,7 +858,7 @@ static mp_uint_t lwip_tcp_receive(lwip_socket_obj_t *socket, byte *buf, mp_uint_ poll_sockets(); } - if (socket->state == STATE_PEER_CLOSED) { + if (socket->state >= STATE_PEER_CLOSED) { if (socket->incoming.tcp.pbuf == NULL) { // socket closed and no data left in buffer return 0; @@ -864,8 +875,6 @@ static mp_uint_t lwip_tcp_receive(lwip_socket_obj_t *socket, byte *buf, mp_uint_ MICROPY_PY_LWIP_ENTER - assert(socket->pcb.tcp != NULL); - struct pbuf *p = socket->incoming.tcp.pbuf; mp_uint_t remaining = p->len - socket->recv_offset; @@ -888,7 +897,9 @@ static mp_uint_t lwip_tcp_receive(lwip_socket_obj_t *socket, byte *buf, mp_uint_ } else { socket->recv_offset += len; } - tcp_recved(socket->pcb.tcp, len); + if (socket->pcb.tcp != NULL) { + tcp_recved(socket->pcb.tcp, len); + } } MICROPY_PY_LWIP_EXIT @@ -1292,8 +1303,6 @@ static mp_obj_t lwip_socket_recv_common(size_t n_args, const mp_obj_t *args, ip_ vstr_t vstr; mp_uint_t ret = 0; - lwip_socket_check_connected(socket); - vstr_init_len(&vstr, len); switch (socket->type) { @@ -1308,6 +1317,7 @@ static mp_obj_t lwip_socket_recv_common(size_t n_args, const mp_obj_t *args, ip_ #if MICROPY_PY_LWIP_SOCK_RAW case MOD_NETWORK_SOCK_RAW: #endif + lwip_socket_check_connected(socket); ret = lwip_raw_udp_receive(socket, (byte *)vstr.buf, len, flags, ip, port, &_errno); break; } @@ -1580,7 +1590,10 @@ static mp_uint_t lwip_socket_ioctl(mp_obj_t self_in, mp_uint_t request, uintptr_ } } else if (socket->type == MOD_NETWORK_SOCK_STREAM) { // For TCP sockets there is just one slot for incoming data - if (socket->incoming.tcp.pbuf != NULL) { + // The socket is also readable when in RST state + if (socket->incoming.tcp.pbuf != NULL + || socket->state == ERR_RST + || socket->state == STATE_PEER_RST_HANDLED) { ret |= MP_STREAM_POLL_RD; } } else { @@ -1619,6 +1632,8 @@ static mp_uint_t lwip_socket_ioctl(mp_obj_t self_in, mp_uint_t request, uintptr_ } else if (socket->state == ERR_RST) { // Socket was reset by peer, a write will return an error ret |= flags & MP_STREAM_POLL_WR; + ret |= MP_STREAM_POLL_ERR | MP_STREAM_POLL_HUP; + } else if (socket->state == STATE_PEER_RST_HANDLED) { ret |= MP_STREAM_POLL_HUP; } else if (socket->state == _ERR_BADF) { ret |= MP_STREAM_POLL_NVAL; @@ -1629,14 +1644,15 @@ static mp_uint_t lwip_socket_ioctl(mp_obj_t self_in, mp_uint_t request, uintptr_ } } else if (request == MP_STREAM_CLOSE) { + // Free any incoming buffers or connections that are stored + lwip_socket_free_incoming(socket, true); + if (socket->pcb.tcp == NULL) { + socket->state = _ERR_BADF; MICROPY_PY_LWIP_EXIT return 0; } - // Free any incoming buffers or connections that are stored - lwip_socket_free_incoming(socket); - switch (socket->type) { case MOD_NETWORK_SOCK_STREAM: { // Deregister callback (pcb.tcp is set to NULL below so must deregister now)