Skip to content

Commit 766a0cd

Browse files
committed
add tests
1 parent 6775d38 commit 766a0cd

33 files changed

+2429
-287
lines changed

.github/workflows/tests-my.yml

Lines changed: 0 additions & 68 deletions
This file was deleted.

runtimes/huggingface/mlserver_huggingface/codecs/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22
from .image import PILImageCodec
33
from .json import HuggingfaceSingleJSONCodec
44
from .jsonlist import HuggingfaceListJSONCodec
5+
from .numpylist import NumpyListCodec
6+
from .conversation import HuggingfaceConversationCodec
7+
from .raw import RawCodec
8+
from .utils import EqualUtil
59

610
__all__ = [
7-
MultiInputRequestCodec,
8-
HuggingfaceRequestCodec,
9-
PILImageCodec,
10-
HuggingfaceSingleJSONCodec,
11-
HuggingfaceListJSONCodec,
11+
"MultiInputRequestCodec",
12+
"HuggingfaceRequestCodec",
13+
"PILImageCodec",
14+
"HuggingfaceSingleJSONCodec",
15+
"HuggingfaceListJSONCodec",
16+
"HuggingfaceConversationCodec",
17+
"NumpyListCodec",
18+
"RawCodec",
19+
"EqualUtil",
1220
]

runtimes/huggingface/mlserver_huggingface/codecs/base.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Type, Any, Dict, AnyStr, List, Union
1+
from typing import Optional, Type, Any, Dict, List, Union
22
from mlserver.codecs.utils import (
33
has_decoded,
44
_save_decoded,
@@ -40,38 +40,39 @@ class MultiInputRequestCodec(RequestCodec):
4040
Huggingface codecs are preferred, then mlserver's
4141
"""
4242

43-
ContentType: str = StringCodec.ContentType
44-
DefaultCodec: Type[InputCodecTy] = StringCodec
43+
DefaultCodec: Type["InputCodecTy"] = StringCodec
4544
InputCodecsWithPriority: List[Type[InputCodecTy]] = []
45+
ContentType = StringCodec.ContentType
4646

4747
@classmethod
def _find_encode_codecs(
    cls, payload: Dict[str, Any]
) -> Dict[str, Union[Type["InputCodecTy"], "InputCodecTy", None]]:
    """
    Select the codec used to encode each field of ``payload``.

    The codecs in ``cls.InputCodecsWithPriority`` are tried in order and the
    first one whose ``can_encode`` accepts the field value is chosen. If none
    of them accepts the value, whatever ``find_input_codec_by_payload``
    returns for that value is stored instead (per the return annotation this
    may be ``None``).

    Returns a mapping from field name to the selected codec.
    """
    chosen: Dict[str, Union[Type["InputCodecTy"], "InputCodecTy", None]] = {}
    for name, item in payload.items():
        for candidate in cls.InputCodecsWithPriority:
            if candidate.can_encode(item):
                chosen[name] = candidate
                break
        else:
            # No prioritised codec matched; defer to the generic lookup.
            chosen[name] = find_input_codec_by_payload(item)

    return chosen
6061

6162
@classmethod
6263
def _find_decode_codecs(
6364
cls, data: Union[InferenceResponse, InferenceRequest]
64-
) -> Dict[str, Union[Type[InputCodecTy], None]]:
65+
) -> Dict[str, Union[Type[InputCodecTy], InputCodecTy, None]]:
6566
field_codec = {}
66-
fields = []
67+
fields = [] # type: ignore
6768
if data.parameters:
6869
default_codec = find_input_codec(data.parameters.content_type)
6970
else:
7071
default_codec = cls.DefaultCodec
7172
if isinstance(data, InferenceRequest):
7273
fields = data.inputs
7374
else:
74-
fields = data.outputs
75+
fields = data.outputs # type: ignore
7576
for field in fields:
7677
if not field.parameters:
7778
field_codec[field.name] = default_codec
@@ -87,12 +88,12 @@ def _find_decode_codecs(
8788
return field_codec
8889

8990
@classmethod
def _can_encode_request(cls, payload: Dict[str, Any]) -> bool:
    """
    Tell whether every field of ``payload`` resolved to a usable codec.

    Returns ``True`` only when all values produced by
    ``cls._find_encode_codecs`` are truthy (an empty payload therefore
    yields ``True``).
    """
    return all(cls._find_encode_codecs(payload).values())
9394

9495
@classmethod
95-
def can_encode(cls, payload: Dict[AnyStr, Any]) -> bool:
96+
def can_encode(cls, payload: Dict[str, Any]) -> bool:
9697
"""
9798
Inputs are always a Dict, outputs are always a list
9899
"""
@@ -128,17 +129,21 @@ def encode_response(
128129
)
129130

130131
@classmethod
131-
def decode_response(cls, response: InferenceResponse) -> List[Any]:
132+
def decode_response(
133+
cls, response: InferenceResponse
134+
) -> Union[List[Any], Dict[Any, Any]]:
132135
"""
133136
Always use HuggingfaceJSONCodec
134137
"""
135138
data = {}
136139
is_list = True
137140
field_codecs = cls._find_decode_codecs(response)
138141
for item in response.outputs:
139-
if not has_decoded(item) and field_codecs.get(item.name):
140-
decoded_payload = field_codecs[item.name].decode_input(item)
141-
_save_decoded(item, decoded_payload)
142+
if not has_decoded(item):
143+
codec = field_codecs[item.name]
144+
if codec is not None:
145+
decoded_payload = codec.decode_output(item)
146+
_save_decoded(item, decoded_payload)
142147

143148
value = get_decoded_or_raw(item)
144149
data[item.name] = value
@@ -154,6 +159,10 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:
154159
inputs = []
155160
for key, value in payload.items():
156161
codec = field_codecs[key]
162+
if codec is None:
163+
raise Exception(
164+
f"codec for key {key} value not found, value is {value}"
165+
)
157166
input_v = codec.encode_input(key, value, **kwargs)
158167
set_content_type(input_v, codec.ContentType)
159168
inputs.append(input_v)
@@ -169,9 +178,11 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
169178
values = {}
170179
field_codecs = cls._find_decode_codecs(request)
171180
for item in request.inputs:
172-
if not has_decoded(item) and field_codecs.get(item.name):
173-
decoded_payload = field_codecs[item.name].decode_input(item)
174-
_save_decoded(item, decoded_payload)
181+
if not has_decoded(item):
182+
codec = field_codecs[item.name]
183+
if codec is not None:
184+
decoded_payload = codec.decode_input(item)
185+
_save_decoded(item, decoded_payload)
175186

176187
value = get_decoded_or_raw(item)
177188
values[item.name] = value

runtimes/huggingface/mlserver_huggingface/codecs/conversation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Any, Dict
1+
from typing import List, Any
22
from mlserver.codecs.base import InputCodec, register_input_codec
33
from mlserver.types import RequestInput, ResponseOutput, Parameters
44
from transformers.pipelines import Conversation
@@ -22,7 +22,7 @@ def can_encode(cls, payload: Any) -> bool:
2222
def encode_output(
2323
cls, name: str, payload: List[Conversation], use_bytes: bool = True, **kwargs
2424
) -> ResponseOutput:
25-
encoded = [json_encode(item) for item in payload]
25+
encoded = [json_encode(item, use_bytes=use_bytes) for item in payload]
2626
shape = [len(encoded), 1]
2727
return ResponseOutput(
2828
name=name,
@@ -35,7 +35,7 @@ def encode_output(
3535
)
3636

3737
@classmethod
38-
def decode_output(cls, response_output: ResponseOutput) -> Dict[Any, Any]:
38+
def decode_output(cls, response_output: ResponseOutput) -> List[Any]:
3939
packed = response_output.data.__root__
4040
return [json_decode(item) for item in packed]
4141

runtimes/huggingface/mlserver_huggingface/codecs/image.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
import io
22
import base64
3-
from typing import List, Any
3+
from typing import List, Any, Union
44
from PIL import Image
55
from mlserver.codecs.base import InputCodec, register_input_codec
66
from mlserver.codecs.lists import as_list, is_list_of
77
from mlserver.types import RequestInput, ResponseOutput, Parameters
88
from functools import partial
99

1010

11-
def _pil_base64encode(
    img: "Image.Image", use_bytes: bool = False
) -> Union[bytes, str]:
    """
    Serialise a PIL image and return it base64-encoded.

    The image is written to an in-memory buffer using its own ``format``
    when it has one, and as PNG otherwise.

    Args:
        img: Image to serialise.
        use_bytes: When true, return the base64 payload as ``bytes``;
            otherwise return it decoded to ``str``.

    Returns:
        The base64 representation of the serialised image.
    """
    buf = io.BytesIO()
    # ``img.format`` is None for images that were not read from a file
    # (e.g. created in memory or produced by a transform). Pillow cannot
    # infer a format from a BytesIO, so fall back to PNG in that case.
    img.save(buf, format=img.format or "PNG")
    encoded = base64.b64encode(buf.getvalue())
    return encoded if use_bytes else encoded.decode()
1717

1818

19-
def _pil_base64decode(imgbytes: Union[bytes, str]) -> "Image.Image":
    """
    Rebuild a PIL image from its base64 representation.

    Args:
        imgbytes: Base64 payload, either as ``bytes`` or as ``str``. A
            ``bytes`` payload is decoded to text before base64-decoding.

    Returns:
        The image opened by ``Image.open`` on the decoded binary data.
    """
    encoded = imgbytes.decode() if isinstance(imgbytes, bytes) else imgbytes
    raw = base64.b64decode(encoded)
    return Image.open(io.BytesIO(raw))
2224

runtimes/huggingface/mlserver_huggingface/codecs/json.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ def encode_output(
2121
cls, name: str, payload: Dict[Any, Any], use_bytes: bool = True, **kwargs
2222
) -> ResponseOutput:
2323
encoded = json_encode(payload, use_bytes)
24-
shape = [len(encoded), 1]
2524
return ResponseOutput(
2625
name=name,
2726
parameters=Parameters(
2827
content_type=cls.ContentType,
2928
),
3029
datatype="BYTES",
31-
shape=shape,
30+
shape=[1],
3231
data=[encoded],
3332
)
3433

runtimes/huggingface/mlserver_huggingface/codecs/jsonlist.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from mlserver.codecs.base import InputCodec, register_input_codec
33
from mlserver.types import RequestInput, ResponseOutput, Parameters
44
from mlserver.codecs.lists import is_list_of
5+
from functools import partial
56
from .utils import json_decode, json_encode
67

78

@@ -21,7 +22,7 @@ def can_encode(cls, payload: Any) -> bool:
2122
def encode_output(
2223
cls, name: str, payload: List[Dict[Any, Any]], use_bytes: bool = True, **kwargs
2324
) -> ResponseOutput:
24-
packed = map(json_encode, payload)
25+
packed = map(partial(json_encode, use_bytes=use_bytes), payload)
2526
shape = [len(payload), 1]
2627
return ResponseOutput(
2728
name=name,
@@ -40,7 +41,7 @@ def decode_output(cls, response_output: ResponseOutput) -> List[Dict[Any, Any]]:
4041

4142
@classmethod
4243
def encode_input(
43-
cls, name: str, payload: Dict[Any, Any], use_bytes: bool = True, **kwargs
44+
cls, name: str, payload: List[Dict[Any, Any]], use_bytes: bool = True, **kwargs
4445
) -> RequestInput:
4546
output = cls.encode_output(name, payload, use_bytes)
4647
return RequestInput(

runtimes/huggingface/mlserver_huggingface/codecs/numpylist.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ class NumpyListCodec(InputCodec):
2020

2121
@classmethod
def can_encode(cls, payload: Any) -> bool:
    """
    Tell whether ``payload`` can be encoded by this codec.

    The payload must satisfy ``is_list_of(payload, np.ndarray)`` and all of
    its arrays must share exactly one distinct shape.
    """
    if not is_list_of(payload, np.ndarray):
        return False
    # Only arrays of the same shape are supported.
    return len({matrix.shape for matrix in payload}) == 1
2427

2528
@classmethod
2629
def encode_output(
@@ -36,6 +39,7 @@ def encode_output(
3639
datatype=datatype,
3740
shape=shape,
3841
data=_encode_data(composed, datatype),
42+
parameters=Parameters(content_type=cls.ContentType),
3943
)
4044

4145
@classmethod
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Any, Union
2+
from mlserver.types import RequestInput, ResponseOutput, Parameters
3+
from mlserver.codecs.base import InputCodec, register_input_codec
4+
5+
6+
@register_input_codec
class RawCodec(InputCodec):
    """
    Encode/decode a single raw Python value (``int``, ``str`` or ``float``).

    The value is carried as the only element of the ``data`` list, with
    shape ``[1]``, datatype ``"BYTES"`` and content type ``"raw"``.
    """

    ContentType = "raw"

    @classmethod
    def can_encode(cls, payload: Any) -> bool:
        """Return True when ``payload`` is an ``int``, ``str`` or ``float``."""
        # NOTE: ``bool`` is a subclass of ``int`` and is therefore accepted too.
        return isinstance(payload, (int, str, float))

    @classmethod
    def encode_output(
        cls, name: str, payload: Union[int, str, float], **kwargs
    ) -> ResponseOutput:
        """
        Wrap ``payload`` in a one-element ``ResponseOutput`` called ``name``.

        Extra keyword arguments are accepted but not used.
        """
        content_params = Parameters(content_type=cls.ContentType)
        return ResponseOutput(
            name=name,
            datatype="BYTES",
            shape=[1],
            data=[payload],
            parameters=content_params,
        )

    @classmethod
    def decode_output(cls, response_output: ResponseOutput) -> Union[int, str, float]:
        """Return the first element of the output's data (via ``decode_input``)."""
        return cls.decode_input(response_output)  # type: ignore

    @classmethod
    def encode_input(
        cls, name: str, payload: Union[int, str, float], **kwargs
    ) -> RequestInput:
        """
        Wrap ``payload`` in a one-element ``RequestInput`` called ``name``.

        The fields are copied from the result of ``encode_output``; extra
        keyword arguments are accepted but not forwarded.
        """
        encoded = cls.encode_output(name=name, payload=payload)

        return RequestInput(
            name=encoded.name,
            datatype=encoded.datatype,
            shape=encoded.shape,
            data=encoded.data,
            parameters=encoded.parameters,
        )

    @classmethod
    def decode_input(cls, request_input: RequestInput) -> Union[int, str, float]:
        """Return the first element of ``request_input.data``."""
        return request_input.data[0]

0 commit comments

Comments
 (0)