Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skyrl/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
)
gradient_checkpointing: bool = Field(
default=False,
description="Whether to use gradient checkpointing (full recomputation strategy)",
description="Per-layer activation checkpointing: recompute activations during backward to save memory",
)
loss_chunk_size: int = Field(
default=1024,
Expand Down
2 changes: 1 addition & 1 deletion skyrl/tx/models/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ModelConfig(PretrainedConfig):
max_lora_rank: Maximum rank for LoRA adapters
shard_attention_heads: Whether to shard attention across tensor parallel devices
loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking)
gradient_checkpointing: Whether to use gradient checkpointing for chunked loss
gradient_checkpointing: Recompute activations during backward to save memory
"""

# Type hints for config attributes
Expand Down
43 changes: 37 additions & 6 deletions tests/tx/models/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from flax import nnx
import jax
import jax.numpy as jnp
Expand All @@ -8,23 +10,52 @@
from skyrl.tx.utils.models import load_safetensors, resolve_model_path


def create_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model with initialized weights.

    Args:
        model_name: HuggingFace model identifier used to fetch the base config.
        config_cls: Project config class that wraps the HuggingFace config.
        model_cls: Model class to instantiate.
        mesh_axes: Names of the device-mesh axes.
        mesh_shape: Size of each mesh axis; defaults to all ones (single device).
        mesh_axis_types: Axis types for the mesh; defaults to Auto on every axis.
        seed: PRNG seed used to initialize the model weights.
        auto_config_kwargs: Extra kwargs forwarded to ``AutoConfig.from_pretrained``
            (e.g. ``trust_remote_code=True`` for models with custom code on the Hub).
        **config_kwargs: Forwarded to ``config_cls``.

    Returns:
        Tuple of ``(config, model)``.
    """
    if auto_config_kwargs is None:
        auto_config_kwargs = {}
    base_config = AutoConfig.from_pretrained(model_name, **auto_config_kwargs)
    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
    if mesh_shape is None:
        mesh_shape = (1,) * len(mesh_axes)
    if mesh_axis_types is None:
        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types)
    # Instantiate under the mesh context so parameters are sharded per the mesh spec.
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return config, model


def load_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Build a model via ``create_model`` and populate it with pretrained weights.

    Weights are read as safetensors from the local HuggingFace cache entry
    for ``model_name``. Returns the same ``(config, model)`` pair as
    ``create_model``, with the model's parameters overwritten in place.
    """
    config, model = create_model(
        model_name,
        config_cls,
        model_cls,
        mesh_axes,
        mesh_shape=mesh_shape,
        mesh_axis_types=mesh_axis_types,
        seed=seed,
        **config_kwargs,
    )
    load_safetensors(resolve_model_path(model_name), config, model)
    return config, model
Comment on lines +13 to 61
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new create_model helper function is a great addition for simplifying test setup. However, it currently doesn't allow passing arguments to AutoConfig.from_pretrained, which is necessary for models that require trust_remote_code=True (as seen in test_deepseekv3.py).

To make this helper more versatile and promote consistency across tests, consider adding an auto_config_kwargs parameter to both create_model and load_model. This would allow all tests to use this common helper, avoiding duplicated model creation logic.

def create_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model with initialized weights."""
    if auto_config_kwargs is None:
        auto_config_kwargs = {}
    base_config = AutoConfig.from_pretrained(model_name, **auto_config_kwargs)
    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
    if mesh_shape is None:
        mesh_shape = (1,) * len(mesh_axes)
    if mesh_axis_types is None:
        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types)
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return config, model


def load_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model and load weights from the HuggingFace cache."""
    config, model = create_model(
        model_name,
        config_cls,
        model_cls,
        mesh_axes,
        mesh_shape=mesh_shape,
        mesh_axis_types=mesh_axis_types,
        seed=seed,
        auto_config_kwargs=auto_config_kwargs,
        **config_kwargs,
    )
    weights_dir = resolve_model_path(model_name)
    load_safetensors(weights_dir, config, model)
    return config, model

48 changes: 48 additions & 0 deletions tests/tx/models/test_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int):
output_merged = moe_layer_merged(x_sample)

assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3)


def test_deepseekv3_gradient_checkpointing():
    """Test that gradient checkpointing produces identical outputs for DeepSeekV3.

    DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests
    that gradient checkpointing works correctly with heterogeneous layer types.
    """
    model_name = "yujiepan/deepseek-v3-tiny-random"
    # This model ships custom modeling code on the Hub, hence trust_remote_code.
    base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True)

    batch_size, seq_len = 2, 8
    mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)

    results = {}
    for use_checkpointing in [False, True]:
        config = DeepseekV3Config(
            base_config,
            max_lora_adapters=1,
            max_lora_rank=1,
            shard_attention_heads=True,
            gradient_checkpointing=use_checkpointing,
        )
        with jax.set_mesh(mesh):
            model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))

        input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size)
        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)

        out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        logits = model.compute_logits(out.last_hidden_state)

        results[use_checkpointing] = {
            "logits": np.array(logits),
            "hidden_states": [np.array(hs) for hs in out.hidden_states],
            "kv_cache_len": len(out.kv_cache.keys),
        }

    # Checkpointing only changes when activations are recomputed, never the math:
    # outputs must match bitwise-close across both runs.
    np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6)

    # Verify hidden states match
    assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"])
    for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])):
        np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}")

    # Verify KV cache has correct number of layers
    assert results[True]["kv_cache_len"] == config.num_hidden_layers
86 changes: 84 additions & 2 deletions tests/tx/models/test_models_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
import pytest
Expand All @@ -6,9 +9,9 @@
from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
from skyrl.tx.models.llama3 import Llama3ForCausalLM
from skyrl.tx.models.qwen3 import Qwen3ForCausalLM
from skyrl.tx.models.types import ModelForCausalLM
from skyrl.tx.models.types import CausalLMOutput, ModelForCausalLM

from tests.tx.models.conftest import load_model
from tests.tx.models.conftest import create_model, load_model

MODEL_PARAMS = [
("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
Expand All @@ -17,6 +20,85 @@
MODEL_IDS = ["llama3", "qwen3"]


@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
class TestGradientCheckpointing:

    def _forward(
        self,
        model_name: str,
        config_cls: type[ModelConfig],
        model_cls: type[ModelForCausalLM],
        mesh_axes: tuple[str, str],
        gradient_checkpointing: bool,
        **forward_kwargs: Any,
    ) -> tuple[ModelForCausalLM, ModelConfig, CausalLMOutput]:
        """Build a model, run a single forward pass, and return (model, config, out)."""
        n_batch, n_tokens = 2, 8
        config, model = create_model(
            model_name,
            config_cls,
            model_cls,
            mesh_axes,
            max_lora_adapters=1,
            max_lora_rank=1,
            gradient_checkpointing=gradient_checkpointing,
        )
        tokens = jax.random.randint(jax.random.key(0), (n_batch, n_tokens), 0, config.vocab_size)
        mask = jnp.ones((n_batch, n_tokens), dtype=jnp.int32)
        out = model(tokens, attention_mask=mask, **forward_kwargs)
        return model, config, out

    def test_output_and_hidden_states_match(
        self,
        model_name: str,
        config_cls: type[ModelConfig],
        model_cls: type[ModelForCausalLM],
        mesh_axes: tuple[str, str],
    ) -> None:
        """Forward pass should produce identical outputs and hidden states with/without checkpointing."""
        snapshots = {}
        for ckpt_enabled in (False, True):
            model, config, out = self._forward(
                model_name,
                config_cls,
                model_cls,
                mesh_axes,
                gradient_checkpointing=ckpt_enabled,
                output_hidden_states=True,
            )
            snapshots[ckpt_enabled] = {
                "logits": np.asarray(model.compute_logits(out.last_hidden_state)),
                "hidden_states": [np.asarray(hs) for hs in out.hidden_states],
                "num_hidden_layers": config.num_hidden_layers,
            }
            del model, config, out

        baseline, checkpointed = snapshots[False], snapshots[True]
        np.testing.assert_allclose(baseline["logits"], checkpointed["logits"], rtol=1e-4, atol=1e-6)

        # Expect the embedding output plus one hidden state per transformer layer.
        assert (
            len(baseline["hidden_states"])
            == len(checkpointed["hidden_states"])
            == baseline["num_hidden_layers"] + 1
        )
        for idx, (hs_base, hs_ckpt) in enumerate(zip(baseline["hidden_states"], checkpointed["hidden_states"])):
            np.testing.assert_allclose(
                hs_base, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {idx}"
            )

    def test_kv_cache_with_checkpointing(
        self,
        model_name: str,
        config_cls: type[ModelConfig],
        model_cls: type[ModelForCausalLM],
        mesh_axes: tuple[str, str],
    ) -> None:
        """KV cache should be populated even with gradient checkpointing enabled."""
        _, config, out = self._forward(
            model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True
        )

        # keys is a list with one entry per layer
        assert len(out.kv_cache.keys) == config.num_hidden_layers


@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
def test_compute_logits(
model_name: str,
Expand Down
Loading