Skip to content

Commit 54f02ed

Browse files
committed
fix: propagate pre-endpoint errors in sse_client instead of deadlocking
When sse_reader encounters an error before receiving the endpoint event, the except handler tried to send the exception to read_stream_writer. With a zero-buffer stream and no reader (the caller is still blocked in tg.start() waiting for task_status.started()), send() blocks forever. Track whether started() has fired. Before it, re-raise so the exception propagates through tg.start(). After it, send to the stream as before. This also adds a guard for the case where a server sends a message event before the endpoint event, which would deadlock on the same send() call. The dedicated SSEError handler from #975 is removed since the started flag now handles all pre-endpoint exceptions uniformly. Github-Issue: #447
1 parent 7ba4fb8 commit 54f02ed

File tree

2 files changed

+119
-9
lines changed

2 files changed

+119
-9
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from anyio.abc import TaskStatus
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111
from httpx_sse import aconnect_sse
12-
from httpx_sse._exceptions import SSEError
1312

1413
from mcp import types
1514
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
@@ -69,6 +68,12 @@ async def sse_client(
6968
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
7069

7170
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
71+
# Before task_status.started() fires, the caller is blocked inside
72+
# tg.start() and nobody reads from read_stream. Sending to the
73+
# zero-buffer stream in that phase would deadlock, so errors must
74+
# be raised instead. After started(), the caller has the streams
75+
# and errors are delivered through read_stream.
76+
started = False
7277
try:
7378
async for sse in event_source.aiter_sse(): # pragma: no branch
7479
logger.debug(f"Received SSE event: {sse.event}")
@@ -79,27 +84,28 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
7984

8085
url_parsed = urlparse(url)
8186
endpoint_parsed = urlparse(endpoint_url)
82-
if ( # pragma: no cover
87+
if (
8388
url_parsed.netloc != endpoint_parsed.netloc
8489
or url_parsed.scheme != endpoint_parsed.scheme
8590
):
86-
error_msg = ( # pragma: no cover
91+
raise ValueError(
8792
f"Endpoint origin does not match connection origin: {endpoint_url}"
8893
)
89-
logger.error(error_msg) # pragma: no cover
90-
raise ValueError(error_msg) # pragma: no cover
9194

9295
if on_session_created:
9396
session_id = _extract_session_id_from_endpoint(endpoint_url)
9497
if session_id:
9598
on_session_created(session_id)
9699

97100
task_status.started(endpoint_url)
101+
started = True
98102

99103
case "message":
100104
# Skip empty data (keep-alive pings)
101105
if not sse.data:
102106
continue
107+
if not started:
108+
raise RuntimeError("Received message event before endpoint event")
103109
try:
104110
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
105111
logger.debug(f"Received server message: {message}")
@@ -112,11 +118,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
112118
await read_stream_writer.send(session_message)
113119
case _: # pragma: no cover
114120
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
115-
except SSEError as sse_exc: # pragma: lax no cover
116-
logger.exception("Encountered SSE exception")
117-
raise sse_exc
118-
except Exception as exc: # pragma: lax no cover
121+
except Exception as exc:
119122
logger.exception("Error in sse_reader")
123+
if not started:
124+
raise
120125
await read_stream_writer.send(exc)
121126
finally:
122127
await read_stream_writer.aclose()

tests/shared/test_sse.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import json
22
import multiprocessing
33
import socket
4+
import sys
45
from collections.abc import AsyncGenerator, Generator
56
from typing import Any
67
from unittest.mock import AsyncMock, MagicMock, Mock, patch
78
from urllib.parse import urlparse
89

10+
# BaseExceptionGroup is builtin on 3.11+. On 3.10 it comes from the
11+
# exceptiongroup backport, which anyio pulls in as a dependency.
12+
if sys.version_info < (3, 11): # pragma: lax no cover
13+
from exceptiongroup import BaseExceptionGroup
14+
915
import anyio
1016
import httpx
1117
import pytest
@@ -604,6 +610,105 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
604610
assert msg.message.id == 1
605611

606612

613+
def _mock_sse_connection(aiter_sse: AsyncGenerator[ServerSentEvent, None]) -> Any:
614+
"""Patch sse_client's HTTP layer to yield the given SSE event stream."""
615+
mock_event_source = MagicMock()
616+
mock_event_source.aiter_sse.return_value = aiter_sse
617+
mock_event_source.response.raise_for_status = MagicMock()
618+
619+
mock_aconnect_sse = MagicMock()
620+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
621+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
622+
623+
mock_client = MagicMock()
624+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
625+
mock_client.__aexit__ = AsyncMock(return_value=None)
626+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
627+
628+
return patch.multiple(
629+
"mcp.client.sse",
630+
create_mcp_http_client=Mock(return_value=mock_client),
631+
aconnect_sse=Mock(return_value=mock_aconnect_sse),
632+
)
633+
634+
635+
@pytest.mark.anyio
636+
async def test_sse_client_raises_on_endpoint_origin_mismatch() -> None:
637+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
638+
639+
When the server sends an endpoint URL with a different origin than the
640+
connection URL, sse_client must raise promptly instead of deadlocking.
641+
Before the fix, the ValueError was caught and sent to a zero-buffer stream
642+
with no reader, hanging forever.
643+
"""
644+
645+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
646+
yield ServerSentEvent(event="endpoint", data="http://wrong-host:9999/messages?sessionId=abc")
647+
await anyio.sleep_forever() # pragma: no cover
648+
649+
with _mock_sse_connection(events()), anyio.fail_after(5):
650+
with pytest.raises(BaseExceptionGroup) as exc_info:
651+
async with sse_client("http://test/sse"): # pragma: no branch
652+
pytest.fail("sse_client should not yield on origin mismatch") # pragma: no cover
653+
assert exc_info.group_contains(ValueError, match="Endpoint origin does not match")
654+
655+
656+
@pytest.mark.anyio
657+
async def test_sse_client_raises_on_error_before_endpoint() -> None:
658+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
659+
660+
Any exception raised while waiting for the endpoint event must propagate
661+
instead of deadlocking on the zero-buffer read stream.
662+
"""
663+
664+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
665+
raise ConnectionError("connection reset by peer")
666+
yield # pragma: no cover
667+
668+
with _mock_sse_connection(events()), anyio.fail_after(5):
669+
with pytest.raises(BaseExceptionGroup) as exc_info:
670+
async with sse_client("http://test/sse"): # pragma: no branch
671+
pytest.fail("sse_client should not yield on pre-endpoint error") # pragma: no cover
672+
assert exc_info.group_contains(ConnectionError, match="connection reset")
673+
674+
675+
@pytest.mark.anyio
676+
async def test_sse_client_raises_on_message_before_endpoint() -> None:
677+
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/447
678+
679+
If the server sends a message event before the endpoint event (protocol
680+
violation), sse_client must raise rather than deadlock trying to send the
681+
message to a stream nobody is reading yet.
682+
"""
683+
684+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
685+
yield ServerSentEvent(event="message", data='{"jsonrpc":"2.0","id":1,"result":{}}')
686+
await anyio.sleep_forever() # pragma: no cover
687+
688+
with _mock_sse_connection(events()), anyio.fail_after(5):
689+
with pytest.raises(BaseExceptionGroup) as exc_info:
690+
async with sse_client("http://test/sse"): # pragma: no branch
691+
pytest.fail("sse_client should not yield on protocol violation") # pragma: no cover
692+
assert exc_info.group_contains(RuntimeError, match="before endpoint event")
693+
694+
695+
@pytest.mark.anyio
696+
async def test_sse_client_delivers_post_endpoint_errors_via_stream() -> None:
697+
"""After the endpoint is received, errors in sse_reader are delivered on the
698+
read stream so the session can handle them, rather than crashing the task group.
699+
"""
700+
701+
async def events() -> AsyncGenerator[ServerSentEvent, None]:
702+
yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc")
703+
raise ConnectionError("mid-stream failure")
704+
705+
with _mock_sse_connection(events()), anyio.fail_after(5):
706+
async with sse_client("http://test/sse") as (read_stream, _):
707+
received = await read_stream.receive()
708+
assert isinstance(received, ConnectionError)
709+
assert "mid-stream failure" in str(received)
710+
711+
607712
@pytest.mark.anyio
608713
async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None:
609714
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227

0 commit comments

Comments
 (0)