
Port https://github.com/NovaSky-AI/SkyRL/pull/1079 to skyrl folder #1127

Merged
pcmoritz merged 3 commits into NovaSky-AI:main from pcmoritz:port-stacked-pr on Feb 15, 2026

Conversation


@pcmoritz pcmoritz commented Feb 15, 2026

See #1079


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-executed refactoring to improve performance by using stacked decoder layers. The use of jax.lax.scan for training and a Python loop for decoding is a smart optimization. The ArrayRef class provides a clever and clean way to handle stacked weights, simplifying checkpointing and adapter logic. The overall changes make the codebase more efficient and maintainable. I've found one high-severity bug in the new unstack_state utility function and have provided a detailed comment with a suggested fix.
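The scan-over-stacked-layers pattern the review praises can be sketched in a few lines. This is an illustrative toy, not the PR's code: all layer weights live in one stacked array whose leading axis is the layer index, and `jax.lax.scan` threads the hidden state through them, compiling the layer body once instead of once per layer.

```python
import jax
import jax.numpy as jnp

# Toy dimensions; the real model's layer body is of course a full decoder block.
num_layers, hidden = 4, 8
# One array holding every layer's weights, stacked along the leading axis.
stacked_w = jnp.stack([jnp.eye(hidden) * (i + 1) for i in range(num_layers)])

def layer(h, w):
    # carry = hidden state; scan feeds one layer's weights per step
    return jnp.tanh(h @ w), None

x = jnp.ones((hidden,))
h, _ = jax.lax.scan(layer, x, stacked_w)
assert h.shape == (hidden,)
```

For decoding, where each step's KV cache differs per layer, a plain Python loop over the same stacked weights avoids scan's uniform-shape constraints, which matches the split described above.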

Comment on lines +298 to +326
def unstack_state(module: nnx.Module) -> nnx.GraphState:
    """Transform stacked layer state to unstacked ArrayRef views.

    Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
    Each entry is an ArrayRef that writes through to the original stacked variable.

    This is useful for checkpoint loading where weights are stored per-layer.

    Args:
        module: Module containing StackedDecoderLayers.

    Returns:
        GraphState with unstacked paths and ArrayRef views.
    """
    state = nnx.state(module)
    expanded = []

    # Delegate to layers if they support unstacking
    if hasattr(module, "model") and hasattr(module.model, "layers"):
        layers = module.model.layers
        if isinstance(layers, StackedDecoderLayers):
            expanded.extend(layers.unstack_paths(state, base_path=("model", "layers")))

    # Keep all non-stacked paths as-is
    for path, param in nnx.to_flat_state(state):
        if "_stacked" not in path:
            expanded.append((path, param))

    return nnx.from_flat_state(expanded)

Severity: high

This function has a bug that results in an incorrect model state. The check for _stacked in a path on line 323 is incorrect because path is a tuple of objects, not strings. This causes the condition to always be true, leading to all original parameters being added to the expanded list, which already contains the unstacked ArrayRef views. This results in a state with both the original stacked parameters and the new unstacked views, which is incorrect.

The corrected version below fixes this by properly identifying and excluding the original stacked parameters. I've also added a TODO to note that the discovery of StackedDecoderLayers could be made more generic instead of using a hardcoded path.
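The membership bug can be reproduced in isolation. This sketch assumes flat-state paths are tuples of key objects exposing a `.key` string, as the comment describes; the `Key` dataclass here is a hypothetical stand-in for the real path key type.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Key:
    """Hypothetical stand-in for a flat-state path key object."""
    key: str

path = (Key("model"), Key("layers"), Key("_stacked"), Key("kernel"))

# Buggy check: comparing a str against Key objects never matches,
# so the stacked parameter is never filtered out.
assert "_stacked" not in path

# Intended check: convert the path elements to strings first.
path_strs = [p.key for p in path]
assert "_stacked" in path_strs
```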

def unstack_state(module: nnx.Module) -> nnx.GraphState:
    """Transform stacked layer state to unstacked ArrayRef views.

    Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
    Each entry is an ArrayRef that writes through to the original stacked variable.

    This is useful for checkpoint loading where weights are stored per-layer.

    Args:
        module: Module containing StackedDecoderLayers.

    Returns:
        GraphState with unstacked paths and ArrayRef views.
    """
    state = nnx.state(module)
    expanded = []

    # Helper to check for stacked paths, avoiding circular import.
    def _is_stacked_path(p: tuple) -> bool:
        path_strs = [i.key if hasattr(i, "key") else str(i) for i in p]
        return "_stacked" in path_strs

    # TODO: Generalize this to find all StackedDecoderLayers instances automatically.
    if hasattr(module, "model") and hasattr(module.model, "layers"):
        layers = module.model.layers
        if isinstance(layers, StackedDecoderLayers):
            expanded.extend(layers.unstack_paths(state, base_path=("model", "layers")))

    # Add all parameters that are NOT part of any stacked layer.
    for path, param in nnx.to_flat_state(state):
        if not _is_stacked_path(path):
            expanded.append((path, param))

    return nnx.from_flat_state(expanded)


@devin-ai-integration devin-ai-integration bot left a comment

Devin Review found 1 potential issue.

View 6 additional findings in Devin Review.

Comment on lines +114 to +117
    if "lora_A" in path:
        return (adapter_index, slice(None), slice(None, rank))
    if "lora_B" in path:
        return (adapter_index, slice(None, rank), slice(None))

🔴 get_lora_adapter_slice generates wrong index tuple for 4D expert LoRA params

When saving or loading LoRA checkpoints for MoE models, get_lora_adapter_slice produces incorrect index tuples for expert LoRA parameters (e.g., LoRAExpert), which have shape (num_adapters, num_experts, in_features, max_rank) for lora_A and (num_adapters, num_experts, max_rank, out_features) for lora_B.

Root Cause

The function always returns a 3-element tuple:

  • lora_A: (adapter_index, slice(None), slice(None, rank))
  • lora_B: (adapter_index, slice(None, rank), slice(None))

This is correct for non-expert 3D LoRA params (adapters, in_features, max_rank), but wrong for 4D expert LoRA params (adapters, num_experts, in_features, max_rank).

For a 4D lora_A with shape (adapters, experts, in_features, max_rank):

  • arr[adapter_index, :, :rank] first selects arr[adapter_index] with shape (experts, in_features, max_rank), then [:, :rank] slices the in_features dim to rank → result shape (experts, rank, max_rank)
  • Should be arr[adapter_index, :, :, :rank] → shape (experts, in_features, rank)

Similarly for lora_B with shape (adapters, experts, max_rank, out_features):

  • arr[adapter_index, :rank, :] slices the experts dim to rank → result shape (rank, max_rank, out_features)
  • Should be arr[adapter_index, :, :rank, :] → shape (experts, rank, out_features)

Impact: save_lora_checkpoint and load_lora_checkpoint (at skyrl/skyrl/tx/utils/models.py:269-298) call save_safetensors/load_safetensors with adapter_index and rank set, which triggers this code path. For any MoE model (e.g., Qwen3 with experts), saving or loading LoRA checkpoints will silently corrupt expert LoRA weights by slicing along the wrong dimensions.

Suggested change
-if "lora_A" in path:
-    return (adapter_index, slice(None), slice(None, rank))
-if "lora_B" in path:
-    return (adapter_index, slice(None, rank), slice(None))
+if "lora_A" in path:
+    return (adapter_index, ..., slice(None, rank))
+if "lora_B" in path:
+    return (adapter_index, ..., slice(None, rank), slice(None))


@devin-ai-integration devin-ai-integration bot left a comment

Devin Review found 1 new potential issue.

View 10 additional findings in Devin Review.

Comment on lines +167 to +168
    if key not in tensors:
        continue

🔴 Early key not in tensors check silently skips all expert (MoE) weights during loading

When loading weights for MoE models (e.g., DeepSeek, Qwen3-MoE), load_safetensors silently skips all expert parameters because the combined key (e.g., model.layers.0.mlp.experts.gate_proj.weight) doesn't exist in the checkpoint tensors dict. HuggingFace checkpoints store per-expert keys like model.layers.0.mlp.experts.0.gate_proj.weight.

Root Cause

The new if key not in tensors: continue check at skyrl/skyrl/tx/utils/models.py:167 runs before the if "experts" in path: branch at line 169 that constructs the tensor from per-expert keys via get_expert_key. In the old code, there was no such early-exit check, so the expert branch was always reached.

For example, given path ("model", "layers", "0", "mlp", "experts", "gate_proj", "weight"):

  • get_param_key(path)model.layers.0.mlp.experts.gate_proj.weight — this key does NOT exist in the tensors dict
  • The code hits continue, never reaching get_expert_key(path, i) which would produce model.layers.0.mlp.experts.0.gate_proj.weight — which DOES exist

Impact: All expert weights (base weights AND LoRA weights) for MoE models will be silently left as their initialization values (zeros/random), producing completely wrong model outputs. This affects both base model loading (skyrl/skyrl/backends/jax.py:197) and LoRA checkpoint loading.

Prompt for agents
In skyrl/skyrl/tx/utils/models.py, in the load_safetensors function around lines 167-170, the check `if key not in tensors: continue` must be moved AFTER the expert weight handling block. For expert paths, the combined key won't be in tensors, but per-expert keys (obtained via get_expert_key) will be. The fix should restructure the logic so that for expert paths, the per-expert keys are checked instead. For example:

if "experts" in path:
    expert_key = get_expert_key(path, 0)
    if expert_key not in tensors:
        continue
    tensor = np.stack([tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
elif key not in tensors:
    continue
else:
    tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
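The restructured lookup can be sketched with a toy checkpoint dict standing in for the HuggingFace tensors. The `load_param` helper, the dot-joined key construction, and the `replace`-based per-expert key are all simplifications of the real `get_param_key`/`get_expert_key` helpers, not their actual implementations.

```python
import numpy as np

num_experts = 2
# Toy checkpoint: only per-expert keys exist, as in HF MoE checkpoints.
tensors = {
    f"model.layers.0.mlp.experts.{i}.gate_proj.weight": np.ones((3, 5)) * i
    for i in range(num_experts)
}

def load_param(path):
    key = ".".join(path)  # simplified stand-in for get_param_key
    if "experts" in path:
        # Combined key is absent for expert params; probe per-expert keys.
        expert_keys = [key.replace("experts", f"experts.{i}") for i in range(num_experts)]
        if expert_keys[0] not in tensors:
            return None
        return np.stack([tensors[k].T for k in expert_keys], axis=0)
    if key not in tensors:  # the early-exit check, now after the expert branch
        return None
    return tensors[key].T

out = load_param(("model", "layers", "0", "mlp", "experts", "gate_proj", "weight"))
assert out is not None and out.shape == (num_experts, 5, 3)
```

With the early exit ordered first, as in the flagged code, `load_param` would return None here and the expert weights would keep their initialization values.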

@pcmoritz pcmoritz merged commit 6b711d7 into NovaSky-AI:main Feb 15, 2026
2 of 5 checks passed