
[tx] Fix tied embedding closure for TPUs #1233

Merged: pcmoritz merged 3 commits into NovaSky-AI:main from pcmoritz:fix-tied-embedding-closure, Mar 1, 2026

Conversation

pcmoritz (Collaborator) commented on Feb 27, 2026:

This PR fixes a problem discovered by @andrewsykim while working on #1024: the Jax TPU backend is stricter about closing over variables. The root cause is https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl/tx/layers/lora.py#L173, where lm_head is closed over at (eager) model initialization time. Instead of closing over lm_head at initialization, we now capture it when the model is evaluated, so it is properly JITted.

This is also a good fix for GPUs since it should reduce memory consumption.
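The pattern behind the fix can be illustrated with a minimal JAX sketch (names like make_logits_fn_closure, logits_fn, and the array shapes are illustrative, not the actual SkyRL code): weights closed over when the jitted function is built are baked into the compiled computation as constants, whereas weights passed as an argument are traced like any other input, which is what the stricter TPU backend expects and what avoids duplicating the tensor in memory.

```python
import jax
import jax.numpy as jnp

embedding = jnp.ones((8, 4))  # stands in for the tied lm_head weights

# Problematic pattern: lm_head is captured in the closure at build time,
# so jit embeds it in the compiled function as a constant.
def make_logits_fn_closure(lm_head):
    @jax.jit
    def logits_fn(hidden):
        return hidden @ lm_head.T
    return logits_fn

# Fixed pattern: lm_head is an explicit argument, so it is traced at call
# time like any other input instead of being baked into the executable.
@jax.jit
def logits_fn(hidden, lm_head):
    return hidden @ lm_head.T

hidden = jnp.ones((2, 4))
closed = make_logits_fn_closure(embedding)
assert jnp.allclose(closed(hidden), logits_fn(hidden, embedding))
```

Both variants compute the same logits; the difference is only in when the weights are bound, which matters for compilation and device memory.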



pcmoritz added the tx label Feb 27, 2026
gemini-code-assist[bot]

This comment was marked as resolved.

devin-ai-integration[bot] (Contributor) left a comment:

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report. View in Devin Review to see 4 additional findings.

andrewsykim (Contributor) commented:

This PR works on my multi-host TPU cluster (4x4 v6e) on GKE.

I used this script, which starts the coordinator node and the workers:

#!/bin/bash
# Start the SkyRL Tinker API server across a multi-host TPU slice: the pod
# with TPU_WORKER_ID=0 becomes the coordinator, the rest join as workers.

LABEL_SELECTOR="ray.io/cluster=raycluster-tpu-v6e-multihost,ray.io/node-type=worker"
NAMESPACE="default"

PODS=$(kubectl get pods -l "$LABEL_SELECTOR" -n "$NAMESPACE" -o name)

if [ -z "$PODS" ]; then
    echo ">> No pods found matching label: $LABEL_SELECTOR"
    exit 1
fi

COORDINATOR_ADDRESS=""

echo ">> Searching for coordinator (TPU_WORKER_ID=0)..."
for POD in $PODS; do
    WORKER_ID=$(kubectl exec -n "$NAMESPACE" "$POD" -- sh -c 'echo $TPU_WORKER_ID')
    if [ "$WORKER_ID" == "0" ]; then
        COORDINATOR_ADDRESS=$(kubectl -n "$NAMESPACE" get "$POD" --template '{{.status.podIP}}')
        kubectl exec -n "$NAMESPACE" "$POD" -- \
            uv run --extra tpu --extra tinker --extra jax -m skyrl.tinker.api \
            --base-model "Qwen/Qwen3-0.6B" \
            --backend-config "{\"train_micro_batch_size\": 8, \"sample_max_num_sequences\": 256, \"tensor_parallel_size\": 4, \"fully_sharded_data_parallel_size\": 4, \"num_processes\": 4, \"coordinator_address\": \"$COORDINATOR_ADDRESS:7777\"}" &
        echo ">> Found Coordinator: $POD at $COORDINATOR_ADDRESS"
        # Give the coordinator time to come up before attaching workers.
        sleep 60
        break
    fi
done

if [ -z "$COORDINATOR_ADDRESS" ]; then
    echo ">> Error: Could not find a pod with TPU_WORKER_ID=0. Exiting."
    exit 1
fi

# Attach the remaining pods as JAX worker processes.
for POD in $PODS; do
    PROCESS_ID=$(kubectl exec -n "$NAMESPACE" "$POD" -- sh -c 'echo $TPU_WORKER_ID')

    if [ "$PROCESS_ID" != "0" ]; then
        echo ">> Executing on $POD (Process ID: $PROCESS_ID)..."
        kubectl exec -n "$NAMESPACE" "$POD" -- \
            uv run --extra tpu --extra tinker --extra jax -m skyrl.backends.jax \
            --coordinator-address "$COORDINATOR_ADDRESS:7777" \
            --num-processes 4 \
            --process-id "$PROCESS_ID" &
    fi
    echo "-----------------------------------------------"
done
$ ./start_server.sh
>> Searching for coordinator (TPU_WORKER_ID=0)...
>> Found Coordinator: pod/raycluster-tpu-v6e-multihost-tpu-group-worker-z4s5l at 10.72.7.9
...
...
2026-02-27 14:48:18,649 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,654 - INFO - skyrl: Initialized TinkerEngine with
backend=JaxBackend
2026-02-27 14:48:18,655 - INFO - skyrl: Starting background engine...
2026-02-27 14:48:18,657 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,658 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,658 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,663 - INFO - skyrl: Worker process_id=1 entering command
loop
2026-02-27 14:48:18,662 - INFO - skyrl: Worker process_id=2 entering command
loop
2026-02-27 14:48:18,663 - INFO - skyrl: Worker process_id=3 entering command
loop

Then I ran the rl_loop.py example from the README on node 0:

$ uv run rl_loop.py
warning: The `extra-build-dependencies` option is experimental and may change without warning. Pass `--preview-features extra-build-dependencies` to disable this warning.
merges.txt: 1.67MB [00:00, 12.6MB/s]
Loss: 7.0421
Loss: 5.9501
Loss: 5.1374
Loss: 4.2922
Loss: 3.5262
Loss: 2.8256

Some relevant logs from the server:

2026-02-27 14:48:18,649 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,654 - INFO - skyrl: Initialized TinkerEngine with
backend=JaxBackend
2026-02-27 14:48:18,655 - INFO - skyrl: Starting background engine...
2026-02-27 14:48:18,657 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,658 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,658 - INFO - skyrl: Initialized base model Qwen/Qwen3-0.6B
with max_lora_adapters=32, max_lora_rank=32
2026-02-27 14:48:18,663 - INFO - skyrl: Worker process_id=1 entering command
loop
2026-02-27 14:48:18,662 - INFO - skyrl: Worker process_id=2 entering command
loop
2026-02-27 14:48:18,663 - INFO - skyrl: Worker process_id=3 entering command
loop
2026-02-27 14:49:21,489 - INFO - uvicorn.access: 127.0.0.1:49076 - "POST
/api/v1/create_session HTTP/1.1" 200
2026-02-27 14:49:21,504 - INFO - uvicorn.access: 127.0.0.1:49088 - "POST
/api/v1/create_model HTTP/1.1" 200
2026-02-27 14:49:29,643 - INFO - skyrl: Created model model_9dd5254a with
adapter_index=1, config=rank=32 alpha=32.0 seed=528014039 train_attn=True
train_mlp=True train_unembed=False
2026-02-27 14:49:29,644 - INFO - skyrl: Created LoRA model model_9dd5254a
2026-02-27 14:49:29,645 - INFO - skyrl: (timing)
process_single_request(create_model) took 8.118s
2026-02-27 14:49:29,643 - INFO - skyrl: Created model model_9dd5254a with
adapter_index=1, config=rank=32 alpha=32.0 seed=528014039 train_attn=True
train_mlp=True train_unembed=False
2026-02-27 14:49:29,645 - INFO - skyrl: Created model model_9dd5254a with
adapter_index=1, config=rank=32 alpha=32.0 seed=528014039 train_attn=True
train_mlp=True train_unembed=False
2026-02-27 14:49:29,648 - INFO - skyrl: Created model model_9dd5254a with
adapter_index=1, config=rank=32 alpha=32.0 seed=528014039 train_attn=True
train_mlp=True train_unembed=False

...
...
...
2026-02-27 14:49:33,094 - INFO - uvicorn.access: 127.0.0.1:57806 - "POST
/api/v1/forward_backward HTTP/1.1" 200
2026-02-27 14:49:33,151 - INFO - skyrl: JIT compiling for train seq_len=32 in
progress...
2026-02-27 14:49:33,151 - INFO - skyrl: JIT compiling for train seq_len=32 in
progress...
2026-02-27 14:49:33,151 - INFO - skyrl: JIT compiling for train seq_len=32 in
progress...
2026-02-27 14:49:33,151 - INFO - skyrl: JIT compiling for train seq_len=32 in
progress...
2026-02-27 14:49:38,374 - INFO - skyrl: JIT compilation for train seq_len=32
took 5.22s
2026-02-27 14:49:38,460 - INFO - skyrl: JIT compilation for train seq_len=32
took 5.31s
2026-02-27 14:49:38,515 - INFO - skyrl: JIT compilation for train seq_len=32
took 5.36s
2026-02-27 14:49:38,524 - INFO - skyrl: JIT compilation for train seq_len=32
took 5.37s
2026-02-27 14:49:38,585 - INFO - skyrl: (timing)
process_batch_requests(forward_backward, n=1) took 5.477s
2026-02-27 14:49:39,188 - INFO - uvicorn.access: 127.0.0.1:49100 - "POST
/api/v1/retrieve_future HTTP/1.1" 200
2026-02-27 14:49:39,196 - INFO - uvicorn.access: 127.0.0.1:39872 - "POST
/api/v1/optim_step HTTP/1.1" 200
2026-02-27 14:49:40,506 - INFO - skyrl: Applied optimizer step for model
model_9dd5254a (adapter 1), metrics={'skyrl.ai/grad_norm': 10.125,
'skyrl.ai/learning_rate': 0.00010013580322265625}
2026-02-27 14:49:40,506 - INFO - skyrl: Applied optimizer step for model
model_9dd5254a (adapter 1), metrics={'skyrl.ai/grad_norm': 10.125,
'skyrl.ai/learning_rate': 0.00010013580322265625}
2026-02-27 14:49:40,506 - INFO - skyrl: Applied optimizer step for model
model_9dd5254a (adapter 1), metrics={'skyrl.ai/grad_norm': 10.125,
'skyrl.ai/learning_rate': 0.00010013580322265625}
2026-02-27 14:49:40,507 - INFO - skyrl: (timing)
process_single_request(optim_step) took 1.308s

pcmoritz changed the title from [WIP] [tx] Fix tied embedding closure for TPUs to [tx] Fix tied embedding closure for TPUs on Mar 1, 2026
pcmoritz merged commit 1c19403 into NovaSky-AI:main on Mar 1, 2026
4 of 5 checks passed