first comit
This commit is contained in:
159
venv/lib/python3.10/site-packages/zmq/tests/test_zmqstream.py
Normal file
159
venv/lib/python3.10/site-packages/zmq/tests/test_zmqstream.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
try:
|
||||
import tornado
|
||||
|
||||
from zmq.eventloop import zmqstream
|
||||
except ImportError:
|
||||
tornado = None # type: ignore
|
||||
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("io_loop")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def push_pull(socket):
|
||||
push = zmqstream.ZMQStream(socket(zmq.PUSH))
|
||||
pull = zmqstream.ZMQStream(socket(zmq.PULL))
|
||||
port = push.bind_to_random_port('tcp://127.0.0.1')
|
||||
pull.connect('tcp://127.0.0.1:%i' % port)
|
||||
return (push, pull)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def push(push_pull):
|
||||
push, pull = push_pull
|
||||
return push
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pull(push_pull):
|
||||
push, pull = push_pull
|
||||
return pull
|
||||
|
||||
|
||||
async def test_callable_check(pull):
|
||||
"""Ensure callable check works."""
|
||||
|
||||
pull.on_send(lambda *args: None)
|
||||
pull.on_recv(lambda *args: None)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_recv(1)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_send(1)
|
||||
with pytest.raises(AssertionError):
|
||||
pull.on_recv(zmq)
|
||||
|
||||
|
||||
async def test_on_recv_basic(push, pull):
|
||||
sent = [b'basic']
|
||||
push.send_multipart(sent)
|
||||
f = asyncio.Future()
|
||||
|
||||
def callback(msg):
|
||||
f.set_result(msg)
|
||||
|
||||
pull.on_recv(callback)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_wake(push, pull):
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
pull.on_recv(f.set_result)
|
||||
await asyncio.sleep(0.5)
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_async(push, pull):
|
||||
if tornado.version_info < (5,):
|
||||
pytest.skip()
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
|
||||
async def callback(msg):
|
||||
await asyncio.sleep(0.1)
|
||||
f.set_result(msg)
|
||||
|
||||
pull.on_recv(callback)
|
||||
await asyncio.sleep(0.5)
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
|
||||
|
||||
async def test_on_recv_async_error(push, pull, caplog):
|
||||
sent = [b'wake']
|
||||
|
||||
f = asyncio.Future()
|
||||
|
||||
async def callback(msg):
|
||||
f.set_result(msg)
|
||||
1 / 0
|
||||
|
||||
pull.on_recv(callback)
|
||||
await asyncio.sleep(0.1)
|
||||
with caplog.at_level(logging.ERROR, logger=zmqstream.gen_log.name):
|
||||
push.send_multipart(sent)
|
||||
recvd = await asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == sent
|
||||
# logging error takes a tick later
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
messages = [
|
||||
x.message
|
||||
for x in caplog.get_records("call")
|
||||
if x.name == zmqstream.gen_log.name
|
||||
]
|
||||
assert "Uncaught exception in ZMQStream callback" in "\n".join(messages)
|
||||
|
||||
|
||||
async def test_shadow_socket(context):
|
||||
with context.socket(zmq.PUSH, socket_class=zmq.asyncio.Socket) as socket:
|
||||
with pytest.warns(RuntimeWarning):
|
||||
stream = zmqstream.ZMQStream(socket)
|
||||
assert type(stream.socket) is zmq.Socket
|
||||
assert stream.socket.underlying == socket.underlying
|
||||
stream.close()
|
||||
|
||||
|
||||
async def test_shadow_socket_close(context, caplog):
|
||||
with context.socket(zmq.PUSH) as push, context.socket(zmq.PULL) as pull:
|
||||
push.linger = pull.linger = 0
|
||||
port = push.bind_to_random_port('tcp://127.0.0.1')
|
||||
pull.connect(f'tcp://127.0.0.1:{port}')
|
||||
shadow_pull = zmq.Socket.shadow(pull)
|
||||
stream = zmqstream.ZMQStream(shadow_pull)
|
||||
# send some messages
|
||||
for i in range(10):
|
||||
push.send_string(str(i))
|
||||
# make sure at least one message has been delivered
|
||||
pull.recv()
|
||||
# register callback
|
||||
# this should schedule event callback on the next tick
|
||||
stream.on_recv(print)
|
||||
# close the shadowed socket
|
||||
pull.close()
|
||||
# run the event loop, which should see some events on the shadow socket
|
||||
# but the socket has been closed!
|
||||
with warnings.catch_warnings(record=True) as records:
|
||||
await asyncio.sleep(0.2)
|
||||
warning_text = "\n".join(str(r.message) for r in records)
|
||||
assert "after closing socket" in warning_text
|
||||
assert "closed socket" in caplog.text
|
||||
Reference in New Issue
Block a user