[tx] Port https://github.com/NovaSky-AI/SkyRL/pull/1008 to skyrl folder #1217
Conversation
Code Review
This pull request ports changes to add support for Manifold constrained HyperConnections (mHC). The changes are extensive, introducing a new LoRAConnector layer and integrating it into various parts of the system, including model configurations, model implementations, and checkpointing utilities. The PR also includes a comprehensive set of new tests for the added functionality.
I've identified one critical issue regarding parameter filtering that needs to be addressed, along with a couple of medium-severity suggestions to improve code clarity and maintainability. Overall, the changes are well-structured, but the identified issues should be resolved.
```python
is_lora = any(name in path for name in ("lora_A", "lora_B"))
is_connector = self.config.mhc_expansion_rate > 1 and any(
    name in path for name in ("attn_connector", "mlp_connector")
)
return is_lora or is_connector
```
The `path` argument received by this method is a tuple of `nnx.path.PathEntry` objects, not strings. Therefore, using the `in` operator like `name in path` will not work as intended to check for the presence of a key. This will cause trainable LoRA and connector parameters not to be identified correctly.
You should normalize the path to a tuple of strings (the keys from the `PathEntry` objects) before checking for containment, similar to how `is_connector_path` is implemented.
Suggested change:
```diff
-is_lora = any(name in path for name in ("lora_A", "lora_B"))
-is_connector = self.config.mhc_expansion_rate > 1 and any(
-    name in path for name in ("attn_connector", "mlp_connector")
-)
-return is_lora or is_connector
+normalized_path = tuple(p.key for p in path if hasattr(p, "key"))
+is_lora = any(name in normalized_path for name in ("lora_A", "lora_B"))
+is_connector = self.config.mhc_expansion_rate > 1 and any(
+    name in normalized_path for name in ("attn_connector", "mlp_connector")
+)
+return is_lora or is_connector
```
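The containment bug can be demonstrated with a minimal stand-in. The `PathEntry` class below is a hypothetical substitute for flax's nnx path entry type (assumed here only to carry a string `key` attribute), so the snippet runs without flax installed:

```python
from dataclasses import dataclass

# Hypothetical stand-in for an nnx path entry carrying a string key;
# not the real flax type.
@dataclass(frozen=True)
class PathEntry:
    key: str

path = (PathEntry("layers"), PathEntry("0"), PathEntry("lora_A"))

# Containment against the raw tuple compares strings to PathEntry objects,
# so the LoRA parameter is silently missed.
broken = any(name in path for name in ("lora_A", "lora_B"))

# Normalizing to the string keys first makes the check behave as intended.
normalized_path = tuple(p.key for p in path if hasattr(p, "key"))
fixed = any(name in normalized_path for name in ("lora_A", "lora_B"))

print(broken, fixed)  # prints: False True
```

This is why the suggested change normalizes `path` before any membership tests: string-vs-object equality is always `False`, so every trainable parameter would be filtered out.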
```python
class LoRAConnector(nnx.Module):
    """
    Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880)
```
The docstring for `LoRAConnector` refers to an arXiv paper https://arxiv.org/pdf/2512.24880. This appears to be a placeholder, as the publication date would be in the future (December 2025). Please update this to the correct link when available, or remove it if it's not intended to be a real reference.
```python
hidden_states, residual_norm = self.attn_connector.pre(hidden_states, self.input_layernorm, adapter_indices)
hidden_states = self.input_layernorm(hidden_states)
```
The `input_layernorm` is applied inside `attn_connector.pre` to determine routing based on the normalized input, and then it's applied again to the aggregated output of `pre` before passing it to the attention block. While functionally correct, this pattern is a bit hard to follow. Consider adding a comment to clarify the data flow, or refactoring the `pre` method's signature for better readability.
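To make the double application concrete, here is a toy sketch of the flow the comment describes. All names and the "normalization" itself are illustrative stand-ins, not the actual `LoRAConnector` code:

```python
# Toy stand-ins: NOT the real LoRAConnector / layernorm implementation,
# just an illustration of the double-normalization data flow.
def input_layernorm(x):
    # Illustrative "normalization": center the values around zero.
    mean = sum(x) / len(x)
    return [v - mean for v in x]

def attn_connector_pre(hidden_states, norm_fn):
    # 1) The norm is applied *inside* pre, but only to drive routing;
    #    its output is discarded once the routing decision is made.
    routing_input = norm_fn(hidden_states)
    _ = routing_input  # a routing decision would be derived here
    # 2) pre returns aggregated, still-unnormalized hidden states plus
    #    the residual stream, so the caller must normalize again.
    residual_norm = hidden_states
    return hidden_states, residual_norm

hidden = [3.0, 1.0, 2.0]
hidden, residual_norm = attn_connector_pre(hidden, input_layernorm)
# Second application of the same norm, now on the output that actually
# feeds the attention block.
hidden = input_layernorm(hidden)
```

A one-line comment at the call site stating "norm inside `pre` is for routing only; the output is normalized here" would make the intent obvious without changing behavior.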
See #1008