[tx] Fix loss function config keys and add validation #1159
pcmoritz merged 4 commits into NovaSky-AI:main
Code Review
This pull request correctly refactors the loss function configuration keys, replacing the ambiguous clip_ratio with the more explicit clip_low_threshold and clip_high_threshold. The changes are consistently applied across the backend implementation, loss functions, and tests. This significantly improves the clarity and maintainability of the configuration handling, and as the description notes, brings it in line with the documentation. The test updates are thorough and correctly validate the new behavior. This is a high-quality and beneficial change.
/gemini review
Code Review
This pull request correctly updates the loss function configuration keys for PPO, replacing clip_ratio with clip_low_threshold and clip_high_threshold. The changes are applied consistently across the API, backend, and tests. The addition of a Pydantic validator to enforce correct loss_fn_config keys is a great improvement for API robustness. The new tests effectively cover the changes. I have one suggestion to improve maintainability.
skyrl-tx/tx/tinker/api.py
Outdated
allowed_keys_by_loss_fn = {
    "cross_entropy": set(),
    "importance_sampling": set(),
    "ppo": {"clip_low_threshold", "clip_high_threshold"},
}
The allowed_keys_by_loss_fn dictionary is defined inside the validate_loss_fn_config_keys method, so it is rebuilt every time a ForwardBackwardInput instance is validated. To avoid that cost and to follow the convention of defining constants at class or module level, consider making it a class attribute of ForwardBackwardInput (e.g., _ALLOWED_KEYS_BY_LOSS_FN).
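A minimal, framework-free sketch of the idea in this comment: the allowed-keys table is hoisted to a module-level constant and the check reads from it. In the PR the check lives in a Pydantic validator on ForwardBackwardInput; the standalone function signature, the frozenset values, the handling of unknown loss functions, and the exact error wording below are assumptions made for illustration.

```python
from typing import Mapping, Optional

# Built once at import time instead of on every validation call.
ALLOWED_KEYS_BY_LOSS_FN = {
    "cross_entropy": frozenset(),
    "importance_sampling": frozenset(),
    "ppo": frozenset({"clip_low_threshold", "clip_high_threshold"}),
}


def validate_loss_fn_config_keys(
    loss_fn: str, loss_fn_config: Optional[Mapping[str, float]]
) -> None:
    """Raise ValueError if loss_fn_config has keys not allowed for loss_fn."""
    if loss_fn not in ALLOWED_KEYS_BY_LOSS_FN:
        # Hypothetical behavior: the PR thread does not show this case.
        raise ValueError(f"Unknown loss_fn={loss_fn!r}")
    if not loss_fn_config:
        return  # None or empty config is always acceptable in this sketch
    invalid = set(loss_fn_config) - ALLOWED_KEYS_BY_LOSS_FN[loss_fn]
    if invalid:
        # Message wording is assumed, modeled on the test's match string.
        raise ValueError(
            f"Invalid loss_fn_config keys for loss_fn={loss_fn!r}: {sorted(invalid)}"
        )
```

If the table is kept on a Pydantic v2 model instead, an underscore-prefixed name generally needs a ClassVar annotation, since Pydantic otherwise treats it as a private instance attribute.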
/gemini review
/gemini review
Code Review
This pull request successfully refactors the loss function configuration by replacing the ambiguous clip_ratio with explicit clip_low_threshold and clip_high_threshold keys for the PPO loss function. It also introduces robust validation at the API level to ensure only valid configuration keys are used for each loss function, which is a great improvement for API clarity and correctness. The changes are consistently applied across the API, backend, and loss function implementations, and are well-supported by new and updated tests.
def test_forward_backward_input_rejects_invalid_ppo_loss_fn_config_keys():
    with pytest.raises(ValidationError, match="Invalid loss_fn_config keys"):
The match string for pytest.raises is a bit broad. Making it more specific to the error message for the 'ppo' loss function will make this test more robust against future changes to other validation error messages.
Suggested change
-    with pytest.raises(ValidationError, match="Invalid loss_fn_config keys"):
+    with pytest.raises(ValidationError, match="Invalid loss_fn_config keys for loss_fn='ppo'"):
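To illustrate why the narrower pattern is more robust: pytest.raises applies match to the exception text with re.search, so a short prefix also matches errors raised for other loss functions. The sketch below reproduces that check with the standard library only; both full error messages are hypothetical, since the thread shows only the prefix and the loss_fn='ppo' fragment.

```python
import re

# Hypothetical messages, assumed for illustration.
ppo_error = "Invalid loss_fn_config keys for loss_fn='ppo': ['clip_ratio']"
other_error = "Invalid loss_fn_config keys for loss_fn='cross_entropy': ['clip_ratio']"

broad = "Invalid loss_fn_config keys"
# re.escape keeps the pattern literal if the message ever gains regex metacharacters.
specific = re.escape("Invalid loss_fn_config keys for loss_fn='ppo'")


def matches(pattern: str, message: str) -> bool:
    """Mirror the re.search check that pytest.raises(match=...) performs."""
    return re.search(pattern, message) is not None


print(matches(broad, other_error))     # broad pattern accepts the unrelated error
print(matches(specific, other_error))  # specific pattern rejects it
print(matches(specific, ppo_error))    # and still accepts the intended one
```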
See #1159
This fixes the loss function keys and brings them in line with the upstream keys in https://tinker-docs.thinkingmachines.ai/losses
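As a rough sketch of what the renamed keys control, the function below computes a clipped PPO loss in plain Python. It assumes clip_low_threshold and clip_high_threshold are lower and upper bounds applied directly to the probability ratio; that reading, the function name, and the argument layout are assumptions for illustration and are not taken from the skyrl-tx implementation.

```python
import math


def ppo_clipped_loss(
    target_logprobs,
    sampling_logprobs,
    advantages,
    clip_low_threshold,
    clip_high_threshold,
):
    """Negative sum of per-token clipped PPO objectives (illustrative only)."""
    total = 0.0
    for target_lp, sampling_lp, adv in zip(
        target_logprobs, sampling_logprobs, advantages
    ):
        # Probability ratio between the current policy and the sampling policy.
        ratio = math.exp(target_lp - sampling_lp)
        # Assumed semantics: the thresholds bound the ratio itself.
        clipped = min(max(ratio, clip_low_threshold), clip_high_threshold)
        # Take the more conservative of the unclipped and clipped objectives.
        total += min(ratio * adv, clipped * adv)
    return -total
```

With two separate bounds the clipping range can be asymmetric, which a single clip_ratio value cannot express.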