[tx] General implementation of trainable Hyper Connections#1008
[tx] General implementation of trainable Hyper Connections#1008pcmoritz merged 51 commits intoNovaSky-AI:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.
My review found a couple of issues:
- An unused
trainableparameter in theConnectorclass which should be removed for clarity. - A bug in
DeepseekV3Modelwhen handling intermediate hidden states forexpansion_rate > 1, wheresqueeze()is used incorrectly.
Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.
skyrl-tx/tx/models/deepseekv3.py
Outdated
| for layer_idx, layer in enumerate(self.layers): | ||
| if output_hidden_states: | ||
| all_hidden_states.append(hidden_states) | ||
| all_hidden_states.append(hidden_states.squeeze()) |
There was a problem hiding this comment.
hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.
A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
| all_hidden_states.append(hidden_states.squeeze()) | |
| all_hidden_states.append(hidden_states.mean(axis=-2)) |
skyrl-tx/tx/layers/layernorm.py
Outdated
| self.eps = eps | ||
| self.weight = Param( | ||
| size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs | ||
| size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs |
There was a problem hiding this comment.
Temporary, testing
There was a problem hiding this comment.
https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
Torch also initalizes to one by default
There was a problem hiding this comment.
Due to adapter indexing, ended up re-implementing norm in the connector layer itself - this change can be removed. But considering torch as the baseline, ones_init fits better still
|
This looks very elegant, thanks a lot for putting it together! Have you tried to do any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy, as well as how much slowdown it incurs :) |
|
Just waiting for the weekend to give it a spin 😅 I'll give Qwen0.6B a shot on an A/H100 |
|
Sounds great! I'm putting together the 0.3.0 release at the moment, so it will probably need to wait then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyways, if somebody wants to try it out, they can just apply the diff themselves given how simple this is :) |
|
Did some analysis on the step times for each on Qwen 0.6B (on a 5060Ti) Expansion rate as 1 does cause a hit to the average step time (about 0.3s slower, baseline has a step time of 2.1s vs 2.4s). An easy fix would be to just short circuit the entire thing for expansion rate = 1. For expansion rate = 4, the step time was around 3.17s, so about 46% slower. |
skyrl-tx/tx/tinker/backends/jax.py
Outdated
| """Compute full gradients, apply optimizer update, and reset accumulated grads.""" | ||
| optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index)) | ||
| return accumulated_grads.reset_adapter(adapter_index) | ||
| if global_optimizer is not None and self.has_global_trainables: | ||
| global_optimizer.update(global_params, global_accumulated_grads.get_mean()) | ||
| global_accumulated_grads = global_accumulated_grads.reset() | ||
| return accumulated_grads.reset_adapter(adapter_index), global_accumulated_grads |
There was a problem hiding this comment.
🔴 Global optimizer updated with zero gradients on second adapter's optim_step
When multiple LoRA adapters are active, the shared global optimizer receives spurious zero-gradient updates, corrupting its Adam state.
Root Cause
In compute_grads_and_update (jax.py:531-536), the global optimizer is updated and the global accumulated gradients are reset unconditionally on every call:
if global_optimizer is not None and self.has_global_trainables:
global_optimizer.update(global_params, global_accumulated_grads.get_mean())
global_accumulated_grads = global_accumulated_grads.reset()Since optim_step is called once per adapter (jax.py:773-809), with two adapters the sequence is:
optim_step(adapter_1)→ updates global optimizer with real mean gradients, resetsglobal_accumulated_gradsto zerooptim_step(adapter_2)→ updates global optimizer again withget_mean()of the now-zeroed gradients (all zeros), resets again
The second zero-gradient update corrupts Adam's internal state:
- First moments decay:
m_t = β₁ · m_{t-1} + (1-β₁) · 0— momentum decays toward zero - Second moments decay:
v_t = β₂ · v_{t-1} + (1-β₂) · 0— variance estimate shrinks - Step counter increments, affecting bias correction
Impact: Global trainable parameters (connectors) receive incorrect optimizer updates that degrade training quality, with severity proportional to the number of adapters.
Prompt for agents
The global optimizer should only be updated once per training iteration, not once per adapter. Currently in compute_grads_and_update (jax.py:531-536), the global optimizer is updated and global accumulated gradients are reset on every call, but optim_step is called once per adapter. Fix this by either: (1) tracking whether global grads have already been applied in this iteration and skipping if already done (e.g., check global_accumulated_grads.count > 0 before updating), or (2) decoupling the global optimizer step from the per-adapter optim_step so it runs exactly once per training iteration. Option (1) is simpler: guard the global optimizer update with a check like `if global_accumulated_grads.count > 0` before calling global_optimizer.update.
Was this helpful? React with 👍 or 👎 to provide feedback.
| def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array: | ||
| if adapter_indices is None: | ||
| return jnp.zeros((batch_size,), dtype=jnp.int32) | ||
| return adapter_indices.astype(jnp.int32) |
There was a problem hiding this comment.
🟡 LoRAConnector broken when max_lora_adapters=0 — indexing into 0-sized parameter arrays returns wrong values
When a model is created with max_lora_adapters=0 (e.g., tx/run/train.py:80), the LoRAConnector creates all parameter arrays with a first dimension of 0. When pre() or post() is called, _get_adapter_indices returns jnp.zeros((B,), dtype=jnp.int32), and _get_params indexes into these 0-sized arrays, producing zero-filled results instead of the identity-preserving values.
Detailed Explanation
Unlike LoRAMixin.apply_lora which short-circuits when max_lora_adapters == 0 (lora.py:85), LoRAConnector has no such guard. When max_lora_adapters=0:
self.b_prehas shape(0, n),self.b_reshas shape(0, n, n), etc._get_adapter_indices(B, None)returnsjnp.zeros((B,))atconnectors.py:66_get_paramsindexes into 0-sized arrays atconnectors.py:71-80— JAX clips out-of-bounds indices and returns zeros- In
pre():b_pre=0→H_pre = sigmoid(0) = 0.5instead of1/n - In
post():b_res=0→M = sinkhorn(zeros)produces a uniform1/nmatrix instead of identity
For the default expansion_rate=1, the impact on pre is masked by RMSNorm (the 0.5 scale cancels during normalization), and post still produces the correct residual + output. So the default case is approximately correct. However, for expansion_rate > 1 with max_lora_adapters=0, the connector would produce completely wrong outputs (uniform mixing instead of identity passthrough).
This path is exercised in production via tx/run/train.py:80 which uses max_lora_adapters=0.
Prompt for agents
Add a guard in LoRAConnector to handle the max_lora_adapters=0 case. The simplest approach is to add a check at the start of pre() and post() methods that bypasses the connector logic when max_lora_adapters is 0, falling back to identity behavior: pre() should return x.sum(axis=-2) / n (or equivalently the mean), and post() should return residual + output[..., None, :] (broadcasting output into the expansion dimension). Alternatively, ensure the constructor always creates at least 1 adapter slot (with identity initialization) even when max_lora_adapters=0, similar to how the default adapter_index=0 is used when adapter_indices is None.
Was this helpful? React with 👍 or 👎 to provide feedback.
This is in preparation for merging #1008 and to make it easier to introduce metrics. <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1191" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
|
Thanks a lot for all the updates, I'll do the rest (already merged a PR that cleans things up a little #1191) :) |
skyrl-tx/tx/layers/connectors.py
Outdated
| C = hidden_dim | ||
|
|
||
| # Phi matrices are zero-initialized so that alpha * x @ 0 + bias = bias at init. | ||
| self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=dtype)) |
There was a problem hiding this comment.
I'm curious, why did you make the RMSNorm per adapter and trainable? That seems wrong, we should probably just use the RMSNorm from the base model :) [I don't think any of the LoRA codes out there make the RMSNorm trainable]
There was a problem hiding this comment.
Actually I think I misunderstood the code and you are doing the right thing :)
There was a problem hiding this comment.
Sorry for going back on forth on this, but I think the actually correct implementation would be to pass the input norm parameters from the model to the constructor of LoRAConnector and use it for the normalization below and keep it non-trainable. It will be slightly redundant to apply the norm twice, but I think for code clarity that's fine for now (there is more optimizations to do anyways). Let me know about your thoughts, I'll give that a shot :)
There was a problem hiding this comment.
Thanks for the change!
My layernorm change was around making the entire block including the norms trainable - but yeah if its something like lora that shouldn't be the case
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a general implementation of trainable Hyper Connections (mHC) as an extension to LoRA. The changes are extensive, touching model configurations, layer implementations, backend logic, and utility functions for checkpointing. The core logic resides in the new LoRAConnector module, and its integration into models like DeepseekV3 and Qwen3 appears correct, properly handling the new stream dimension. The utility functions for state management and checkpointing have also been updated to support these new connector parameters. The tests provide a good foundation, but I have identified a few areas for improvement to enhance their robustness and address minor issues in the implementation.
| logits_e1 = np.asarray(model_e1.compute_logits(outputs_e1.last_hidden_state)) | ||
| logits_e4 = np.asarray(model_e4.compute_logits(outputs_e4.last_hidden_state)) | ||
|
|
||
| np.testing.assert_allclose(logits_e1, logits_e4, rtol=5e-2, atol=5e-2) |
There was a problem hiding this comment.
The tolerance for this assert_allclose is set to 5e-2 (5%), which is quite high for a test that aims to verify that the initial connector behavior keeps logits unchanged. This high tolerance might mask subtle deviations from the expected identity mapping. Consider lowering the tolerance (e.g., to 1e-5 or 1e-6) to ensure the identity initialization is working as precisely as intended.
| class _TinyConnector(nnx.Module): | ||
| def __init__(self, max_adapters: int): | ||
| self.alpha_pre = nnx.Param(jnp.zeros((max_adapters, 4), dtype=jnp.float32)) | ||
| self.phi_pre = nnx.Param(jnp.zeros((max_adapters, 4, 2), dtype=jnp.float32)) |
There was a problem hiding this comment.
The _TinyConnector mock is incomplete and only contains a subset of the parameters from the actual LoRAConnector. This means that tests relying on this mock (like test_connector_adapter_slice_save_load_safetensors and test_connector_extract_insert_adapter_state_roundtrip) are not comprehensively verifying the serialization and state management logic for all connector parameters (e.g., b_pre, b_post, b_res, phi_post, phi_res, etc.).
To improve test coverage and ensure correctness, please expand _TinyConnector to include all parameters present in LoRAConnector and update the corresponding tests to check these additional parameters.
| return x[..., 0, :], x.reshape(B, T, n * C) | ||
|
|
||
| adapter_indices = self._get_adapter_indices(B, adapter_indices) | ||
| # Apply input_norm independently to each of the n streams. |
There was a problem hiding this comment.
The paper is not super clear on whether this is the right way to do it -- below equation (5) it says RMSNorm is applied to the last dimension C. In equation (7) it looks more like the RMSNorm is applied on the full n * C dimension. I chose the interpretation according to equation (5) since it is slightly more elegant and doesn't require changing the definition of the RMSNorm. Once a DeepSeek model is released that supports mHC, we can revisit this.
There was a problem hiding this comment.
I think HC applies the RMSNorm to the last dimension C and mHC applies it to n * C, this becomes pretty apparent from equations (15) and (16). We should switch as soon as we have a model that natively supports mHC.
There was a problem hiding this comment.
Did you observe anything off with trainable norms for n*C? Just curious as to why that wont fit
There was a problem hiding this comment.
The actual difference in performance is pretty small, and once we have a model that is trained with mHC, it will have an RMSNorm weight of size n * C, so it will be very easy to adapt the current code to it (and there is no need for a trainable norm in the LoRA setting), so I feel like that's the better solution for now :)
There was a problem hiding this comment.
Sounds good. For already pretrained models, the original HC paper is probably a better fit than mHC as it does have experimentations with HC as an augmentation (unlike mHC).
So RMS over C is probably better for the general case.
See #1008 <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1217" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->




Addresses #952
This PR is a general implementation of Hyper connections.
This is supposed to be an extension like Lora, where the default case mimics a standard residual connection with identity mappings.
Default case - Trainable is false. Expansion rate is 1.
[edit] we now bypass this case entirely for a regular residual network.
For expansion rate > 1
These matrices preserve identity mapping. So expansion rate > 1 but untrainable still results in the the same outputs.
Todos
Future work