Skip to content

Commit 9f5562a

Browse files
Adds back support for partial tool calls during streaming + tests
1 parent 31d3250 commit 9f5562a

File tree

7 files changed

+533
-25
lines changed

7 files changed

+533
-25
lines changed

.changeset/ten-sloths-cross.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"workers-ai-provider": minor
3+
---
4+
5+
Adds support for new tool call format during streaming

packages/workers-ai-provider/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"@ai-sdk/provider": "^1.1.3"
4242
},
4343
"devDependencies": {
44-
"@cloudflare/workers-types": "^4.20250525.0"
44+
"@cloudflare/workers-types": "^4.20250525.0",
45+
"zod": "^3.25.28"
4546
}
4647
}

packages/workers-ai-provider/src/streaming.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ import { events } from "fetch-event-stream";
22

33
import type { LanguageModelV1StreamPart } from "@ai-sdk/provider";
44
import { mapWorkersAIUsage } from "./map-workersai-usage";
5+
import { processPartialToolCalls } from "./utils";
56

67
export function getMappedStream(response: Response) {
78
const chunkEvent = events(response);
89
let usage = { promptTokens: 0, completionTokens: 0 };
10+
const partialToolCalls: any[] = [];
911

1012
return new ReadableStream<LanguageModelV1StreamPart>({
1113
async start(controller) {
@@ -20,12 +22,26 @@ export function getMappedStream(response: Response) {
2022
if (chunk.usage) {
2123
usage = mapWorkersAIUsage(chunk);
2224
}
25+
if (chunk.tool_calls) {
26+
partialToolCalls.push(...chunk.tool_calls);
27+
}
2328
chunk.response?.length &&
2429
controller.enqueue({
2530
type: "text-delta",
2631
textDelta: chunk.response,
2732
});
2833
}
34+
35+
if (partialToolCalls.length > 0) {
36+
const toolCalls = processPartialToolCalls(partialToolCalls);
37+
toolCalls.map((toolCall) => {
38+
controller.enqueue({
39+
type: "tool-call",
40+
...toolCall,
41+
});
42+
});
43+
}
44+
2945
controller.enqueue({
3046
type: "finish",
3147
finishReason: "stop",

packages/workers-ai-provider/src/utils.ts

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { LanguageModelV1 } from "@ai-sdk/provider";
1+
import type { LanguageModelV1, LanguageModelV1FunctionToolCall } from "@ai-sdk/provider";
22

33
/**
44
* General AI run interface with overloads to handle distinct return types.
@@ -182,33 +182,79 @@ export function lastMessageWasUser<T extends { role: string }>(messages: T[]) {
182182
return messages.length > 0 && messages[messages.length - 1]!.role === "user";
183183
}
184184

185-
export function processToolCalls(output: any) {
185+
function mergePartialToolCalls(partialCalls: any[]) {
186+
const mergedCallsByIndex: any = {};
187+
188+
for (const partialCall of partialCalls) {
189+
const index = partialCall.index;
190+
191+
if (!mergedCallsByIndex[index]) {
192+
mergedCallsByIndex[index] = {
193+
id: partialCall.id || "",
194+
type: partialCall.type || "",
195+
function: {
196+
name: partialCall.function?.name || "",
197+
arguments: "",
198+
},
199+
};
200+
} else {
201+
if (partialCall.id) {
202+
mergedCallsByIndex[index].id = partialCall.id;
203+
}
204+
if (partialCall.type) {
205+
mergedCallsByIndex[index].type = partialCall.type;
206+
}
207+
208+
if (partialCall.function?.name) {
209+
mergedCallsByIndex[index].function.name = partialCall.function.name;
210+
}
211+
}
212+
213+
// Append arguments if available, this assumes arguments come in the right order
214+
if (partialCall.function?.arguments) {
215+
mergedCallsByIndex[index].function.arguments += partialCall.function.arguments;
216+
}
217+
}
218+
219+
return Object.values(mergedCallsByIndex);
220+
}
221+
222+
function processToolCall(toolCall: any): LanguageModelV1FunctionToolCall {
186223
// Check for OpenAI format tool calls first
224+
if (toolCall.function && toolCall.id) {
225+
return {
226+
toolCallType: "function",
227+
toolCallId: toolCall.id,
228+
toolName: toolCall.function.name,
229+
args:
230+
typeof toolCall.function.arguments === "string"
231+
? toolCall.function.arguments
232+
: JSON.stringify(toolCall.function.arguments || {}),
233+
};
234+
}
235+
return {
236+
toolCallType: "function",
237+
toolCallId: toolCall.name,
238+
toolName: toolCall.name,
239+
args:
240+
typeof toolCall.arguments === "string"
241+
? toolCall.arguments
242+
: JSON.stringify(toolCall.arguments || {}),
243+
};
244+
}
245+
246+
export function processToolCalls(output: any): LanguageModelV1FunctionToolCall[] {
187247
if (output.tool_calls && Array.isArray(output.tool_calls)) {
188248
return output.tool_calls.map((toolCall: any) => {
189-
// Handle new format
190-
if (toolCall.function && toolCall.id) {
191-
return {
192-
toolCallType: "function",
193-
toolCallId: toolCall.id,
194-
toolName: toolCall.function.name,
195-
args:
196-
typeof toolCall.function.arguments === "string"
197-
? toolCall.function.arguments
198-
: JSON.stringify(toolCall.function.arguments || {}),
199-
};
200-
}
201-
return {
202-
toolCallType: "function",
203-
toolCallId: toolCall.name,
204-
toolName: toolCall.name,
205-
args:
206-
typeof toolCall.arguments === "string"
207-
? toolCall.arguments
208-
: JSON.stringify(toolCall.arguments || {}),
209-
};
249+
const processedToolCall = processToolCall(toolCall);
250+
return processedToolCall;
210251
});
211252
}
212253

213254
return [];
214255
}
256+
257+
export function processPartialToolCalls(partialToolCalls: any[]) {
258+
const mergedToolCalls = mergePartialToolCalls(partialToolCalls);
259+
return processToolCalls({ tool_calls: mergedToolCalls });
260+
}

0 commit comments

Comments
 (0)