
Lm_eval static generation improved #2241

Merged
regisss merged 7 commits into huggingface:main from 12010486:lm_eval_static_generation
Sep 11, 2025

Conversation

Collaborator

@12010486 12010486 commented Sep 3, 2025

Main changes are as follows:

Precision and device support improvements

  • Added a mixed_precision_dtype parameter to the ModelAdapter class, allowing users to specify the desired precision for HPU autocasting. The torch.autocast context is now used during generation, improving performance and memory usage on HPUs.
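As a rough illustration of the autocast idea (not the actual optimum-habana wiring): `mixed_precision_dtype` is the parameter name from the PR description, and on Gaudi the `device_type` would be `"hpu"`; `"cpu"` is used below so the snippet runs anywhere.

```python
import torch

# Hypothetical value, e.g. parsed from a CLI flag; the PR lets users
# choose the precision used for HPU autocasting.
mixed_precision_dtype = torch.bfloat16

x = torch.randn(4, 8)
w = torch.randn(8, 8)

# On HPU this would be torch.autocast(device_type="hpu", ...),
# wrapped around the generation call.
with torch.autocast(device_type="cpu", dtype=mixed_precision_dtype):
    y = x @ w  # matmul is autocast-eligible and runs in bf16 here
```

Running eligible ops in a lower precision under autocast is what reduces memory pressure during generation.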

Static shape and padding logic

  • Improved static shape bucket management: the code now calculates the required left-padding for input contexts to fit the selected bucket size, ensuring correct input shapes for HPUs. This includes padding both the context and the attention mask.
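The bucket selection and left-padding described above can be sketched in plain Python. The names (`find_bucket`, the bucket list) mirror the PR description, but this is an illustrative simplification, not the exact implementation, which operates on tensors.

```python
# Default input length buckets from run_lm_eval.py (after this PR).
BUCKETS = [16, 32, 64, 128, 189, 284, 384]

def find_bucket(length, buckets=BUCKETS):
    """Return the smallest bucket that fits `length`, else the largest bucket."""
    for b in buckets:
        if b >= length:
            return b
    return buckets[-1]

def left_pad(context, pad_token_id, bucket_length):
    """Left-pad a token list to the bucket size and build the matching attention mask."""
    padding_length = bucket_length - len(context)
    padded = [pad_token_id] * padding_length + context
    # 0 for padding positions, 1 for real tokens
    attention_mask = [0] * padding_length + [1] * len(context)
    return padded, attention_mask

ctx = [101, 2009, 2003, 102]            # 4 input tokens
bucket = find_bucket(len(ctx))          # smallest bucket that fits: 16
padded, mask = left_pad(ctx, 0, bucket) # both now have static length 16
```

Padding on the left keeps the real tokens adjacent to the generated ones, while the static bucket length gives HPU graphs a fixed input shape.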

Argument and API updates

  • Updated the default input length buckets in run_lm_eval.py to remove the largest value (985), which might have caused a mismatch with previous legacy results.
  • Added new arguments (ignore_eos) and updated references for compatibility with the latest version of the underlying evaluation harness.

Note: ignore_eos is passed but not used, as it decreases accuracy results.

How to test the impact

The command below used to produce an OOM error; with this change, memory usage on G2 is ~27 GB.

PT_HPU_LAZY_MODE=1 python run_lm_eval.py --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
--attn_softmax_bf16 --use_hpu_graphs --limit_hpu_graphs --use_kv_cache --bf16 --sdp_on_bf16 --trim_logits \
--batch_size=4 --tasks gsm8k_cot_llama -o eval_gsm8k.json --num_fewshot=8 \
--fewshot_as_multiturn --apply_chat_template True

@12010486 12010486 requested a review from regisss as a code owner September 3, 2025 13:47
@12010486 12010486 requested review from astachowiczhabana and removed request for regisss September 3, 2025 13:48

Collaborator Author

12010486 commented Sep 3, 2025

Checking this PR on other LM tasks, I've seen a drop with respect to the numbers I was getting before (on mbpp or humaneval, for example), so I'm converting it to a draft while I investigate further.

@12010486 12010486 marked this pull request as draft September 3, 2025 16:31
@12010486 12010486 marked this pull request as ready for review September 5, 2025 14:23
bucket_length = self.find_bucket(context.shape[1])
padding_length = bucket_length - context.shape[1]
max_gen_toks = max_length - context.shape[1]
if padding_length > 0 and self.hpu_graphs:
Collaborator

Hi @12010486
This code is to counteract a tokenizer with hardcoded padding, right?

Collaborator Author


Correct. We could also have done it by modifying this function https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858, but I wanted to avoid patching another function.

    nargs="+",
    help="Input length buckets to use with static_shapes",
-   default=[16, 32, 64, 128, 189, 284, 384, 985],
+   default=[16, 32, 64, 128, 189, 284, 384],
Collaborator

Is this change intentional?

Collaborator Author

Yes, but thanks for double-checking. I introduced it because on v1.19 there was this commit for Granite accuracy (0222c48), so I wanted to be sure not to reintroduce a known regression.

Collaborator

@regisss regisss left a comment

LGTM

@regisss regisss merged commit 962056c into huggingface:main Sep 11, 2025
4 of 9 checks passed
gplutop7 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Oct 15, 2025
Co-authored-by: Silvia Colabrese <silvia.colabrese@intel.com>

5 participants