Extending Muon Optimizer Support for ZeRO Stage 3#7919

Merged
PKUWZP merged 4 commits into deepspeedai:master from PKUWZP:pr-7798-clean
Mar 26, 2026

Conversation

@PKUWZP
Collaborator

@PKUWZP PKUWZP commented Mar 23, 2026

Authors: @pengdurice and @PKUWZP

Create a separate PR based on #7798 with the same functional diff on a clean signed-off branch to resolve DCO issues.

We aim to add the Muon optimizer to ZeRO stage 3 in this PR:

  • Created a dedicated momentum buffer in the ZeRO stage 3 optimizer to hold the momentum state specifically for the Muon optimizer.
  • The optimizer states can be dispatched to three device tiers: GPU, CPU, and NVMe. For GPU and CPU, the new buffers are placed on the same device as self.fp32_partitioned_groups_flat; when the device is NVMe, the momentum buffers are swapped in and out along with the other components of the optimizer states.
  • The new momentum buffers are also partitioned like self.fp32_partitioned_groups_flat to reduce the memory footprint. Before the Muon update, we therefore perform an all_gather across the data-parallel group ranks. The Muon updates of the parameters are likewise divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the all_gather, the momentum buffers are partitioned and flattened again.
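The round trip described in the last bullet can be sketched as follows. This is a hypothetical toy model, not DeepSpeed's actual implementation: plain Python lists stand in for flattened tensors, `partition` and `all_gather` are simplified stand-ins for the real collectives, and the 0.95 scaling is a placeholder for the Muon update itself.

```python
# Toy sketch of: partition momentum buffer -> all_gather before the Muon
# update -> apply update -> re-partition. All names here are illustrative.

WORLD_SIZE = 4  # assumed data-parallel world size for this example

def partition(flat, world_size):
    """Split a flat buffer into equal per-rank shards, zero-padding to divide evenly."""
    pad = (world_size - len(flat) % world_size) % world_size
    padded = flat + [0.0] * pad
    shard = len(padded) // world_size
    return [padded[r * shard:(r + 1) * shard] for r in range(world_size)]

def all_gather(shards):
    """Every rank contributes its shard; the result is the full flat buffer."""
    return [x for shard in shards for x in shard]

momentum = [float(i) for i in range(10)]                  # full momentum buffer
shards = partition(momentum, WORLD_SIZE)                  # each rank holds one shard
full = all_gather(shards)                                 # gathered before the update
updated = [m * 0.95 for m in full]                        # stand-in for the Muon update
shards = partition(updated[:len(momentum)], WORLD_SIZE)   # partitioned/flattened again
```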

Next steps:

  • Explore quantization of momentum buffers for saving memory
  • Explore using highly optimized Adam / AdamW Optimizers

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 56451ea0fe


Comment on lines +1529 to +1532
grads_pad = [param.grad for param in params] + [torch.empty_like(params[-1].grad)] * (
(world_sz - len(params) % world_sz) % world_sz)
gathered_momentums_pad = gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * (
(world_sz - len(gathered_params_momentums) % world_sz) % world_sz)


P2: Use distinct pad tensors in Muon all_gather

When a Muon subgroup has fewer tensors than the data-parallel world size, these padding expressions reuse the same torch.empty_like(...) object for every padded slot. The subsequent dist.all_gather then receives overlapping output buffers in the final partial chunk, which can error or corrupt the gathered gradients/momentum on small models or whenever the last chunk is not full.
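The aliasing hazard can be demonstrated with plain Python (using lists as stand-ins for tensors; `empty_like_grad` below is a hypothetical stand-in for `torch.empty_like(params[-1].grad)`, which allocates one fresh tensor per call):

```python
pad_slots = 3

def empty_like_grad():
    """Stand-in for torch.empty_like(...): allocates a fresh buffer per call."""
    return [0.0, 0.0]

# Buggy pattern: the allocator is called once and the single result is
# replicated, so every padded slot aliases the same buffer.
shared = [empty_like_grad()] * pad_slots
shared[0][0] = 123.0    # a write through slot 0...
aliased = shared[1][0]  # ...is visible through slot 1

# Fixed pattern: call the allocator once per slot, yielding distinct buffers.
distinct = [empty_like_grad() for _ in range(pad_slots)]
distinct[0][0] = 123.0
independent = distinct[1][0]  # unaffected by the write to slot 0
```

This is why `dist.all_gather` needs the list-comprehension form: its output list must contain non-overlapping buffers.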


Collaborator

@tohtana tohtana left a comment


Thank you @PKUWZP for extending Muon optimizer support. This is a significant improvement.
I left some comments about a regression and possible enhancements.

PKUWZP and others added 2 commits March 25, 2026 17:08
- Move save_muon_momentum_buffer_in_memory config to DeepSpeedZeroConfig
  in config.py instead of reading inline from ds_config dict
- Fix index bug: change muon_momentum_buffer_partitioned_groups_flat from
  list to dict keyed by sub-group index to avoid out-of-bounds access
  when non-muon groups precede muon groups
- Add valid code path for non-swappable (GPU/CPU) optimizer without
  save_muon_momentum_buffer_in_memory, replacing ValueError
- Validate that all Muon parameter groups share the same momentum (beta)
- Parametrize tests for both True and False save_muon_momentum_buffer_in_memory
- Update docs to show config under zero_optimization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
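The index fix in the second commit bullet can be illustrated with a toy example (names are hypothetical; the real structure is the stage-3 optimizer's `muon_momentum_buffer_partitioned_groups_flat`):

```python
# Sub-groups 0 and 1 use a non-Muon optimizer; sub-group 2 uses Muon.
muon_subgroups = [False, False, True]

# Buggy: appending only for Muon groups makes list positions drift from
# sub-group indices -- indexing buffers_list[2] into a length-1 list
# would raise IndexError.
buffers_list = []
for i, is_muon in enumerate(muon_subgroups):
    if is_muon:
        buffers_list.append(f"momentum_for_subgroup_{i}")

# Fixed: key by sub-group index, so lookups use the same index everywhere,
# regardless of how many non-Muon groups precede the Muon ones.
buffers_dict = {}
for i, is_muon in enumerate(muon_subgroups):
    if is_muon:
        buffers_dict[i] = f"momentum_for_subgroup_{i}"
```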
Collaborator

@tohtana tohtana left a comment


I left a comment regarding simplifying conditions, but overall looks good to me. Thank you for the great work, @PKUWZP!

continue

momentum_buffer = []
if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:


The condition looks correct, but consider simplifying it like the following (they should be equivalent):

  if self.save_muon_momentum_buffer_in_memory:
      ...
  elif self._swappable_optimizer_subgroup(i):
      ...
  else:
      ...

@delock
Copy link
Collaborator

delock commented Mar 26, 2026

Suggest turning on stage 3 for test_muon_partial_training.py as well; this would check the case where almost all parameters are frozen, so all trainable parameters use the Muon optimizer.

@delock
Collaborator

delock commented Mar 26, 2026

I remember DeepSpeed allows separate learning rates for Muon and Adam (muon_lr and adam_lr); can we have a config in the UT to cover this usage?
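A config for such a test might look like the following sketch. The `muon_lr` / `adam_lr` names come from the comment above; their exact placement under the optimizer `params` block is an assumption, not a verified DeepSpeed schema:

```json
{
  "optimizer": {
    "type": "Muon",
    "params": {
      "muon_lr": 0.01,
      "adam_lr": 0.001
    }
  },
  "zero_optimization": {
    "stage": 3,
    "save_muon_momentum_buffer_in_memory": true
  }
}
```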

Collaborator

@delock delock left a comment


LGTM; some test case coverage suggestions are added in the comments.

"zero_optimization": {
    "stage": 3,
    "save_muon_momentum_buffer_in_memory": true
}
Collaborator


Does the Muon optimizer for stage 3 mandate reduce_scatter = false? Does "reduce_scatter": false need to be added to the example?

@PKUWZP PKUWZP merged commit 956ec6f into deepspeedai:master Mar 26, 2026
9 checks passed
