Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
589c150
Move loss reduction normalization to trainer-level advantage scaling,…
justinvyu Mar 9, 2026
333f31a
Add token_mean_baseline loss reduction for mean-of-microbatch-means c…
justinvyu Mar 9, 2026
aaaba4c
fix assertion
justinvyu Mar 9, 2026
a121360
Update tests for sum-based reduce_loss and dp_size scaling changes
justinvyu Mar 10, 2026
15de89a
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 17, 2026
e3842c3
lint
justinvyu Mar 17, 2026
13bfe80
fix tests
justinvyu Mar 17, 2026
e76bece
Refactor advantage normalization: fix z-score propagation, skip for c…
justinvyu Mar 20, 2026
0192e8e
token_mean_baseline -> token_mean_legacy
justinvyu Mar 20, 2026
4ee0b31
Extract apply_loss_reduction_to_advantages_minibatch to ppo_utils and…
justinvyu Mar 20, 2026
c8f06cc
Fix metric reporting: remove dp_size scaling, separate micro-batch vs…
justinvyu Mar 25, 2026
2c13315
Fix critic metric reporting: explicit sum_loss_metrics flag for reduc…
justinvyu Mar 27, 2026
14ba02e
Remove reduce_metrics_across_minibatches, reuse reduce_metrics
justinvyu Mar 27, 2026
0cfc95b
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu Mar 27, 2026
717c3a7
add some comments about sum metrics
justinvyu Mar 27, 2026
661f5d8
add clarifying comments and rename loss_scale
justinvyu Mar 27, 2026
5cc95a1
no_grad for safety and make private
justinvyu Mar 27, 2026
ce8f6aa
remove outdated comments about loss reduction type in sapo tests
justinvyu Mar 27, 2026
1a60bb5
fix test
justinvyu Mar 27, 2026
c5feb83
fix test
justinvyu Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/train/async/async_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from skyrl.train.trainer import RayPPOTrainer
from tqdm import tqdm
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import normalize_advantages_dict
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.train.generators.base import GeneratorOutput
from skyrl.train.utils.trainer_utils import ResumeMode
Expand Down Expand Up @@ -146,9 +145,6 @@ async def _run_training(self, generation_buffer):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
4 changes: 0 additions & 4 deletions skyrl-agent/skyrl_agent/integrations/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from skyrl.train.utils import Timer
from skyrl.backends.skyrl_train.utils.ppo_utils import (
get_kl_controller,
normalize_advantages_dict,
)
from skyrl.train.utils.trainer_utils import (
validate_generator_output,
Expand Down Expand Up @@ -381,9 +380,6 @@ async def train(self):
training_input.pop(key)
training_input.metadata.pop("uids")

if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)

if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
Expand Down
17 changes: 9 additions & 8 deletions skyrl/backends/skyrl_train/distributed/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def get_rank(self) -> int:
"""Get current process rank"""
return dist.get_rank()

def all_reduce(self, data: DataT, op="mean") -> DataT:
"""Perform all_reduce across all processes"""
def all_reduce(self, data: DataT, op="mean", group=None) -> DataT:
"""Perform all_reduce across all processes (or within a process group)."""
assert op in ("mean", "max", "sum", "min")
if isinstance(data, dict):
return {k: self.all_reduce(v, op) for k, v in data.items()}
return {k: self.all_reduce(v, op, group=group) for k, v in data.items()}
else:
is_tensor = True
if not isinstance(data, torch.Tensor):
Expand All @@ -82,14 +82,15 @@ def all_reduce(self, data: DataT, op="mean") -> DataT:
if is_cpu_tensor:
data = data.to(torch.cuda.current_device())
if op == "mean":
data /= self.world_size
dist.all_reduce(data, op=dist.ReduceOp.SUM)
group_size = dist.get_world_size(group) if group is not None else self.world_size
data /= group_size
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
elif op == "max":
dist.all_reduce(data, op=dist.ReduceOp.MAX)
dist.all_reduce(data, op=dist.ReduceOp.MAX, group=group)
elif op == "min":
dist.all_reduce(data, op=dist.ReduceOp.MIN)
dist.all_reduce(data, op=dist.ReduceOp.MIN, group=group)
elif op == "sum":
dist.all_reduce(data, op=dist.ReduceOp.SUM)
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=group)
if is_cpu_tensor:
data = data.cpu()
return data.item() if not is_tensor else data
Expand Down
132 changes: 57 additions & 75 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
from collections import defaultdict
from enum import StrEnum
from functools import wraps
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import ray
import torch
from jaxtyping import Float
from loguru import logger

from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.utils.off_policy_correction_utils import (
apply_off_policy_correction,
)
Expand Down Expand Up @@ -125,27 +124,6 @@ def compute_approx_kl(
return kld


@torch.no_grad()
def normalize_advantages_dict(data: TrainingInputBatch) -> TrainingInputBatch:
"""Normalizes the advantages in the data batch.

Expects:
- `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["response_mask"]`: Float[torch.Tensor, "batch_size seqlen"]
"""
advantages: Float[torch.Tensor, "batch_size seqlen"] = data["advantages"]
response_masks: Float[torch.Tensor, "batch_size seqlen"] = data["response_mask"]
num_actions: float = response_masks.sum()
# mean
mean: float = advantages.mean()
# std
std: float = ((advantages - mean).pow(2) * response_masks).sum()
rstd: float = (std / num_actions).clamp(min=1e-8).rsqrt()

data["advantages"] = (advantages - mean) * rstd
return data


def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
Expand Down Expand Up @@ -559,12 +537,6 @@ def ppo_policy_loss(
rollout_logprobs: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict[str, float]]:
assert config.policy_loss_type in ["regular", "dual_clip"], "loss_type must be either 'regular' or 'dual_clip'"
loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
surr1 = ratio * advantages
Expand All @@ -585,7 +557,7 @@ def ppo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -658,7 +630,7 @@ def gate_function(x, tau):
loss_metrics.update(off_policy_metrics)

# for SAPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -727,7 +699,7 @@ def gspo_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)

return loss, loss_metrics

Expand Down Expand Up @@ -764,7 +736,7 @@ def compute_policy_loss_cispo(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, config.loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -792,13 +764,6 @@ def rollout_is_policy_loss(
"""
assert rollout_logprobs is not None, "rollout_logprobs are required for rollout_is"

loss_reduction = config.loss_reduction
assert loss_reduction in [
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = safe_exp_delta(log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)

in_range = (ratio > 1 - config.eps_clip_low) & (ratio < 1 + config.eps_clip_high)
Expand All @@ -813,7 +778,7 @@ def rollout_is_policy_loss(
)
loss_metrics.update(off_policy_metrics)

loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
loss = reduce_loss(loss, loss_mask)
return loss, loss_metrics


Expand Down Expand Up @@ -875,12 +840,7 @@ def compute_policy_loss_clip_cov(
# Apply correction mask to losses
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

return pg_loss, {"clip_ratio": clip_frac.item()}

Expand Down Expand Up @@ -934,12 +894,7 @@ def compute_policy_loss_kl_cov(
large_cov_idxs % advantages.shape[1],
]

pg_loss = reduce_loss(
loss=pg_losses,
loss_mask=loss_mask,
loss_reduction=config.loss_reduction,
max_seq_len=config.max_seq_len,
)
pg_loss = reduce_loss(loss=pg_losses, loss_mask=loss_mask)

# NOTE (sumanthrh): Since the pg clip ratio is not applicable for KL-COV so we just use 0.0
return pg_loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -978,10 +933,7 @@ def cross_entropy_loss(
elementwise_loss = -log_probs

# Apply loss mask and sum (matching Tinker's SUM reduction semantics)
if loss_mask is not None:
loss = (elementwise_loss * loss_mask).sum()
else:
loss = elementwise_loss.sum()
loss = reduce_loss(elementwise_loss, loss_mask)

# No clipping in cross-entropy loss
return loss, {"clip_ratio": 0.0}
Expand Down Expand Up @@ -1040,30 +992,60 @@ def importance_sampling_loss(
def reduce_loss(
    loss: torch.Tensor,
    loss_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Reduce a per-token loss tensor to a scalar via a masked sum.

    Any normalization (token-mean, sequence-mean, ...) is expected to have
    been folded into the advantages upstream, so the reduction here is a
    plain sum over the valid tokens.

    Args:
        loss: Per-token loss tensor.
        loss_mask: Optional mask of the same shape marking valid loss tokens;
            when ``None``, every token is treated as valid.

    Returns:
        Scalar tensor: the (masked) sum of ``loss``.
    """
    if loss_mask is None:
        return loss.sum()
    return (loss * loss_mask).sum()


def apply_loss_reduction_to_advantages_minibatch(
    advantages: torch.Tensor,
    loss_mask: torch.Tensor,
    loss_reduction: str,
    micro_batch_size: int,
    max_seq_len: int,
) -> torch.Tensor:
    """Scale advantages so that summing the per-token losses produces the desired loss reduction.

    Instead of reducing the loss itself, the normalization is folded into the
    advantages up front; a plain masked *sum* of the resulting per-token
    losses then equals the chosen reduction.

    Args:
        advantages: Advantage tensor of shape (minibatch_size, seq_len).
        loss_mask: Mask of shape (minibatch_size, seq_len) indicating valid loss tokens.
        loss_reduction: One of "token_mean", "token_mean_legacy", "sequence_mean",
            "seq_mean_token_sum_norm".
        micro_batch_size: Number of sequences per micro-batch (only used by
            "token_mean_legacy").
        max_seq_len: Maximum sequence length (only used by "seq_mean_token_sum_norm").

    Returns:
        Scaled advantages tensor with the same shape as ``advantages``.

    Raises:
        ValueError: If ``loss_reduction`` is not one of the supported modes.
    """
    batch_size = advantages.shape[0]

    # Option 1: token mean — normalize by the total valid-token count over the
    # whole minibatch, so the summed loss is a global token mean.
    if loss_reduction == "token_mean":
        normalized_advantages = advantages / loss_mask.sum().clamp(min=1)

    # Option 1b: legacy token-mean that normalizes per-microbatch, then
    # averages across microbatches (mean of micro-batch means).
    elif loss_reduction == "token_mean_legacy":
        # Robustness fix: a trailing partial micro-batch would previously be
        # left as all-zero advantages (silently dropped). Fail loudly instead.
        assert (
            batch_size % micro_batch_size == 0
        ), f"batch_size ({batch_size}) must be divisible by micro_batch_size ({micro_batch_size})"
        num_micro_batches = batch_size // micro_batch_size
        normalized_advantages = torch.zeros_like(advantages)
        for i in range(num_micro_batches):
            start_idx = i * micro_batch_size
            end_idx = (i + 1) * micro_batch_size
            mb_advantages = advantages[start_idx:end_idx]
            mb_loss_mask = loss_mask[start_idx:end_idx]
            mb_advantages = mb_advantages / mb_loss_mask.sum().clamp(min=1)
            mb_advantages /= num_micro_batches
            normalized_advantages[start_idx:end_idx] = mb_advantages

    # Option 2: sequence mean — per-sequence token-mean (normalize each row by
    # its own valid-token count), then batch-mean.
    elif loss_reduction == "sequence_mean":
        normalized_advantages = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1))

    # Option 3: Dr. GRPO style reduction: per-sequence token-sum normalized by
    # a *constant* max sequence length (avoids length bias), then batch mean.
    elif loss_reduction == "seq_mean_token_sum_norm":
        assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction"
        # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to
        # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
        normalized_advantages = advantages / (batch_size * max_seq_len)

    else:
        raise ValueError(f"Invalid loss reduction type: {loss_reduction}")

    return normalized_advantages


# NOTE (erictang000): below ported from verl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,19 @@ def loss_func(logits, data):
loss_mask = data["loss_mask"]
rollout_action_logprobs = data["rollout_action_logprobs"]
action_mask = data.get("action_mask")
num_microbatches = data.get("num_microbatches")

dp_size = mpu.get_data_parallel_world_size()
tp_grp = mpu.get_tensor_model_parallel_group()
tp_rank = mpu.get_tensor_model_parallel_rank()

# Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches
# https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248
# we want to maintain a sum of losses across all micro batches, so we reverse this division.
# we additionally multiply by the data parallelism size to undo the DDP all-reduce mean
# https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285
loss_scale = num_microbatches * dp_size

# temperature normalization
if temperature != 1.0:
logits.div_(temperature)
Expand Down Expand Up @@ -276,13 +285,15 @@ def loss_func(logits, data):

# SFT path: cross_entropy loss (negative log likelihood)
if resolved_loss_name == "cross_entropy":
loss = policy_loss
unscaled_loss = policy_loss
loss = unscaled_loss * loss_scale

# Compute elementwise loss for Tinker API (per-token NLL)
with torch.no_grad():
elementwise_loss = -action_log_probs
if loss_mask is not None:
elementwise_loss = elementwise_loss * loss_mask
elementwise_loss = elementwise_loss * loss_scale

# Build per-sequence loss_fn_outputs
batch_size = action_log_probs.shape[0]
Expand All @@ -303,7 +314,7 @@ def loss_func(logits, data):
)

metrics = {
"loss": loss.detach().item(),
"loss": unscaled_loss.detach().item(),
"response_length": num_actions,
"loss_fn_outputs": loss_fn_outputs,
}
Expand Down Expand Up @@ -333,7 +344,8 @@ def loss_func(logits, data):
kl_loss = torch.tensor(0.0)
kl_loss_term = kl_loss * loss_config.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term
unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term
loss = unscaled_loss * loss_scale

# Build per-sequence loss_fn_outputs with logprobs.
batch_size = action_log_probs.shape[0]
Expand All @@ -356,7 +368,7 @@ def loss_func(logits, data):
)

metrics = {
"final_loss": loss.detach().item(),
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
Comment on lines +380 to +381
Copy link
Copy Markdown
Contributor Author

@justinvyu justinvyu Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metrics fix 1: remove dp_size multiplier in reported metrics, since there's no average that we need to correct for, since reduce_microbatch_metrics and all_reduce_metrics both do sums for *_loss metrics.

"policy_entropy": entropy.detach().item(),
"policy_kl": kl_loss.detach().item(),
Expand Down
Loading