Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-mcp-oauth-restore.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": patch
---

Export `DurableObjectOAuthClientProvider` from top-level `agents` package and fix `restoreConnectionsFromStorage()` to use the Agent's `createMcpOAuthProvider()` override instead of hardcoding the default provider
21 changes: 21 additions & 0 deletions docs/mcp-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,27 @@ class MyAgent extends Agent {

Your custom class must implement the `AgentMcpOAuthProvider` interface, which extends the MCP SDK's `OAuthClientProvider` with additional properties (`authUrl`, `clientId`, `serverId`) and methods (`checkState`, `consumeState`, `deleteCodeVerifier`) used by the agent's MCP connection lifecycle.

The override is used for both new connections (`addMcpServer`) and restored connections after a Durable Object restart, so your custom provider is always used consistently.

#### Custom storage backend

The most common customization is using a different storage backend while keeping the built-in OAuth logic (CSRF state, PKCE, nonce generation, token management). Import `DurableObjectOAuthClientProvider` and pass your own storage adapter:

```typescript
import { Agent, DurableObjectOAuthClientProvider } from "agents";
import type { AgentMcpOAuthProvider } from "agents";

class MyAgent extends Agent {
createMcpOAuthProvider(callbackUrl: string): AgentMcpOAuthProvider {
return new DurableObjectOAuthClientProvider(
myCustomStorage, // any DurableObjectStorage-compatible adapter
this.name,
callbackUrl
);
}
}
```

## Using MCP Capabilities

Once connected, access the server's capabilities:
Expand Down
11 changes: 7 additions & 4 deletions packages/agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,11 @@ function getNextCronTime(cron: string) {

export type { TransportType } from "./mcp/types";
export type { RetryOptions } from "./retries";
export type {
AgentMcpOAuthProvider,
export {
DurableObjectOAuthClientProvider,
type AgentMcpOAuthProvider,
/** @deprecated Use {@link AgentMcpOAuthProvider} instead. */
AgentsOAuthProvider
type AgentsOAuthProvider
} from "./mcp/do-oauth-client-provider";

/**
Expand Down Expand Up @@ -823,7 +824,9 @@ export class Agent<

// Initialize MCPClientManager AFTER tables are created
this.mcp = new MCPClientManager(this._ParentClass.name, "0.0.1", {
storage: this.ctx.storage
storage: this.ctx.storage,
createAuthProvider: (callbackUrl) =>
this.createMcpOAuthProvider(callbackUrl)
});

// Broadcast server state whenever MCP state changes (register, connect, OAuth, remove, etc.)
Expand Down
23 changes: 17 additions & 6 deletions packages/agents/src/mcp/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export type MCPClientOAuthResult =

export type MCPClientManagerOptions = {
storage: DurableObjectStorage;
createAuthProvider?: (callbackUrl: string) => AgentMcpOAuthProvider;
};

/**
Expand All @@ -127,6 +128,9 @@ export class MCPClientManager {
private _oauthCallbackConfig?: MCPClientOAuthCallbackConfig;
private _connectionDisposables = new Map<string, DisposableStore>();
private _storage: DurableObjectStorage;
private _createAuthProviderFn?: (
callbackUrl: string
) => AgentMcpOAuthProvider;
private _isRestored = false;

/** @internal Protected for testing purposes. */
Expand Down Expand Up @@ -159,6 +163,7 @@ export class MCPClientManager {
);
}
this._storage = options.storage;
this._createAuthProviderFn = options.createAuthProvider;
}

// SQL helper - runs a query and returns results as array
Expand Down Expand Up @@ -317,12 +322,18 @@ export class MCPClientManager {
? JSON.parse(server.server_options)
: null;

const authProvider = this.createAuthProvider(
server.id,
server.callback_url,
clientName,
server.client_id ?? undefined
);
const authProvider = this._createAuthProviderFn
? this._createAuthProviderFn(server.callback_url)
: this.createAuthProvider(
server.id,
server.callback_url,
clientName,
server.client_id ?? undefined
);
authProvider.serverId = server.id;
if (server.client_id) {
authProvider.clientId = server.client_id;
}

// Create the in-memory connection object (no need to save to storage - we just read from it!)
const conn = this.createConnection(server.id, server.server_url, {
Expand Down
37 changes: 37 additions & 0 deletions packages/agents/src/tests/agents/oauth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,41 @@ export class TestCustomOAuthAgent extends Agent<Record<string, unknown>> {
callbackUrl: this._customProviderCallbackUrl
};
}

async testRestoreUsesOverride(): Promise<{
overrideWasCalled: boolean;
restoredProviderClientId: string | undefined;
}> {
const serverId = "restore-override-test";
const callbackUrl = "http://example.com/restore-callback";

this.sql`
INSERT OR REPLACE INTO cf_agents_mcp_servers (
id, name, server_url, client_id, auth_url, callback_url, server_options
) VALUES (
${serverId},
${"Restore Test Server"},
${"http://restore-test.com"},
${null},
${"https://auth.example.com/authorize"},
${callbackUrl},
${null}
)
`;

// Reset restored flag so restoreConnectionsFromStorage runs again
// @ts-expect-error - accessing private property for testing
this.mcp._isRestored = false;
// Clear any existing connection for this server
delete this.mcp.mcpConnections[serverId];

this._customProviderCallbackUrl = undefined;
await this.mcp.restoreConnectionsFromStorage(this.name);

const conn = this.mcp.mcpConnections[serverId];
return {
overrideWasCalled: this._customProviderCallbackUrl === callbackUrl,
restoredProviderClientId: conn?.options.transport.authProvider?.clientId
};
}
}
219 changes: 219 additions & 0 deletions packages/agents/src/tests/mcp/client-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,225 @@ describe("MCPClientManager OAuth Integration", () => {
});
});

describe("restoreConnectionsFromStorage() - createAuthProvider factory", () => {
it("should use the injected factory when restoring connections", async () => {
const serverId = "factory-test";
const callbackUrl = "http://localhost:3000/callback";

saveServerToMock({
id: serverId,
name: "Factory Test Server",
server_url: "http://factory.com",
callback_url: callbackUrl,
client_id: null,
auth_url: null,
server_options: null
});

const mockProvider = createMockAuthProvider(createMockStateStorage());
const factory = vi.fn().mockReturnValue(mockProvider);

const factoryManager = new TestMCPClientManager("test-client", "1.0.0", {
storage: manager["_storage"],
createAuthProvider: factory
});

vi.spyOn(factoryManager, "connectToServer").mockResolvedValue({
state: "connected"
});

await factoryManager.restoreConnectionsFromStorage("test-agent");

expect(factory).toHaveBeenCalledWith(callbackUrl);
const conn = factoryManager.mcpConnections[serverId];
expect(conn).toBeDefined();
expect(conn.options.transport.authProvider).toBe(mockProvider);
});

it("should set serverId and clientId on the provider returned by the factory", async () => {
const serverId = "factory-ids-test";
const clientId = "custom-client-123";
const callbackUrl = "http://localhost:3000/callback";

saveServerToMock({
id: serverId,
name: "IDs Test Server",
server_url: "http://ids-test.com",
callback_url: callbackUrl,
client_id: clientId,
auth_url: null,
server_options: null
});

const mockProvider = createMockAuthProvider(createMockStateStorage());
const factory = vi.fn().mockReturnValue(mockProvider);

const factoryManager = new TestMCPClientManager("test-client", "1.0.0", {
storage: manager["_storage"],
createAuthProvider: factory
});

vi.spyOn(factoryManager, "connectToServer").mockResolvedValue({
state: "connected"
});

await factoryManager.restoreConnectionsFromStorage("test-agent");

expect(mockProvider.serverId).toBe(serverId);
expect(mockProvider.clientId).toBe(clientId);
});

it("should fall back to default createAuthProvider when no factory is provided", async () => {
const serverId = "no-factory-test";
const callbackUrl = "http://localhost:3000/callback";

saveServerToMock({
id: serverId,
name: "No Factory Server",
server_url: "http://no-factory.com",
callback_url: callbackUrl,
client_id: null,
auth_url: null,
server_options: null
});

vi.spyOn(manager, "connectToServer").mockResolvedValue({
state: "connected"
});

await manager.restoreConnectionsFromStorage("test-agent");

const conn = manager.mcpConnections[serverId];
expect(conn).toBeDefined();
expect(conn.options.transport.authProvider).toBeDefined();
expect(conn.options.transport.authProvider?.serverId).toBe(serverId);
});

it("should use the factory for OAuth servers with auth_url set", async () => {
const serverId = "factory-oauth-test";
const callbackUrl = "http://localhost:3000/callback";
const clientId = "oauth-client-id";

saveServerToMock({
id: serverId,
name: "OAuth Factory Server",
server_url: "http://oauth-factory.com",
callback_url: callbackUrl,
client_id: clientId,
auth_url: "https://auth.example.com/authorize",
server_options: null
});

const mockProvider = createMockAuthProvider(createMockStateStorage());
const factory = vi.fn().mockReturnValue(mockProvider);

const factoryManager = new TestMCPClientManager("test-client", "1.0.0", {
storage: manager["_storage"],
createAuthProvider: factory
});

await factoryManager.restoreConnectionsFromStorage("test-agent");

expect(factory).toHaveBeenCalledWith(callbackUrl);
const conn = factoryManager.mcpConnections[serverId];
expect(conn).toBeDefined();
expect(conn.connectionState).toBe("authenticating");
expect(mockProvider.serverId).toBe(serverId);
expect(mockProvider.clientId).toBe(clientId);
});

it("should use the factory when recreating failed connections", async () => {
const serverId = "factory-failed-test";
const callbackUrl = "http://localhost:3000/callback";

saveServerToMock({
id: serverId,
name: "Failed Server",
server_url: "http://failed-factory.com",
callback_url: callbackUrl,
client_id: null,
auth_url: null,
server_options: null
});

const mockProvider = createMockAuthProvider(createMockStateStorage());
const factory = vi.fn().mockReturnValue(mockProvider);

const factoryManager = new TestMCPClientManager("test-client", "1.0.0", {
storage: manager["_storage"],
createAuthProvider: factory
});

// Pre-populate with a failed connection
const failedConnection = new MCPClientConnection(
new URL("http://failed-factory.com"),
{ name: "test-client", version: "1.0.0" },
{ transport: { type: "auto" }, client: {} }
);
failedConnection.connectionState = "failed";
failedConnection.client.close = vi.fn().mockResolvedValue(undefined);
factoryManager.mcpConnections[serverId] = failedConnection;

vi.spyOn(factoryManager, "connectToServer").mockResolvedValue({
state: "connected"
});

await factoryManager.restoreConnectionsFromStorage("test-agent");

expect(factory).toHaveBeenCalledWith(callbackUrl);
expect(factoryManager.mcpConnections[serverId]).not.toBe(
failedConnection
);
});

it("should call the factory once per server in mixed restore", async () => {
const callbackUrl1 = "http://localhost:3000/callback/s1";
const callbackUrl2 = "http://localhost:3000/callback/s2";

saveServerToMock({
id: "server-1",
name: "Server 1",
server_url: "http://s1.com",
callback_url: callbackUrl1,
client_id: null,
auth_url: null,
server_options: null
});

saveServerToMock({
id: "server-2",
name: "Server 2",
server_url: "http://s2.com",
callback_url: callbackUrl2,
client_id: "client-2",
auth_url: "https://auth.example.com/authorize",
server_options: null
});

const mockProvider1 = createMockAuthProvider(createMockStateStorage());
const mockProvider2 = createMockAuthProvider(createMockStateStorage());
const factory = vi
.fn()
.mockReturnValueOnce(mockProvider1)
.mockReturnValueOnce(mockProvider2);

const factoryManager = new TestMCPClientManager("test-client", "1.0.0", {
storage: manager["_storage"],
createAuthProvider: factory
});

vi.spyOn(factoryManager, "connectToServer").mockResolvedValue({
state: "connected"
});

await factoryManager.restoreConnectionsFromStorage("test-agent");

expect(factory).toHaveBeenCalledTimes(2);
expect(factory).toHaveBeenCalledWith(callbackUrl1);
expect(factory).toHaveBeenCalledWith(callbackUrl2);
});
});

describe("connectToServer() - Connection States", () => {
it("should return connected state for successful non-OAuth connection", async () => {
const serverId = "non-oauth-connect-test";
Expand Down
14 changes: 14 additions & 0 deletions packages/agents/src/tests/mcp/create-oauth-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,18 @@ describe("createMcpOAuthProvider", () => {
expect(result.clientId).toBe("custom-client-id");
expect(result.callbackUrl).toBe("http://example.com/custom-callback");
});

it("should use the custom provider override during restoreConnectionsFromStorage", async () => {
const agentId = env.TestCustomOAuthAgent.idFromName(
"test-restore-override"
);
const agentStub = env.TestCustomOAuthAgent.get(agentId);

await agentStub.setName("restore-test");

const result = await agentStub.testRestoreUsesOverride();

expect(result.overrideWasCalled).toBe(true);
expect(result.restoredProviderClientId).toBe("custom-client-id");
});
});
Loading