Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions examples/text-generation/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import argparse
import logging
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
logits_cache: bool = True,
max_length: Optional[int] = None,
softmax_dtype: Union[str, torch.dtype, None] = None,
mixed_precision_dtype: Union[str, torch.dtype, None] = None,
add_bos_token: Optional[bool] = True,
prefix_token_id: Optional[int] = None,
delta: Optional[str] = None,
Expand Down Expand Up @@ -86,8 +87,10 @@ def __init__(
self.add_bos_token = add_bos_token
self._max_length = max_length
self.softmax_dtype = get_dtype(softmax_dtype) if softmax_dtype is not None else None
self.mixed_precision_dtype = get_dtype(mixed_precision_dtype) if mixed_precision_dtype is not None else None
self.hpu_graphs = args.use_hpu_graphs
self.use_lazy_mode = True
self.ignore_eos = args.ignore_eos
if args.torch_compile:
self.use_lazy_mode = False
self.vocab_size = self._model.config.vocab_size
Expand Down Expand Up @@ -205,19 +208,23 @@ def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -
self.max_length = legacy_max_length
return res

def _model_generate(self, context, max_length, stop, **generation_kwargs):
def _model_generate(
self,
context,
max_length: int,
stop: list[str],
**generation_kwargs: dict[str, Any],
) -> torch.Tensor:
"""
Patched method
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
"""

# temperature = 0.0 if not set
# if do_sample is false and temp==0.0:
# remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)

do_sample = generation_kwargs.get("do_sample")
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False
Expand All @@ -229,18 +236,31 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
# to avoid graph recompilation
if self.options.static_shapes:
self.options.bucket_internal = True
_ = self.find_bucket(context.shape[1])
bucket_length = self.find_bucket(context.shape[1])
padding_length = bucket_length - context.shape[1]
max_gen_toks = max_length - context.shape[1]
if padding_length > 0 and self.hpu_graphs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @12010486
Is this code meant to counteract the tokenizer's hardcoded padding?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. We could also have done this by modifying this function https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858 but I wanted to avoid having to patch another function.

# Static shapes require right-padding (left-padding due to batch encoding is performed at tok_batch_encode level)
# See https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L869
context = F.pad(context, (0, padding_length), value=self.tokenizer.pad_token_id)
generation_kwargs["attention_mask"] = F.pad(
generation_kwargs["attention_mask"], (0, padding_length), value=0
)
# move context & attention_mask to hpu
context = context.to("hpu")
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
return self.model.generate(
input_ids=context,
max_new_tokens=max_gen_toks,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
hpu_graphs=self.hpu_graphs,
lazy_mode=self.use_lazy_mode,
**generation_kwargs,
)
with torch.autocast(
device_type="hpu",
dtype=self.mixed_precision_dtype,
enabled=self.mixed_precision_dtype is not None,
):
return self.model.generate(
input_ids=context,
max_new_tokens=max_gen_toks,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
hpu_graphs=self.hpu_graphs,
lazy_mode=self.use_lazy_mode,
**generation_kwargs,
)
2 changes: 1 addition & 1 deletion examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def setup_lm_eval_parser():
type=int,
nargs="+",
help="Input length buckets to use with static_shapes",
default=[16, 32, 64, 128, 189, 284, 384, 985],
default=[16, 32, 64, 128, 189, 284, 384],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but thanks for double-checking. I introduced it because v1.19 included this commit (0222c48), which fixed Granite accuracy, and I wanted to be sure not to reintroduce a known regression.

)

parser.add_argument(
Expand Down
Loading