Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skyrl/tinker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,11 @@ class ForwardBackwardInput(BaseModel):
"cross_entropy": set(),
"importance_sampling": set(),
"ppo": {"clip_low_threshold", "clip_high_threshold"},
"cispo": {"clip_low_threshold", "clip_high_threshold"},
}

data: list[Datum]
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo"]
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo"]
loss_fn_config: dict[str, float] | None = None

@model_validator(mode="after")
Expand Down
17 changes: 17 additions & 0 deletions skyrl/tinker/loss_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,28 @@ def ppo_loss(
return -safe_loss_mask(jnp.minimum(unclipped, clipped), loss_mask)


def cispo_loss(
    target_logprobs: jax.Array,
    loss_mask: jax.Array,
    sampling_logprobs: jax.Array,
    advantages: jax.Array,
    loss_fn_config: LossFnConfig,
) -> jax.Array:
    """CISPO clipped-ratio policy gradient loss.

    REINFORCE-style objective weighted by the importance-sampling ratio
    between the training and sampling policies; the ratio is clipped to the
    configured thresholds and treated as a constant (no gradient flows
    through it).
    """
    # Importance-sampling ratio of the current policy vs. the sampling policy.
    ratio = jnp.exp(target_logprobs - sampling_logprobs)
    bounded_ratio = jnp.clip(
        ratio,
        loss_fn_config.clip_low_threshold,
        loss_fn_config.clip_high_threshold,
    )
    # The clipped ratio only scales the policy-gradient term; gradients flow
    # through target_logprobs alone.
    objective = jax.lax.stop_gradient(bounded_ratio) * target_logprobs * advantages
    return -safe_loss_mask(objective, loss_mask)


# Registry mapping each loss-fn name accepted by the API to its implementation.
LOSS_FUNCTION_MAP = dict(
    cross_entropy=cross_entropy_loss,
    importance_sampling=importance_sampling_loss,
    ppo=ppo_loss,
    cispo=cispo_loss,
)

# Build list of functions indexed by LOSS_TYPES values (for jax.lax.switch)
Expand Down
3 changes: 2 additions & 1 deletion skyrl/tinker/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class Datum(BaseModel):

class ForwardBackwardInput(BaseModel):
    """Request payload for a forward/backward pass over a batch of data."""

    # Batch of training examples to run the pass over.
    data: list[Datum]
    # Which loss function to apply to the batch.
    # NOTE(review): diff residue left both the old and new annotation for this
    # field; only the variant including "cispo" is kept.
    loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo"]
    # Optional per-loss hyperparameters (e.g. clip thresholds for ppo/cispo).
    loss_fn_config: dict[str, float] | None = None


Expand Down Expand Up @@ -265,4 +265,5 @@ class PreparedSampleBatch(BaseModel):
"cross_entropy": 0,
"importance_sampling": 1,
"ppo": 2,
"cispo": 3,
}
61 changes: 36 additions & 25 deletions tests/tinker/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from cloudpathlib import AnyPath
from datetime import datetime, timedelta, timezone

import pytest
from sqlmodel import Session, SQLModel

from skyrl.tinker.engine import TinkerEngine, prepare_model_pass_batch
Expand Down Expand Up @@ -82,44 +83,54 @@ def test_cleanup_stale_sessions():
assert not engine.backend.has_model(model_id)


@pytest.mark.parametrize(
    ("loss_fn", "loss_fn_config", "advantages", "logprobs"),
    [
        pytest.param(
            "ppo",
            {"clip_low_threshold": 0.7, "clip_high_threshold": 1.3},
            [],
            [],
            id="ppo_with_loss_fn_config",
        ),
        pytest.param("cross_entropy", None, [], [], id="cross_entropy_default_config"),
        pytest.param(
            "cispo",
            {"clip_low_threshold": 0.7, "clip_high_threshold": 1.3},
            [0.1, 0.2, 0.3],
            [-1.1, -1.0, -0.9],
            id="cispo",
        ),
    ],
)
def test_prepare_model_pass_batch_loss_fn_and_config(
    loss_fn: str,
    loss_fn_config: dict[str, float] | None,
    advantages: list[float],
    logprobs: list[float],
):
    """Test that prepare_model_pass_batch preserves loss_fn and loss_fn_config values."""
    # Minimal single-datum batch; advantages/logprobs vary per loss function.
    datum = types.Datum(
        model_input=types.ModelInput(chunks=[types.ModelInputChunk(tokens=[1, 2, 3])]),
        loss_fn_inputs=types.LossFnInputs(
            target_tokens=types.TensorData(data=[2, 3, 4]),
            weights=types.TensorData(data=[1.0, 1.0, 1.0]),
            advantages=types.TensorData(data=advantages),
            logprobs=types.TensorData(data=logprobs),
        ),
    )
    requests = {
        "req1": (
            "model1",
            types.ForwardBackwardInput(
                data=[datum],
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
            ),
        ),
    }
    batch = prepare_model_pass_batch(requests)
    # The batch must carry the loss_fn and config through unchanged
    # (including None when no config was supplied).
    assert batch.all_loss_fns == [loss_fn]
    assert batch.all_loss_fn_configs == [loss_fn_config]
Loading