
Port https://github.com/NovaSky-AI/SkyRL/pull/1083 to skyrl folder#1162

Merged
pcmoritz merged 2 commits into NovaSky-AI:main from pcmoritz:port-1083
Feb 16, 2026

Conversation


@pcmoritz pcmoritz commented Feb 16, 2026

See #1083



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully ports changes to enable and test gradient checkpointing. The updates to configuration descriptions are clearer, and the new tests for gradient checkpointing are comprehensive, covering multiple models. The refactoring in conftest.py to introduce a create_model helper is a good improvement for test structure. I have provided a couple of suggestions to further enhance this new helper for better versatility and to refactor a new test to use it, which would improve consistency across the test suite. Overall, this is a high-quality contribution.
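As background on what the new tests verify: gradient checkpointing (rematerialization) trades memory for recomputation in the backward pass, so it may change memory behavior but must not change numerics — gradients with and without it have to match. A minimal JAX sketch of that equivalence check (the toy layer, shapes, and loop depth here are illustrative, not taken from this PR):

```python
import jax
import jax.numpy as jnp

def layer(x):
    # Toy "layer": a fixed linear map followed by tanh.
    return jnp.tanh(x @ (jnp.ones((4, 4)) * 0.1))

def forward(x, use_checkpointing):
    # jax.checkpoint (a.k.a. jax.remat) discards intermediate activations
    # in the forward pass and recomputes them during backprop.
    f = jax.checkpoint(layer) if use_checkpointing else layer
    for _ in range(3):
        x = f(x)
    return x.sum()

x = jnp.ones((2, 4))
g_plain = jax.grad(lambda x: forward(x, False))(x)
g_ckpt = jax.grad(lambda x: forward(x, True))(x)
# Checkpointing must be numerically transparent.
assert jnp.allclose(g_plain, g_ckpt)
```

The per-model tests in this PR apply the same idea at full model scale: run a forward/backward pass with `gradient_checkpointing` on and off, then assert the outputs agree.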

Comment on lines +13 to 61
 def create_model(
     model_name: str,
     config_cls: type[ModelConfig],
     model_cls: type[ModelForCausalLM],
     mesh_axes: tuple[str, ...],
     *,
     mesh_shape: tuple[int, ...] | None = None,
-    **config_kwargs,
+    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
+    seed: int = 0,
+    **config_kwargs: Any,
 ) -> tuple[ModelConfig, ModelForCausalLM]:
-    """Create a JAX model and load weights from the HuggingFace cache."""
-    weights_dir = resolve_model_path(model_name)
+    """Create a JAX model with initialized weights."""
     base_config = AutoConfig.from_pretrained(model_name)
     config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
     if mesh_shape is None:
         mesh_shape = (1,) * len(mesh_axes)
-    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes))
+    if mesh_axis_types is None:
+        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
+    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types)
     with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
+        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
     return config, model
+
+
+def load_model(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, ...],
+    *,
+    mesh_shape: tuple[int, ...] | None = None,
+    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
+    seed: int = 0,
+    **config_kwargs: Any,
+) -> tuple[ModelConfig, ModelForCausalLM]:
+    """Create a JAX model and load weights from the HuggingFace cache."""
+    config, model = create_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        mesh_shape=mesh_shape,
+        mesh_axis_types=mesh_axis_types,
+        seed=seed,
+        **config_kwargs,
+    )
+    weights_dir = resolve_model_path(model_name)
+    load_safetensors(weights_dir, config, model)
+    return config, model


medium

The new create_model helper function is a great addition for simplifying test setup. However, it currently doesn't allow passing arguments to AutoConfig.from_pretrained, which is necessary for models that require trust_remote_code=True (as seen in test_deepseekv3.py).

To make this helper more versatile and promote consistency across tests, consider adding an auto_config_kwargs parameter to both create_model and load_model. This would allow all tests to use this common helper, avoiding duplicated model creation logic.

def create_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model with initialized weights."""
    if auto_config_kwargs is None:
        auto_config_kwargs = {}
    base_config = AutoConfig.from_pretrained(model_name, **auto_config_kwargs)
    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
    if mesh_shape is None:
        mesh_shape = (1,) * len(mesh_axes)
    if mesh_axis_types is None:
        mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes)
    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types)
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return config, model


def load_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model and load weights from the HuggingFace cache."""
    config, model = create_model(
        model_name,
        config_cls,
        model_cls,
        mesh_axes,
        mesh_shape=mesh_shape,
        mesh_axis_types=mesh_axis_types,
        seed=seed,
        auto_config_kwargs=auto_config_kwargs,
        **config_kwargs,
    )
    weights_dir = resolve_model_path(model_name)
    load_safetensors(weights_dir, config, model)
    return config, model
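One detail worth noting in the suggestion above: `auto_config_kwargs: dict[str, Any] | None = None` with the `if ... is None` guard is the standard way to take an optional dict argument in Python, because a `{}` default would be a single dict shared across all calls. A self-contained sketch of the pattern (generic names, not code from this PR):

```python
def make_config(name, extra_kwargs=None, **overrides):
    # A None default means a fresh dict is built on every call; a `{}`
    # default would be one mutable dict shared by all calls.
    if extra_kwargs is None:
        extra_kwargs = {}
    config = {"name": name, **extra_kwargs}
    config.update(overrides)
    return config

a = make_config("m1", extra_kwargs={"trust_remote_code": True}, seed=1)
b = make_config("m2")
# No leakage between calls: b never sees m1's extra kwargs.
assert a == {"name": "m1", "trust_remote_code": True, "seed": 1}
assert b == {"name": "m2"}
```

This is why the suggested `create_model` builds the empty dict inside the function body before forwarding it to `AutoConfig.from_pretrained`.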

Comment on lines +197 to +213
    model_name = "yujiepan/deepseek-v3-tiny-random"
    base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True)

    batch_size, seq_len = 2, 8
    mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)

    results = {}
    for use_checkpointing in [False, True]:
        config = DeepseekV3Config(
            base_config,
            max_lora_adapters=1,
            max_lora_rank=1,
            shard_attention_heads=True,
            gradient_checkpointing=use_checkpointing,
        )
        with jax.set_mesh(mesh):
            model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0))


medium

This test duplicates model creation logic that is now available in the create_model helper function in conftest.py. To improve code reuse and consistency, this test should be refactored to use create_model.

Assuming the suggested change to create_model (to support auto_config_kwargs) is applied, this test can be simplified as follows:

    model_name = "yujiepan/deepseek-v3-tiny-random"

    batch_size, seq_len = 2, 8
    mesh_axes = ("fsdp", "ep", "tp")

    results = {}
    for use_checkpointing in [False, True]:
        config, model = create_model(
            model_name,
            DeepseekV3Config,
            DeepseekV3ForCausalLM,
            mesh_axes,
            mesh_shape=(1, 1, 1),
            auto_config_kwargs={"trust_remote_code": True},
            max_lora_adapters=1,
            max_lora_rank=1,
            gradient_checkpointing=use_checkpointing,
            seed=0,
        )
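Both the original test and this refactored version rely on the same structure: collect one output per configuration in a `results` dict keyed by `use_checkpointing`, then assert the two entries agree. A stripped-down sketch of that parametrized-comparison pattern (the stand-in `run_forward` is hypothetical; the real test runs a model forward pass):

```python
def run_forward(tokens, use_checkpointing):
    # Stand-in forward pass: checkpointing may change memory use,
    # but must leave the computed values untouched, so this toy
    # function deliberately ignores the flag.
    return sum(t * t for t in tokens)

results = {}
for use_checkpointing in [False, True]:
    results[use_checkpointing] = run_forward([1.0, 2.0, 3.0], use_checkpointing)

# The core assertion: identical outputs with checkpointing on and off.
assert results[False] == results[True]
```

Keying the dict by the flag (rather than overwriting a single variable) keeps both outputs available for the final comparison and for any debugging output on mismatch.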


@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 4 additional findings.


@pcmoritz pcmoritz merged commit 8e7358b into NovaSky-AI:main Feb 16, 2026
2 of 5 checks passed
