
Inference streaming support #1750

Merged
RobertSamoilescu merged 58 commits into SeldonIO:master from RobertSamoilescu:feature/inference-streaming-poc
May 22, 2024

Conversation

Contributor

@RobertSamoilescu RobertSamoilescu commented May 9, 2024

This PR adds streaming support to MLServer by allowing the user to implement a predict_stream method in the runtime, which takes an async generator of requests as input and outputs an async generator of responses.

class MyModel(MLModel):

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        pass

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        pass
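The request-stream-in, response-stream-out contract can be sketched with plain async generators. The snippet below is illustrative only: strings stand in for InferenceRequest and InferenceResponse so it runs without MLServer installed, and the tokenising predict_stream is a hypothetical example rather than this PR's implementation.

```python
import asyncio
from typing import AsyncIterator


# Illustrative sketch: strings stand in for InferenceRequest/InferenceResponse.
async def predict_stream(payloads: AsyncIterator[str]) -> AsyncIterator[str]:
    # Consume the request stream and emit one response per token.
    async for payload in payloads:
        for token in payload.split():
            yield token


async def request_stream() -> AsyncIterator[str]:
    # A single request still arrives as a (one-element) stream.
    yield "hello streaming world"


async def main() -> list:
    return [token async for token in predict_stream(request_stream())]


print(asyncio.run(main()))  # ['hello', 'streaming', 'world']
```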

While the input and output types for predict remain the same, the predict_stream implementation can handle a stream of inputs and a stream of outputs. This design choice is quite general and covers many input-output scenarios:

  • unary input - unary output (handled by predict)
  • unary input - stream output (handled by predict_stream)
  • stream input - unary output (handled by predict_stream)
  • stream input - stream output (handled by predict_stream)

Although streamed input is not really applicable to REST and is currently not supported there, it is quite natural for gRPC. Users who would like to use streamed inputs will therefore have to use gRPC.
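For instance, the stream input - unary output case falls out of the same signature: a predict_stream implementation can drain the request stream and then yield exactly one response. The snippet below is a dependency-free, hypothetical sketch, with strings again standing in for the request and response types.

```python
import asyncio
from typing import AsyncIterator


async def predict_stream(payloads: AsyncIterator[str]) -> AsyncIterator[str]:
    # stream input -> unary output: drain all requests, yield one response.
    chunks = [payload async for payload in payloads]
    yield " ".join(chunks)


async def request_stream() -> AsyncIterator[str]:
    for chunk in ("stream", "in,", "one", "answer", "out"):
        yield chunk


async def main() -> list:
    return [response async for response in predict_stream(request_stream())]


print(asyncio.run(main()))  # ['stream in, one answer out']
```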

Exposed endpoints

We expose the following endpoints (+ the ones including the version) to the user:

  • /v2/models/{model_name}/infer
  • /v2/models/{model_name}/infer_stream
  • /v2/models/{model_name}/generate
  • /v2/models/{model_name}/generate_stream

The first two are general-purpose endpoints, while the latter two are LLM-specific (see the Open Inference Protocol here). Note that the infer and generate endpoints map to the predict implementation, while infer_stream and generate_stream map to the predict_stream implementation defined above.

Client calls

REST non-streaming

import os
import requests
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)


api_url = "http://localhost:8080/v2/models/text-model/generate"
response = requests.post(api_url, json=inference_request.dict())
response = types.InferenceResponse.parse_raw(response.text)
print(StringCodec.decode_output(response.outputs[0]))

REST streaming

import os
import httpx
from httpx_sse import connect_sse
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

with httpx.Client() as client:
    url = "http://localhost:8080/v2/models/text-model/generate_stream"
    with connect_sse(client, "POST", url, json=inference_request.dict()) as event_source:
        for sse in event_source.iter_sse():
            response = types.InferenceResponse.parse_raw(sse.data)
            print(StringCodec.decode_output(response.outputs[0]))

gRPC non-streaming

import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)
grpc_channel = grpc.insecure_channel("localhost:8081")
grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
response = grpc_stub.ModelInfer(inference_request_g)

response = ModelInferResponseConverter.to_types(response)
print(StringCodec.decode_output(response.outputs[0]))

gRPC streaming

import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter


TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)

async def get_inference_request_stream(inference_request):
    yield inference_request

import asyncio

async def main():
    async with grpc.aio.insecure_channel("localhost:8081") as grpc_channel:
        grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
        inference_request_stream = get_inference_request_stream(inference_request_g)

        async for response in grpc_stub.ModelStreamInfer(inference_request_stream):
            response = ModelInferResponseConverter.to_types(response)
            print(StringCodec.decode_output(response.outputs[0]))

asyncio.run(main())

Limitations

  • GZipMiddleware must be disabled, since it is not compatible with Starlette streaming ("gzip_enabled": false)
  • gRPC metrics endpoints must be disabled - further investigation in a following PR ("metrics_endpoint": null)
  • Parallel workers are not supported ("parallel_workers": 0)
  • Error handling for REST is not supported. This is because the error raised is of type asyncio.exceptions.CancelledError, which inherits from BaseException, while the Starlette error-handling middleware only checks for Exception.
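Put together, the flags quoted above would sit in the server's settings.json roughly as follows. This is a sketch showing only the three streaming-related settings; everything else keeps its defaults.

```json
{
    "gzip_enabled": false,
    "metrics_endpoint": null,
    "parallel_workers": 0
}
```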

RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from 32c3d7a to c2cf03c on May 10, 2024 09:28
RobertSamoilescu requested review from lc525 and sakoush on May 10, 2024 09:58
RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from fb3eed6 to 3637aef on May 10, 2024 10:32
RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from a519933 to 63af613 on May 15, 2024 13:00
Contributor

@sakoush sakoush left a comment


In general it looks great, I left some comments though and I will look at tests next.

@@ -0,0 +1,45 @@
import asyncio
Contributor

should we add an example infer.py that uses this model? It is part of the PR description but probably better to have it here as well. Happy for it to be part of a follow-up docs and examples PR.

payload: AsyncIterator[InferenceRequest],
) -> AsyncIterator[InferenceResponse]:
model = _get_model(f)
logger.warning(
Contributor

is this going to be logged on mlserver for every request? I think this might pollute the logs? I guess if the user doesn't set adaptive batching then this code path will not be hit anyway?

Contributor Author

Moved it outside.

break

payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)
Contributor

is this really a decorator logic?

Contributor Author

I renamed it.

payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)

async for prediction in model.predict_stream(payloads_decorated):
Contributor

what happens if one element in the stream fails? do we still keep going or should we break?

"""
async for inference_response in infer_stream:
# TODO: How should we send headers back?
# response_headers = extract_headers(inference_response)
Contributor

what are the kind of headers we usually send back in the response of infer?

Contributor Author

See link here

Contributor

@sakoush sakoush left a comment


comments on testing. I think we should add cases for:

  • errors on infer_stream
  • input streaming

@pytest.mark.parametrize(
"sum_model_settings", [lazy_fixture("text_stream_model_settings")]
)
@pytest.mark.parametrize("sum_model", [lazy_fixture("text_stream_model")])
Contributor

what is sum_model?

Contributor Author

It is a fixture (see here). Also the definition is here.



@pytest.mark.parametrize("settings", [lazy_fixture("settings_stream")])
@pytest.mark.parametrize(
Contributor

what is sum_model_settings?

Contributor Author

It is a fixture (see here)

@pytest.mark.parametrize(
"model_name,model_version", [("text-model", "v1.2.3"), ("text-model", None)]
)
async def test_generate(
Contributor

can generate test be an extra parametrized item in infer test?

Contributor Author

It is a bit tricky to parameterise the model loading through lazy_fixtures due to a recursive dependency involving fixtures. I will leave it like this since I don't want to refactor the tests.



async def test_infer_error(rest_client, inference_request):
async def test_infer_error(
Contributor

should we tests errors for the stream case as well?

yield generate_request


async def test_predict_stream_fallback(
Contributor

as explained earlier I am not sure if we should fallback to predict or raise not implemented.

RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from 63af613 to 499e693 on May 20, 2024 13:53
Contributor

@sakoush sakoush left a comment

lgtm - great work! This should be followed by a docs PR to describe streaming and the current limitations more explicitly.
