[tx] Fix tied embedding closure for TPUs #1233
Merged
pcmoritz merged 3 commits into NovaSky-AI:main · Mar 1, 2026

Conversation

Contributor:
This PR works on my multi-host TPU cluster (4x4 v6e) on GKE. I used this script, which starts the coordinator node and the workers:

```shell
#!/bin/bash
LABEL_SELECTOR="ray.io/cluster=raycluster-tpu-v6e-multihost,ray.io/node-type=worker"
NAMESPACE="default"

PODS=$(kubectl get pods -l "$LABEL_SELECTOR" -n "$NAMESPACE" -o name)
if [ -z "$PODS" ]; then
  echo "No pods found matching label: $LABEL_SELECTOR"
  exit 1
fi

COORDINATOR_ADDRESS=""
echo ">> Searching for coordinator (TPU_WORKER_ID=0)..."
for POD in $PODS; do
  WORKER_ID=$(kubectl exec -n "$NAMESPACE" "$POD" -- sh -c 'echo $TPU_WORKER_ID')
  if [ "$WORKER_ID" == "0" ]; then
    COORDINATOR_ADDRESS=$(kubectl -n "$NAMESPACE" get "$POD" --template '{{.status.podIP}}')
    kubectl exec -n "$NAMESPACE" "$POD" -- \
      uv run --extra tpu --extra tinker --extra jax -m skyrl.tinker.api \
        --base-model "Qwen/Qwen3-0.6B" \
        --backend-config "{\"train_micro_batch_size\": 8, \"sample_max_num_sequences\": 256, \"tensor_parallel_size\": 4, \"fully_sharded_data_parallel_size\": 4, \"num_processes\": 4, \"coordinator_address\": \"$COORDINATOR_ADDRESS:7777\"}" &
    echo ">> Found coordinator: $POD at $COORDINATOR_ADDRESS"
    sleep 60
    break
  fi
done

if [ -z "$COORDINATOR_ADDRESS" ]; then
  echo ">> Error: Could not find a pod with TPU_WORKER_ID=0. Exiting."
  exit 1
fi

for POD in $PODS; do
  PROCESS_ID=$(kubectl exec -n "$NAMESPACE" "$POD" -- sh -c 'echo $TPU_WORKER_ID')
  if [ "$PROCESS_ID" != "0" ]; then
    echo ">> Executing on $POD (process ID: $PROCESS_ID)..."
    kubectl exec -n "$NAMESPACE" "$POD" -- \
      uv run --extra tpu --extra tinker --extra jax -m skyrl.backends.jax \
        --coordinator-address "$COORDINATOR_ADDRESS:7777" \
        --num-processes 4 \
        --process-id "$PROCESS_ID" &
  fi
  echo "-----------------------------------------------"
done
```

Then I ran the `rl_loop.py` example from the README on node 0. Some relevant logs from the server:
This PR fixes a problem discovered by @andrewsykim while working on #1024: the JAX TPU backend is stricter about closing over variables. The root cause is https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl/tx/layers/lora.py#L173 being closed over during (eager) model initialization. Instead of closing over lm_head at initialization time, we now do it when the model is evaluated, so it is properly JITted.
This is also a good fix for GPUs, since it should reduce memory consumption.
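The general pattern behind the fix can be sketched as follows. This is an illustrative toy, not SkyRL's actual code: `TiedLMHead` and its fields are hypothetical names, and the point is only where the transposed embedding is materialized relative to JIT tracing.

```python
# Sketch of the closure fix, assuming a tied lm_head that reuses the
# embedding table. All names here are illustrative, not SkyRL's API.
import jax
import jax.numpy as jnp


class TiedLMHead:
    """Output head tied to the token-embedding table."""

    def __init__(self, embedding: jnp.ndarray):
        # Old, problematic pattern: eagerly building the head weight here,
        # e.g. `self.kernel = embedding.T`, materializes a transformed copy
        # at init time and closes over it outside any traced function.
        self.embedding = embedding  # keep only a reference

    def __call__(self, hidden: jnp.ndarray) -> jnp.ndarray:
        # Fixed pattern: the transpose happens at evaluation time, inside
        # the JITted call, so XLA can fuse it and no eager copy is kept.
        return hidden @ self.embedding.T


embedding = jnp.ones((16, 8))  # (vocab_size, hidden_dim)
head = TiedLMHead(embedding)
logits = jax.jit(lambda h: head(h))(jnp.ones((2, 8)))
print(logits.shape)  # (2, 16)
```

Deferring the transform into `__call__` means the closed-over value is consumed during tracing rather than at construction, which is what the stricter TPU backend requires and why the eager copy (and its memory) goes away on GPUs as well.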