-
Notifications
You must be signed in to change notification settings - Fork 286
[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy
#1296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
justinvyu
wants to merge
20
commits into
NovaSky-AI:main
Choose a base branch
from
justinvyu:token_mean_loss_reduction
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
589c150
Move loss reduction normalization to trainer-level advantage scaling,…
justinvyu 333f31a
Add token_mean_baseline loss reduction for mean-of-microbatch-means c…
justinvyu aaaba4c
fix assertion
justinvyu a121360
Update tests for sum-based reduce_loss and dp_size scaling changes
justinvyu 15de89a
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu e3842c3
lint
justinvyu 13bfe80
fix tests
justinvyu e76bece
Refactor advantage normalization: fix z-score propagation, skip for c…
justinvyu 0192e8e
token_mean_baseline -> token_mean_legacy
justinvyu 4ee0b31
Extract apply_loss_reduction_to_advantages_minibatch to ppo_utils and…
justinvyu c8f06cc
Fix metric reporting: remove dp_size scaling, separate micro-batch vs…
justinvyu 2c13315
Fix critic metric reporting: explicit sum_loss_metrics flag for reduc…
justinvyu 14ba02e
Remove reduce_metrics_across_minibatches, reuse reduce_metrics
justinvyu 0cfc95b
Merge remote-tracking branch 'upstream/main' into token_mean_loss_red…
justinvyu 717c3a7
add some comments about sum metrics
justinvyu 661f5d8
add clarifying comments and rename loss_scale
justinvyu 5cc95a1
no_grad for safety and make private
justinvyu ce8f6aa
remove outdated comments about loss reduction type in sapo tests
justinvyu 1a60bb5
fix test
justinvyu c5feb83
fix test
justinvyu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -243,10 +243,19 @@ def loss_func(logits, data): | |
| loss_mask = data["loss_mask"] | ||
| rollout_action_logprobs = data["rollout_action_logprobs"] | ||
| action_mask = data.get("action_mask") | ||
| num_microbatches = data.get("num_microbatches") | ||
|
|
||
| dp_size = mpu.get_data_parallel_world_size() | ||
| tp_grp = mpu.get_tensor_model_parallel_group() | ||
| tp_rank = mpu.get_tensor_model_parallel_rank() | ||
|
|
||
| # Megatron's pipeline parallel forward_backward_func internally divides loss by num_microbatches | ||
| # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248 | ||
| # we want to maintain a sum of losses across all micro batches, so we reverse this division. | ||
| # we additionally multiply by the data parallelism size to undo the DDP all-reduce mean | ||
| # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285 | ||
| loss_scale = num_microbatches * dp_size | ||
|
|
||
| # temperature normalization | ||
| if temperature != 1.0: | ||
| logits.div_(temperature) | ||
|
|
@@ -276,13 +285,15 @@ def loss_func(logits, data): | |
|
|
||
| # SFT path: cross_entropy loss (negative log likelihood) | ||
| if resolved_loss_name == "cross_entropy": | ||
| loss = policy_loss | ||
| unscaled_loss = policy_loss | ||
| loss = unscaled_loss * loss_scale | ||
|
|
||
| # Compute elementwise loss for Tinker API (per-token NLL) | ||
| with torch.no_grad(): | ||
| elementwise_loss = -action_log_probs | ||
| if loss_mask is not None: | ||
| elementwise_loss = elementwise_loss * loss_mask | ||
| elementwise_loss = elementwise_loss * loss_scale | ||
|
|
||
| # Build per-sequence loss_fn_outputs | ||
| batch_size = action_log_probs.shape[0] | ||
|
|
@@ -303,7 +314,7 @@ def loss_func(logits, data): | |
| ) | ||
|
|
||
| metrics = { | ||
| "loss": loss.detach().item(), | ||
| "loss": unscaled_loss.detach().item(), | ||
| "response_length": num_actions, | ||
| "loss_fn_outputs": loss_fn_outputs, | ||
| } | ||
|
|
@@ -333,7 +344,8 @@ def loss_func(logits, data): | |
| kl_loss = torch.tensor(0.0) | ||
| kl_loss_term = kl_loss * loss_config.kl_loss_coef | ||
|
|
||
| loss = policy_loss + kl_loss_term - entropy_loss_term | ||
| unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term | ||
| loss = unscaled_loss * loss_scale | ||
|
|
||
| # Build per-sequence loss_fn_outputs with logprobs. | ||
| batch_size = action_log_probs.shape[0] | ||
|
|
@@ -356,7 +368,7 @@ def loss_func(logits, data): | |
| ) | ||
|
|
||
| metrics = { | ||
| "final_loss": loss.detach().item(), | ||
| "final_loss": unscaled_loss.detach().item(), | ||
| "policy_loss": policy_loss.detach().item(), | ||
|
Comment on lines
+380
to
+381
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Metrics fix 1: remove |
||
| "policy_entropy": entropy.detach().item(), | ||
| "policy_kl": kl_loss.detach().item(), | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.