Unify Megatron and FSDP training interfaces with forward_backward + optim_step#901
Merged
erictang000 merged 9 commits into NovaSky-AI:main on Jan 31, 2026
Conversation
…ptim_step

- Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase
- Update trainer to use unified interface for both strategies
- Remove strategy branching in train_critic_and_policy()
- Mark ppo_train() as deprecated (kept for backward compatibility)
- Update test_megatron_worker.py to use new interface

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Contributor
Code Review
This pull request successfully unifies the training interfaces for Megatron and FSDP strategies by introducing forward_backward and optim_step methods in MegatronPolicyWorkerBase. The trainer.py file is updated to use this unified interface, removing strategy-specific branching, which significantly improves code maintainability and clarity. The ppo_train method is appropriately marked as deprecated, and the tests are updated to reflect these changes. Overall, this is a well-executed refactoring that aligns with the goal of creating a more consistent training pipeline.
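To illustrate the unified interface the review describes, here is a hedged sketch of what a strategy-agnostic training step could look like once both worker types expose forward_backward() and optim_step(). The DummyWorker class, train_mini_batch helper, and all metric names are hypothetical stand-ins, not the actual SkyRL worker API.

```python
# Sketch of a strategy-agnostic mini-batch step, assuming both Megatron and
# FSDP workers expose forward_backward() + optim_step(). Names other than
# those two methods are illustrative only.

class DummyWorker:
    """Stand-in for a policy worker implementing the unified interface."""

    def __init__(self):
        self.steps = 0

    def forward_backward(self, micro_batch):
        # Accumulate gradients for one micro-batch; return its loss.
        return sum(micro_batch) / len(micro_batch)

    def optim_step(self):
        # Apply the accumulated gradients once per mini-batch.
        self.steps += 1
        return {"lr": 1e-5, "step": self.steps}


def train_mini_batch(worker, micro_batches):
    # The same loop serves both strategies: no isinstance()/strategy branching.
    losses = [worker.forward_backward(mb) for mb in micro_batches]
    metrics = worker.optim_step()
    metrics["loss"] = sum(losses) / len(losses)
    return metrics


worker = DummyWorker()
metrics = train_mini_batch(worker, [[1.0, 2.0], [3.0, 5.0]])
```

The design choice being reviewed is exactly this: the trainer calls the two methods in sequence and never needs to know which distributed-training backend sits behind the worker.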
…nterface

- Remove ppo_train from MegatronPolicyWorkerBase and WorkerDispatch
- Update test_megatron_dp, test_megatron_offload to use forward_backward + optim_step
- Update test_save_load_model.py and test_save_load_checkpoint.py for unified interface
- Simplify _normalize_mini_batch_size (no longer needs policy_mini_batch_size_per_gpu)

Both FSDP and Megatron now use the same forward_backward + optim_step interface.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The method just set _micro_batches_accumulated = 0, which can be done directly in __init__. This removes unnecessary indirection and the vestigial mesh_rank guard. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
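A minimal before/after sketch of the simplification described in this commit, with illustrative class names (the real worker class is not shown here):

```python
# Before: a helper method that only performed a single assignment.
class Before:
    def __init__(self):
        self._reset_accumulation()  # indirection for one assignment

    def _reset_accumulation(self):
        self._micro_batches_accumulated = 0


# After: the counter is initialized directly; no helper method needed.
class After:
    def __init__(self):
        self._micro_batches_accumulated = 0
```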
erictang000 reviewed on Jan 20, 2026
…ggs/megatron_refactor
Collaborator


Summary

- Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase to match the FSDP worker interface
- Mark ppo_train() as deprecated (kept for backward compatibility)
- Update test_megatron_worker.py to use the new interface
- Add get_lr and set_lr to the Megatron worker, in line with the behavior from "Add set_lr() for dynamic learning rate updates from Tinker" (#978)

This brings Megatron up to parity with FSDP following the refactoring in PR #859.
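The get_lr/set_lr parity mentioned above can be sketched as follows. This is a hedged stand-in, not the actual Megatron worker implementation; only the method names mirror the interface being added.

```python
# Illustrative sketch of a worker exposing learning-rate get/set, assuming
# the interface added for parity with the FSDP worker (names only; the real
# Megatron optimizer wiring is more involved).

class LRHolder:
    """Minimal stand-in for a worker with dynamic LR control."""

    def __init__(self, lr=1e-5):
        self._lr = lr

    def get_lr(self):
        return self._lr

    def set_lr(self, lr):
        # Supports externally driven LR updates (the PR #978 use case).
        self._lr = lr


w = LRHolder()
w.set_lr(3e-6)
```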
Test plan

- Ran test_megatron_worker.py to verify forward_backward + optim_step works correctly

Co-Authored-By: Eric Tang <erictang000@gmail.com>