
Add flex attention flags and args to Llama on Habana #2246

Merged
regisss merged 4 commits into huggingface:main from AKloniecki:flex_attention
Sep 11, 2025
Conversation

@AKloniecki
Collaborator

What does this PR do?

This is a cherry-pick of PR #2124 with tests added.

Flex attention is an alternative attention implementation to flash attention.

Example usage:
Consider the current command using flash attention:
PT_HPU_LAZY_MODE=1 python3 run_generation.py --model_name_or_path meta-llama/Llama-3.1-8B-Instruct --attn_softmax_bf16 --use_hpu_graphs --limit_hpu_graphs --use_kv_cache --max_new_tokens 8192 --bf16 --batch_size 100 --use_flash_attention --flash_attention_recompute --bucket_size=128 --bucket_internal --trim_logits --max_input_tokens 128 --warmup 2

The same workload can be run with flex attention as follows:
PT_HPU_LAZY_MODE=1 python3 run_generation.py --model_name_or_path meta-llama/Llama-3.1-8B-Instruct --attn_softmax_bf16 --use_hpu_graphs --limit_hpu_graphs --use_kv_cache --max_new_tokens 8192 --bf16 --batch_size 100 --use_flex_attention --bucket_size=128 --bucket_internal --trim_logits --max_input_tokens 128 --warmup 2
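To illustrate the idea behind the `--use_flex_attention` flag: flex attention generalizes attention by letting a user-supplied score-modification function rewrite raw attention scores (e.g. to apply causal masking) before the softmax. The sketch below is a minimal NumPy illustration of that concept, not the Habana or PyTorch implementation; all names in it are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def flex_attention(q, k, v, score_mod=None):
    """Attention where raw scores can be modified per element
    before the softmax (the core idea behind flex attention)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if score_mod is not None:
        qi = np.arange(scores.shape[0])[:, None]  # query indices
        ki = np.arange(scores.shape[1])[None, :]  # key indices
        scores = score_mod(scores, qi, ki)
    return softmax(scores) @ v

# Example score_mod: causal masking expressed as a score function,
# rather than as a hard-coded mask inside the kernel.
causal = lambda s, qi, ki: np.where(ki <= qi, s, -np.inf)

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(4, 8))
v = rng.normal(size=(4, 8))
out = flex_attention(q, k, v, score_mod=causal)
```

With causal masking, the first query can only attend to the first key, so the first output row equals `v[0]` exactly.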

@astachowiczhabana
Collaborator

@AKloniecki please check

ruff check . setup.py
All checks passed!
ruff format --check . setup.py
Would reformat: tests/test_text_generation_example.py
1 file would be reformatted, 468 files already formatted
make: *** [Makefile:29: style_check] Error 1
test test_methods failed to style_check with exit status 2
failed executing optimum-habana-fork tests

kareemshaik80 and others added 2 commits September 9, 2025 10:30
Signed-off-by: kareem <kshaik@habana.ai>
Signed-off-by: Artur Kloniecki <aklonieckix@habana.ai>
Artur KlonieckiX added 2 commits September 9, 2025 11:50
…tor.

Signed-off-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
Signed-off-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
@AKloniecki
Collaborator Author

CI has turned green.
@astachowiczhabana Please review and proceed with the merge.

@astachowiczhabana
Collaborator

Hi @regisss
Can you please merge?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@regisss regisss left a comment


LGTM

@regisss regisss merged commit bab72df into huggingface:main Sep 11, 2025
3 of 5 checks passed
astachowiczhabana pushed a commit that referenced this pull request Sep 11, 2025
Signed-off-by: kareem <kshaik@habana.ai>
Signed-off-by: Artur Kloniecki <aklonieckix@habana.ai>
Signed-off-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
Co-authored-by: kareem <kshaik@habana.ai>
Co-authored-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
@imangohari1
Contributor

@astachowiczhabana @AKloniecki @regisss
This PR is causing failures in conftest.py. Reverting the changes in tests/test_text_generation_example.py makes the test pass; this needs more investigation.

$ PT_HPU_LAZY_MODE=1  RUN_SLOW=true python -m pytest tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-2-9b-1-False-True-False] -s -v --token XXX
...
09/11/2025 19:24:52 - INFO - __main__ - Time to rest of tokens = 9.363325838439962ms
09/11/2025 19:24:52 - INFO - __main__ - End to end latency = 1077.9737390039372ms
09/11/2025 19:24:53 - INFO - __main__ - Time to first token = 15.506711999478284ms
09/11/2025 19:24:53 - INFO - __main__ - Time to rest of tokens = 9.361828282837651ms
09/11/2025 19:24:53 - INFO - __main__ - End to end latency = 1077.886902996397ms
09/11/2025 19:24:53 - INFO - __main__ - Finished running generate

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1.1: ('DeepSpeed is a machine learning framework that enables training of large-scale deep learning models on a single GPU or across multiple GPUs. It is designed to be easy to use and highly scalable, making it a popular choice for training large-scale models such as GPT-3 and BERT.\n\nDeepSpeed is built on top of PyTorch, a popular deep learning framework, and provides a set of tools and libraries that make it easy to train large-scale models. It includes features such as zero-shot learning, which allows models to',)


Stats:
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 92.73568232505258 tokens/second
Average first token latency         = 15.54601880197879 ms
Average rest token latency          = 9.3620312161746 ms
Average end to end latency          = 1077.9293371990207 ms
Memory allocated                    = 19.45 GB
Max memory allocated                = 19.46 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 3.9837110380030936 seconds
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

FAILED

============================================================================================================================================================== FAILURES ==============================================================================================================================================================
_________________________________________________________________________________________________________________________________ test_text_generation_bf16_1x[google/gemma-2-9b-1-False-True-False] _________________________________________________________________________________________________________________________________

model_name = 'google/gemma-2-9b', batch_size = 1, reuse_cache = False, check_output = True, use_flex_attention = False, baseline = <conftest.BaselineRequest object at 0x7f78cec6a1a0>, token = Secret(********)

    @pytest.mark.parametrize(
        "model_name, batch_size, reuse_cache, check_output, use_flex_attention", MODELS_TO_TEST["bf16_1x"]
    )
    def test_text_generation_bf16_1x(
        model_name: str, batch_size: int, reuse_cache: bool, check_output: bool, use_flex_attention: bool, baseline, token
    ):
>       _test_text_generation(
            model_name=model_name,
            baseline=baseline,
            token=token,
            batch_size=batch_size,
            reuse_cache=reuse_cache,
            check_output=check_output,
            use_flex_attention=use_flex_attention,
        )

tests/test_text_generation_example.py:425: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/test_text_generation_example.py:404: in _test_text_generation
    baseline.assertRef(
conftest.py:75: in assertRef
    assert compare(actual, ref)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

actual = 92.73568232505258, ref = None

>       compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
        context=[OH_DEVICE_CONTEXT],
        throughput=results["throughput"],
    )
E   TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'

tests/test_text_generation_example.py:405: TypeError
====================================================================================================================================================== short test summary info =======================================================================================================================================================
FAILED tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-2-9b-1-False-True-False] - TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
========================================================================================================================================================= 1 failed in 31.15s =========================================================================================================================================================
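The traceback above shows the failure mode: the baseline lookup returned `ref = None` (no baseline entry exists for the new test parametrization), and the comparison lambda then multiplies a float threshold by `None`. A minimal sketch of the reproduction and one possible defensive fix, using a hypothetical `TIME_PERF_FACTOR` value and function names:

```python
TIME_PERF_FACTOR = 1.05  # hypothetical tolerance factor

def compare(actual, ref):
    # Mirrors the lambda in the log; raises TypeError when ref is None,
    # because (2 - TIME_PERF_FACTOR) * None is not defined.
    return actual >= (2 - TIME_PERF_FACTOR) * ref

def compare_guarded(actual, ref):
    # One possible fix: fail with a clear message instead of a TypeError.
    if ref is None:
        raise AssertionError("no baseline recorded for this test id")
    return actual >= (2 - TIME_PERF_FACTOR) * ref

# Reproducing the CI failure:
try:
    compare(92.73568232505258, None)
    reproduced = False
except TypeError:
    reproduced = True
```

This suggests the new `use_flex_attention` parametrization changed the test ids, so the existing baselines no longer match.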

@AKloniecki AKloniecki deleted the flex_attention branch September 12, 2025 08:34
astachowiczhabana pushed a commit that referenced this pull request Sep 17, 2025
Signed-off-by: kareem <kshaik@habana.ai>
Signed-off-by: Artur Kloniecki <aklonieckix@habana.ai>
Signed-off-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
Co-authored-by: kareem <kshaik@habana.ai>
Co-authored-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
gplutop7 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Oct 15, 2025
… (huggingface#663)

Signed-off-by: kareem <kshaik@habana.ai>
Signed-off-by: Artur Kloniecki <aklonieckix@habana.ai>
Signed-off-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
Co-authored-by: Artur KlonieckiX <aklonieckix@habana.ai>
Co-authored-by: kareem <kshaik@habana.ai>
Co-authored-by: Artur KlonieckiX <akloniex@akloniex-vm-u24.habana-labs.com>
