
[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy#1296

Open
justinvyu wants to merge 11 commits into NovaSky-AI:main from justinvyu:token_mean_loss_reduction

Conversation

@justinvyu
Contributor

@justinvyu justinvyu commented Mar 9, 2026

Summary

  • Change reduce_loss() to always return a simple masked sum ((loss * mask).sum()). Different reduction strategies are achieved by pre-scaling the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
    • Scales the loss by the DP size before calling backward() to counteract the default data-parallel mean gradient all-reduce across workers, turning it into a sum.
  • Fixes the token_mean loss reduction method to take a mean across all tokens in the mini-batch rather than averaging across micro-batches. The old reduction remains available via the token_mean_legacy config.
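A minimal sketch of the new reduction (illustrative only; the tensor shapes and example values are assumptions, not the exact skyrl-train code):

```python
import torch

def reduce_loss(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Always a plain masked sum; any normalization is pre-applied to the advantages."""
    return (loss * mask).sum()

# Hypothetical per-token losses for 2 sequences, with padding masked out.
loss = torch.tensor([[0.5, 1.0, 0.0],
                     [2.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 0.0, 0.0]])
print(reduce_loss(loss, mask))  # tensor(3.5000)
```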

Loss reduction strategies

  • Option 1: token_mean

    • Average loss per token across the entire mini-batch.
    • This is the fixed version where the denominator is the total token count across the full mini-batch, so the gradient is independent of how the mini-batch is split into micro-batches.
  • Option 1b: token_mean_legacy

    • Compute token-mean loss within each micro-batch, then average across micro-batches.
    • This reproduces the token_mean behavior before this PR.
    • The problem: if micro-batches have different token counts, the effective weighting differs from a true global token mean. It is also less usable, since changing the micro-batch size affects the loss and the training dynamics.
    • Kept as a fallback in case of performance regressions — we should remove this down the line.
  • Option 2: sequence_mean

    • Compute per-token loss within each sequence, average across sequences.
    • This is unchanged and is just implemented via advantage normalization instead.
  • Option 3: seq_mean_token_sum_norm

    • Dr. GRPO style — normalize by a fixed constant to avoid any length-dependent weighting.
    • This is unchanged and is just implemented via advantage normalization instead.
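The three strategies above can be sketched as advantage pre-scales (a simplified illustration; the helper name, signature, and mask convention are assumptions, not the skyrl-train API):

```python
import torch

def normalize_advantages(advantages: torch.Tensor, mask: torch.Tensor,
                         strategy: str, norm_const: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: each reduction strategy becomes an advantage pre-scale."""
    if strategy == "token_mean":
        # Denominator is the total token count across the whole mini-batch.
        return advantages / mask.sum()
    if strategy == "sequence_mean":
        # Per-sequence token mean, then averaged over the number of sequences.
        seq_lens = mask.sum(dim=-1, keepdim=True).clamp(min=1)
        return advantages / (seq_lens * advantages.shape[0])
    if strategy == "seq_mean_token_sum_norm":
        # Dr. GRPO style: fixed constant, no length-dependent weighting.
        return advantages / (norm_const * advantages.shape[0])
    raise ValueError(f"unknown strategy: {strategy}")
```

With the advantages scaled this way, the plain masked sum in the loss function reproduces each reduction.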

Mean all-reduce -> sum all-reduce

We need the loss to be summed across microbatches and data parallel workers:

  • DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts that by multiplying the loss by the DP world size, so the effective reduction is a sum across data-parallel groups.
  • Megatron performs a similar mean reduction across both micro-batches and workers, so we counteract it by multiplying by the number of micro-batches and the DP size to achieve the sum.
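A toy single-process illustration of the counter-scaling (values and names are assumptions): DDP's gradient all-reduce divides by dp_size, so pre-multiplying the loss by dp_size cancels that division and leaves a sum of per-worker gradients.

```python
import torch

dp_size = 4
w = torch.tensor(2.0, requires_grad=True)
loss = w * 3.0                 # stand-in for this worker's masked-sum loss
(loss * dp_size).backward()    # local gradient is now dp_size * 3.0
# After DDP's mean all-reduce divides by dp_size, the effective gradient
# equals the sum of per-worker gradients (3.0 from each of 4 workers).
print(w.grad)  # tensor(12.)
```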

Tinker compatibility

Here was the first attempt at fixing the loss reduction across microbatches: #909

That approach tracked the total token count and applied one global normalization at optim_step to get an average per-token loss. Instead, we decided to align with Tinker's approach of simply summing the loss at the end and pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control over their loss reduction strategy, rather than having it happen inside our opaque forward_backward / optim_step implementation, which would require a configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:

client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible

The current PR aligns with Tinker semantics:

Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.

Example for loss_reduction="token_mean":

  • Move the 1/num_minibatch_tokens normalization into the advantage: loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
  • -> sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )
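A quick numeric check of this identity (toy numbers, not from the PR):

```python
import torch

ratio = torch.tensor([1.1, 0.9, 1.0, 1.2])       # per-token importance ratios
advantage = torch.tensor([0.5, -0.3, 0.2, 0.1])  # per-token advantages
n = ratio.numel()                                # num_minibatch_tokens

loss_mean = (-(advantage * ratio)).sum() / n       # normalize after the sum
loss_prescaled = (-(advantage / n) * ratio).sum()  # fold 1/n into the advantages
assert torch.allclose(loss_mean, loss_prescaled)
```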

Learning curve comparisons before/after the PR

FSDP (wandb)


Megatron (wandb)

1.7B:


30B lora:

master baseline:

token_mean_legacy + fixed token_mean:


justinvyu and others added 3 commits March 9, 2026 11:51
… scale loss by dp_size for FSDP/Megatron parity

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…omparison

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
justinvyu and others added 7 commits March 9, 2026 18:27
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…uction

# Conflicts:
#	skyrl/backends/skyrl_train/utils/ppo_utils.py
#	skyrl/train/fully_async_trainer.py
#	skyrl/train/trainer.py
#	tests/backends/skyrl_train/gpu/test_grpo_sp_sanity.py
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
…ritic, rename token_mean_baseline to token_mean_legacy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
… add unit tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@justinvyu justinvyu marked this pull request as ready for review March 20, 2026 22:34
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant refactoring of the loss reduction mechanism. The core change moves the reduction logic from a monolithic reduce_loss function to pre-scaling advantages based on the desired reduction strategy. This aligns with Tinker's API, making the system more modular and explicit. The changes also correctly handle gradient accumulation in distributed settings (DDP/FSDP/Megatron) by scaling the loss to counteract the default mean-reduction of gradients, effectively performing a sum. The refactoring is extensive, touching core training logic, worker implementations, and associated tests. My review identified a couple of issues related to metric reporting where losses are incorrectly scaled, which could lead to misleading monitoring data. Apart from that, the changes appear solid and well-implemented.

Comment on lines +371 to +372
"final_loss": unscaled_loss.detach().item() * dp_size,
"policy_loss": policy_loss.detach().item() * dp_size,
Contributor


high

The reported final_loss and policy_loss metrics are being scaled by dp_size. Since these metrics are summed across micro-batches and then sum-reduced across data-parallel ranks, this will result in the total loss being over-reported by a factor of dp_size. The loss scaling is necessary for correct gradient computation, but for metric reporting, the unscaled loss should be used to reflect the true total loss.

Suggested change
"final_loss": unscaled_loss.detach().item() * dp_size,
"policy_loss": policy_loss.detach().item() * dp_size,
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),

Collaborator


@justinvyu i think i remember i made these changes to get these metrics matching and to be invariant to dp size... but just want to check - were the metric scales roughly similar for FSDP vs megatron on your runs?

for the 1.7b runs could you paste these metrics for megatron vs fsdp?

Comment on lines +931 to +932
"final_loss": loss.item(),
"policy_loss": policy_loss.item(),
"policy_loss": policy_loss.item() * loss_scale,
Contributor


high

The reported final_loss and policy_loss metrics are being scaled by loss_scale (which is dp_size). Since these metrics are summed across micro-batches and then sum-reduced across data-parallel ranks, this will result in the total loss being over-reported by a factor of dp_size. While loss scaling is correct for the backward pass to counteract DDP's mean reduction, the reported metrics should be based on the unscaled loss to accurately reflect the total loss.

Suggested change
"final_loss": loss.item(),
"policy_loss": policy_loss.item(),
"policy_loss": policy_loss.item() * loss_scale,
"final_loss": unscaled_loss.item(),
"policy_loss": policy_loss.item(),

Contributor

@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no bugs or issues to report.


@justinvyu justinvyu changed the title [wip] loss reduction [skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy Mar 20, 2026
Collaborator

@erictang000 erictang000 left a comment


this looks almost good to merge, super clean thanks for adding the token_mean_legacy path

just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...


… mini-batch reduction

- Report unscaled loss metrics (remove * loss_scale / * dp_size) in both FSDP and Megatron workers
- Rename reduce_metrics -> reduce_metrics_across_microbatches (sums _loss for gradient accumulation)
- Add reduce_metrics_across_minibatches in trainer_utils (averages _loss for logging)
- Use sum all-reduce for _loss keys across DP workers to reconstruct full mini-batch loss

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines +371 to +372
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
Contributor Author

@justinvyu justinvyu Mar 25, 2026


Metrics fix 1: remove the dp_size multiplier in reported metrics. There's no mean to correct for, since reduce_microbatch_metrics and all_reduce_metrics both sum the *_loss metrics.

# pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
all_metrics.pop("loss_fn_outputs", None)
reduced_metrics = reduce_metrics(all_metrics)
reduced_metrics = reduce_metrics_across_minibatches(all_metrics)
Contributor Author


Metrics fix 2: take an average across mini-batches instead of summing. The loss-reduction normalization happens at the mini-batch level, so across mini-batches we should just average; otherwise the reported loss scale inflates by ~num_minibatches.

all_metrics[k].append(v)

return reduce_metrics(dict(all_metrics))
return reduce_metrics_across_microbatches(dict(all_metrics))
Contributor Author


critic codepath may need to be reverted since it doesn't use the advantages?

