2323 # Gaudi2+
2424 MODELS_TO_TEST = {
2525 "bf16_1x" : [
26- ("bigscience/bloomz-7b1" , 1 , False , False ),
27- ("gpt2-xl" , 1 , False , False ),
28- pytest .param ("EleutherAI/gpt-j-6b" , 1 , False , False , marks = pytest .mark .skip ("Deprecated in v1.20" )),
29- ("EleutherAI/gpt-neox-20b" , 1 , False , False ),
30- ("meta-llama/Llama-2-7b-hf" , 1 , True , True ),
31- ("tiiuae/falcon-40b" , 1 , True , False ),
32- ("bigcode/starcoder" , 256 , True , True ),
33- pytest .param ("Salesforce/codegen2-1B" , 1 , False , False , marks = pytest .mark .skip ("Deprecated" )),
34- ("mosaicml/mpt-30b" , 1 , False , False ),
35- ("mistralai/Mistral-7B-v0.1" , 1 , True , True ),
36- ("mistralai/Mixtral-8x7B-v0.1" , 1 , False , True ),
37- ("microsoft/phi-2" , 1 , False , False ),
38- ("meta-llama/Meta-Llama-3-8B" , 1 , True , False ),
39- ("meta-llama/Llama-2-7b-hf" , 512 , True , False ),
40- ("meta-llama/Llama-2-7b-hf" , 512 , False , False ), # in some cases like TGI, reuse_cache isn't used
41- ("stabilityai/stablelm-2-12b" , 1 , False , False ),
42- ("codellama/CodeLlama-34b-hf" , 1 , True , False ),
43- ("bigcode/starcoder2-3b" , 1 , False , True ),
44- ("adept/persimmon-8b-base" , 4 , False , False ),
45- # ("Qwen/Qwen1.5-7B", 4, False, False),
46- ("google/gemma-7b" , 1 , False , True ),
47- ("google/gemma-2-9b" , 1 , False , True ),
48- ("google/gemma-2-27b" , 1 , False , True ),
49- pytest .param ("state-spaces/mamba-130m-hf" , 1536 , False , False , marks = pytest .mark .skip ("Deprecated" )),
50- # ("Deci/DeciLM-7B", 1, False, False),
51- ("Qwen/Qwen2-7B" , 256 , False , True ),
52- ("Qwen/Qwen1.5-MoE-A2.7B" , 1 , True , False ),
53- # ("EleutherAI/gpt-neo-2.7B", 1, False, False),
54- # ("facebook/xglm-1.7B", 1, False, False),
55- # ("CohereForAI/c4ai-command-r-v01", 1, False, False),
56- ("tiiuae/falcon-mamba-7b" , 1 , False , False ),
57- ("openbmb/MiniCPM3-4B" , 1 , False , False ),
58- ("baichuan-inc/Baichuan2-7B-Chat" , 1 , True , False ),
59- ("baichuan-inc/Baichuan2-13B-Chat" , 1 , False , False ),
60- ("deepseek-ai/DeepSeek-V2-Lite" , 1 , False , False ),
61- ("THUDM/chatglm2-6b" , 1 , True , False ),
62- ("THUDM/chatglm3-6b" , 1 , True , False ),
63- ("Qwen/Qwen2.5-7B" , 4 , False , False ),
64- ("moonshotai/Moonlight-16B-A3B" , 1 , False , False ),
65- ("Qwen/Qwen3-8B" , 1 , False , False ),
66- ("Qwen/Qwen3-30B-A3B" , 1 , False , False ),
26+ ("bigscience/bloomz-7b1" , 1 , False , False , False ),
27+ ("gpt2-xl" , 1 , False , False , False ),
28+ pytest .param ("EleutherAI/gpt-j-6b" , 1 , False , False , False , marks = pytest .mark .skip ("Deprecated in v1.20" )),
29+ ("EleutherAI/gpt-neox-20b" , 1 , False , False , False ),
30+ ("meta-llama/Llama-2-7b-hf" , 1 , True , True , False ),
31+ ("meta-llama/Llama-2-7b-hf" , 1 , True , True , True ),
32+ ("tiiuae/falcon-40b" , 1 , True , False , False ),
33+ ("bigcode/starcoder" , 256 , True , True , False ),
34+ pytest .param ("Salesforce/codegen2-1B" , 1 , False , False , False , marks = pytest .mark .skip ("Deprecated" )),
35+ ("mosaicml/mpt-30b" , 1 , False , False , False ),
36+ ("mistralai/Mistral-7B-v0.1" , 1 , True , True , False ),
37+ ("mistralai/Mixtral-8x7B-v0.1" , 1 , False , True , False ),
38+ ("microsoft/phi-2" , 1 , False , False , False ),
39+ ("meta-llama/Meta-Llama-3-8B" , 1 , True , False , False ),
40+ ("meta-llama/Meta-Llama-3-8B" , 1 , True , False , True ),
41+ ("meta-llama/Llama-2-7b-hf" , 512 , True , False , False ),
42+ ("meta-llama/Llama-2-7b-hf" , 512 , True , False , True ),
43+ ("meta-llama/Llama-2-7b-hf" , 512 , False , False , False ), # in some cases like TGI, reuse_cache isn't used
44+ ("meta-llama/Llama-2-7b-hf" , 512 , False , False , True ), # in some cases like TGI, reuse_cache isn't used
45+ ("stabilityai/stablelm-2-12b" , 1 , False , False , False ),
46+ ("codellama/CodeLlama-34b-hf" , 1 , True , False , False ),
47+ ("codellama/CodeLlama-34b-hf" , 1 , True , False , True ),
48+ ("bigcode/starcoder2-3b" , 1 , False , True , False ),
49+ ("adept/persimmon-8b-base" , 4 , False , False , False ),
50+ # ("Qwen/Qwen1.5-7B", 4, False, False, False),
51+ ("google/gemma-7b" , 1 , False , True , False ),
52+ ("google/gemma-2-9b" , 1 , False , True , False ),
53+ ("google/gemma-2-27b" , 1 , False , True , False ),
54+ pytest .param (
55+ "state-spaces/mamba-130m-hf" , 1536 , False , False , False , marks = pytest .mark .skip ("Deprecated" )
56+ ),
57+ # ("Deci/DeciLM-7B", 1, False, False, False),
58+ ("Qwen/Qwen2-7B" , 256 , False , True , False ),
59+ ("Qwen/Qwen1.5-MoE-A2.7B" , 1 , True , False , False ),
60+ # ("EleutherAI/gpt-neo-2.7B", 1, False, False, False),
61+ # ("facebook/xglm-1.7B", 1, False, False, False),
62+ # ("CohereForAI/c4ai-command-r-v01", 1, False, False, False),
63+ ("tiiuae/falcon-mamba-7b" , 1 , False , False , False ),
64+ ("openbmb/MiniCPM3-4B" , 1 , False , False , False ),
65+ ("baichuan-inc/Baichuan2-7B-Chat" , 1 , True , False , False ),
66+ ("baichuan-inc/Baichuan2-13B-Chat" , 1 , False , False , False ),
67+ ("deepseek-ai/DeepSeek-V2-Lite" , 1 , False , False , False ),
68+ ("THUDM/chatglm2-6b" , 1 , True , False , False ),
69+ ("THUDM/chatglm3-6b" , 1 , True , False , False ),
70+ ("Qwen/Qwen2.5-7B" , 4 , False , False , False ),
71+ ("moonshotai/Moonlight-16B-A3B" , 1 , False , False , False ),
72+ ("Qwen/Qwen3-8B" , 1 , False , False , False ),
73+ ("Qwen/Qwen3-30B-A3B" , 1 , False , False , False ),
6774 ],
6875 "fp8" : [
69- pytest .param ("tiiuae/falcon-180B" , 4 , 950 , True , 128 , 128 , marks = pytest .mark .x4 ),
70- ("meta-llama/Llama-2-7b-hf" , 1 , 1230 , False , 128 , 128 ),
71- ("meta-llama/Llama-2-7b-hf" , 1 , 163 , False , 128 , 2048 ),
72- ("meta-llama/Llama-2-7b-hf" , 1 , 94 , False , 2048 , 128 ),
73- ("meta-llama/Llama-2-7b-hf" , 1 , 81 , False , 2048 , 2048 ),
74- pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 3042 , False , 128 , 128 , marks = pytest .mark .x4 ),
75- pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 750 , False , 128 , 2048 , marks = pytest .mark .x4 ),
76- pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 207 , False , 2048 , 128 , marks = pytest .mark .x4 ),
77- pytest .param ("meta-llama/Llama-2-70b-hf" , 8 , 172 , False , 2048 , 2048 , marks = pytest .mark .x8 ),
78- ("mistralai/Mistral-7B-Instruct-v0.2" , 1 , 896 , True , 128 , 128 ),
79- # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048),
80- # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128),
81- ("mistralai/Mistral-7B-Instruct-v0.2" , 1 , 44 , True , 2048 , 2048 ),
82- ("mistralai/Mixtral-8x7B-v0.1" , 1 , 1 , True , 128 , 128 ),
83- pytest .param ("mistralai/Mixtral-8x7B-v0.1" , 2 , 768 , True , 128 , 128 , marks = pytest .mark .x2 ),
84- # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, marks=pytest.mark.x2),
85- # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, marks=pytest.mark.x2),
86- pytest .param ("mistralai/Mixtral-8x7B-v0.1" , 2 , 48 , True , 2048 , 2048 , marks = pytest .mark .x2 ),
87- ("microsoft/phi-2" , 1 , 1 , True , 128 , 128 ),
76+ pytest .param ("tiiuae/falcon-180B" , 4 , 950 , True , 128 , 128 , False , marks = pytest .mark .x4 ),
77+ ("meta-llama/Llama-2-7b-hf" , 1 , 1230 , False , 128 , 128 , False ),
78+ ("meta-llama/Llama-2-7b-hf" , 1 , 1230 , False , 128 , 128 , True ),
79+ ("meta-llama/Llama-2-7b-hf" , 1 , 163 , False , 128 , 2048 , False ),
80+ ("meta-llama/Llama-2-7b-hf" , 1 , 163 , False , 128 , 2048 , True ),
81+ ("meta-llama/Llama-2-7b-hf" , 1 , 94 , False , 2048 , 128 , False ),
82+ ("meta-llama/Llama-2-7b-hf" , 1 , 94 , False , 2048 , 128 , True ),
83+ ("meta-llama/Llama-2-7b-hf" , 1 , 81 , False , 2048 , 2048 , False ),
84+ ("meta-llama/Llama-2-7b-hf" , 1 , 81 , False , 2048 , 2048 , True ),
85+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 3042 , False , 128 , 128 , False , marks = pytest .mark .x4 ),
86+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 3042 , False , 128 , 128 , True , marks = pytest .mark .x4 ),
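87+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 750 , False , 128 , 2048 , False , marks = pytest .mark .x4 ),
87+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 750 , False , 128 , 2048 , True , marks = pytest .mark .x4 ),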
88+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 207 , False , 2048 , 128 , False , marks = pytest .mark .x4 ),
89+ pytest .param ("meta-llama/Llama-2-70b-hf" , 4 , 207 , False , 2048 , 128 , True , marks = pytest .mark .x4 ),
90+ pytest .param ("meta-llama/Llama-2-70b-hf" , 8 , 172 , False , 2048 , 2048 , False , marks = pytest .mark .x8 ),
91+ pytest .param ("meta-llama/Llama-2-70b-hf" , 8 , 172 , False , 2048 , 2048 , True , marks = pytest .mark .x8 ),
92+ ("mistralai/Mistral-7B-Instruct-v0.2" , 1 , 896 , True , 128 , 128 , False ),
93+ # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, False),
94+ # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, False),
95+ ("mistralai/Mistral-7B-Instruct-v0.2" , 1 , 44 , True , 2048 , 2048 , False ),
96+ ("mistralai/Mixtral-8x7B-v0.1" , 1 , 1 , True , 128 , 128 , False ),
97+ pytest .param ("mistralai/Mixtral-8x7B-v0.1" , 2 , 768 , True , 128 , 128 , False , marks = pytest .mark .x2 ),
98+ # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, False, marks=pytest.mark.x2),
99+ # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, False, marks=pytest.mark.x2),
100+ pytest .param ("mistralai/Mixtral-8x7B-v0.1" , 2 , 48 , True , 2048 , 2048 , False , marks = pytest .mark .x2 ),
101+ ("microsoft/phi-2" , 1 , 1 , True , 128 , 128 , False ),
88102 ],
89103 "load_quantized_model_with_autogptq" : [
90104 ("TheBloke/Llama-2-7b-Chat-GPTQ" , 1 , 10 , False , 128 , 2048 ),
121135 # Gaudi1
122136 MODELS_TO_TEST = {
123137 "bf16_1x" : [
124- ("bigscience/bloomz-7b1" , 1 , False , False ),
125- ("gpt2-xl" , 1 , False , False ),
138+ ("bigscience/bloomz-7b1" , 1 , False , False , False ),
139+ ("gpt2-xl" , 1 , False , False , False ),
126140 # TODO: fix OPT 6.7B
127141 # ("facebook/opt-6.7b", 0.0),
128- ("EleutherAI/gpt-j-6b" , 1 , True , False ),
129- ("meta-llama/Llama-2-7b-hf" , 1 , True , False ),
130- ("tiiuae/falcon-7b" , 1 , True , False ),
131- ("bigcode/starcoder" , 1 , False , False ),
132- ("Salesforce/codegen2-1B" , 1 , False , False ),
133- ("mosaicml/mpt-7b" , 1 , False , False ),
134- ("mistralai/Mistral-7B-v0.1" , 1 , True , False ),
135- ("microsoft/phi-2" , 1 , False , False ),
136- ("google/gemma-7b" , 1 , False , False ),
137- ("stabilityai/stablelm-2-12b" , 1 , False , False ),
138- ("Qwen/Qwen1.5-7B" , 1 , False , False ),
139- ("adept/persimmon-8b-base" , 1 , False , False ),
140- ("bigcode/starcoder2-3b" , 1 , False , False ),
141- ("state-spaces/mamba-130m-hf" , 224 , False , False ),
142+ ("EleutherAI/gpt-j-6b" , 1 , True , False , False ),
143+ ("meta-llama/Llama-2-7b-hf" , 1 , True , False , False ),
144+ ("meta-llama/Llama-2-7b-hf" , 1 , True , False , True ),
145+ ("tiiuae/falcon-7b" , 1 , True , False , False ),
146+ ("bigcode/starcoder" , 1 , False , False , False ),
147+ ("Salesforce/codegen2-1B" , 1 , False , False , False ),
148+ ("mosaicml/mpt-7b" , 1 , False , False , False ),
149+ ("mistralai/Mistral-7B-v0.1" , 1 , True , False , False ),
150+ ("microsoft/phi-2" , 1 , False , False , False ),
151+ ("google/gemma-7b" , 1 , False , False , False ),
152+ ("stabilityai/stablelm-2-12b" , 1 , False , False , False ),
153+ ("Qwen/Qwen1.5-7B" , 1 , False , False , False ),
154+ ("adept/persimmon-8b-base" , 1 , False , False , False ),
155+ ("bigcode/starcoder2-3b" , 1 , False , False , False ),
156+ ("state-spaces/mamba-130m-hf" , 224 , False , False , False ),
142157 ],
143158 "fp8" : [],
144159 "load_quantized_model_with_autogptq" : [],
@@ -175,6 +190,7 @@ def _test_text_generation(
175190 num_beams : int = 1 ,
176191 num_return_sequences : int = 1 ,
177192 check_output : bool = False ,
193+ use_flex_attention : bool = False ,
178194):
179195 command = ["python3" ]
180196 path_to_example_dir = Path (__file__ ).resolve ().parent .parent / "examples"
@@ -237,8 +253,11 @@ def _test_text_generation(
237253 if torch_compile :
238254 command += ["--torch_compile" ]
239255 if parallel_strategy == "tp" :
240- command += ["--use_flash_attention" ]
241- command += ["--flash_attention_recompute" ]
256+ if use_flex_attention :
257+ command += ["--use_flex_attention" ]
258+ else :
259+ command += ["--use_flash_attention" ]
260+ command += ["--flash_attention_recompute" ]
242261 env_variables ["PT_ENABLE_INT64_SUPPORT" ] = "1"
243262 env_variables ["PT_HPU_LAZY_MODE" ] = "0"
244263 else :
@@ -268,8 +287,11 @@ def _test_text_generation(
268287 if "--trim_logits" not in command :
269288 command += ["--trim_logits" ]
270289 if "Llama-2" in model_name :
271- command .insert (- 2 , "--use_flash_attention" )
272- command .insert (- 2 , "--flash_attention_recompute" )
290+ if use_flex_attention :
291+ command .insert (- 2 , "--use_flex_attention" )
292+ else :
293+ command .insert (- 2 , "--use_flash_attention" )
294+ command .insert (- 2 , "--flash_attention_recompute" )
273295 command .insert (- 2 , "--bucket_size 128" )
274296 command .insert (- 2 , "--bucket_internal" )
275297 if "Mistral" in model_name :
@@ -394,9 +416,11 @@ def _test_text_generation(
394416 )
395417
396418
397- @pytest .mark .parametrize ("model_name, batch_size, reuse_cache, check_output" , MODELS_TO_TEST ["bf16_1x" ])
419+ @pytest .mark .parametrize (
420+ "model_name, batch_size, reuse_cache, check_output, use_flex_attention" , MODELS_TO_TEST ["bf16_1x" ]
421+ )
398422def test_text_generation_bf16_1x (
399- model_name : str , batch_size : int , reuse_cache : bool , check_output : bool , baseline , token
423+ model_name : str , batch_size : int , reuse_cache : bool , check_output : bool , use_flex_attention : bool , baseline , token
400424):
401425 _test_text_generation (
402426 model_name = model_name ,
@@ -405,12 +429,13 @@ def test_text_generation_bf16_1x(
405429 batch_size = batch_size ,
406430 reuse_cache = reuse_cache ,
407431 check_output = check_output ,
432+ use_flex_attention = use_flex_attention ,
408433 )
409434
410435
411436@pytest .mark .skipif (condition = bool ("gaudi1" == OH_DEVICE_CONTEXT ), reason = f"Skipping test for { OH_DEVICE_CONTEXT } " )
412437@pytest .mark .parametrize (
413- "model_name, world_size, batch_size, reuse_cache, input_len, output_len" , MODELS_TO_TEST ["fp8" ]
438+ "model_name, world_size, batch_size, reuse_cache, input_len, output_len, use_flex_attention" , MODELS_TO_TEST ["fp8" ]
414439)
415440def test_text_generation_fp8 (
416441 model_name : str ,
@@ -419,6 +444,7 @@ def test_text_generation_fp8(
419444 reuse_cache : bool ,
420445 input_len : int ,
421446 output_len : int ,
447+ use_flex_attention : bool ,
422448 baseline ,
423449 token ,
424450):
@@ -434,6 +460,7 @@ def test_text_generation_fp8(
434460 reuse_cache = reuse_cache ,
435461 max_input_tokens = input_len ,
436462 max_output_tokens = output_len ,
463+ use_flex_attention = use_flex_attention ,
437464 )
438465
439466
0 commit comments