Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,9 @@ def zero_legacy_stage1(self):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters

def zero_save_muon_momentum_buffer_in_memory(self):
return self._config.zero_config.save_muon_momentum_buffer_in_memory

def tensor_parallel_config(self):
return self._config.tensor_parallel_config

Expand Down Expand Up @@ -1733,7 +1736,6 @@ def _configure_basic_optimizer(self, model_parameters):
optimizer = MuSGD(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUON_OPTIMIZER:
zero_stage = self.zero_optimization_stage()
assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3"
if not all([hasattr(p, 'use_muon') for p in model_parameters]):
msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
"please set by `param.use_muon = True / False` for all params"
Expand Down Expand Up @@ -2045,6 +2047,7 @@ def _configure_zero_optimizer(self, optimizer):
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
enable_sanity_checks=self.is_sanity_checks_enabled(),
cpuadam_cores_perc=self.cpuadam_cores_perc(),
save_muon_momentum_buffer_in_memory=self.zero_save_muon_momentum_buffer_in_memory(),
)

else:
Expand Down
7 changes: 7 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Enable internal sanity checks, which could be useful for debugging
"""

save_muon_momentum_buffer_in_memory: bool = False
"""
When using the Muon optimizer with ZeRO Stage 3, keeps the Muon momentum
buffer in GPU/CPU memory instead of swapping to NVMe with other optimizer
states. Only relevant when using NVMe offloading.
"""

leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
"""
Configuration for modules that should be treated as ZeRO3 leaf modules.
Expand Down
Loading
Loading