[skyrl-train] Refactor TIS to use more comprehensive off policy correction config #849
Conversation
Code Review
This pull request refactors the Truncated Importance Sampling (TIS) configuration into a more comprehensive rollout_correction system, which is a great improvement for structure and extensibility. The new implementation adds flexible rollout correction mechanisms, including different TIS ratio types and rejection masks. The changes are well-documented and handle the deprecation of old parameters gracefully. I've identified a bug in a conditional check that could cause a crash, and an opportunity to refactor for better efficiency and code clarity. My detailed feedback is in the comments below.
…kyRL into rollout_correction
… unite metrics under loss_metrics, other clean up
/gemini review
Code Review
This pull request introduces a significant refactoring of the off-policy correction mechanism, replacing the simple TIS flags with a more comprehensive off_policy_correction configuration. This is a great improvement for flexibility and experimentation. The changes are well-implemented across the codebase, including documentation, examples, and tests. I've identified a few critical bugs in the implementation and some areas for improvement in the examples and utility functions to enhance clarity and correctness. Please see the detailed comments below.
skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_fp8.sh (outdated, resolved)
skyrl-train/examples/flash_rl/run_dapo_gsm8k_flashrl_0.5b_int8.sh (outdated, resolved)
/gemini review
CharlieFRuan left a comment
Made an initial round of review. Will take another round!
> - ``tau_neg``: Temperature for gating function for tokens with negative (or zero) advantages.
>
> Off Policy Correction Configuration
> ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Let's cite the blogpost here as well
Depends on how you add the separate correction doc page (see other comment). But it'd be easier for the user if we can do the following: basically, help the users understand each config (3 groups of them) one by one by pointing them to other resources.
1. Group these three together:
   - `algorithm.off_policy_correction.tis_ratio_type`
   - `algorithm.off_policy_correction.token_tis_ratio_clip_high`
   - `algorithm.off_policy_correction.sequence_tis_ratio_clip_high`
and tell them:
- how to do the basic TIS proposed in https://fengyao.notion.site/Your-Efficient-RL-Framework-Secretly-Brings-You-Off-Policy-RL-Training-237721e3f6c48094ad67dad3ac091c56
- i.e. token level, and a default clip value
- what is sequence level, the difference, and the motivation of doing that; perhaps simply refer to section 4.2 of https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda#27b211a558b78099ba48fa8849ab54c8
2. Then group these together:

       sequence_mask_metric: null # null, "product", "geometric"
       geo_mask_high: 1.01
       geo_mask_low: 0.99
       product_mask_high: 2.0
       product_mask_low: 0.5
3. Then group the outlier thresholds together:

       outlier_token_is_threshold_low: 1e-4
       outlier_token_is_threshold_high: 100
Other remarks: pointing to our implementation would also be helpful, namely `rollout_corrections.py` or whatever name you decide on in the end.
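The group-2 rejection masks above can be sketched as follows. This is a hypothetical illustration on plain Python floats; the function name is invented here, and the real implementation in the PR operates on torch tensors:

```python
import math

def sequence_mask(token_ratios, metric, low, high):
    """Keep (True) or reject (False) a sequence based on an aggregate importance ratio.

    token_ratios are per-token ratios pi_old / pi_rollout for one sequence.
    """
    if metric == "product":
        # product of per-token importance ratios, gated by product_mask_low/high
        stat = math.prod(token_ratios)
    elif metric == "geometric":
        # geometric mean: the product taken to the power 1/T, gated by geo_mask_low/high
        stat = math.prod(token_ratios) ** (1.0 / len(token_ratios))
    else:
        # metric is null: no rejection at all
        return True
    return low <= stat <= high
```

With the defaults quoted above, the geometric gate (0.99–1.01) is far tighter than the product gate (0.5–2.0), since the geometric mean normalizes away sequence length.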
> ```python
>     return y.to(out_dtype or delta.dtype)
>
> def compute_tis_ratio(
> ```
These are great! Can we put them into a separate file? Our ppo_utils.py is 1.4k LOC now.
In the long term we could break ppo_utils.py down, but for now let's create an off_policy_correction_utils.py (or some other name you see fit) with all the methods you added. We can keep the rest where they currently are and come back later if we want to clean up further.
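As a rough sketch of what such a utilities module could contain, here are float-level analogues of the quoted helpers. The names mirror the diff context, but the bodies are illustrative only, not the PR's actual torch implementation:

```python
import math

def safe_exp_delta(delta, clip=20.0):
    """exp(delta) with the exponent clamped so extreme logprob gaps cannot overflow."""
    return math.exp(max(-clip, min(clip, delta)))

def compute_tis_ratio(old_logprob, rollout_logprob, clip_high):
    """Token-level truncated importance sampling ratio, clipped from above.

    The ratio pi_old / pi_rollout is recovered from the logprob difference.
    """
    ratio = safe_exp_delta(old_logprob - rollout_logprob)
    return min(ratio, clip_high)
```

A sequence-level variant would aggregate the per-token log differences over the whole sequence before exponentiating and clipping.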
> ```python
>     return 0.5 * loss, clipfrac
>
> class LossMetrics(TypedDict, total=False):
> ```
This single-field TypedDict is a bit confusing. I know we can extend it with a lot of fields depending on the setup. Is there a better solution? Do we plan to add more fields to this? If not, should we remove this class for now?
Hmm, you're right; let me just change the convention to return a plain Python dictionary of metrics.
> ```python
> @register_policy_loss(PolicyLossType.REGULAR)
> @register_policy_loss(PolicyLossType.DUAL_CLIP)
> def ppo_policy_loss(
> ```
The return is typed as `Tuple[torch.Tensor, float]`, which isn't correct, right? It currently returns `loss_metrics`. Depending on what we do with `LossMetrics` as noted in the other comment, we could make it `dict[str, float]`.
> ```python
>     tis_imp_ratio = _safe_exp_delta(old_log_probs - rollout_logprobs, clip=20.0, out_dtype=log_probs.dtype)
>     tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
>     loss = loss * tis_imp_ratio
>     # apply off policy correction
> ```
These seem redundant; the same block is used in SAPO, GSPO, CISPO, and PPO. Can we write a functional helper to extract it out?
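One possible shape for such a helper, sketched on plain Python floats rather than torch tensors (the name and signature are hypothetical, not from the PR):

```python
import math

def apply_token_tis(per_token_loss, old_logprobs, rollout_logprobs, ratio_cap):
    """Scale each token's loss by its capped importance ratio pi_old / pi_rollout.

    Each policy loss (PPO, GSPO, SAPO, CISPO) could call this once instead of
    repeating the exp/clamp/multiply block inline.
    """
    corrected = []
    for loss, old_lp, roll_lp in zip(per_token_loss, old_logprobs, rollout_logprobs):
        # recover the importance ratio from the logprob gap, then cap it
        ratio = min(math.exp(old_lp - roll_lp), ratio_cap)
        corrected.append(loss * ratio)
    return corrected
```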
> ```python
>     return loss, loss_metrics
>
> @register_policy_loss(PolicyLossType.SAPO)
> ```
Might be a dumb question: why is off-policy correction only used in sapo, gspo, cispo, and ppo, not the other loss functions like compute_policy_loss_clip_cov, compute_policy_loss_kl_cov?
Hmm, it was because the covariance calculation could include masked-out samples, so just adding the sequence masking before reduce_loss didn't seem sufficient.
I think we could add it, but I would vote to just skip it for now since these are not commonly used anyway.
CharlieFRuan left a comment
Such a neat PR!! Thank you so much!
Added some nits, after addressing them please feel free to merge!
> ```python
>     return pg_loss, {"clip_ratio": 0.0}
>
> @register_policy_loss(PolicyLossType.CROSS_ENTROPY)
> ```
Should this return `loss, {"clip_ratio": 0.0}` and change the return type to `Tuple[torch.Tensor, dict[str, float]]`? We only use it in SFT, but it might make sense to keep it consistent.
This doc is just great! Thank you so much for the effort!
Some really minor nits:
- For a user who just wants to pick a correction config, they might not want to read everything. Could we give some quick pointers at the top (like a TL;DR: here are the configs you can start from)? For example, just use

      algorithm.off_policy_correction.tis_ratio_type=xxx
      algorithm.off_policy_correction.token_tis_ratio_clip_high=xxx

  if you want to do the most popular (or basic?) TIS. Use xxx if you want to follow this blog, etc.
- And could we add a reference section at the top or the bottom as well, please?
> And could we add a reference section at the top or the bottom as well, please?
already there! (just didn't screenshot)





Overview
- `trainer.algorithm.use_tis` and `trainer.algorithm.tis_imp_ratio_cap` kept for deprecation
- New `trainer.algorithm.off_policy_correction` config (see new config below)
- `LossMetrics` TypedDict containing loss metrics (previously returned just `loss, clip_ratio`)

Off Policy Correction Config
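Piecing together the keys quoted in the review comments above, the new block presumably looks roughly like the following. Key names and the mask/outlier defaults come from the reviewer's quotes; the `tis_ratio_type` values and clip defaults are illustrative assumptions:

```yaml
algorithm:
  off_policy_correction:
    tis_ratio_type: token              # assumed values: "token" or "sequence"
    token_tis_ratio_clip_high: 2.0     # illustrative default
    sequence_tis_ratio_clip_high: 2.0  # illustrative default
    sequence_mask_metric: null         # null, "product", "geometric"
    geo_mask_high: 1.01
    geo_mask_low: 0.99
    product_mask_high: 2.0
    product_mask_low: 0.5
    outlier_token_is_threshold_low: 1e-4
    outlier_token_is_threshold_high: 100
```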