diff --git a/skyrl/backends/jax.py b/skyrl/backends/jax.py index bf2f1fd297..ae8f1c18ab 100644 --- a/skyrl/backends/jax.py +++ b/skyrl/backends/jax.py @@ -83,7 +83,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): ) gradient_checkpointing: bool = Field( default=False, - description="Whether to use gradient checkpointing (full recomputation strategy)", + description="Per-layer activation checkpointing: recompute activations during backward to save memory", ) loss_chunk_size: int = Field( default=1024, @@ -275,14 +275,10 @@ def _model_forward( input_ids, attention_mask=attention_mask, adapter_indices=adapter_indices, + is_training=True, ) return model.compute_logprobs(output.last_hidden_state, target_ids, adapter_indices) - if self.config.gradient_checkpointing: - # Wrap the model forward call to use jax.checkpoint for gradient checkpointing - # policy=None corresponds to full activation recomputation - _model_forward = jax.checkpoint(_model_forward, policy=None) - def loss_for_lora( lora_params: nnx.State, non_lora_params: nnx.State, diff --git a/skyrl/tx/models/configs.py b/skyrl/tx/models/configs.py index 7dcd7ec7b1..db8f21717f 100644 --- a/skyrl/tx/models/configs.py +++ b/skyrl/tx/models/configs.py @@ -15,7 +15,7 @@ class ModelConfig(PretrainedConfig): max_lora_rank: Maximum rank for LoRA adapters shard_attention_heads: Whether to shard attention across tensor parallel devices loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking) - gradient_checkpointing: Whether to use gradient checkpointing for chunked loss + gradient_checkpointing: Recompute activations during backward to save memory """ # Type hints for config attributes diff --git a/tests/tx/models/conftest.py b/tests/tx/models/conftest.py index 6560e4e393..372d3912e1 100644 --- a/tests/tx/models/conftest.py +++ b/tests/tx/models/conftest.py @@ -1,3 +1,5 @@ +from typing import Any + from flax import nnx import jax import jax.numpy as jnp @@ -8,23 +10,52 @@ from skyrl.tx.utils.models import load_safetensors, resolve_model_path -def load_model( +def create_model( model_name: str, config_cls: type[ModelConfig], model_cls: type[ModelForCausalLM], mesh_axes: tuple[str, ...], *, mesh_shape: tuple[int, ...] | None = None, - **config_kwargs, + mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, + seed: int = 0, + **config_kwargs: Any, ) -> tuple[ModelConfig, ModelForCausalLM]: - """Create a JAX model and load weights from the HuggingFace cache.""" - weights_dir = resolve_model_path(model_name) + """Create a JAX model with initialized weights.""" base_config = AutoConfig.from_pretrained(model_name) config = config_cls(base_config, shard_attention_heads=True, **config_kwargs) if mesh_shape is None: mesh_shape = (1,) * len(mesh_axes) - mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes)) + if mesh_axis_types is None: + mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes) + mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types) with jax.set_mesh(mesh): - model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed)) + return config, model + + +def load_model( + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, ...], + *, + mesh_shape: tuple[int, ...] | None = None, + mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, + seed: int = 0, + **config_kwargs: Any, +) -> tuple[ModelConfig, ModelForCausalLM]: + """Create a JAX model and load weights from the HuggingFace cache.""" + config, model = create_model( + model_name, + config_cls, + model_cls, + mesh_axes, + mesh_shape=mesh_shape, + mesh_axis_types=mesh_axis_types, + seed=seed, + **config_kwargs, + ) + weights_dir = resolve_model_path(model_name) load_safetensors(weights_dir, config, model) return config, model diff --git a/tests/tx/models/test_deepseekv3.py b/tests/tx/models/test_deepseekv3.py index ada8fd6613..27e0040910 100644 --- a/tests/tx/models/test_deepseekv3.py +++ b/tests/tx/models/test_deepseekv3.py @@ -186,3 +186,51 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int): output_merged = moe_layer_merged(x_sample) assert np.allclose(output_with_lora[sample_idx : sample_idx + 1], output_merged, rtol=1e-3, atol=1e-3) + + +def test_deepseekv3_gradient_checkpointing(): + """Test that gradient checkpointing produces identical outputs for DeepSeekV3. + + DeepSeekV3 has split stacking (dense_layers + moe_layers), so this tests + that gradient checkpointing works correctly with heterogeneous layer types. + """ + model_name = "yujiepan/deepseek-v3-tiny-random" + base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) + + batch_size, seq_len = 2, 8 + mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + + results = {} + for use_checkpointing in [False, True]: + config = DeepseekV3Config( + base_config, + max_lora_adapters=1, + max_lora_rank=1, + shard_attention_heads=True, + gradient_checkpointing=use_checkpointing, + ) + with jax.set_mesh(mesh): + model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) + + input_ids = jax.random.randint(jax.random.key(42), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + out = model(input_ids, attention_mask=attention_mask, output_hidden_states=True) + logits = model.compute_logits(out.last_hidden_state) + + results[use_checkpointing] = { + "logits": np.array(logits), + "hidden_states": [np.array(hs) for hs in out.hidden_states], + "kv_cache_len": len(out.kv_cache.keys), + } + + # Verify outputs match + np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6) + + # Verify hidden states match + assert len(results[False]["hidden_states"]) == len(results[True]["hidden_states"]) + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(results[False]["hidden_states"], results[True]["hidden_states"])): + np.testing.assert_allclose(hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}") + + # Verify KV cache has correct number of layers + assert results[True]["kv_cache_len"] == config.num_hidden_layers diff --git a/tests/tx/models/test_models_common.py b/tests/tx/models/test_models_common.py index df6cba5af4..f980468584 100644 --- a/tests/tx/models/test_models_common.py +++ b/tests/tx/models/test_models_common.py @@ -1,3 +1,6 @@ +from typing import Any + +import jax import jax.numpy as jnp import numpy as np import pytest @@ -6,9 +9,9 @@ from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config from skyrl.tx.models.llama3 import Llama3ForCausalLM from skyrl.tx.models.qwen3 import Qwen3ForCausalLM -from skyrl.tx.models.types import ModelForCausalLM +from skyrl.tx.models.types import CausalLMOutput, ModelForCausalLM -from tests.tx.models.conftest import load_model +from tests.tx.models.conftest import create_model, load_model MODEL_PARAMS = [ ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")), @@ -17,6 +20,85 @@ MODEL_IDS = ["llama3", "qwen3"] +@pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) +class TestGradientCheckpointing: + + def _forward( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + gradient_checkpointing: bool, + **forward_kwargs: Any, + ) -> tuple[ModelForCausalLM, ModelConfig, CausalLMOutput]: + """Create model, run forward pass, and return (model, config, out).""" + batch_size, seq_len = 2, 8 + config, model = create_model( + model_name, + config_cls, + model_cls, + mesh_axes, + max_lora_adapters=1, + max_lora_rank=1, + gradient_checkpointing=gradient_checkpointing, + ) + input_ids = jax.random.randint(jax.random.key(0), (batch_size, seq_len), 0, config.vocab_size) + attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + out = model(input_ids, attention_mask=attention_mask, **forward_kwargs) + return model, config, out + + def test_output_and_hidden_states_match( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: + """Forward pass should produce identical outputs and hidden states with/without checkpointing.""" + results = {} + for use_checkpointing in (False, True): + model, config, out = self._forward( + model_name, + config_cls, + model_cls, + mesh_axes, + gradient_checkpointing=use_checkpointing, + output_hidden_states=True, + ) + results[use_checkpointing] = { + "logits": np.asarray(model.compute_logits(out.last_hidden_state)), + "hidden_states": [np.asarray(hs) for hs in out.hidden_states], + "num_hidden_layers": config.num_hidden_layers, + } + del model, config, out + + np.testing.assert_allclose(results[False]["logits"], results[True]["logits"], rtol=1e-4, atol=1e-6) + + hidden_states_no_ckpt = results[False]["hidden_states"] + hidden_states_ckpt = results[True]["hidden_states"] + assert len(hidden_states_no_ckpt) == len(hidden_states_ckpt) == results[False]["num_hidden_layers"] + 1 + for i, (hs_no_ckpt, hs_ckpt) in enumerate(zip(hidden_states_no_ckpt, hidden_states_ckpt)): + np.testing.assert_allclose( + hs_no_ckpt, hs_ckpt, rtol=1e-4, atol=1e-6, err_msg=f"Mismatch at hidden state {i}" + ) + + def test_kv_cache_with_checkpointing( + self, + model_name: str, + config_cls: type[ModelConfig], + model_cls: type[ModelForCausalLM], + mesh_axes: tuple[str, str], + ) -> None: + """KV cache should be populated even with gradient checkpointing enabled.""" + _, config, out = self._forward( + model_name, config_cls, model_cls, mesh_axes, gradient_checkpointing=True + ) + + # keys is a list with one entry per layer + assert len(out.kv_cache.keys) == config.num_hidden_layers + + @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS) def test_compute_logits( model_name: str,