Port https://github.com/NovaSky-AI/SkyRL/pull/1083 to skyrl folder #1162
```diff
@@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int):
         output_merged = moe_layer_merged(x_sample)
 
         assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3)
+
+
+def test_deepseekv3_gradient_checkpointing():
+    """Test that gradient checkpointing produces identical outputs for DeepSeekV3.
+
+    DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests
+    that gradient checkpointing works correctly with heterogeneous layer types.
+    """
+    model_name = "yujiepan/deepseek-v3-tiny-random"
+    base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True)
+
+    batch_size, seq_len = 2, 8
+    mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
+
+    results = {}
+    for use_checkpointing in [False, True]:
+        config = DeepseekV3Config(
+            base_config,
+            max_lora_adapters=1,
+            max_lora_rank=1,
+            shard_attention_heads=True,
+            gradient_checkpointing=use_checkpointing,
+        )
+        with jax.set_mesh(mesh):
+            model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
```
Comment on lines +197 to +213

Contributor

This test duplicates model creation logic that is now available in the `create_model` helper. Assuming the suggested change to add an `auto_config_kwargs` parameter is applied, this setup can be simplified to:

```python
model_name = "yujiepan/deepseek-v3-tiny-random"
batch_size, seq_len = 2, 8
mesh_axes = ("fsdp", "ep", "tp")
results = {}
for use_checkpointing in [False, True]:
    config, model = create_model(
        model_name,
        DeepseekV3Config,
        DeepseekV3ForCausalLM,
        mesh_axes,
        mesh_shape=(1, 1, 1),
        auto_config_kwargs={"trust_remote_code": True},
        max_lora_adapters=1,
        max_lora_rank=1,
        gradient_checkpointing=use_checkpointing,
        seed=0,
    )
```
```diff
+
+        input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size)
+        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
+
+        out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
+        logits = model.compute_logits(out.last_hidden_state)
+
+        results[use_checkpointing] = {
+            "logits": np.array(logits),
+            "hidden_states": [np.array(hs) for hs in out.hidden_states],
+            "kv_cache_len": len(out.kv_cache.keys),
+        }
+
+    # Verify outputs match
+    np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6)
+
+    # Verify hidden states match
+    assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"])
+    for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])):
+        np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}")
+
+    # Verify KV cache has correct number of layers
+    assert results[True]["kv_cache_len"] == config.num_hidden_layers
```
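For context on the "split stacking" the docstring refers to, the following is a minimal, self-contained sketch of the pattern under test. It is not this PR's implementation: `DenseBlock`, `MoEBlock`, and `SplitStack` are illustrative names, and the sketch assumes `flax.nnx.remat` is used for activation rematerialization. Rematerialization only recomputes activations during the backward pass, so the forward outputs must be bitwise-compatible with the uncheckpointed path within numerical tolerance, which is exactly what the test above asserts.

```python
import jax
from flax import nnx


class DenseBlock(nnx.Module):
    """Illustrative dense layer (hypothetical, not the PR's DeepSeekV3 block)."""

    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return jax.nn.gelu(self.proj(x))


class MoEBlock(nnx.Module):
    """Illustrative MoE layer; toy expert averaging stands in for top-k routing."""

    def __init__(self, dim: int, n_experts: int, rngs: nnx.Rngs):
        self.experts = [nnx.Linear(dim, dim, rngs=rngs) for _ in range(n_experts)]

    def __call__(self, x):
        return sum(e(x) for e in self.experts) / len(self.experts)


class SplitStack(nnx.Module):
    """Dense layers followed by MoE layers: a split stack of heterogeneous types."""

    def __init__(self, dim: int, n_dense: int, n_moe: int,
                 gradient_checkpointing: bool, rngs: nnx.Rngs):
        self.dense_layers = [DenseBlock(dim, rngs=rngs) for _ in range(n_dense)]
        self.moe_layers = [MoEBlock(dim, 4, rngs=rngs) for _ in range(n_moe)]
        self.gradient_checkpointing = gradient_checkpointing

    def __call__(self, x):
        # When enabled, rematerialize each layer's activations in the backward
        # pass; the forward result is identical, only backward memory changes.
        def apply(layer, h):
            return layer(h)

        wrapped = nnx.remat(apply) if self.gradient_checkpointing else apply
        for layer in self.dense_layers + self.moe_layers:
            x = wrapped(layer, x)
        return x
```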
Contributor
The new `create_model` helper function is a great addition for simplifying test setup. However, it currently doesn't allow passing arguments to `AutoConfig.from_pretrained`, which is necessary for models that require `trust_remote_code=True` (as seen in `test_deepseekv3.py`).

To make this helper more versatile and promote consistency across tests, consider adding an `auto_config_kwargs` parameter to both `create_model` and `load_model`. This would allow all tests to use this common helper, avoiding duplicated model creation logic.
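To make the proposal concrete, here is a minimal sketch of what such a parameter could look like. This is an assumption about the helper's shape rather than its actual signature in the repo; it only reuses the names visible in the suggestion above (`create_model`, `mesh_axes`, `mesh_shape`, `auto_config_kwargs`, `seed`) and the setup pattern from the diff. `load_model` would accept and forward the same parameter.

```python
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import AutoConfig


def create_model(
    model_name,
    config_cls,
    model_cls,
    mesh_axes,
    mesh_shape,
    auto_config_kwargs=None,  # proposed: forwarded to AutoConfig.from_pretrained
    seed=0,
    **config_kwargs,          # e.g. max_lora_adapters, gradient_checkpointing
):
    """Hypothetical helper: build a config and model under a device mesh."""
    base_config = AutoConfig.from_pretrained(model_name, **(auto_config_kwargs or {}))
    config = config_cls(base_config, **config_kwargs)
    mesh = jax.make_mesh(
        mesh_shape, mesh_axes,
        axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes),
    )
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return config, model
```

Threading the kwargs through as an opaque dict keeps the helper agnostic to which `AutoConfig` options a given model needs, so tests like `test_deepseekv3.py` can opt into `trust_remote_code=True` without special-casing the helper.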