Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/exo/master/adapters/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,40 @@ def chunk_to_response(
)


async def sse_with_keepalive(
    sse_stream: AsyncGenerator[str, None],
    interval: float = 15.0,
) -> AsyncGenerator[str, None]:
    """Wrap an SSE line stream, emitting a keepalive comment during idle gaps.

    Re-yields every line of *sse_stream* unchanged. Whenever no line arrives
    for *interval* seconds, yields the SSE comment ``": keepalive\\n\\n"`` so
    proxies and clients do not drop the connection as idle.

    Args:
        sse_stream: The upstream async generator of SSE-formatted strings.
        interval: Idle timeout in seconds before a keepalive comment is sent.

    Raises:
        Whatever exception *sse_stream* raises — it is re-raised to the
        consumer instead of being silently swallowed as end-of-stream.
    """
    import asyncio

    # Producer pushes lines into an unbounded queue; the consumer side can
    # then wait with a timeout, which a plain `async for` cannot do.
    queue: asyncio.Queue[str | None] = asyncio.Queue()
    # Single-slot holder for a producer-side failure, re-raised by the consumer.
    errors: list[BaseException] = []

    async def _producer() -> None:
        try:
            async for line in sse_stream:
                await queue.put(line)
        except asyncio.CancelledError:
            # Cancellation (consumer teardown) is not an error — propagate.
            raise
        except Exception as exc:
            # Preserve upstream failures: without this, a crash in the
            # wrapped stream would look like a clean end-of-stream.
            errors.append(exc)
        finally:
            # Sentinel: tells the consumer the stream is finished.
            await queue.put(None)

    task = asyncio.create_task(_producer())
    try:
        while True:
            try:
                item = await asyncio.wait_for(queue.get(), timeout=interval)
            except asyncio.TimeoutError:
                # No data within the idle window — emit an SSE comment line.
                yield ": keepalive\n\n"
                continue
            if item is None:
                if errors:
                    # Surface the upstream failure to our consumer.
                    raise errors[0]
                return
            yield item
    finally:
        # Ensure the producer task is reaped even if the consumer stops early.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass



async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[
Expand Down
16 changes: 10 additions & 6 deletions src/exo/master/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import contextlib
import json
Expand Down Expand Up @@ -25,6 +26,7 @@
chat_request_to_text_generation,
collect_chat_response,
generate_chat_stream,
sse_with_keepalive,
)
from exo.master.adapters.claude import (
claude_request_to_text_generation,
Expand Down Expand Up @@ -611,7 +613,7 @@ async def _token_chunk_stream(
if chunk.finish_reason is not None:
break

except anyio.get_cancelled_exc_class():
except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
Expand Down Expand Up @@ -712,9 +714,11 @@ async def chat_completions(

if payload.stream:
return StreamingResponse(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
sse_with_keepalive(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
),
media_type="text/event-stream",
headers={
Expand Down Expand Up @@ -965,7 +969,7 @@ async def _generate_image_stream(
del image_total_chunks[key]
del image_metadata[key]

except anyio.get_cancelled_exc_class():
except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
Expand Down Expand Up @@ -1051,7 +1055,7 @@ async def _collect_image_chunks(
)

return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
Expand Down