Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/swift-moments-care.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": patch
---

add auto transport option
1 change: 1 addition & 0 deletions packages/agents/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ const result = await generateText({

**Transport Options:**

- **Auto**: Automatically determine the correct transport
- **HTTP Streamable**: Best performance, batch requests, session management
- **SSE**: Simple setup, legacy compatibility

Expand Down
97 changes: 70 additions & 27 deletions packages/agents/src/mcp/client-connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import {
type ListResourceTemplatesResult,
type ListResourcesResult,
type ListToolsResult,
// type Notification,
type Prompt,
PromptListChangedNotificationSchema,
type Resource,
Expand All @@ -29,9 +28,11 @@ export type MCPTransportOptions = (
| StreamableHTTPClientTransportOptions
) & {
authProvider?: AgentsOAuthProvider;
type?: "sse" | "streamable-http";
type?: "sse" | "streamable-http" | "auto";
};

type TransportType = Exclude<MCPTransportOptions["type"], "auto">;

export class MCPClientConnection {
client: Client;
connectionState:
Expand Down Expand Up @@ -75,31 +76,7 @@ export class MCPClientConnection {
async init(code?: string) {
try {
const transportType = this.options.transport.type || "streamable-http";
const transport =
transportType === "streamable-http"
? new StreamableHTTPEdgeClientTransport(
this.url,
this.options.transport as StreamableHTTPClientTransportOptions
)
: new SSEEdgeClientTransport(
this.url,
this.options.transport as SSEClientTransportOptions
);

if (code) {
await transport.finishAuth(code);
}

await this.client.connect(transport);

// Set up elicitation request handler
this.client.setRequestHandler(
ElicitRequestSchema,
async (request: ElicitRequest) => {
return await this.handleElicitationRequest(request);
}
);

await this.tryConnect(transportType, code);
// biome-ignore lint/suspicious/noExplicitAny: allow for the error check here
} catch (e: any) {
if (e.toString().includes("Unauthorized")) {
Expand Down Expand Up @@ -301,6 +278,72 @@ export class MCPClientConnection {
"Elicitation handler must be implemented for your platform. Override handleElicitationRequest method."
);
}
/**
* Get the transport for the client
* @param transportType - The transport type to get
* @returns The transport for the client
*/
getTransport(transportType: TransportType) {
switch (transportType) {
case "streamable-http":
return new StreamableHTTPEdgeClientTransport(
this.url,
this.options.transport as StreamableHTTPClientTransportOptions
);
case "sse":
return new SSEEdgeClientTransport(
this.url,
this.options.transport as SSEClientTransportOptions
);
default:
throw new Error(`Unsupported transport type: ${transportType}`);
}
}

async tryConnect(transportType: MCPTransportOptions["type"], code?: string) {
const transports: TransportType[] =
transportType === "auto" ? ["streamable-http", "sse"] : [transportType];

for (const currentTransportType of transports) {
const isLastTransport =
currentTransportType === transports[transports.length - 1];
const hasFallback =
transportType === "auto" &&
currentTransportType === "streamable-http" &&
!isLastTransport;

const transport = await this.getTransport(currentTransportType);

if (code) {
await transport.finishAuth(code);
}

try {
await this.client.connect(transport);
break;
} catch (e) {
const error = e instanceof Error ? e : new Error(String(e));

if (
hasFallback &&
(error.message.includes("404") || error.message.includes("405"))
) {
// try the next transport if we have a fallback
continue;
}

throw e;
}
}

// Set up elicitation request handler
this.client.setRequestHandler(
ElicitRequestSchema,
async (request: ElicitRequest) => {
return await this.handleElicitationRequest(request);
}
);
}
}

function capabilityErrorHandler<T>(empty: T, method: string) {
Expand Down
9 changes: 4 additions & 5 deletions packages/agents/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ export abstract class McpAgent<
// will be passed here
async onSSEMcpMessage(
_sessionId: string,
request: Request
messageBody: unknown
): Promise<Error | null> {
if (this._status !== "started") {
// This means the server "woke up" after hibernation
Expand All @@ -612,10 +612,9 @@ export abstract class McpAgent<
}

try {
const message = await request.json();
let parsedMessage: JSONRPCMessage;
try {
parsedMessage = JSONRPCMessageSchema.parse(message);
parsedMessage = JSONRPCMessageSchema.parse(messageBody);
} catch (error) {
this._transport?.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -890,8 +889,8 @@ export abstract class McpAgent<
const id = namespace.idFromName(`sse:${sessionId}`);
const doStub = namespace.get(id);

// Forward the request to the Durable Object
const error = await doStub.onSSEMcpMessage(sessionId, request);
const messageBody = await request.json();
const error = await doStub.onSSEMcpMessage(sessionId, messageBody);
Comment on lines +892 to +893
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was struggling to get the tests to pass when passing the entire request to onSSEMcpMessage. I noticed that the only thing it actually needs is the message body so being explicit seemed okay.


if (error) {
return new Response(error.message, {
Expand Down
84 changes: 84 additions & 0 deletions packages/agents/src/tests/mcp/transports/auto.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { createExecutionContext, env } from "cloudflare:test";
import { describe, expect, it, beforeEach, afterEach } from "vitest";
import worker, { type Env } from "../../worker";
import { initializeMCPClientConnection } from "../../shared/test-utils";

declare module "cloudflare:test" {
interface ProvidedEnv extends Env {}
}

/**
* Tests for the "auto" transport mode which attempts streamable-http first,
* then falls back to SSE if streamable-http returns 404 or 405 errors
*/
describe("Auto Transport Mode", () => {
const originalFetch = globalThis.fetch;

beforeEach(() => {
globalThis.fetch = async (input: RequestInfo | URL, init?: RequestInit) => {
const ctx = createExecutionContext();
const request = new Request(input, init);
return worker.fetch(request, env, ctx);
};
});

afterEach(() => {
globalThis.fetch = originalFetch;
});

describe("Transport Selection Logic", () => {
it("should use connect using streamable-http when available", async () => {
const connection = await initializeMCPClientConnection(
"http://example.com/mcp",
"auto"
);

await connection.init();

expect(connection.connectionState).toBe("ready");
expect(connection.tools).toBeDefined();
});

it("should use connect using sse when available", async () => {
const connection = await initializeMCPClientConnection(
"http://example.com/sse",
"auto"
);

await connection.init();

expect(connection.connectionState).toBe("ready");
expect(connection.tools).toBeDefined();
});

it("should not fallback for 5XX errors", async () => {
const connection = await initializeMCPClientConnection(
"http://example.com/500",
"auto"
);

await expect(connection.init()).rejects.toThrow();
expect(connection.connectionState).toBe("failed");
});

it("should fail when endpoint returns 404 for both streamable-http and sse", async () => {
const connection = await initializeMCPClientConnection(
"http://example.com/not-found",
"auto"
);

await expect(connection.init()).rejects.toThrow();
expect(connection.connectionState).toBe("failed");
});

it("should fail when asking for an incorrect transport type", async () => {
const connection = await initializeMCPClientConnection(
"http://example.com/mcp",
"sse"
);

await expect(connection.init()).rejects.toThrow();
expect(connection.connectionState).toBe("failed");
});
});
});
12 changes: 12 additions & 0 deletions packages/agents/src/tests/shared/test-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { env } from "cloudflare:test";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { expect } from "vitest";
import worker, { type Env } from "../worker";
import { MCPClientConnection } from "../../mcp/client-connection";

declare module "cloudflare:test" {
interface ProvidedEnv extends Env {}
Expand Down Expand Up @@ -136,6 +137,17 @@ export async function initializeStreamableHTTPServer(
return sessionId as string;
}

export async function initializeMCPClientConnection(
baseUrl = "http://example.com/mcp",
transportType: "auto" | "streamable-http" | "sse" = "auto"
) {
return new MCPClientConnection(
new URL(baseUrl),
{ name: "test-client", version: "1.0.0" },
{ transport: { type: transportType }, client: {} }
);
}

/**
* Helper to establish SSE connection and get session ID
*/
Expand Down
4 changes: 4 additions & 0 deletions packages/agents/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ export default {
return TestMcpAgent.serve("/mcp").fetch(request, env, ctx);
}

if (url.pathname === "/500") {
return new Response("Internal Server Error", { status: 500 });
}

return new Response("Not found", { status: 404 });
},

Expand Down