 
 import argparse
 import logging
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -47,6 +47,7 @@ def __init__(
         logits_cache: bool = True,
         max_length: Optional[int] = None,
         softmax_dtype: Union[str, torch.dtype, None] = None,
+        mixed_precision_dtype: Union[str, torch.dtype, None] = None,
         add_bos_token: Optional[bool] = True,
         prefix_token_id: Optional[int] = None,
         delta: Optional[str] = None,
@@ -86,8 +87,10 @@ def __init__(
         self.add_bos_token = add_bos_token
         self._max_length = max_length
         self.softmax_dtype = get_dtype(softmax_dtype) if softmax_dtype is not None else None
+        self.mixed_precision_dtype = get_dtype(mixed_precision_dtype) if mixed_precision_dtype is not None else None
         self.hpu_graphs = args.use_hpu_graphs
         self.use_lazy_mode = True
+        self.ignore_eos = args.ignore_eos
         if args.torch_compile:
             self.use_lazy_mode = False
         self.vocab_size = self._model.config.vocab_size
@@ -207,19 +210,23 @@ def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -> list[str]:
         self.max_length = legacy_max_length
         return res
 
-    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    def _model_generate(
+        self,
+        context,
+        max_length: int,
+        stop: list[str],
+        **generation_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
         """
         Patched method
-        source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858
+        source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L951
         """
-
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:
         # remove temperature, as do_sample=False takes care of this
         # and we don't want a warning from HF
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-        do_sample = generation_kwargs.get("do_sample", None)
-
+        do_sample = generation_kwargs.get("do_sample")
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
             generation_kwargs["do_sample"] = do_sample = False
@@ -231,18 +238,31 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # to avoid graph recompilation
         if self.options.static_shapes:
             self.options.bucket_internal = True
-            _ = self.find_bucket(context.shape[1])
+            bucket_length = self.find_bucket(context.shape[1])
+            padding_length = bucket_length - context.shape[1]
         max_gen_toks = max_length - context.shape[1]
+        if padding_length > 0 and self.hpu_graphs:
+            # Static shapes require right-padding (left-padding due to batch encoding is performed at tok_batch_encode level)
+            # See https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/models/huggingface.py#L869
+            context = F.pad(context, (0, padding_length), value=self.tokenizer.pad_token_id)
+            generation_kwargs["attention_mask"] = F.pad(
+                generation_kwargs["attention_mask"], (0, padding_length), value=0
+            )
         # move context & attention_mask to hpu
         context = context.to("hpu")
         generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
-        return self.model.generate(
-            input_ids=context,
-            max_new_tokens=max_gen_toks,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
-            hpu_graphs=self.hpu_graphs,
-            lazy_mode=self.use_lazy_mode,
-            **generation_kwargs,
-        )
+        with torch.autocast(
+            device_type="hpu",
+            dtype=self.mixed_precision_dtype,
+            enabled=self.mixed_precision_dtype is not None,
+        ):
+            return self.model.generate(
+                input_ids=context,
+                max_new_tokens=max_gen_toks,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                hpu_graphs=self.hpu_graphs,
+                lazy_mode=self.use_lazy_mode,
+                **generation_kwargs,
+            )