Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
00f821e
[tx] Add SkyRL-train backend
pcmoritz Jan 13, 2026
8d3e210
update
pcmoritz Jan 13, 2026
1d4914d
build requirements
pcmoritz Jan 13, 2026
40b2fcf
update
pcmoritz Jan 13, 2026
b3f5d34
update
pcmoritz Jan 13, 2026
8d8c2f1
Merge branch 'main' into tx-skyrl-train-backend
pcmoritz Jan 24, 2026
75aa030
update
pcmoritz Jan 24, 2026
a677873
update
pcmoritz Jan 24, 2026
9415999
update
pcmoritz Jan 24, 2026
275cb57
update
pcmoritz Jan 24, 2026
534c0bb
update
pcmoritz Jan 24, 2026
ec2cc81
update
pcmoritz Jan 25, 2026
f540c8e
update
pcmoritz Jan 25, 2026
64d7b01
update
pcmoritz Jan 25, 2026
34da65c
update
pcmoritz Jan 25, 2026
f7557b4
update
pcmoritz Jan 25, 2026
877b3e5
update
pcmoritz Jan 25, 2026
4fe7a67
update
pcmoritz Jan 25, 2026
326280b
update
pcmoritz Jan 25, 2026
6242d9f
update
pcmoritz Jan 25, 2026
ce20462
update
pcmoritz Jan 25, 2026
1066fcf
update
pcmoritz Jan 25, 2026
9461891
black
pcmoritz Jan 25, 2026
9f7b31c
update
pcmoritz Jan 25, 2026
69d3db9
update
pcmoritz Jan 25, 2026
89ff6ba
Merge branch 'main' into tx-skyrl-train-backend
pcmoritz Jan 27, 2026
aeb15b4
update
pcmoritz Jan 27, 2026
2a96b3c
Merge branch 'tx-skyrl-train-backend' of https://github.com/pcmoritz/…
pcmoritz Jan 27, 2026
2eca5bf
update
pcmoritz Jan 27, 2026
234bb5d
update
pcmoritz Jan 27, 2026
695dbfb
fix CI
pcmoritz Jan 27, 2026
ad46cff
lint
pcmoritz Jan 27, 2026
5ea0d82
update
pcmoritz Jan 27, 2026
bdb250b
update
pcmoritz Jan 27, 2026
089204d
update
pcmoritz Jan 27, 2026
3cf9c78
update
pcmoritz Jan 27, 2026
171fbaa
update
pcmoritz Jan 27, 2026
6a436c9
update
pcmoritz Jan 27, 2026
fce28b6
address comments
pcmoritz Jan 27, 2026
d4f6f94
update
pcmoritz Jan 27, 2026
e217606
update
pcmoritz Jan 27, 2026
b9a7cc7
update
pcmoritz Jan 27, 2026
6d117dd
update config
pcmoritz Jan 27, 2026
4ab28a4
use latest
pcmoritz Jan 27, 2026
8eb3fca
updates
pcmoritz Jan 27, 2026
d1756a3
update
pcmoritz Jan 27, 2026
3ae71e5
update
pcmoritz Jan 27, 2026
717ac68
fix
pcmoritz Jan 27, 2026
756be47
update
pcmoritz Jan 27, 2026
17e3777
update
pcmoritz Jan 27, 2026
36f5bfa
update
pcmoritz Jan 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions skyrl-tx/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ gpu = [
"jax[cuda12]>=0.7.2",
]

jax = [
"jax[cuda12]>=0.7.2",
]

tpu = [
"jax[tpu]>=0.7.2",
]
Expand All @@ -56,6 +60,10 @@ azure = [
"cloudpathlib[azure]",
]

skyrl_train = [
"skyrl-train[vllm]; python_version == '3.12'",
]

dev = [
"mkdocs",
"mkdocs-material",
Expand All @@ -75,3 +83,14 @@ version = {attr = "tx.__version__"}

[project.scripts]
tx = "tx.run.main:app"

[tool.uv.extra-build-dependencies]
flash-attn = [{requirement = "torch", match-runtime = true}]
transformer-engine = [{requirement = "torch", match-runtime = true}, "build_tools", "ninja"]
transformer-engine-torch = [{requirement = "torch", match-runtime = true}, "build_tools", "ninja"]

[tool.uv.extra-build-variables]
flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"}

[tool.uv.sources]
skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", rev = "3683cebef13e399cd02bcdb51b50e4e2e709e81c", subdirectory = "skyrl-train" }
2 changes: 1 addition & 1 deletion skyrl-tx/tx/tinker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def lifespan(app: FastAPI):
logger.info("Using internal engine for inference")

# Build subprocess command with engine config parameters
cmd = ["uv", "run", "--extra", "tinker", "-m", "tx.tinker.engine"]
cmd = ["uv", "run", "--extra", "tinker", "--extra", app.state.engine_config.backend, "-m", "tx.tinker.engine"]
cmd.extend(config_to_argv(app.state.engine_config))

background_engine = await asyncio.create_subprocess_exec(*cmd)
Expand Down
197 changes: 197 additions & 0 deletions skyrl-tx/tx/tinker/backends/skyrl_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""SkyRL-Train backend for TinkerEngine.

Uses SkyRL-Train infrastructure for supervised training with cross-entropy loss.
Currently supports a single model only.
"""

import ray
import torch
from pydantic import BaseModel
from ray.util.placement_group import placement_group

from tx.tinker import types
from tx.tinker.backends.backend import AbstractBackend
from tx.utils.log import logger

from skyrl_train.training_batch import TrainingInputBatch
from skyrl_train.workers.worker import PPORayActorGroup
from skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
from skyrl_train.utils import get_ray_pg_ready_with_timeout
from skyrl_train.config.utils import get_default_config


class SkyRLTrainBackendConfig(BaseModel, extra="forbid"):
    """Configuration for the SkyRL-Train backend.

    No backend-specific options are defined yet; ``extra="forbid"`` makes
    pydantic reject unknown fields so misspelled settings fail loudly.
    """


def _build_config(base_model: str, config: SkyRLTrainBackendConfig, lora_config: types.LoraConfig | None = None):
    """Return SkyRL-Train's default worker config pointed at *base_model*.

    Starts from ``get_default_config()`` and only overrides the policy model
    path. NOTE(review): ``config`` and ``lora_config`` are accepted for
    interface symmetry but currently unused — confirm LoRA settings are
    intentionally ignored for now.
    """
    worker_cfg = get_default_config()
    worker_cfg.trainer.policy.model.path = base_model
    return worker_cfg


class SkyRLTrainBackend(AbstractBackend):
"""SkyRL-Train backend for supervised training.

Experimental adapter (a warning is logged in __init__) that delegates
forward/backward and optimizer steps to SkyRL-Train Ray workers. Exactly one
model is supported at a time; sampling, forward-only passes, and
checkpointing all raise NotImplementedError.
"""

def __init__(self, base_model: str, config: SkyRLTrainBackendConfig):
    """Store backend settings and make sure a Ray runtime is available.

    Args:
        base_model: Model path/name; forwarded to SkyRL-Train's
            ``trainer.policy.model.path`` when the model is created.
        config: Backend options (currently none are defined).
    """
    separator = "=" * 80
    logger.warning(separator)
    logger.warning("SkyRLTrainBackend is currently EXPERIMENTAL!")
    logger.warning(separator)

    self.base_model = base_model
    self.config = config

    # Single-model state: populated by create_model, cleared by delete_model.
    self._model_id: str | None = None
    self._model_metadata: types.ModelMetadata | None = None
    self._actor_group: PPORayActorGroup | None = None
    self._dispatch: WorkerDispatch | None = None
    self._cfg = None

    # Attach to an existing Ray cluster or start a local one.
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)

def has_model(self, model_id: str) -> bool:
    """Report whether *model_id* matches the single model this backend holds."""
    return model_id == self._model_id

def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
# Create the single supported model: build the worker config, reserve a Ray
# placement group, launch the policy workers, and record model metadata.
if self._model_id is not None:
raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.")

self._cfg = _build_config(self.base_model, self.config, lora_config)
num_gpus = self._cfg.trainer.placement.policy_num_gpus_per_node

# Reserve all worker GPUs on a single node (PACK) before starting actors.
pg = placement_group([{"GPU": num_gpus, "CPU": 4}], strategy="PACK")
# NOTE(review): 30s can be too short on first run when workers install
# dependencies; SkyRL-Train's SKYRL_RAY_PG_TIMEOUT_IN_S defaults to 180
# (see the review thread below).
get_ray_pg_ready_with_timeout(pg, timeout=30)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit - this is sometimes not long enough since installing dependencies on all ray workers the first time can be expensive

we have this env var:

from skyrl_train.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S

that we set to 180 by default


self._actor_group = PPORayActorGroup(
cfg=self._cfg,
num_nodes=1,
num_gpus_per_node=num_gpus,
ray_actor_type=PolicyWorker,
pg=pg,
# NOTE(review): the fractional GPU in the single-GPU case presumably leaves
# headroom for other actors on the same device — confirm.
num_gpus_per_actor=0.75 if num_gpus == 1 else 1.0,
colocate_all=False,
sequence_parallel_size=1,
)
# Block until every worker has loaded the base model weights.
ray.get(self._actor_group.async_init_model(self.base_model))
self._dispatch = WorkerDispatch(self._cfg, policy_actor_group=self._actor_group)

self._model_id = model_id
self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config)
logger.info(f"Created model {model_id}")

def delete_model(self, model_id: str) -> None:
    """Forget the backend's single model and reset all cached state.

    Raises:
        ValueError: if *model_id* does not match the created model.
    """
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    # Drop all references. NOTE(review): presumably Ray reclaims the worker
    # actors once the actor-group reference is released — confirm.
    self._cfg = None
    self._dispatch = None
    self._actor_group = None
    self._model_metadata = None
    self._model_id = None

def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> TrainingInputBatch:
"""Convert PreparedModelPassBatch to TrainingInputBatch."""
# Empty input produces an empty batch so callers can short-circuit.
if not prepared_batch.all_input_ids:
return TrainingInputBatch({})

max_len = max(len(seq) for seq in prepared_batch.all_input_ids)
# Token weights are binarized (w > 0): each example's "action" count is the
# number of trainable tokens. NOTE(review): fractional token weights are
# discarded here — confirm that is intended.
num_actions_per_example = [sum(1 for w in weights if w > 0) for weights in prepared_batch.all_token_weights]
max_num_actions = max(num_actions_per_example, default=0)

sequences, attention_masks, loss_masks = [], [], []

# Sequences are left-padded to max_len; loss masks are left-padded to
# max_num_actions (right-aligned over the trainable region).
# NOTE(review): pad token id 0 is hard-coded — confirm it matches the
# tokenizer's actual pad token for the chosen base model.
for seq, num_actions in zip(prepared_batch.all_input_ids, num_actions_per_example):
pad_len = max_len - len(seq)
sequences.append([0] * pad_len + list(seq))
attention_masks.append([0] * pad_len + [1] * len(seq))
action_pad = max_num_actions - num_actions
loss_masks.append([0] * action_pad + [1] * num_actions)

sequences_tensor = torch.tensor(sequences, dtype=torch.long)
attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long)
loss_mask_tensor = torch.tensor(loss_masks, dtype=torch.long)

batch = TrainingInputBatch(
{
"sequences": sequences_tensor,
"attention_mask": attention_mask_tensor,
"loss_mask": loss_mask_tensor,
}
Comment on lines +129 to +150
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a naming inconsistency for the mask. Here, you are creating a response_mask, but in skyrl-train/skyrl_train/workers/worker.py the code expects an action_mask. This inconsistency can be confusing. For better clarity and maintainability, I recommend using a consistent name across the components. Since worker.py is being updated to use action_mask, it seems to be the intended name.

This change would also require updating skyrl-train/training_batch.py to use action_mask in the TrainingInput TypedDict.

        sequences, attention_masks, loss_masks, action_masks = [], [], [], []

        for seq, weights in zip(full_sequences, prepared_batch.all_token_weights):
            pad_len = max_seq_len - len(seq)
            sequences.append([self._tokenizer.pad_token_id] * pad_len + list(seq))
            attention_masks.append([0] * pad_len + [1] * len(seq))
            action_pad = max_response_len - len(weights)
            loss_masks.append([0.0] * action_pad + [float(w) for w in weights])
            action_masks.append([0] * action_pad + [1] * len(weights))

        sequences_tensor = torch.tensor(sequences, dtype=torch.long)
        attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long)
        loss_mask_tensor = torch.tensor(loss_masks, dtype=torch.float32)
        action_mask_tensor = torch.tensor(action_masks, dtype=torch.long)

        batch = TrainingInputBatch(
            {
                "sequences": sequences_tensor,
                "attention_mask": attention_mask_tensor,
                "loss_mask": loss_mask_tensor,
                "action_mask": action_mask_tensor,
            }

)
# NOTE(review): presumably consumed downstream to locate the response
# region within each padded sequence — confirm.
batch.metadata = {"response_length": max_num_actions}
return batch

def forward_backward(
    self,
    prepared_batch: types.PreparedModelPassBatch,
    loss_fn: str = "cross_entropy",
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    """Run one forward/backward pass on the policy workers.

    Converts *prepared_batch* into a TrainingInputBatch, dispatches it to
    SkyRL-Train, and re-packages the per-token outputs per request.

    Args:
        prepared_batch: Batched inputs plus per-request slice boundaries.
        loss_fn: Name of the loss function to apply (default cross-entropy).

    Returns:
        Mapping from request id to its ForwardBackwardOutput; empty when the
        batch carries no input ids.
    """
    if not prepared_batch.all_input_ids:
        return {}

    dispatched = self._dispatch.forward_backward(
        "policy", self._to_training_batch(prepared_batch), loss_fn=loss_fn
    )
    per_example_outputs = dispatched["loss_fn_outputs"]

    def _align(values, target_len):
        # SkyRL-Train emits response-only values; left-pad with zeros and keep
        # only the trailing target_len entries so each vector lines up with
        # the full input sequence length.
        tail = list(values)[-target_len:]
        return [0.0] * max(target_len - len(tail), 0) + tail

    outputs: dict[str, types.ForwardBackwardOutput | types.ErrorResponse] = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        packaged = []
        for idx in range(start_idx, end_idx):
            raw = per_example_outputs[idx]
            target_len = len(prepared_batch.all_input_ids[idx])
            loss_vec = _align(raw.get("elementwise_loss", []), target_len)
            logprob_vec = _align(raw.get("logprobs", []), target_len)
            # TensorData-shaped dicts expected by the API layer.
            packaged.append(
                {
                    "elementwise_loss": {
                        "data": loss_vec,
                        "dtype": "float32",
                        "shape": [len(loss_vec)],
                    },
                    "logprobs": {
                        "data": logprob_vec,
                        "dtype": "float32",
                        "shape": [len(logprob_vec)],
                    },
                }
            )
        outputs[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar",
            loss_fn_outputs=packaged,
            metrics={},
        )
    return outputs

def forward(
self,
prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
# Forward-only (no-gradient) evaluation is not implemented in this backend;
# only the combined forward_backward path is supported.
raise NotImplementedError("Forward-only pass not supported")

def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
    """Apply one optimizer step to the policy model.

    Args:
        model_id: Must match the backend's single created model.
        request_data: NOTE(review): currently ignored — confirm that
            optimizer overrides are intentionally unsupported.

    Returns:
        An empty OptimStepOutput; the gradient norm is only logged.

    Raises:
        ValueError: if *model_id* does not match the created model.
    """
    if self._model_id != model_id:
        raise ValueError(f"Model {model_id} not found")
    grad_norm = self._dispatch.optim_step("policy")
    logger.info(f"grad_norm: {grad_norm}")
    return types.OptimStepOutput()

def sample(
self,
prepared_batch: types.PreparedSampleBatch,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
# Token sampling/generation is not implemented in this backend.
raise NotImplementedError("Sampling not supported")

def save_checkpoint(self, output_path, model_id: str) -> None:
# Checkpoint saving is not implemented yet for the SkyRL-Train backend.
raise NotImplementedError("Saving checkpoints not supported")

def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
# Checkpoint loading is not implemented yet for the SkyRL-Train backend.
raise NotImplementedError("Loading checkpoints not supported")

def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
# Sampler checkpoints are not implemented (sampling itself is unsupported).
raise NotImplementedError("Sampler checkpoints not supported")
2 changes: 2 additions & 0 deletions skyrl-tx/tx/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tx.tinker import types
from tx.tinker.config import EngineConfig, add_model
from tx.tinker.backends.jax import JaxBackend, JaxBackendConfig
from tx.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig
from tx.tinker.backends.utils import log_timing
from tx.tinker.loss_fns import LOSS_TYPES
from tx.utils.log import logger
Expand Down Expand Up @@ -131,6 +132,7 @@ def prepare_model_pass_batch(

# Registry mapping the engine config's `backend` name to its
# (backend class, backend config class) pair. The API server also passes this
# backend name as a uv extra when spawning the engine subprocess (see api.py),
# so keys here should match extras declared in pyproject.toml.
BACKENDS = {
"jax": (JaxBackend, JaxBackendConfig),
"skyrl_train": (SkyRLTrainBackend, SkyRLTrainBackendConfig),
}


Expand Down
Loading