Port https://github.com/NovaSky-AI/SkyRL/pull/1079 to skyrl folder #1127
Code Review
This pull request introduces a significant and well-executed refactoring to improve performance by using stacked decoder layers. The use of `jax.lax.scan` for training and a Python loop for decoding is a smart optimization. The `ArrayRef` class provides a clever and clean way to handle stacked weights, simplifying checkpointing and adapter logic. The overall changes make the codebase more efficient and maintainable. I've found one high-severity bug in the new `unstack_state` utility function and have provided a detailed comment with a suggested fix.
```python
def unstack_state(module: nnx.Module) -> nnx.GraphState:
    """Transform stacked layer state to unstacked ArrayRef views.

    Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
    Each entry is an ArrayRef that writes through to the original stacked variable.

    This is useful for checkpoint loading where weights are stored per-layer.

    Args:
        module: Module containing StackedDecoderLayers.

    Returns:
        GraphState with unstacked paths and ArrayRef views.
    """
    state = nnx.state(module)
    expanded = []

    # Delegate to layers if they support unstacking
    if hasattr(module, "model") and hasattr(module.model, "layers"):
        layers = module.model.layers
        if isinstance(layers, StackedDecoderLayers):
            expanded.extend(layers.unstack_paths(state, base_path=("model", "layers")))

    # Keep all non-stacked paths as-is
    for path, param in nnx.to_flat_state(state):
        if "_stacked" not in path:
            expanded.append((path, param))

    return nnx.from_flat_state(expanded)
```
This function has a bug that results in an incorrect model state. The check for `_stacked` in a path on line 323 is incorrect because `path` is a tuple of key objects, not strings. This causes the condition to always be true, so all original parameters are added to the `expanded` list, which already contains the unstacked `ArrayRef` views. The result is a state containing both the original stacked parameters and the new unstacked views, which is incorrect.
The corrected version below fixes this by properly identifying and excluding the original stacked parameters. I've also added a TODO noting that the discovery of `StackedDecoderLayers` instances could be made generic instead of relying on a hardcoded path.
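The failure mode is easy to reproduce in isolation. A minimal sketch, using a hypothetical `Key` class as a stand-in for the path key objects that `nnx.to_flat_state` yields (not the real flax type):

```python
class Key:
    """Hypothetical stand-in for an nnx flat-state path key object."""

    def __init__(self, key):
        self.key = key


# A flat-state path is a tuple of key objects, not strings.
path = (Key("model"), Key("layers"), Key("_stacked"), Key("kernel"))

# Tuple membership compares elements by equality, and a str never equals
# a Key object, so the buggy filter lets every stacked param through:
print("_stacked" not in path)  # True -- stacked params are NOT filtered out

# Converting each path element to a string first makes the check work:
path_strs = [k.key if hasattr(k, "key") else str(k) for k in path]
print("_stacked" in path_strs)  # True -- path correctly identified as stacked
```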
```python
def unstack_state(module: nnx.Module) -> nnx.GraphState:
    """Transform stacked layer state to unstacked ArrayRef views.

    Converts paths like `layers._stacked.xxx` to `layers.0.xxx`, `layers.1.xxx`, etc.
    Each entry is an ArrayRef that writes through to the original stacked variable.

    This is useful for checkpoint loading where weights are stored per-layer.

    Args:
        module: Module containing StackedDecoderLayers.

    Returns:
        GraphState with unstacked paths and ArrayRef views.
    """
    state = nnx.state(module)
    expanded = []

    # Helper to check for stacked paths, avoiding circular import.
    def _is_stacked_path(p: tuple) -> bool:
        path_strs = [i.key if hasattr(i, "key") else str(i) for i in p]
        return "_stacked" in path_strs

    # TODO: Generalize this to find all StackedDecoderLayers instances automatically.
    if hasattr(module, "model") and hasattr(module.model, "layers"):
        layers = module.model.layers
        if isinstance(layers, StackedDecoderLayers):
            expanded.extend(layers.unstack_paths(state, base_path=("model", "layers")))

    # Add all parameters that are NOT part of any stacked layer.
    for path, param in nnx.to_flat_state(state):
        if not _is_stacked_path(path):
            expanded.append((path, param))

    return nnx.from_flat_state(expanded)
```

```python
if "lora_A" in path:
    return (adapter_index, slice(None), slice(None, rank))
if "lora_B" in path:
    return (adapter_index, slice(None, rank), slice(None))
```
🔴 `get_lora_adapter_slice` generates wrong index tuple for 4D expert LoRA params
When saving or loading LoRA checkpoints for MoE models, `get_lora_adapter_slice` produces incorrect index tuples for expert LoRA parameters (e.g., `LoRAExpert`), which have shape `(num_adapters, num_experts, in_features, max_rank)` for `lora_A` and `(num_adapters, num_experts, max_rank, out_features)` for `lora_B`.
Root Cause
The function always returns a 3-element tuple:

- `lora_A`: `(adapter_index, slice(None), slice(None, rank))`
- `lora_B`: `(adapter_index, slice(None, rank), slice(None))`

This is correct for non-expert 3D LoRA params `(adapters, in_features, max_rank)`, but wrong for 4D expert LoRA params `(adapters, num_experts, in_features, max_rank)`.

For a 4D `lora_A` with shape `(adapters, experts, in_features, max_rank)`:

- `arr[adapter_index, :, :rank]` selects `arr[adapter_index]` → `(experts, in_features, max_rank)`, then `[:, :rank]` slices the `in_features` dim to `rank` → result shape `(experts, rank, max_rank)` ✗
- Should be `arr[adapter_index, :, :, :rank]` → shape `(experts, in_features, rank)` ✓

Similarly for `lora_B` with shape `(adapters, experts, max_rank, out_features)`:

- `arr[adapter_index, :rank, :]` slices the `experts` dim to `rank` → result shape `(rank, max_rank, out_features)` ✗
- Should be `arr[adapter_index, :, :rank, :]` → shape `(experts, rank, out_features)` ✓
Impact: `save_lora_checkpoint` and `load_lora_checkpoint` (at skyrl/skyrl/tx/utils/models.py:269-298) call `save_safetensors`/`load_safetensors` with `adapter_index` and `rank` set, which triggers this code path. For any MoE model (e.g., Qwen3 with experts), saving or loading LoRA checkpoints will silently corrupt expert LoRA weights by slicing along the wrong dimensions.
Suggested change:

```diff
 if "lora_A" in path:
-    return (adapter_index, slice(None), slice(None, rank))
+    return (adapter_index, ..., slice(None, rank))
 if "lora_B" in path:
-    return (adapter_index, slice(None, rank), slice(None))
+    return (adapter_index, ..., slice(None, rank), slice(None))
```
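The shape mismatch is easy to verify directly with NumPy; a small sketch with made-up dimensions (the shapes mirror the review, not the real model config):

```python
import numpy as np

adapters, experts, in_features, max_rank, rank = 2, 4, 8, 6, 3
lora_A = np.zeros((adapters, experts, in_features, max_rank))

# Buggy 3-element index: the rank slice lands on the in_features dim.
bad = lora_A[(0, slice(None), slice(None, rank))]
print(bad.shape)   # (4, 3, 6) -- (experts, rank, max_rank), wrong

# Ellipsis index: the rank slice always lands on the trailing max_rank dim.
good = lora_A[(0, ..., slice(None, rank))]
print(good.shape)  # (4, 8, 3) -- (experts, in_features, rank), correct

# The same ellipsis index stays correct for non-expert 3D params,
# because ... simply expands to zero dimensions there:
lora_A_3d = np.zeros((adapters, in_features, max_rank))
print(lora_A_3d[(0, ..., slice(None, rank))].shape)  # (8, 3)
```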
```python
if key not in tensors:
    continue
```
🔴 Early `key not in tensors` check silently skips all expert (MoE) weights during loading
When loading weights for MoE models (e.g., DeepSeek, Qwen3-MoE), `load_safetensors` silently skips all expert parameters because the combined key (e.g., `model.layers.0.mlp.experts.gate_proj.weight`) doesn't exist in the checkpoint tensors dict. HuggingFace checkpoints store per-expert keys like `model.layers.0.mlp.experts.0.gate_proj.weight`.
Root Cause
The new `if key not in tensors: continue` check at skyrl/skyrl/tx/utils/models.py:167 runs before the `if "experts" in path:` branch at line 169 that constructs the tensor from per-expert keys via `get_expert_key`. In the old code, there was no such early-exit check, so the expert branch was always reached.

For example, given path `("model", "layers", "0", "mlp", "experts", "gate_proj", "weight")`:

- `get_param_key(path)` → `model.layers.0.mlp.experts.gate_proj.weight` — this key does NOT exist in the tensors dict
- The code hits `continue`, never reaching `get_expert_key(path, i)`, which would produce `model.layers.0.mlp.experts.0.gate_proj.weight` — which DOES exist
Impact: All expert weights (base weights AND LoRA weights) for MoE models will be silently left as their initialization values (zeros/random), producing completely wrong model outputs. This affects both base model loading (skyrl/skyrl/backends/jax.py:197) and LoRA checkpoint loading.
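The key mismatch can be sketched with a mock tensors dict (key names mirror the review; the dimensions and stacking are illustrative, not the real skyrl helpers):

```python
import numpy as np

num_experts, in_features, out_features = 2, 4, 3

# HuggingFace-style checkpoints store one tensor per expert:
tensors = {
    f"model.layers.0.mlp.experts.{i}.gate_proj.weight": np.full((out_features, in_features), float(i))
    for i in range(num_experts)
}

# The combined key that the early-exit check looks up does not exist,
# so `continue` fires and the expert branch is never reached:
combined_key = "model.layers.0.mlp.experts.gate_proj.weight"
print(combined_key in tensors)  # False

# Stacking the transposed per-expert tensors is what the expert branch
# would have built: shape (num_experts, in_features, out_features).
stacked = np.stack(
    [tensors[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"].T for i in range(num_experts)],
    axis=0,
)
print(stacked.shape)  # (2, 4, 3)
```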
Prompt for agents
In skyrl/skyrl/tx/utils/models.py, in the load_safetensors function around lines 167-170, the check `if key not in tensors: continue` must be moved AFTER the expert weight handling block. For expert paths, the combined key won't be in tensors, but per-expert keys (obtained via get_expert_key) will be. The fix should restructure the logic so that for expert paths, the per-expert keys are checked instead. For example:
```python
if "experts" in path:
    expert_key = get_expert_key(path, 0)
    if expert_key not in tensors:
        continue
    tensor = np.stack(
        [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())],
        axis=0,
    )
elif key not in tensors:
    continue
else:
    tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
```
See #1079