|
1 | 1 | import json |
2 | 2 | import multiprocessing |
3 | 3 | import socket |
| 4 | +import sys |
4 | 5 | from collections.abc import AsyncGenerator, Generator |
5 | 6 | from typing import Any |
6 | 7 | from unittest.mock import AsyncMock, MagicMock, Mock, patch |
7 | 8 | from urllib.parse import urlparse |
8 | 9 |
|
| 10 | +# BaseExceptionGroup is builtin on 3.11+. On 3.10 it comes from the |
| 11 | +# exceptiongroup backport, which anyio pulls in as a dependency. |
| 12 | +if sys.version_info < (3, 11): # pragma: lax no cover |
| 13 | + from exceptiongroup import BaseExceptionGroup |
| 14 | + |
9 | 15 | import anyio |
10 | 16 | import httpx |
11 | 17 | import pytest |
@@ -604,6 +610,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: |
604 | 610 | assert msg.message.id == 1 |
605 | 611 |
|
606 | 612 |
|
| 613 | +def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any: |
| 614 | + """Patch sse_client's HTTP layer to yield the given SSE event stream.""" |
| 615 | + mock_event_source = MagicMock() |
| 616 | + mock_event_source.aiter_sse.return_value = aiter_sse |
| 617 | + mock_event_source.response.raise_for_status = MagicMock() |
| 618 | + |
| 619 | + mock_aconnect_sse = MagicMock() |
| 620 | + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) |
| 621 | + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) |
| 622 | + |
| 623 | + mock_client = MagicMock() |
| 624 | + mock_client.__aenter__ = AsyncMock(return_value=mock_client) |
| 625 | + mock_client.__aexit__ = AsyncMock(return_value=None) |
| 626 | + mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock())) |
| 627 | + |
| 628 | + return patch.multiple( |
| 629 | + "mcp.client.sse", |
| 630 | + create_mcp_http_client=Mock(return_value=mock_client), |
| 631 | + aconnect_sse=Mock(return_value=mock_aconnect_sse), |
| 632 | + ) |
| 633 | + |
| 634 | + |
| 635 | +@pytest.mark.anyio |
| 636 | +async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None: |
| 637 | + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 |
| 638 | +
|
| 639 | + When the server sends an endpoint URL with a different origin than the |
| 640 | + connection URL, sse_client must raise promptly instead of deadlocking. |
| 641 | + Before the fix, the ValueError was caught and sent to a zero-buffer stream |
| 642 | + with no reader, hanging forever. |
| 643 | + """ |
| 644 | + |
| 645 | + async def events() -> AsyncGenerator[ServerSentEvent, None]: |
| 646 | + yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc") |
| 647 | + await anyio.sleep_forever() # pragma: no cover |
| 648 | + |
| 649 | + with _mock_sse_connection(events()), anyio.fail_after(5): |
| 650 | + with pytest.raises(BaseExceptionGroup) as exc_info: |
| 651 | + async with sse_client("http://test/sse"): # pragma: no branch |
| 652 | + pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover |
| 653 | + assert exc_info.group_contains(ValueError, match="Endpoint origin does not match") |
| 654 | + |
| 655 | + |
| 656 | +@pytest.mark.anyio |
| 657 | +async def test_sse_client_raises_on_error_before_endpoint() -> None: |
| 658 | + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 |
| 659 | +
|
| 660 | + Any exception raised while waiting for the endpoint event must propagate |
| 661 | + instead of deadlocking on the zero-buffer read stream. |
| 662 | + """ |
| 663 | + |
| 664 | + async def events() -> AsyncGenerator[ServerSentEvent, None]: |
| 665 | + raise ConnectionError("connection reset by peer") |
| 666 | + yield # pragma: no cover |
| 667 | + |
| 668 | + with _mock_sse_connection(events()), anyio.fail_after(5): |
| 669 | + with pytest.raises(BaseExceptionGroup) as exc_info: |
| 670 | + async with sse_client("http://test/sse"): # pragma: no branch |
| 671 | + pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover |
| 672 | + assert exc_info.group_contains(ConnectionError, match="connection reset") |
| 673 | + |
| 674 | + |
| 675 | +@pytest.mark.anyio |
| 676 | +async def test_sse_client_raises_on_message_before_endpoint() -> None: |
| 677 | + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447 |
| 678 | +
|
| 679 | + If the server sends a message event before the endpoint event (protocol |
| 680 | + violation), sse_client must raise rather than deadlock trying to send the |
| 681 | + message to a stream nobody is reading yet. |
| 682 | + """ |
| 683 | + |
| 684 | + async def events() -> AsyncGenerator[ServerSentEvent, None]: |
| 685 | + yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}') |
| 686 | + await anyio.sleep_forever() # pragma: no cover |
| 687 | + |
| 688 | + with _mock_sse_connection(events()), anyio.fail_after(5): |
| 689 | + with pytest.raises(BaseExceptionGroup) as exc_info: |
| 690 | + async with sse_client("http://test/sse"): # pragma: no branch |
| 691 | + pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover |
| 692 | + assert exc_info.group_contains(RuntimeError, match="before endpoint event") |
| 693 | + |
| 694 | + |
| 695 | +@pytest.mark.anyio |
| 696 | +async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None: |
| 697 | + """After the endpoint is received, errors in sse_reader are delivered on the |
| 698 | + read stream so the session can handle them, rather than crashing the task group. |
| 699 | + """ |
| 700 | + |
| 701 | + async def events() -> AsyncGenerator[ServerSentEvent, None]: |
| 702 | + yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc") |
| 703 | + raise ConnectionError("mid-stream failure") |
| 704 | + |
| 705 | + with _mock_sse_connection(events()), anyio.fail_after(5): |
| 706 | + async with sse_client("http://test/sse") as (read_stream, _): |
| 707 | + received = await read_stream.receive() |
| 708 | + assert isinstance(received, ConnectionError) |
| 709 | + assert "mid-stream failure" in str(received) |
| 710 | + |
| 711 | + |
607 | 712 | @pytest.mark.anyio |
608 | 713 | async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: |
609 | 714 | """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 |
|
0 commit comments