[skyrl-train] assert that the policy loss type is regular/dual clip for tis#546
[skyrl-train] assert that the policy loss type is regular/dual clip for tis#546erictang000 merged 2 commits intoNovaSky-AI:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds a validation check to ensure that Truncated Importance Sampling (TIS) is only used with compatible policy loss types, which is a good safeguard. My feedback suggests a small improvement to make this validation more robust by using a ValueError instead of an assert, aligning with the existing validation patterns in the file.
| raise ValueError( | ||
| "Gneration with `trainer.algorithm.use_tis` needs to be batched with only single turn generation" | ||
| ) | ||
| assert cfg.trainer.algorithm.policy_loss_type in ["regular", "dual_clip"], "TIS is only implemented for regular and dual_clip policy loss types" |
There was a problem hiding this comment.
For configuration validation, it's better to raise a ValueError instead of using assert. assert statements can be disabled when Python is run with the -O (optimize) flag, which would cause this important validation to be skipped. Using ValueError ensures the check is always performed and is consistent with other checks in this function.
| assert cfg.trainer.algorithm.policy_loss_type in ["regular", "dual_clip"], "TIS is only implemented for regular and dual_clip policy loss types" | |
| if cfg.trainer.algorithm.policy_loss_type not in ["regular", "dual_clip"]: | |
| raise ValueError( | |
| f"TIS is only implemented for 'regular' and 'dual_clip' policy loss types, but got '{cfg.trainer.algorithm.policy_loss_type}'" | |
| ) |
…or tis (NovaSky-AI#546) TIS is currently only enabled for policy loss types that use the `ppo_policy_loss` code path
…or tis (NovaSky-AI#546) TIS is currently only enabled for policy loss types that use the `ppo_policy_loss` code path
…or tis (NovaSky-AI#546) TIS is currently only enabled for policy loss types that use the `ppo_policy_loss` code path
TIS is currently only enabled for policy loss types that use the
ppo_policy_losscode path