Extending Muon Optimizer Support for ZeRO Stage 3 #7919
PKUWZP merged 4 commits into deepspeedai:master
Conversation
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 56451ea0fe
```python
grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * (
    (world_sz - len(params) % world_sz) % world_sz)
gathered_momentums_pad = gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * (
    (world_sz - len(gathered_params_momentums) % world_sz) % world_sz)
```
Use distinct pad tensors in Muon all_gather
When a Muon subgroup has fewer tensors than the data-parallel world size, these padding expressions reuse the same torch.empty_like(...) object for every padded slot. The subsequent dist.all_gather then receives overlapping output buffers in the final partial chunk, which can error or corrupt the gathered gradients/momentum on small models or whenever the last chunk is not full.
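The aliasing issue flagged here does not need torch to reproduce; a minimal sketch, where `FakeTensor` is a stand-in for a real tensor object:

```python
class FakeTensor:
    """Stand-in for torch.Tensor, used only to show object aliasing."""

world_sz, num_real = 4, 1
pad = (world_sz - num_real % world_sz) % world_sz  # 3 padded slots

# Buggy pattern from the diff: list multiplication repeats ONE object,
# so every padded all_gather output slot would share the same buffer.
shared = [FakeTensor()] * pad

# Fix: allocate a fresh tensor per padded slot with a comprehension.
distinct = [FakeTensor() for _ in range(pad)]

shared_ids = len({id(t) for t in shared})      # all slots alias one object
distinct_ids = len({id(t) for t in distinct})  # one object per slot
```

With real tensors, the `distinct` variant gives `dist.all_gather` non-overlapping output buffers in the final partial chunk.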
- Move save_muon_momentum_buffer_in_memory config to DeepSpeedZeroConfig in config.py instead of reading inline from the ds_config dict
- Fix index bug: change muon_momentum_buffer_partitioned_groups_flat from a list to a dict keyed by sub-group index, to avoid out-of-bounds access when non-Muon groups precede Muon groups
- Add a valid code path for a non-swappable (GPU/CPU) optimizer without save_muon_momentum_buffer_in_memory, replacing the ValueError
- Validate that all Muon parameter groups share the same momentum (beta)
- Parametrize tests for both True and False save_muon_momentum_buffer_in_memory
- Update docs to show the config under zero_optimization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
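The list-to-dict fix in the second bullet can be illustrated with a toy sub-group layout; the names and layout here are hypothetical, not the PR's exact data structures:

```python
# Hypothetical sub-group layout where non-Muon groups come first.
sub_groups = ["adam", "adam", "muon", "muon"]

# With a plain list, appending buffers only for Muon sub-groups makes
# positional indexing drift relative to the sub-group index i.
# Keying the dict by sub-group index keeps lookups valid regardless
# of where the Muon groups sit.
muon_momentum_buffer_partitioned_groups_flat = {}
for i, kind in enumerate(sub_groups):
    if kind == "muon":
        muon_momentum_buffer_partitioned_groups_flat[i] = f"flat_buffer_{i}"

muon_indices = sorted(muon_momentum_buffer_partitioned_groups_flat)
```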
```python
continue

momentum_buffer = []
if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
```
The condition looks correct, but consider simplifying like the following (should be equivalent):

```python
if self.save_muon_momentum_buffer_in_memory:
    ...
elif self._swappable_optimizer_subgroup(i):
    ...
else:
    ...
```
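The equivalence can be checked exhaustively over the two booleans. The branch bodies below are placeholders, and the implied elif ordering of the original is an assumption:

```python
from itertools import product

def original(save_in_memory, swappable):
    # Mirrors the diff's condition order (assumed branch structure).
    if swappable and not save_in_memory:
        return "swap path"
    elif save_in_memory:
        return "in-memory path"
    else:
        return "default path"

def simplified(save_in_memory, swappable):
    # The reviewer's suggested restructuring.
    if save_in_memory:
        return "in-memory path"
    elif swappable:
        return "swap path"
    else:
        return "default path"

# Exhaustive check over all four combinations of the two flags.
matches = all(
    original(s, w) == simplified(s, w)
    for s, w in product([False, True], repeat=2)
)
```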
Suggest turning on stage 3 for
I remember DeepSpeed allows separate learning rates for Muon and Adam (muon_lr and adam_lr); can we have a config in the UT to cover this usage?
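A sketch of such a test config, assuming the muon_lr/adam_lr key names mentioned above; the surrounding schema is an assumption, not a confirmed DeepSpeed API:

```python
# Hypothetical unit-test config exercising separate Muon/Adam
# learning rates together with the new stage-3 option. Key names
# muon_lr/adam_lr come from the comment above; everything else is
# an illustrative assumption.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",
        "params": {
            "muon_lr": 0.02,   # rate for params Muon updates
            "adam_lr": 0.001,  # rate for the Adam fallback params
        },
    },
    "zero_optimization": {
        "stage": 3,
        "save_muon_momentum_buffer_in_memory": True,
    },
}
```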
delock left a comment:
LGTM, some test case coverage suggestions are added in the comments.
```json
"zero_optimization": {
    "stage": 3,
    "save_muon_momentum_buffer_in_memory": true
}
```
Does the Muon optimizer for stage 3 mandate reduce_scatter = false? Does "reduce_scatter": false need to be added to the example?
Authors: @pengdurice and @PKUWZP
Create a separate PR based on #7798 with the same functional diff on a clean signed-off branch to resolve DCO issues.
We aim to add Muon optimizer support to ZeRO stage 3 in this draft PR:

- The momentum buffers are managed together with self.fp32_partitioned_groups_flat; when device == NVME, we make sure that the momentum buffers can be swapped in and out along with the other components of the optimizer states.
- The momentum buffers are partitioned like self.fp32_partitioned_groups_flat to save memory footprint. So, before the Muon update, we need to perform all_gather on top of each data-parallel group rank. The Muon updates of the parameters are also divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the all_gather, the momentum buffers are partitioned and flattened again.

Next steps:
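The gather → update → re-partition cycle described in the PR summary can be simulated in a single process with plain lists; this is a sketch standing in for dist.all_gather and the Muon math, not the PR's actual code:

```python
# Single-process simulation of the partitioned-momentum flow:
# momentum lives partitioned, is gathered before the update, each
# rank updates only its slice, and results are re-partitioned.
# All names and the plain momentum update are illustrative.

def partition(flat, world_sz):
    """Split a flat list into world_sz equal shards (assumes divisibility)."""
    n = len(flat) // world_sz
    return [flat[i * n:(i + 1) * n] for i in range(world_sz)]

def all_gather(shards):
    """Stand-in for dist.all_gather: every rank ends up with the full list."""
    return [x for shard in shards for x in shard]

world_sz = 2
momentum_shards = partition([0.0, 0.0, 0.0, 0.0], world_sz)  # partitioned state
grads = [1.0, 2.0, 3.0, 4.0]
beta = 0.9

# 1) gather the partitioned momentum before the Muon update
momentum = all_gather(momentum_shards)

# 2) each rank updates only its own slice (simulated sequentially here)
updated = []
for rank, m_slice in enumerate(partition(momentum, world_sz)):
    g_slice = partition(grads, world_sz)[rank]
    updated.append([beta * m + g for m, g in zip(m_slice, g_slice)])

# 3) all_gather the updated slices, then re-partition to save memory
momentum = all_gather(updated)
momentum_shards = partition(momentum, world_sz)
```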