Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
57d1881
Initial design
tanmaysachan Feb 2, 2026
98c0994
Merge branch 'main' into tanmay/mhc
tanmaysachan Feb 2, 2026
91d5e74
Add comment
tanmaysachan Feb 2, 2026
24b82d7
Identity mapping for initial passthrough
tanmaysachan Feb 2, 2026
874ab08
Add trainable flag for freezing weights
tanmaysachan Feb 2, 2026
975faa1
Stray comment
tanmaysachan Feb 2, 2026
f685543
simplify
tanmaysachan Feb 2, 2026
e493ae5
Add elementwise_affine flag to RMS to match pytorch impl. Replace raw…
tanmaysachan Feb 3, 2026
b4ad7ad
Merge branch 'main' into tanmay/mhc
tanmaysachan Feb 10, 2026
066af09
Add to qwen, restore norms
tanmaysachan Feb 11, 2026
bcd4e41
remove rms changes
tanmaysachan Feb 12, 2026
744ce19
make mhc trainable
tanmaysachan Feb 12, 2026
9f88cc5
jax backend alternate path for global trainables
tanmaysachan Feb 12, 2026
495bb38
Make connectors part of Lora training flow
tanmaysachan Feb 12, 2026
c204a38
stray change
tanmaysachan Feb 12, 2026
7a2e921
how even
tanmaysachan Feb 12, 2026
c9b3b93
resolve conflicts with main
tanmaysachan Feb 17, 2026
587a3bf
Make mhc compatible with stacked decoder layers
tanmaysachan Feb 17, 2026
6798eae
Simplifications
tanmaysachan Feb 17, 2026
0b8ed0a
unused import
tanmaysachan Feb 17, 2026
7c02962
Address comments
tanmaysachan Feb 17, 2026
cea717b
Merge remote-tracking branch 'upstream/main' into tanmay/mhc
tanmaysachan Feb 18, 2026
8321e5d
Address comments, add tests, add comments
tanmaysachan Feb 18, 2026
c14203b
Move expansion rate parity from deepseek to test_connector
tanmaysachan Feb 18, 2026
41add0d
Expose mHC flags to backend-config
tanmaysachan Feb 18, 2026
43c53d4
Old change, not needed
tanmaysachan Feb 18, 2026
4f42fe2
clear_lora_adapter resets to identity semantics instead of zeroing out
tanmaysachan Feb 18, 2026
01e3e3d
Merge branch 'main' into tanmay/mhc
pcmoritz Feb 18, 2026
a7b6925
update
pcmoritz Feb 19, 2026
dcc241e
update
pcmoritz Feb 20, 2026
cbbff6f
cleanup clearing/init to be connector specific
tanmaysachan Feb 20, 2026
ac6c81b
Update alpha values to have 0.1 default
tanmaysachan Feb 20, 2026
3fcb413
Merge branch 'main' into tanmay/mhc
tanmaysachan Feb 20, 2026
5260f38
Merge branch 'main' into tanmay/mhc
pcmoritz Feb 20, 2026
9353ecb
update
pcmoritz Feb 20, 2026
8086212
update
pcmoritz Feb 20, 2026
cb7a559
update
pcmoritz Feb 21, 2026
7752d22
update
pcmoritz Feb 21, 2026
5da4994
cleanup
pcmoritz Feb 21, 2026
07e5cce
update
pcmoritz Feb 21, 2026
f7a6756
simplify
pcmoritz Feb 22, 2026
a0a3ce0
unify defaults
pcmoritz Feb 22, 2026
96089f3
update
pcmoritz Feb 22, 2026
4e60734
update
pcmoritz Feb 22, 2026
ac9b815
simplify
pcmoritz Feb 22, 2026
644b714
update
pcmoritz Feb 22, 2026
209dccc
do not reorder
pcmoritz Feb 22, 2026
72d4369
update
pcmoritz Feb 22, 2026
df3501f
update
pcmoritz Feb 22, 2026
c546a8a
canonicalize configuration
pcmoritz Feb 23, 2026
74895d0
rename
pcmoritz Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions skyrl-tx/tx/layers/connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Connection mechanisms for transformer layers (residual, learned connectors, etc.)."""

from flax import nnx
import jax
from jax import numpy as jnp

from tx.layers.util import Param
from tx.layers.layernorm import RMSNorm


class Connector(nnx.Module):
    """General implementation of (m?)Hyper Connections.

    Maintains ``expansion_rate`` (n) parallel copies of the residual stream of
    width ``hidden_dim`` (C). ``pre`` collapses the n copies into a single
    normalized input for a sublayer (attention / MLP); ``post`` redistributes
    the sublayer output across the n copies and mixes the copies with a
    doubly-stochastic matrix obtained via Sinkhorn-Knopp normalization.

    Note: ``pre`` stashes the data-dependent mixing weights on ``self``
    (``_H_post``, ``_M``); ``post`` must be called after the matching ``pre``.
    """

    def __init__(
        self,
        hidden_dim: int,
        expansion_rate: int,
        *,
        trainable: bool = False,
        sinkhorn_iters: int = 20,
        eps: float = 1e-5,
        dtype: jnp.dtype,
        rngs: nnx.Rngs,
    ) -> None:
        """Create the connector parameters.

        Args:
            hidden_dim: Model hidden size C.
            expansion_rate: Number of residual-stream copies n.
            trainable: Stored flag; not read inside this module.
                NOTE(review): presumably consumed by the training setup to
                freeze/unfreeze connector weights — confirm against callers.
            sinkhorn_iters: Number of Sinkhorn-Knopp normalization iterations.
            eps: Numerical-stability epsilon for RMS and Sinkhorn divisions.
            dtype: Parameter dtype.
            rngs: Flax NNX RNG container for initialization.
        """
        self.hidden_dim = hidden_dim
        self.expansion_rate = expansion_rate
        self.trainable = trainable
        self.sinkhorn_iters = sinkhorn_iters
        self.eps = eps
        n = expansion_rate
        C = hidden_dim

        self.norm = RMSNorm(hidden_dim, eps=eps, dtype=dtype, rngs=rngs)

        # Projections from the flattened (n*C) stream to the mixing logits.
        self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)
        self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)
        self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs)

        # Static biases; zero-initialized so the initial mapping is data-independent.
        self.b_pre = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init())
        self.b_post = Param(n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init())
        self.b_res = Param(n, n, dtype=dtype, rngs=rngs, kernel_init=nnx.initializers.zeros_init())

        # Small scalar gates on the data-dependent terms.
        self.alpha_pre = nnx.Param(jnp.array(0.01, dtype=dtype))
        self.alpha_post = nnx.Param(jnp.array(0.01, dtype=dtype))
        self.alpha_res = nnx.Param(jnp.array(0.01, dtype=dtype))

    def _sinkhorn_knopp(self, M: jax.Array) -> jax.Array:
        """Project logits ``M`` (..., n, n) to an (approximately) doubly-stochastic matrix."""
        # Subtract the per-matrix max before exponentiating so jnp.exp cannot
        # overflow to inf; a constant shift cancels in the row/column
        # normalization (up to the eps terms). stop_gradient keeps the
        # backward pass identical to the unshifted formulation.
        M = jnp.exp(M - jax.lax.stop_gradient(M.max(axis=(-2, -1), keepdims=True)))
        for _ in range(self.sinkhorn_iters):
            M = M / (M.sum(axis=-1, keepdims=True) + self.eps)  # normalize rows
            M = M / (M.sum(axis=-2, keepdims=True) + self.eps)  # normalize columns
        return M

    def pre(self, x: jax.Array) -> jax.Array:
        """Aggregate the n stream copies in ``x`` (..., n, C) into one (..., C) sublayer input."""
        *batch_dims, n, C = x.shape

        # RMS-normalize over the flattened (n*C) stream before computing the
        # mixing logits (no learned scale here; self.norm handles the output).
        x_flat = x.reshape(*batch_dims, n * C)
        rms = jnp.sqrt(jnp.mean(x_flat * x_flat, axis=-1, keepdims=True) + self.eps)
        x_norm = x_flat / rms

        tilde_H_pre = self.alpha_pre[...] * (x_norm @ self.phi_pre[...]) + self.b_pre[...]
        tilde_H_post = self.alpha_post[...] * (x_norm @ self.phi_post[...]) + self.b_post[...]
        tilde_H_res = self.alpha_res[...] * (x_norm @ self.phi_res[...]).reshape(*batch_dims, n, n) + self.b_res[...]

        H_pre = jax.nn.sigmoid(tilde_H_pre)
        # Stash the post/residual mixing weights for the matching post() call.
        self._H_post = 2.0 * jax.nn.sigmoid(tilde_H_post)  # in (0, 2), identity-ish at 1
        self._M = self._sinkhorn_knopp(tilde_H_res)

        # Weighted sum of the n copies, then a learned RMSNorm.
        x_agg = jnp.einsum("...i,...ic->...c", H_pre, x)
        x_normed = self.norm(x_agg)

        return x_normed

    def post(self, residual: jax.Array, output: jax.Array) -> jax.Array:
        """Combine sublayer ``output`` (..., C) back into the n-copy ``residual`` (..., n, C)."""
        # Broadcast the output to each copy, scaled per copy by _H_post.
        y_dist = self._H_post[..., None] * output[..., None, :]
        # Mix the residual copies with the doubly-stochastic matrix from pre().
        x_mixed = jnp.einsum("...ij,...jc->...ic", self._M, residual)
        return x_mixed + y_dist
37 changes: 28 additions & 9 deletions skyrl-tx/tx/models/deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tx.layers.rotary_embedding import get_rope
from tx.layers.util import Param, prepare_routing, shard_map_ep
from tx.layers.layernorm import RMSNorm
from tx.layers.connectors import Connector
from tx.models.configs import DeepseekV3Config
from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
from tx.utils.generator import GeneratorMixin, KVCache
Expand Down Expand Up @@ -417,17 +418,28 @@ def __call__(

class DeepseekV3DecoderLayer(nnx.Module):

def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
def __init__(
    self,
    config: DeepseekV3Config,
    layer_idx: int,
    *,
    dtype: jnp.dtype,
    rngs: nnx.Rngs,
    expansion_rate: int = 1,
) -> None:
    """Build one decoder layer: attention + MLP wrapped in hyper-connections.

    Args:
        config: DeepseekV3 model configuration.
        layer_idx: Index of this layer; selects dense vs MoE MLP and is used
            by __call__ for first/last-layer stream handling.
        dtype: Parameter dtype.
        rngs: Flax NNX RNG container for initialization.
        expansion_rate: Number of parallel residual-stream copies (n) for the
            hyper-connections; 1 recovers plain residual behavior.
    """
    self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs)
    self.layer_idx = layer_idx
    self.num_layers = config.num_hidden_layers
    self.expansion_rate = expansion_rate

    # Use dense MLP for initial layers, MoE for the rest
    if layer_idx >= config.first_k_dense_replace:
        self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs)
    else:
        self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs)

    # Connectors replace the plain residual adds (and the former
    # input/post-attention RMSNorms) around attention and the MLP.
    self.attn_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)
    self.mlp_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)

def __call__(
self,
hidden_states: jax.Array,
Expand All @@ -437,21 +449,28 @@ def __call__(
adapter_indices: jax.Array | None = None,
kv_cache: tuple[jax.Array, jax.Array] | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
n = self.expansion_rate
if self.layer_idx == 0:
hidden_states = jnp.repeat(hidden_states[..., None, :], n, axis=-2)

residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.attn_connector.pre(hidden_states)
hidden_states, updated_cache = self.self_attn(
hidden_states,
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=kv_cache,
)
hidden_states = residual + hidden_states
hidden_states = self.attn_connector.post(residual, hidden_states)

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp_connector.pre(hidden_states)
mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices)
hidden_states = residual + mlp_output
hidden_states = self.mlp_connector.post(residual, mlp_output)

if self.layer_idx == self.num_layers - 1:
hidden_states = hidden_states.sum(axis=-2)

return hidden_states, updated_cache

Expand Down Expand Up @@ -500,7 +519,7 @@ def __call__(

for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)
all_hidden_states.append(hidden_states.squeeze())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change
all_hidden_states.append(hidden_states.squeeze())
all_hidden_states.append(hidden_states.mean(axis=-2))


hidden_states, (k, v) = layer(
hidden_states,
Expand Down
3 changes: 3 additions & 0 deletions skyrl-tx/tx/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def load_safetensors(
# Skip LoRA parameters if requested
if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
continue
# Skip connector parameters
if any("connector" in str(p) for p in path):
continue
if "experts" in path:
tensors[key] = np.stack(
[tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0
Expand Down
Loading