Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import builtins
from typing import Any, Dict, List, Literal, Mapping

from pydantic import BaseModel
Expand Down Expand Up @@ -55,12 +56,12 @@ async def call_tool(
try:
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token))
cancellation_token.link_future(result_future)
result = await result_future
actual_tool_output = await result_future
is_error = False
result_str = tool.return_value_as_string(actual_tool_output)
except Exception as e:
result = str(e)
result_str = self._format_errors(e)
is_error = True
result_str = tool.return_value_as_string(result)
return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)

async def start(self) -> None:
Expand Down Expand Up @@ -90,3 +91,16 @@ def _to_config(self) -> StaticWorkbenchConfig:
@classmethod
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])

def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""

error_message = ""
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
# ExceptionGroup is available in Python 3.11+.
# TODO: how to make this compatible with Python 3.10?
for sub_exception in error.exceptions: # type: ignore
error_message += self._format_errors(sub_exception) # type: ignore
else:
error_message += f"{str(error)}\n"
return error_message.strip()
56 changes: 37 additions & 19 deletions python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,52 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A
await session.initialize()
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)

def _normalize_payload_to_content_list(
self, payload: list[TextContent | ImageContent | EmbeddedResource]
) -> list[TextContent | ImageContent | EmbeddedResource]:
"""
Normalizes a raw tool output payload into a list of content items.
- If payload is already a list of (TextContent, ImageContent, EmbeddedResource), it's returned as is.
- If payload is a single TextContent, ImageContent, or EmbeddedResource, it's wrapped in a list.
- If payload is a string, it's wrapped in [TextContent(text=payload)].
- Otherwise, the payload is stringified and wrapped in [TextContent(text=str(payload))].
"""
if isinstance(payload, list) and all(
isinstance(item, (TextContent, ImageContent, EmbeddedResource)) for item in payload
):
return payload
elif isinstance(payload, (TextContent, ImageContent, EmbeddedResource)):
return [payload]
elif isinstance(payload, str):
return [TextContent(text=payload, type="text")]
else:
return [TextContent(text=str(payload), type="text")]

async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
exceptions_to_catch: tuple[Type[BaseException], ...]
if hasattr(builtins, "ExceptionGroup"):
exceptions_to_catch = (asyncio.CancelledError, builtins.ExceptionGroup)
else:
exceptions_to_catch = (asyncio.CancelledError,)

try:
if cancellation_token.is_cancelled():
raise Exception("Operation cancelled")
raise asyncio.CancelledError("Operation cancelled")

result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
cancellation_token.link_future(result_future)
result = await result_future

normalized_content_list = self._normalize_payload_to_content_list(result.content)

if result.isError:
raise Exception(f"MCP tool execution failed: {result.content}")
return result.content
except Exception as e:
error_message = self._format_errors(e)
raise Exception(error_message) from e
serialized_error_message = self.return_value_as_string(normalized_content_list)
raise Exception(serialized_error_message)
return normalized_content_list

except exceptions_to_catch:
# Re-raise these specific exception types directly.
raise

@classmethod
async def from_server_params(cls, server_params: TServerParams, tool_name: str) -> "McpToolAdapter[TServerParams]":
Expand Down Expand Up @@ -138,16 +169,3 @@ def serialize_item(item: Any) -> dict[str, Any]:
return {}

return json.dumps([serialize_item(item) for item in value])

def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""

error_message = ""
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
# ExceptionGroup is available in Python 3.11+.
# TODO: how to make this compatible with Python 3.10?
for sub_exception in error.exceptions: # type: ignore
error_message += self._format_errors(sub_exception) # type: ignore
else:
error_message += f"{str(error)}\n"
return error_message
119 changes: 119 additions & 0 deletions python/packages/autogen-ext/tests/tools/test_mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ def cancellation_token() -> CancellationToken:
return CancellationToken()


@pytest.fixture
def mock_error_tool_response() -> MagicMock:
response = MagicMock()
response.isError = True
response.content = [TextContent(text="error output", type="text")]
return response


def test_adapter_config_serialization(sample_tool: Tool, sample_server_params: StdioServerParams) -> None:
"""Test that adapter can be saved to and loaded from config."""
original_adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool)
Expand Down Expand Up @@ -650,3 +658,114 @@ def test_del_raises_when_loop_closed() -> None:

with pytest.warns(RuntimeWarning, match="loop is closed or not running"):
del workbench


def test_mcp_tool_adapter_normalize_payload(sample_tool: Tool, sample_server_params: StdioServerParams) -> None:
"""Test the _normalize_payload_to_content_list method of McpToolAdapter."""
adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool)

# Case 1: Payload is already a list of valid content items
valid_content_list: list[TextContent | ImageContent | EmbeddedResource] = [
TextContent(text="hello", type="text"),
ImageContent(data="base64data", mimeType="image/png", type="image"),
EmbeddedResource(
type="resource",
resource=TextResourceContents(text="embedded text", uri=AnyUrl(url="http://example.com/resource")),
),
]
assert adapter._normalize_payload_to_content_list(valid_content_list) == valid_content_list # type: ignore[reportPrivateUsage]

# Case 2: Payload is a single TextContent
single_text_content = TextContent(text="single text", type="text")
assert adapter._normalize_payload_to_content_list(single_text_content) == [single_text_content] # type: ignore[reportPrivateUsage, arg-type]

# Case 3: Payload is a single ImageContent
single_image_content = ImageContent(data="imagedata", mimeType="image/jpeg", type="image")
assert adapter._normalize_payload_to_content_list(single_image_content) == [single_image_content] # type: ignore[reportPrivateUsage, arg-type]

# Case 4: Payload is a single EmbeddedResource
single_embedded_resource = EmbeddedResource(
type="resource",
resource=TextResourceContents(text="other embedded", uri=AnyUrl(url="http://example.com/other")),
)
assert adapter._normalize_payload_to_content_list(single_embedded_resource) == [single_embedded_resource] # type: ignore[reportPrivateUsage, arg-type]

# Case 5: Payload is a string
string_payload = "This is a string payload."
expected_from_string = [TextContent(text=string_payload, type="text")]
assert adapter._normalize_payload_to_content_list(string_payload) == expected_from_string # type: ignore[reportPrivateUsage, arg-type]

# Case 6: Payload is an integer
int_payload = 12345
expected_from_int = [TextContent(text=str(int_payload), type="text")]
assert adapter._normalize_payload_to_content_list(int_payload) == expected_from_int # type: ignore[reportPrivateUsage, arg-type]

# Case 7: Payload is a dictionary
dict_payload = {"key": "value", "number": 42}
expected_from_dict = [TextContent(text=str(dict_payload), type="text")]
assert adapter._normalize_payload_to_content_list(dict_payload) == expected_from_dict # type: ignore[reportPrivateUsage, arg-type]

# Case 8: Payload is an empty list (should still be a list of valid items, so returns as is)
empty_list_payload: list[TextContent | ImageContent | EmbeddedResource] = []
assert adapter._normalize_payload_to_content_list(empty_list_payload) == empty_list_payload # type: ignore[reportPrivateUsage]

# Case 9: Payload is None (should be stringified)
none_payload = None
expected_from_none = [TextContent(text=str(none_payload), type="text")]
assert adapter._normalize_payload_to_content_list(none_payload) == expected_from_none # type: ignore[reportPrivateUsage, arg-type]


@pytest.mark.asyncio
async def test_mcp_tool_adapter_run_error(
sample_tool: Tool,
sample_server_params: StdioServerParams,
mock_session: AsyncMock,
mock_error_tool_response: MagicMock,
cancellation_token: CancellationToken,
) -> None:
"""Test McpToolAdapter._run when tool returns an error."""
adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool, session=mock_session)
mock_session.call_tool.return_value = mock_error_tool_response

args = {"test_param": "test_value"}
with pytest.raises(Exception) as excinfo:
await adapter._run(args=args, cancellation_token=cancellation_token, session=mock_session) # type: ignore[reportPrivateUsage]

mock_session.call_tool.assert_called_once_with(name=sample_tool.name, arguments=args)
assert adapter.return_value_as_string([TextContent(text="error output", type="text")]) in str(excinfo.value)


@pytest.mark.asyncio
async def test_mcp_tool_adapter_run_cancelled_before_call(
sample_tool: Tool,
sample_server_params: StdioServerParams,
mock_session: AsyncMock,
cancellation_token: CancellationToken,
) -> None:
"""Test McpToolAdapter._run when operation is cancelled before tool call."""
adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool, session=mock_session)
cancellation_token.cancel() # Cancel before the call

args = {"test_param": "test_value"}
with pytest.raises(asyncio.CancelledError):
await adapter._run(args=args, cancellation_token=cancellation_token, session=mock_session) # type: ignore[reportPrivateUsage]

mock_session.call_tool.assert_not_called()


@pytest.mark.asyncio
async def test_mcp_tool_adapter_run_cancelled_during_call(
sample_tool: Tool,
sample_server_params: StdioServerParams,
mock_session: AsyncMock,
cancellation_token: CancellationToken,
) -> None:
"""Test McpToolAdapter._run when operation is cancelled during tool call."""
adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool, session=mock_session)
mock_session.call_tool.side_effect = asyncio.CancelledError("Tool call cancelled")

args = {"test_param": "test_value"}
with pytest.raises(asyncio.CancelledError):
await adapter._run(args=args, cancellation_token=cancellation_token, session=mock_session) # type: ignore[reportPrivateUsage]

mock_session.call_tool.assert_called_once_with(name=sample_tool.name, arguments=args)
Loading