[skyrl-train] Implement loss reduction via advantage normalization and fix `token_mean` reduction strategy #1296

justinvyu wants to merge 11 commits into NovaSky-AI:main

Conversation
… scale loss by dp_size for FSDP/Megatron parity Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…omparison Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uction # Conflicts: # skyrl/backends/skyrl_train/utils/ppo_utils.py # skyrl/train/fully_async_trainer.py # skyrl/train/trainer.py # tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py
…ritic, rename token_mean_baseline to token_mean_legacy Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
… add unit tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request introduces a significant refactoring of the loss reduction mechanism. The core change moves the reduction logic from a monolithic reduce_loss function to pre-scaling advantages based on the desired reduction strategy. This aligns with Tinker's API, making the system more modular and explicit. The changes also correctly handle gradient accumulation in distributed settings (DDP/FSDP/Megatron) by scaling the loss to counteract the default mean-reduction of gradients, effectively performing a sum. The refactoring is extensive, touching core training logic, worker implementations, and associated tests. My review identified a couple of issues related to metric reporting where losses are incorrectly scaled, which could lead to misleading monitoring data. Apart from that, the changes appear solid and well-implemented.
```python
"final_loss": unscaled_loss.detach().item() * dp_size,
"policy_loss": policy_loss.detach().item() * dp_size,
```
The reported final_loss and policy_loss metrics are being scaled by dp_size. Since these metrics are summed across micro-batches and then sum-reduced across data-parallel ranks, this will result in the total loss being over-reported by a factor of dp_size. The loss scaling is necessary for correct gradient computation, but for metric reporting, the unscaled loss should be used to reflect the true total loss.
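A hypothetical pure-Python illustration of the over-reporting (not the repo's code): reported `*_loss` metrics are sum-reduced across data-parallel ranks, so scaling each rank's reported value by `dp_size` inflates the logged total by that factor.

```python
# Per-rank unscaled losses on a hypothetical dp_size=4 setup.
dp_size = 4
per_rank_losses = [1.0, 2.0, 3.0, 4.0]

# What we want to log: the true total loss across DP ranks.
true_total = sum(per_rank_losses)

# What gets logged if each rank reports loss * dp_size before the sum-reduce.
over_reported = sum(l * dp_size for l in per_rank_losses)

print(true_total)      # 10.0
print(over_reported)   # 40.0 == dp_size * true_total
```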
Suggested change:

```diff
-"final_loss": unscaled_loss.detach().item() * dp_size,
-"policy_loss": policy_loss.detach().item() * dp_size,
+"final_loss": unscaled_loss.detach().item(),
+"policy_loss": policy_loss.detach().item(),
```
@justinvyu i think i remember i made these changes to get these metrics matching and to be invariant to dp size... but just want to check - were the metric scales roughly similar for FSDP vs megatron on your runs?
for the 1.7b runs could you paste these metrics for megatron vs fsdp?
```diff
 "final_loss": loss.item(),
-"policy_loss": policy_loss.item(),
+"policy_loss": policy_loss.item() * loss_scale,
```
The reported final_loss and policy_loss metrics are being scaled by loss_scale (which is dp_size). Since these metrics are summed across micro-batches and then sum-reduced across data-parallel ranks, this will result in the total loss being over-reported by a factor of dp_size. While loss scaling is correct for the backward pass to counteract DDP's mean reduction, the reported metrics should be based on the unscaled loss to accurately reflect the total loss.
Suggested change:

```diff
-"final_loss": loss.item(),
-"policy_loss": policy_loss.item() * loss_scale,
+"final_loss": unscaled_loss.item(),
+"policy_loss": policy_loss.item(),
```
erictang000 left a comment
this looks almost good to merge, super clean! thanks for adding the `token_mean_legacy` path
just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...
… mini-batch reduction
- Report unscaled loss metrics (remove `* loss_scale` / `* dp_size`) in both FSDP and Megatron workers
- Rename `reduce_metrics` -> `reduce_metrics_across_microbatches` (sums `_loss` for gradient accumulation)
- Add `reduce_metrics_across_minibatches` in trainer_utils (averages `_loss` for logging)
- Use sum all-reduce for `_loss` keys across DP workers to reconstruct full mini-batch loss

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```python
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
```
Metrics fix 1: remove the `dp_size` multiplier from the reported metrics. There's no average to correct for, because `reduce_microbatch_metrics` and `all_reduce_metrics` both sum the `*_loss` metrics.
```diff
 # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
 all_metrics.pop("loss_fn_outputs", None)
-reduced_metrics = reduce_metrics(all_metrics)
+reduced_metrics = reduce_metrics_across_minibatches(all_metrics)
```
Metrics fix 2: take an average across minibatches instead of still summing. The loss reduction normalization happens at the minibatch level, so across different minibatches we should just average; otherwise we'd inflate the reported loss scale by a factor of ~`num_minibatches`.
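A hedged sketch of the two-level metric reduction (function names are from the PR, bodies are simplified guesses): `*_loss` metrics are summed across microbatches to match gradient accumulation, then averaged across minibatches since each minibatch loss is already fully normalized.

```python
def reduce_metrics_across_microbatches(metrics):
    # Sum *_loss keys (gradient accumulation sums micro-batch losses);
    # average everything else.
    return {
        k: sum(v) if k.endswith("_loss") else sum(v) / len(v)
        for k, v in metrics.items()
    }

def reduce_metrics_across_minibatches(metrics):
    # Each minibatch loss is already normalized, so just average.
    return {k: sum(v) / len(v) for k, v in metrics.items()}

micro = reduce_metrics_across_microbatches({"policy_loss": [0.5, 0.25], "kl": [0.5, 1.5]})
mini = reduce_metrics_across_minibatches({"policy_loss": [0.75, 0.25]})
print(micro)  # {'policy_loss': 0.75, 'kl': 1.0}
print(mini)   # {'policy_loss': 0.5}
```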
```diff
 all_metrics[k].append(v)

-return reduce_metrics(dict(all_metrics))
+return reduce_metrics_across_microbatches(dict(all_metrics))
```
critic codepath may need to be reverted since it doesn't use the advantages?
Summary
- Refactor `reduce_loss()` to always return a simple masked sum (`(loss * mask).sum()`). To achieve different reduction strategies, we pre-scale the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
- Scale the loss before `backward()` to counteract the default data-parallel mean gradient all-reduce across workers, so it effectively does a sum instead.
- Fix the `token_mean` loss reduction method to take a mean across all tokens in the minibatch rather than averaging across microbatches. The old loss reduction is still available via the `token_mean_legacy` config.

Loss reduction strategies
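A hedged, pure-Python sketch of the scheme described in the summary (simplified stand-in for the tensor code; helper names are illustrative): `reduce_loss` is always a masked sum, and `token_mean` is achieved by pre-scaling the advantages by `1/num_minibatch_tokens`.

```python
def reduce_loss(per_token_loss, mask):
    # Always a simple masked sum -- no mean/normalization baked in.
    return sum(l * m for l, m in zip(per_token_loss, mask))

def prescale_token_mean(advantages, num_minibatch_tokens):
    # Fold the 1/num_tokens normalization of "token_mean" into the advantages.
    return [a / num_minibatch_tokens for a in advantages]

advantages = [0.5, -1.0, 2.0, 0.25]
ratios = [1.0, 1.0, 1.0, 1.0]  # policy ratios, fixed at 1.0 for simplicity
mask = [1, 1, 1, 1]
n = sum(mask)

# Old-style masked mean vs. new-style masked sum of pre-scaled loss.
old = reduce_loss([-a * r for a, r in zip(advantages, ratios)], mask) / n
new = reduce_loss(
    [-a * r for a, r in zip(prescale_token_mean(advantages, n), ratios)], mask
)
assert old == new
```

The point is that the reduction choice lives entirely in the advantage pre-scaling, so the loss function itself stays a plain sum.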
Option 1: `token_mean`

Option 1b: `token_mean_legacy`

The `token_mean` behavior before this PR.

Option 2: `sequence_mean`

Option 3: `seq_mean_token_sum_norm`
Mean all-reduce -> sum all-reduce
We need the loss to be summed across microbatches and data parallel workers:
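A small sketch of the assumed mechanism (matching the PR description, not the actual worker code): DDP/FSDP mean-reduces gradients across `dp_size` ranks, so scaling each rank's loss by `dp_size` before `backward()` turns that mean into the sum we want.

```python
dp_size = 4
per_rank_grads = [1.0, 2.0, 3.0, 4.0]  # hypothetical per-rank gradients

# Default DDP behavior: mean across ranks.
ddp_mean = sum(per_rank_grads) / dp_size

# With loss scaled by dp_size on each rank, the mean becomes a sum.
scaled = sum(g * dp_size for g in per_rank_grads) / dp_size

wanted_sum = sum(per_rank_grads)
assert scaled == wanted_sum
print(ddp_mean, scaled)  # 2.5 10.0
```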
Tinker compatibility
Here was the first attempt at fixing the loss reduction across microbatches: #909
This method was to track total tokens and then do one big normalization at the `optim_step` in order to get an average per-token loss. But we decided to align with Tinker's way of just summing up the loss at the end, pushing any loss normalization to the user's advantage calculation.

The benefit is that users have full control over customizing their loss reduction strategy, rather than having it happen in our opaque `forward_backward` / `optim_step` implementation, which would require some configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss.

The current PR aligns with Tinker semantics:
Example for `loss_reduction="token_mean"`: fold the `1/num_minibatch_tokens` normalization into the advantage:

```
loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
     = sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )
```

Learning curve comparisons before/after the PR
FSDP (wandb)
Megatron (wandb)
1.7B:
30B lora:
master baseline:

`token_mean_legacy` + fixed `token_mean`: