tests/net_hosted: Convert connect-nonblock-xfer test to use unittest.

This allows it to run parts of the test on esp8266 (or any target using
axTLS).

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George
2024-11-13 12:44:10 +11:00
parent f62df1a2c2
commit c7c3ffa45f
2 changed files with 62 additions and 86 deletions

View File

@@ -1,15 +1,14 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS # test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing # and that an immediate write/send/read/recv does the right thing
import unittest
import errno import errno
import select import select
import socket import socket
import ssl import ssl
# only mbedTLS supports non-blocking mode # only mbedTLS supports non-blocking mode
if not hasattr(ssl, "MBEDTLS_VERSION"): ssl_supports_nonblocking = hasattr(ssl, "MBEDTLS_VERSION")
print("SKIP")
raise SystemExit
# get the name of an errno error code # get the name of an errno error code
@@ -24,34 +23,43 @@ def errno_name(er):
# do_connect establishes the socket and wraps it if tls is True. # do_connect establishes the socket and wraps it if tls is True.
# If handshake is true, the initial connect (and TLS handshake) is # If handshake is true, the initial connect (and TLS handshake) is
# allowed to be performed before returning. # allowed to be performed before returning.
def do_connect(peer_addr, tls, handshake): def do_connect(self, peer_addr, tls, handshake):
s = socket.socket() s = socket.socket()
s.setblocking(False) s.setblocking(False)
try: try:
# print("Connecting to", peer_addr) print("Connecting to", peer_addr)
s.connect(peer_addr) s.connect(peer_addr)
self.fail()
except OSError as er: except OSError as er:
print("connect:", errno_name(er.errno)) print("connect:", errno_name(er.errno))
self.assertEqual(er.errno, errno.EINPROGRESS)
# wrap with ssl/tls if desired # wrap with ssl/tls if desired
if tls: if tls:
print("wrap socket")
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
try: s = ssl_context.wrap_socket(s, do_handshake_on_connect=handshake)
s = ssl_context.wrap_socket(s, do_handshake_on_connect=handshake)
print("wrap ok: True")
except Exception as e:
print("wrap er:", e)
return s return s
# poll a socket and print out the result # poll a socket and check the result
def poll(s): def poll(self, s, expect_writable):
poller = select.poll() poller = select.poll()
poller.register(s) poller.register(s)
print("poll: ", poller.poll(0)) result = poller.poll(0)
print("poll:", result)
if expect_writable:
self.assertEqual(len(result), 1)
self.assertEqual(result[0][1], select.POLLOUT)
else:
self.assertEqual(result, [])
# test runs the test against a specific peer address. # do_test runs the test against a specific peer address.
def test(peer_addr, tls, handshake): def do_test(self, peer_addr, tls, handshake):
print()
# MicroPython plain and TLS sockets have read/write # MicroPython plain and TLS sockets have read/write
hasRW = True hasRW = True
@@ -62,54 +70,66 @@ def test(peer_addr, tls, handshake):
# connect + send # connect + send
# non-blocking send should raise EAGAIN # non-blocking send should raise EAGAIN
if hasSR: if hasSR:
s = do_connect(peer_addr, tls, handshake) s = do_connect(self, peer_addr, tls, handshake)
poll(s) poll(self, s, False)
try: with self.assertRaises(OSError) as ctx:
ret = s.send(b"1234") ret = s.send(b"1234")
print("send ok:", ret) # shouldn't get here print("send error:", errno_name(ctx.exception.errno))
except OSError as er: self.assertEqual(ctx.exception.errno, errno.EAGAIN)
print("send er:", errno_name(er.errno))
s.close() s.close()
# connect + write # connect + write
# non-blocking write should return None # non-blocking write should return None
if hasRW: if hasRW:
s = do_connect(peer_addr, tls, handshake) s = do_connect(self, peer_addr, tls, handshake)
poll(s) poll(self, s, tls and handshake)
ret = s.write(b"1234") ret = s.write(b"1234")
print("write: ", ret) print("write:", ret)
if tls and handshake:
self.assertEqual(ret, 4)
else:
self.assertIsNone(ret)
s.close() s.close()
# connect + recv # connect + recv
# non-blocking recv should raise EAGAIN # non-blocking recv should raise EAGAIN
if hasSR: if hasSR:
s = do_connect(peer_addr, tls, handshake) s = do_connect(self, peer_addr, tls, handshake)
poll(s) poll(self, s, False)
try: with self.assertRaises(OSError) as ctx:
ret = s.recv(10) ret = s.recv(10)
print("recv ok:", ret) # shouldn't get here print("recv error:", errno_name(ctx.exception.errno))
except OSError as er: self.assertEqual(ctx.exception.errno, errno.EAGAIN)
print("recv er:", errno_name(er.errno))
s.close() s.close()
# connect + read # connect + read
# non-blocking read should return None # non-blocking read should return None
if hasRW: if hasRW:
s = do_connect(peer_addr, tls, handshake) s = do_connect(self, peer_addr, tls, handshake)
poll(s) poll(self, s, tls and handshake)
ret = s.read(10) ret = s.read(10)
print("read: ", ret) print("read:", ret)
self.assertIsNone(ret)
s.close() s.close()
if __name__ == "__main__": class Test(unittest.TestCase):
# these tests use a non-existent test IP address, this way the connect takes forever and # these tests use a non-existent test IP address, this way the connect takes forever and
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737) # we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
print("--- Plain sockets to nowhere ---") def test_plain_sockets_to_nowhere(self):
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False) do_test(self, socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
print("--- SSL sockets to nowhere ---")
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False) @unittest.skipIf(not ssl_supports_nonblocking, "SSL doesn't support non-blocking")
print("--- Plain sockets ---") def test_ssl_sockets_to_nowhere(self):
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, False) do_test(self, socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
print("--- SSL sockets ---")
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True) def test_plain_sockets(self):
do_test(self, socket.getaddrinfo("micropython.org", 80)[0][-1], False, False)
@unittest.skipIf(not ssl_supports_nonblocking, "SSL doesn't support non-blocking")
def test_ssl_sockets(self):
do_test(self, socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,44 +0,0 @@
--- Plain sockets to nowhere ---
connect: EINPROGRESS
poll: []
send er: EAGAIN
connect: EINPROGRESS
poll: []
write: None
connect: EINPROGRESS
poll: []
recv er: EAGAIN
connect: EINPROGRESS
poll: []
read: None
--- SSL sockets to nowhere ---
connect: EINPROGRESS
wrap ok: True
poll: []
write: None
connect: EINPROGRESS
wrap ok: True
poll: []
read: None
--- Plain sockets ---
connect: EINPROGRESS
poll: []
send er: EAGAIN
connect: EINPROGRESS
poll: []
write: None
connect: EINPROGRESS
poll: []
recv er: EAGAIN
connect: EINPROGRESS
poll: []
read: None
--- SSL sockets ---
connect: EINPROGRESS
wrap ok: True
poll: [(<SSLSocket>, 4)]
write: 4
connect: EINPROGRESS
wrap ok: True
poll: [(<SSLSocket>, 4)]
read: None