Skip to content

Commit 962056c

Browse files
authored
Lm_eval static generation improved (#2241)
1 parent f7a9a94 commit 962056c

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

examples/text-generation/model_adapter.py

Lines changed: 37 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919

2020
import argparse
2121
import logging
22-
from typing import Literal, Optional, Union
22+
from typing import Any, Literal, Optional, Union
2323

2424
import torch
2525
import torch.nn.functional as F
@@ -47,6 +47,7 @@ def __init__(
4747
logits_cache: bool = True,
4848
max_length: Optional[int] = None,
4949
softmax_dtype: Union[str, torch.dtype, None] = None,
50+
mixed_precision_dtype: Union[str, torch.dtype, None] = None,
5051
add_bos_token: Optional[bool] = True,
5152
prefix_token_id: Optional[int] = None,
5253
delta: Optional[str] = None,
@@ -86,8 +87,10 @@ def __init__(
8687
self.add_bos_token = add_bos_token
8788
self._max_length = max_length
8889
self.softmax_dtype = get_dtype(softmax_dtype) if softmax_dtype is not None else None
90+
self.mixed_precision_dtype = get_dtype(mixed_precision_dtype) if mixed_precision_dtype is not None else None
8991
self.hpu_graphs = args.use_hpu_graphs
9092
self.use_lazy_mode = True
93+
self.ignore_eos = args.ignore_eos
9194
if args.torch_compile:
9295
self.use_lazy_mode = False
9396
self.vocab_size = self._model.config.vocab_size
@@ -207,19 +210,23 @@ def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -
207210
self.max_length = legacy_max_length
208211
return res
209212

210-
def _model_generate(self, context, max_length, stop, **generation_kwargs):
213+
def _model_generate(
214+
self,
215+
context,
216+
max_length: int,
217+
stop: list[str],
218+
**generation_kwargs: dict[str, Any],
219+
) -> torch.Tensor:
211220
"""
212221
Patched method
213-
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858
222+
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
214223
"""
215-
216224
# temperature = 0.0 if not set
217225
# if do_sample is false and temp==0.0:
218226
# remove temperature, as do_sample=False takes care of this
219227
# and we don't want a warning from HF
220228
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
221-
do_sample = generation_kwargs.get("do_sample", None)
222-
229+
do_sample = generation_kwargs.get("do_sample")
223230
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
224231
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
225232
generation_kwargs["do_sample"] = do_sample = False
@@ -231,18 +238,31 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
231238
# to avoid graph recompilation
232239
if self.options.static_shapes:
233240
self.options.bucket_internal = True
234-
_ = self.find_bucket(context.shape[1])
241+
bucket_length = self.find_bucket(context.shape[1])
242+
padding_length = bucket_length - context.shape[1]
235243
max_gen_toks = max_length - context.shape[1]
244+
if padding_length > 0 and self.hpu_graphs:
245+
# Static shapes require right-padding (left-padding due to batch encoding is performed at tok_batch_encode level)
246+
# See https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L869
247+
context = F.pad(context, (0, padding_length), value=self.tokenizer.pad_token_id)
248+
generation_kwargs["attention_mask"] = F.pad(
249+
generation_kwargs["attention_mask"], (0, padding_length), value=0
250+
)
236251
# move context & attention_mask to hpu
237252
context = context.to("hpu")
238253
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
239-
return self.model.generate(
240-
input_ids=context,
241-
max_new_tokens=max_gen_toks,
242-
stopping_criteria=stopping_criteria,
243-
pad_token_id=self.tokenizer.pad_token_id,
244-
use_cache=True,
245-
hpu_graphs=self.hpu_graphs,
246-
lazy_mode=self.use_lazy_mode,
247-
**generation_kwargs,
248-
)
254+
with torch.autocast(
255+
device_type="hpu",
256+
dtype=self.mixed_precision_dtype,
257+
enabled=self.mixed_precision_dtype is not None,
258+
):
259+
return self.model.generate(
260+
input_ids=context,
261+
max_new_tokens=max_gen_toks,
262+
stopping_criteria=stopping_criteria,
263+
pad_token_id=self.tokenizer.pad_token_id,
264+
use_cache=True,
265+
hpu_graphs=self.hpu_graphs,
266+
lazy_mode=self.use_lazy_mode,
267+
**generation_kwargs,
268+
)

examples/text-generation/run_lm_eval.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def setup_lm_eval_parser():
7878
type=int,
7979
nargs="+",
8080
help="Input length buckets to use with static_shapes",
81-
default=[16, 32, 64, 128, 189, 284, 384, 985],
81+
default=[16, 32, 64, 128, 189, 284, 384],
8282
)
8383

8484
parser.add_argument(

0 commit comments

Comments (0)