Skip to content

Commit aa98111

Browse files
Artur Kloniecki (AKloniecki) authored and committed
Add tests for flex attention in text generation.
Signed-off-by: Artur Kloniecki <aklonieckix@habana.ai>
1 parent 355e640 commit aa98111

File tree

2 files changed

+112
-83
lines changed

2 files changed

+112
-83
lines changed

examples/text-generation/model_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def __init__(
142142
)
143143
if self.model.config.model_type in ["llama", "qwen2", "baichuan", "gpt_bigcode"]:
144144
self.model_inputs.update({"flash_attention_fast_softmax": self.options.flash_attention_fast_softmax})
145+
if self.model_config.model_type in ["llama"]:
146+
self.model_inputs.update({"use_flex_attention": self.options.use_flex_attention})
145147
if args.warmup:
146148
self.warm_up()
147149

tests/test_text_generation_example.py

Lines changed: 110 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -23,68 +23,82 @@
2323
# Gaudi2+
2424
MODELS_TO_TEST = {
2525
"bf16_1x": [
26-
("bigscience/bloomz-7b1", 1, False, False),
27-
("gpt2-xl", 1, False, False),
28-
pytest.param("EleutherAI/gpt-j-6b", 1, False, False, marks=pytest.mark.skip("Deprecated in v1.20")),
29-
("EleutherAI/gpt-neox-20b", 1, False, False),
30-
("meta-llama/Llama-2-7b-hf", 1, True, True),
31-
("tiiuae/falcon-40b", 1, True, False),
32-
("bigcode/starcoder", 256, True, True),
33-
pytest.param("Salesforce/codegen2-1B", 1, False, False, marks=pytest.mark.skip("Deprecated")),
34-
("mosaicml/mpt-30b", 1, False, False),
35-
("mistralai/Mistral-7B-v0.1", 1, True, True),
36-
("mistralai/Mixtral-8x7B-v0.1", 1, False, True),
37-
("microsoft/phi-2", 1, False, False),
38-
("meta-llama/Meta-Llama-3-8B", 1, True, False),
39-
("meta-llama/Llama-2-7b-hf", 512, True, False),
40-
("meta-llama/Llama-2-7b-hf", 512, False, False), # in some cases like TGI, reuse_cache isn't used
41-
("stabilityai/stablelm-2-12b", 1, False, False),
42-
("codellama/CodeLlama-34b-hf", 1, True, False),
43-
("bigcode/starcoder2-3b", 1, False, True),
44-
("adept/persimmon-8b-base", 4, False, False),
45-
# ("Qwen/Qwen1.5-7B", 4, False, False),
46-
("google/gemma-7b", 1, False, True),
47-
("google/gemma-2-9b", 1, False, True),
48-
("google/gemma-2-27b", 1, False, True),
49-
pytest.param("state-spaces/mamba-130m-hf", 1536, False, False, marks=pytest.mark.skip("Deprecated")),
50-
# ("Deci/DeciLM-7B", 1, False, False),
51-
("Qwen/Qwen2-7B", 256, False, True),
52-
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False),
53-
# ("EleutherAI/gpt-neo-2.7B", 1, False, False),
54-
# ("facebook/xglm-1.7B", 1, False, False),
55-
# ("CohereForAI/c4ai-command-r-v01", 1, False, False),
56-
("tiiuae/falcon-mamba-7b", 1, False, False),
57-
("openbmb/MiniCPM3-4B", 1, False, False),
58-
("baichuan-inc/Baichuan2-7B-Chat", 1, True, False),
59-
("baichuan-inc/Baichuan2-13B-Chat", 1, False, False),
60-
("deepseek-ai/DeepSeek-V2-Lite", 1, False, False),
61-
("THUDM/chatglm2-6b", 1, True, False),
62-
("THUDM/chatglm3-6b", 1, True, False),
63-
("Qwen/Qwen2.5-7B", 4, False, False),
64-
("moonshotai/Moonlight-16B-A3B", 1, False, False),
65-
("Qwen/Qwen3-8B", 1, False, False),
66-
("Qwen/Qwen3-30B-A3B", 1, False, False),
26+
("bigscience/bloomz-7b1", 1, False, False, False),
27+
("gpt2-xl", 1, False, False, False),
28+
pytest.param("EleutherAI/gpt-j-6b", 1, False, False, False, marks=pytest.mark.skip("Deprecated in v1.20")),
29+
("EleutherAI/gpt-neox-20b", 1, False, False, False),
30+
("meta-llama/Llama-2-7b-hf", 1, True, True, False),
31+
("meta-llama/Llama-2-7b-hf", 1, True, True, True),
32+
("tiiuae/falcon-40b", 1, True, False, False),
33+
("bigcode/starcoder", 256, True, True, False),
34+
pytest.param("Salesforce/codegen2-1B", 1, False, False, False, marks=pytest.mark.skip("Deprecated")),
35+
("mosaicml/mpt-30b", 1, False, False, False),
36+
("mistralai/Mistral-7B-v0.1", 1, True, True, False),
37+
("mistralai/Mixtral-8x7B-v0.1", 1, False, True, False),
38+
("microsoft/phi-2", 1, False, False, False),
39+
("meta-llama/Meta-Llama-3-8B", 1, True, False, False),
40+
("meta-llama/Meta-Llama-3-8B", 1, True, False, True),
41+
("meta-llama/Llama-2-7b-hf", 512, True, False, False),
42+
("meta-llama/Llama-2-7b-hf", 512, True, False, True),
43+
("meta-llama/Llama-2-7b-hf", 512, False, False, False), # in some cases like TGI, reuse_cache isn't used
44+
("meta-llama/Llama-2-7b-hf", 512, False, False, True), # in some cases like TGI, reuse_cache isn't used
45+
("stabilityai/stablelm-2-12b", 1, False, False, False),
46+
("codellama/CodeLlama-34b-hf", 1, True, False, False),
47+
("codellama/CodeLlama-34b-hf", 1, True, False, True),
48+
("bigcode/starcoder2-3b", 1, False, True, False),
49+
("adept/persimmon-8b-base", 4, False, False, False),
50+
# ("Qwen/Qwen1.5-7B", 4, False, False, False),
51+
("google/gemma-7b", 1, False, True, False),
52+
("google/gemma-2-9b", 1, False, True, False),
53+
("google/gemma-2-27b", 1, False, True, False),
54+
pytest.param(
55+
"state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated")
56+
),
57+
# ("Deci/DeciLM-7B", 1, False, False, False),
58+
("Qwen/Qwen2-7B", 256, False, True, False),
59+
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False, False),
60+
# ("EleutherAI/gpt-neo-2.7B", 1, False, False, False),
61+
# ("facebook/xglm-1.7B", 1, False, False, False),
62+
# ("CohereForAI/c4ai-command-r-v01", 1, False, False, False),
63+
("tiiuae/falcon-mamba-7b", 1, False, False, False),
64+
("openbmb/MiniCPM3-4B", 1, False, False, False),
65+
("baichuan-inc/Baichuan2-7B-Chat", 1, True, False, False),
66+
("baichuan-inc/Baichuan2-13B-Chat", 1, False, False, False),
67+
("deepseek-ai/DeepSeek-V2-Lite", 1, False, False, False),
68+
("THUDM/chatglm2-6b", 1, True, False, False),
69+
("THUDM/chatglm3-6b", 1, True, False, False),
70+
("Qwen/Qwen2.5-7B", 4, False, False, False),
71+
("moonshotai/Moonlight-16B-A3B", 1, False, False, False),
72+
("Qwen/Qwen3-8B", 1, False, False, False),
73+
("Qwen/Qwen3-30B-A3B", 1, False, False, False),
6774
],
6875
"fp8": [
69-
pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, marks=pytest.mark.x4),
70-
("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128),
71-
("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048),
72-
("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128),
73-
("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048),
74-
pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, marks=pytest.mark.x4),
75-
pytest.param("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, marks=pytest.mark.x4),
76-
pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, marks=pytest.mark.x4),
77-
pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, marks=pytest.mark.x8),
78-
("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128),
79-
# ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048),
80-
# ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128),
81-
("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048),
82-
("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128),
83-
pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 768, True, 128, 128, marks=pytest.mark.x2),
84-
# pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, marks=pytest.mark.x2),
85-
# pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, marks=pytest.mark.x2),
86-
pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, marks=pytest.mark.x2),
87-
("microsoft/phi-2", 1, 1, True, 128, 128),
76+
pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, False, marks=pytest.mark.x4),
77+
("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128, False),
78+
("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128, True),
79+
("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048, False),
80+
("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048, True),
81+
("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128, False),
82+
("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128, True),
83+
("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048, False),
84+
("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048, True),
85+
pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, False, marks=pytest.mark.x4),
86+
pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, True, marks=pytest.mark.x4),
87+
pytest.param("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, True, marks=pytest.mark.x4),
88+
pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, False, marks=pytest.mark.x4),
89+
pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, True, marks=pytest.mark.x4),
90+
pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, False, marks=pytest.mark.x8),
91+
pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, True, marks=pytest.mark.x8),
92+
("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, False),
93+
# ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, False),
94+
# ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, False),
95+
("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048, False),
96+
("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128, False),
97+
pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 768, True, 128, 128, False, marks=pytest.mark.x2),
98+
# pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, False, marks=pytest.mark.x2),
99+
# pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, False, marks=pytest.mark.x2),
100+
pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, False, marks=pytest.mark.x2),
101+
("microsoft/phi-2", 1, 1, True, 128, 128, False),
88102
],
89103
"load_quantized_model_with_autogptq": [
90104
("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048),
@@ -121,24 +135,25 @@
121135
# Gaudi1
122136
MODELS_TO_TEST = {
123137
"bf16_1x": [
124-
("bigscience/bloomz-7b1", 1, False, False),
125-
("gpt2-xl", 1, False, False),
138+
("bigscience/bloomz-7b1", 1, False, False, False),
139+
("gpt2-xl", 1, False, False, False),
126140
# TODO: fix OPT 6.7B
127141
# ("facebook/opt-6.7b", 0.0),
128-
("EleutherAI/gpt-j-6b", 1, True, False),
129-
("meta-llama/Llama-2-7b-hf", 1, True, False),
130-
("tiiuae/falcon-7b", 1, True, False),
131-
("bigcode/starcoder", 1, False, False),
132-
("Salesforce/codegen2-1B", 1, False, False),
133-
("mosaicml/mpt-7b", 1, False, False),
134-
("mistralai/Mistral-7B-v0.1", 1, True, False),
135-
("microsoft/phi-2", 1, False, False),
136-
("google/gemma-7b", 1, False, False),
137-
("stabilityai/stablelm-2-12b", 1, False, False),
138-
("Qwen/Qwen1.5-7B", 1, False, False),
139-
("adept/persimmon-8b-base", 1, False, False),
140-
("bigcode/starcoder2-3b", 1, False, False),
141-
("state-spaces/mamba-130m-hf", 224, False, False),
142+
("EleutherAI/gpt-j-6b", 1, True, False, False),
143+
("meta-llama/Llama-2-7b-hf", 1, True, False, False),
144+
("meta-llama/Llama-2-7b-hf", 1, True, False, True),
145+
("tiiuae/falcon-7b", 1, True, False, False),
146+
("bigcode/starcoder", 1, False, False, False),
147+
("Salesforce/codegen2-1B", 1, False, False, False),
148+
("mosaicml/mpt-7b", 1, False, False, False),
149+
("mistralai/Mistral-7B-v0.1", 1, True, False, False),
150+
("microsoft/phi-2", 1, False, False, False),
151+
("google/gemma-7b", 1, False, False, False),
152+
("stabilityai/stablelm-2-12b", 1, False, False, False),
153+
("Qwen/Qwen1.5-7B", 1, False, False, False),
154+
("adept/persimmon-8b-base", 1, False, False, False),
155+
("bigcode/starcoder2-3b", 1, False, False, False),
156+
("state-spaces/mamba-130m-hf", 224, False, False, False),
142157
],
143158
"fp8": [],
144159
"load_quantized_model_with_autogptq": [],
@@ -175,6 +190,7 @@ def _test_text_generation(
175190
num_beams: int = 1,
176191
num_return_sequences: int = 1,
177192
check_output: bool = False,
193+
use_flex_attention: bool = False,
178194
):
179195
command = ["python3"]
180196
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -237,8 +253,11 @@ def _test_text_generation(
237253
if torch_compile:
238254
command += ["--torch_compile"]
239255
if parallel_strategy == "tp":
240-
command += ["--use_flash_attention"]
241-
command += ["--flash_attention_recompute"]
256+
if use_flex_attention:
257+
command += ["--use_flex_attention"]
258+
else:
259+
command += ["--use_flash_attention"]
260+
command += ["--flash_attention_recompute"]
242261
env_variables["PT_ENABLE_INT64_SUPPORT"] = "1"
243262
env_variables["PT_HPU_LAZY_MODE"] = "0"
244263
else:
@@ -268,8 +287,11 @@ def _test_text_generation(
268287
if "--trim_logits" not in command:
269288
command += ["--trim_logits"]
270289
if "Llama-2" in model_name:
271-
command.insert(-2, "--use_flash_attention")
272-
command.insert(-2, "--flash_attention_recompute")
290+
if use_flex_attention:
291+
command.insert(-2, "--use_flex_attention")
292+
else:
293+
command.insert(-2, "--use_flash_attention")
294+
command.insert(-2, "--flash_attention_recompute")
273295
command.insert(-2, "--bucket_size 128")
274296
command.insert(-2, "--bucket_internal")
275297
if "Mistral" in model_name:
@@ -394,9 +416,11 @@ def _test_text_generation(
394416
)
395417

396418

397-
@pytest.mark.parametrize("model_name, batch_size, reuse_cache, check_output", MODELS_TO_TEST["bf16_1x"])
419+
@pytest.mark.parametrize(
420+
"model_name, batch_size, reuse_cache, check_output, use_flex_attention", MODELS_TO_TEST["bf16_1x"]
421+
)
398422
def test_text_generation_bf16_1x(
399-
model_name: str, batch_size: int, reuse_cache: bool, check_output: bool, baseline, token
423+
model_name: str, batch_size: int, reuse_cache: bool, check_output: bool, use_flex_attention: bool, baseline, token
400424
):
401425
_test_text_generation(
402426
model_name=model_name,
@@ -405,12 +429,13 @@ def test_text_generation_bf16_1x(
405429
batch_size=batch_size,
406430
reuse_cache=reuse_cache,
407431
check_output=check_output,
432+
use_flex_attention=use_flex_attention,
408433
)
409434

410435

411436
@pytest.mark.skipif(condition=bool("gaudi1" == OH_DEVICE_CONTEXT), reason=f"Skipping test for {OH_DEVICE_CONTEXT}")
412437
@pytest.mark.parametrize(
413-
"model_name, world_size, batch_size, reuse_cache, input_len, output_len", MODELS_TO_TEST["fp8"]
438+
"model_name, world_size, batch_size, reuse_cache, input_len, output_len, use_flex_attention", MODELS_TO_TEST["fp8"]
414439
)
415440
def test_text_generation_fp8(
416441
model_name: str,
@@ -419,6 +444,7 @@ def test_text_generation_fp8(
419444
reuse_cache: bool,
420445
input_len: int,
421446
output_len: int,
447+
use_flex_attention: bool,
422448
baseline,
423449
token,
424450
):
@@ -434,6 +460,7 @@ def test_text_generation_fp8(
434460
reuse_cache=reuse_cache,
435461
max_input_tokens=input_len,
436462
max_output_tokens=output_len,
463+
use_flex_attention=use_flex_attention,
437464
)
438465

439466

0 commit comments

Comments (0)