[skyrl-train] Refactor TIS to use more comprehensive off policy correction config#849

Merged
erictang000 merged 46 commits into NovaSky-AI:main from erictang000:rollout_correction
Feb 3, 2026

Collaborator

@erictang000 erictang000 commented Jan 7, 2026

Overview

  • Marks trainer.algorithm.use_tis and trainer.algorithm.tis_imp_ratio_cap for deprecation
  • Introduces new trainer.algorithm.off_policy_correction config (see new config below)
  • Updates loss functions to return a LossMetrics TypedDict containing loss metrics (previously returned just loss, clip_ratio)
  • Updates workers to all-reduce mean/max/min metrics appropriately and to propagate loss metrics back up to the trainer.
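
As a rough illustration of the last bullet, the sketch below combines per-worker metric dictionaries with a reduction chosen per key. The `reduce_metrics` helper and the `_max`/`_min` key suffix convention are assumptions made for this example and are not taken from the PR; a real worker would use a distributed all-reduce rather than a Python loop.

```python
from statistics import fmean


def reduce_metrics(per_worker: list[dict[str, float]]) -> dict[str, float]:
    """Combine metric dicts from several workers (hypothetical helper).

    Keys ending in "_max" take the max, keys ending in "_min" take the min,
    and every other key is averaged over the workers that reported it.
    """
    keys = {key for metrics in per_worker for key in metrics}
    reduced = {}
    for key in sorted(keys):
        values = [m[key] for m in per_worker if key in m]
        if key.endswith("_max"):
            reduced[key] = max(values)
        elif key.endswith("_min"):
            reduced[key] = min(values)
        else:
            reduced[key] = fmean(values)
    return reduced
```

Averaging a maximum or a minimum across workers would understate the extremes, which is presumably why the reductions are chosen per metric.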

Off Policy Correction Config

# To be deprecated in favor of off_policy_correction.tis_ratio_type = "token"
# and "token_tis_ratio_clip_high"
tis_imp_ratio_cap: -1.0
use_tis: false

off_policy_correction:
      # type of importance sampling ratio to use for ppo loss correction
      # here importance sampling ratio refers to exp(logprobs_{policy_old} - logprobs_{rollout_policy})
      tis_ratio_type: null # null, "token", "sequence"

      # used if tis_ratio_type = "token", 1.5-5.0 is recommended for "token" tis_ratio_type
      token_tis_ratio_clip_high: 2.0
      # used if tis_ratio_type = "sequence", 2.0-10.0 is recommended for "sequence" tis_ratio_type
      sequence_tis_ratio_clip_high: 5.0

      # method of masking out sequences with cumulative importance sampling ratios outside the cap
      # "product" masks out sequences with product of importance ratios outside the cap
      # "geometric" masks out sequences with geometric mean of importance ratios outside the cap
      sequence_mask_metric: null # null, "product", "geometric"

      # used if sequence_mask_metric = "geometric"
      # values around 0.99-1.01 are recommended for "geometric" sequence_mask_metric - MoE models may need larger allowed ranges due to higher mismatch
      geo_mask_high: 1.01
      geo_mask_low: 0.99

      # used if sequence_mask_metric = "product"
      # values around 0.5-2.0 are recommended for "product" sequence_mask_metric
      product_mask_high: 2.0
      product_mask_low: 0.5

      # separate from sequence_mask_metric and tis_ratio_type 
      # if any off_policy_correction is enabled, masks out sequences with any token having importance ratio
      # far outside an acceptable range (low and high thresholds)
      outlier_token_is_threshold_low: null
      outlier_token_is_threshold_high: null
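
To make the two `tis_ratio_type` options concrete, here is a minimal pure-Python sketch for a single sequence. It uses the ratio definition from the config comment, exp(logprobs_{policy_old} - logprobs_{rollout_policy}), and assumes that the "sequence" ratio is the product of the token ratios shared by every token. The function name `tis_weights` is hypothetical and the PR's tensor implementation may differ.

```python
import math


def tis_weights(old_logprobs, rollout_logprobs, ratio_type, clip_high):
    """Return one truncated importance weight per token of a single sequence."""
    log_ratios = [o - r for o, r in zip(old_logprobs, rollout_logprobs)]
    if ratio_type is None:
        # No correction: every token keeps weight 1.
        return [1.0] * len(log_ratios)
    if ratio_type == "token":
        # Each token gets its own ratio, truncated from above.
        return [min(math.exp(d), clip_high) for d in log_ratios]
    if ratio_type == "sequence":
        # Assumption: the sequence ratio is the product of token ratios,
        # i.e. exp of the summed log-ratios, applied to every token.
        seq_ratio = min(math.exp(sum(log_ratios)), clip_high)
        return [seq_ratio] * len(log_ratios)
    raise ValueError(f"unknown tis_ratio_type: {ratio_type!r}")
```

With `ratio_type="token"` and `clip_high=2.0`, a token whose ratio is e ≈ 2.72 is truncated to 2.0 while the other tokens are unaffected. Under the product assumption a sequence-level ratio compounds over tokens, which would be consistent with the larger recommended range for `sequence_tis_ratio_clip_high`.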

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the Truncated Importance Sampling (TIS) configuration into a more comprehensive rollout_correction system, which is a great improvement for structure and extensibility. The new implementation adds flexible rollout correction mechanisms, including different TIS ratio types and rejection masks. The changes are well-documented and handle the deprecation of old parameters gracefully. I've identified a bug in a conditional check that could cause a crash, and an opportunity to refactor for better efficiency and code clarity. My detailed feedback is in the comments below.

@erictang000 erictang000 changed the title from "[skyrl-train] Refactor TIS to use more comprehensive rollout correction config" to "[skyrl-train] Refactor TIS to use more comprehensive off policy correction config" on Jan 8, 2026
@erictang000
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant refactoring of the off-policy correction mechanism, replacing the simple TIS flags with a more comprehensive off_policy_correction configuration. This is a great improvement for flexibility and experimentation. The changes are well-implemented across the codebase, including documentation, examples, and tests. I've identified a few critical bugs in the implementation and some areas for improvement in the examples and utility functions to enhance clarity and correctness. Please see the detailed comments below.

@erictang000
Collaborator Author

/gemini review

Member

@CharlieFRuan CharlieFRuan left a comment

Made an initial round of review. Will take another round!

- ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages.

Off Policy Correction Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Member

Let's cite the blogpost here as well

Member

This depends on how you add the separate correction doc page (see other comment), but it'd be easier for the user if we do the following: help users understand each config (three groups of them) one by one by pointing them to other resources.

1. Group these three together

  • algorithm.off_policy_correction.tis_ratio_type
  • algorithm.off_policy_correction.token_tis_ratio_clip_high
  • algorithm.off_policy_correction.sequence_tis_ratio_clip_high

and tell them:

2. Then group these together

sequence_mask_metric: null # null, "product", "geometric"
geo_mask_high: 1.01
geo_mask_low: 0.99
product_mask_high: 2.0
product_mask_low: 0.5

3. Then group the outlier threshold together

outlier_token_is_threshold_low: 1e-4
outlier_token_is_threshold_high: 100

other remarks

Pointing to our implementation would also be helpful, namely rollout_corrections.py or whatever name you decided on in the end.
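
For readers of this thread, groups 2 and 3 above can be sketched as a single keep-or-drop decision per sequence. This is a simplified scalar illustration that assumes "product" means the product of token importance ratios and "geometric" means their geometric mean, as the config comments describe. The function `keep_sequence` and its argument names are hypothetical, not the PR's API.

```python
import math


def keep_sequence(log_ratios, metric=None, low=None, high=None,
                  outlier_low=None, outlier_high=None):
    """Return True if the sequence stays in the loss, False if masked out.

    log_ratios holds old_logprob - rollout_logprob for each token and is
    assumed to be non-empty.
    """
    if metric == "product":
        stat = math.exp(sum(log_ratios))
    elif metric == "geometric":
        # Geometric mean of the ratios = exp(mean of the log-ratios).
        stat = math.exp(sum(log_ratios) / len(log_ratios))
    elif metric is None:
        stat = None
    else:
        raise ValueError(f"unknown sequence_mask_metric: {metric!r}")

    if stat is not None and not (low <= stat <= high):
        return False

    # Outlier check: a single extreme token ratio drops the whole sequence.
    for delta in log_ratios:
        ratio = math.exp(delta)
        if outlier_low is not None and ratio < outlier_low:
            return False
        if outlier_high is not None and ratio > outlier_high:
            return False
    return True
```

Because the geometric mean normalizes by sequence length, its recommended bounds (0.99 to 1.01) are much tighter than the product bounds (0.5 to 2.0), which grow or shrink with every additional token.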

return y.to(out_dtype or delta.dtype)


def compute_tis_ratio(
Member

These are great! Can we put them into a separate file? Our ppo_utils.py is 1.4k LOCs now.

In the long term we could break ppo_utils.py down, but for now let's create a file named off_policy_correction_utils.py (or some other name you see fit) with all the methods you added. We can keep the rest where it currently is and come back later if we want to clean up further.

return 0.5 * loss, clipfrac


class LossMetrics(TypedDict, total=False):
Member

this single-field TypedDict is a bit confusing. I know we can extend this with a lot of fields depending on the set up. Is there a better solution? Do we plan to add more fields to this? If not, should we remove this class for now?

Collaborator Author

hmm you're right, let me just change the convention to return a python dictionary of metrics
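
A minimal sketch of that convention, assuming a loss function returns a `(loss, metrics)` pair where `metrics` is a plain `dict[str, float]`. The toy clipped-surrogate loss below works on Python floats for one sequence and is only meant to show the return shape; the name `ppo_clip_loss` and its arguments are illustrative, not the repository's `ppo_policy_loss`.

```python
import math


def ppo_clip_loss(log_probs, old_log_probs, advantages, eps=0.2):
    """Toy clipped-surrogate loss over one sequence; returns (loss, metrics)."""
    losses = []
    num_clipped = 0
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - old_lp)
        clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
        surr1 = -adv * ratio
        surr2 = -adv * clipped_ratio
        losses.append(max(surr1, surr2))
        if surr2 > surr1:
            # The clipped term is the active one for this token.
            num_clipped += 1
    n = len(losses)
    # Plain dict: more keys can be added without changing the signature.
    return sum(losses) / n, {"clip_ratio": num_clipped / n}
```

A plain dictionary lets each loss function report only the metrics it computes, without a separate class declaring optional fields.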


@register_policy_loss(PolicyLossType.REGULAR)
@register_policy_loss(PolicyLossType.DUAL_CLIP)
def ppo_policy_loss(
Member

@CharlieFRuan CharlieFRuan Jan 22, 2026

The return type is annotated as Tuple[torch.Tensor, float], which isn't correct, right? The function currently returns loss_metrics. Depending on what we do with LossMetrics (see the other comment on LossMetrics), we could make it dict[str, float].

tis_imp_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
loss = loss * tis_imp_ratio
# apply off policy correction
Member

these seem redundant, used in sapo, gspo, cispo, and ppo. Can we write a functional helper to extract these out?

Collaborator Author

done!
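
One way such a shared helper could look, as a sketch only: it takes per-token losses for one sequence, scales them by precomputed importance weights, zeroes them if the sequence is masked out, and reports a few metrics. The name `apply_off_policy_correction`, its arguments, and the metric keys are assumptions for illustration and do not reflect the helper that was actually added.

```python
def apply_off_policy_correction(token_losses, tis_weights=None, keep=True):
    """Hypothetical shared helper; returns (corrected_losses, metrics).

    token_losses: per-token losses of one sequence (assumed non-empty).
    tis_weights:  per-token truncated importance weights, or None to skip.
    keep:         False if the sequence was rejected by a mask.
    """
    if tis_weights is None:
        tis_weights = [1.0] * len(token_losses)
    scale = 1.0 if keep else 0.0
    corrected = [scale * w * loss for w, loss in zip(tis_weights, token_losses)]
    metrics = {
        "tis_weight_mean": sum(tis_weights) / len(tis_weights),
        "sequence_masked": 0.0 if keep else 1.0,
    }
    return corrected, metrics
```

Each loss function would then call the helper once before reducing its per-token losses, instead of repeating the clamp-and-multiply lines.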

return loss, loss_metrics


@register_policy_loss(PolicyLossType.SAPO)
Member

Might be a dumb question: why is off-policy correction only used in sapo, gspo, cispo, and ppo, not the other loss functions like compute_policy_loss_clip_cov, compute_policy_loss_kl_cov?

Collaborator Author

hmm it was because the covariance calculation could include masked out samples, so just adding the sequence masking before reduce_loss didn't seem sufficient.

Collaborator Author

i think we could add it but i would vote to just skip for now since these are not commonly used anyway

@CharlieFRuan CharlieFRuan self-assigned this Jan 22, 2026

Member

@CharlieFRuan CharlieFRuan left a comment

Such a neat PR!! Thank you so much!

Added some nits, after addressing them please feel free to merge!

return pg_loss, {"clip_ratio": 0.0}


@register_policy_loss(PolicyLossType.CROSS_ENTROPY)
Member

Should this be "return loss, {"clip_ratio": 0.0}", with the return type changed to Tuple[torch.Tensor, dict[str, float]]? We only use it in SFT, but it might make sense to keep it consistent.

Member

This doc is just great! Thank you so much for the effort!

Some really minor nits:

  • For a user that just wants to pick a correction config, they might not want to finish reading everything. Could we give some quick pointers (like TLDR, here are the configs you can start from) at the top? Like, just use
algorithm.off_policy_correction.tis_ratio_type=xxx
algorithm.off_policy_correction.token_tis_ratio_clip_high=xxx

if you want to do the most popular (or basic?) TIS.

Use xxx if you want to follow this blog, etc.

And could we add a reference section at the top or the bottom as well please

Collaborator Author

And could we add a reference section at the top or the bottom as well please

already there! (just didn't screenshot)

Collaborator Author

For a user that just wants to pick a correction config, they might not want to finish reading everything. Could we give some quick pointers (like TLDR, here are the configs you can start from) at the top? Like, just use

good point, added!

@erictang000 erictang000 merged commit 5102468 into NovaSky-AI:main Feb 3, 2026
3 checks passed
@erictang000 erictang000 deleted the rollout_correction branch February 3, 2026 20:33