[train][examples] Fix 8 broken example scripts from skyrl-train migration#1230
Conversation
…tion - main_generate.py: migrate from @hydra.main to SkyRLTrainConfig.from_cli_overrides() - gspo/run_gspo_gsm8k.sh, sapo/run_sapo_gsm8k.sh: fix stale path + add $@ passthrough - lora/run_qwen2_5_0.5b_gsm8k_ppo_lora.sh: add missing critic_num_gpus_per_node for PPO - openenv/run_openenv.sh: fix package name openenv -> openenv-core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ription Add missing "$@" CLI override passthrough to harbor/run_codecontest.sh for consistency with other example scripts. Update PR description with expanded test results including DAPO AIME, SAPO AIME, and Harbor CodeContest (all verified with full training steps on 8xH100). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…mple bugs - Remove duplicate `importance_sampling` policy loss registration in `main_on_policy_distill.py` — this loss type is now built-in in ppo_utils.py, causing "already registered" error at startup. - Add missing `generator.sampling_params.logprobs=null` to `run_remote.sh` — the default logprobs=1 is not supported in remote inference mode, causing NotImplementedError during validation. - Update PR description with expanded test results (32 scripts tested). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request provides a series of well-justified fixes for 8 broken example scripts, stemming from a recent migration. The changes include updating stale paths, adding missing required configuration parameters, fixing an incorrect package name, and removing a duplicate function registration. Additionally, several scripts have been improved by adding "$@" to allow passthrough of command-line arguments, enhancing their usability and consistency. The migration of main_generate.py from Hydra to the custom SkyRLTrainConfig parser is a good modernization that aligns it with other entrypoints. The changes are correct and improve the overall quality and reliability of the example scripts.
There was a problem hiding this comment.
Code Review
This pull request addresses 8 bugs across various example scripts in examples/train and examples/train_integrations, which arose from the skyrl-train to skyrl/train migration. The fixes involve updating incorrect file paths in shell scripts, adding missing required configuration parameters for specific training setups (like LoRA with PPO), correcting a dependency package name for OpenEnv integration, and removing a duplicate policy loss registration. Additionally, several shell scripts have been updated to pass through command-line arguments, enhancing their flexibility. A notable change is the refactoring of the main_generate.py entrypoint to use SkyRLTrainConfig.from_cli_overrides() for configuration, moving away from the legacy Hydra loader to align with the project's current standards and fix issues with nested configuration keys. My review of the changes did not find any issues.
SumanthRH
left a comment
There was a problem hiding this comment.
Thanks!
Left a few nits for pending comment fixes
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
| return loss, {"clip_ratio": 0.0} | ||
|
|
||
|
|
||
| class OnPolicyDistillationExp(BasePPOExp): |
There was a problem hiding this comment.
🟡 Removing custom importance_sampling loss silently changes loss computation semantics for on-policy distillation example
The PR removes the custom importance_sampling policy loss registration from main_on_policy_distill.py to avoid a duplicate-registration crash with the built-in one in ppo_utils.py. However, the two implementations have fundamentally different loss reduction behavior.
Behavioral difference between old custom and built-in implementations
The old custom implementation (removed in this PR) used a normalized reduction:
loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len)
return loss, {"clip_ratio": 0.0}The built-in implementation at skyrl/backends/skyrl_train/utils/ppo_utils.py:966-980 uses a raw sum:
loss = (elementwise_loss * loss_mask).sum()
return loss, {"importance_ratio": mean_ratio.item()}reduce_loss with "seq_mean_token_sum_norm" computes per-sequence token-sum normalized by max_seq_len, then takes a batch mean — producing a loss that is invariant to batch size and sequence length. The built-in .sum() produces a loss that scales linearly with both, yielding very different gradient magnitudes.
Additionally, the metrics dict changed from {"clip_ratio": 0.0} to {"importance_ratio": ...}, which may affect downstream logging that expects the clip_ratio key.
Impact: Users of the on-policy distillation example will silently get a different (unnormalized) loss computation, which could lead to training instability or require learning rate re-tuning.
Prompt for agents
In examples/train/on_policy_distillation/main_on_policy_distill.py, the removal of the custom importance_sampling registration now silently delegates to the built-in importance_sampling_loss in skyrl/backends/skyrl_train/utils/ppo_utils.py (line 933-980), which uses a raw .sum() reduction instead of the original reduce_loss(..., 'seq_mean_token_sum_norm', config.max_seq_len). To preserve the original behavior, either:
1. Update the built-in importance_sampling_loss in skyrl/backends/skyrl_train/utils/ppo_utils.py (lines 964-980) to use reduce_loss instead of .sum(), and return {"clip_ratio": 0.0} in the metrics dict for consistency with other loss functions, OR
2. Keep the custom registration in main_on_policy_distill.py but use PolicyLossRegistry.unregister / PolicyLossRegistry.register to replace the built-in, OR
3. Add a comment in the on-policy distillation run scripts (run_on_policy_distill_math_qwen3_1.7b.sh and run_on_policy_distill_math_qwen3_4b.sh) noting that loss_reduction should be set to seq_mean_token_sum_norm and the learning rate may need adjustment since the built-in importance_sampling uses raw sum reduction.
Was this helpful? React with 👍 or 👎 to provide feedback.
There was a problem hiding this comment.
Hi @CharlieFRuan , the importance_sampling reduction behavior is indeed changed from seq_mean_token_sum_norm to sum. Is this expected
Summary
Fix 8 bugs in
examples/trainandexamples/train_integrationsdiscovered while systematically running all example scripts after theskyrl-train→skyrl/trainmigration.Bugs fixed:
main_generate.py: Migrate from legacy@hydra.mainYAML config loader toSkyRLTrainConfig.from_cli_overrides(), matchingmain_base.py. The old loader didn't understand the new nestedgenerator.inference_engine.*config keys, causingKey 'inference_engine' is not in struct.gspo/run_gspo_gsm8k.shandsapo/run_sapo_gsm8k.sh: Fix stale pathexamples/gsm8k/run_gsm8k.sh→examples/train/gsm8k/run_gsm8k.sh. Also add"$@"passthrough so users can append CLI overrides.lora/run_qwen2_5_0.5b_gsm8k_ppo_lora.sh: Add missingtrainer.placement.critic_num_gpus_per_node— required for PPO (GAE) with colocated critic, otherwise hits assertionnum_policy_gpus and num_critic_gpus must be the same.openenv/run_openenv.sh: Fix package nameopenenv→openenv-coreto match the upstream PyPI metadata in the OpenEnv repo.harbor/run_codecontest.sh: Add missing"$@"passthrough so users can append CLI overrides (consistent with other example scripts).on_policy_distillation/main_on_policy_distill.py: Remove duplicate@register_policy_loss("importance_sampling")registration — this loss type is now built-in inppo_utils.py, causingValueError: policy loss 'importance_sampling' already registeredat startup.remote_inference_engine/run_remote.sh: Add missinggenerator.sampling_params.logprobs=null— the defaultlogprobs=1is not supported in remote inference mode, causingNotImplementedErrorduring validation.Test plan
Ran 32 example scripts on 8×H100 with tiny datasets and verified at least one full training step completes for each. Full results:
Passed (29
examples/train+ 3examples/train_integrations):gsm8k/run_gsm8k.sh,gsm8k/run_generation_gsm8k.sh(after fix)ppo/run_ppo.shmultiply/run_multiply.shsft/sft_trainer.pylora/run_qwen2_5_0.5b_gsm8k_grpo_lora.sh,lora/run_qwen2_5_0.5b_gsm8k_ppo_lora.sh(after fix)training_backends/fsdp/run_fsdp.sh,training_backends/fsdp/run_fsdp2.sh,training_backends/run_no_seq_pack.shasync/async_run_gsm8k.sh,fully_async/fully_async_run_gsm8k.shtis_correction/run_dapo_tis.shturn_level_rewards/run_gsm8k_multi_turn.shsearch/run_search.sh(Qwen2.5-1.5B-Instruct, with mock retrieval server — 2/2 training steps, full pipeline verified)text_to_sql/run_skyrl_sql.sh(Qwen2.5-Coder-7B-Instruct, with OmniSQL databases — 8 training steps, multi-turn SQL generation verified)on_policy_distillation/(after fix — Qwen3-1.7B-Base student+teacher, customapply_reward_kl_penaltyandno_opadvantage verified)remote_inference_engine/run_remote.sh(after fix — script bug fixed; NCCL weight sync on single machine is a pre-existing limitation, not a migration bug)train_integrations/harbor/run_codecontest.sh(Qwen3-8B, with Daytona sandbox)train_integrations/openenv/(import-verified after fix)Skipped (with rationale):
main_baseentrypoint already validated; blocked on dataset/API/server setup🤖 Generated with Claude Code