Port https://github.com/NovaSky-AI/SkyRL/pull/1083 to skyrl folder #1162
pcmoritz merged 2 commits into NovaSky-AI:main from
Conversation
There was a problem hiding this comment.
Code Review
This pull request successfully ports changes to enable and test gradient checkpointing. The updates to configuration descriptions are clearer, and the new tests for gradient checkpointing are comprehensive, covering multiple models. The refactoring in conftest.py to introduce a create_model helper is a good improvement for test structure. I have provided a couple of suggestions to further enhance this new helper for better versatility and to refactor a new test to use it, which would improve consistency across the test suite. Overall, this is a high-quality contribution.
| def create_model( | ||
| model_name: str, | ||
| config_cls: type[ModelConfig], | ||
| model_cls: type[ModelForCausalLM], | ||
| mesh_axes: tuple[str, ...], | ||
| *, | ||
| mesh_shape: tuple[int, ...] | None = None, | ||
| **config_kwargs, | ||
| mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, | ||
| seed: int = 0, | ||
| **config_kwargs: Any, | ||
| ) -> tuple[ModelConfig, ModelForCausalLM]: | ||
| """Create a JAX model and load weights from the HuggingFace cache.""" | ||
| weights_dir = resolve_model_path(model_name) | ||
| """Create a JAX model with initialized weights.""" | ||
| base_config = AutoConfig.from_pretrained(model_name) | ||
| config = config_cls(base_config, shard_attention_heads=True, **config_kwargs) | ||
| if mesh_shape is None: | ||
| mesh_shape = (1,) * len(mesh_axes) | ||
| mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes)) | ||
| if mesh_axis_types is None: | ||
| mesh_axis_types = (jax.sharding.AxisType.Auto,) * len(mesh_axes) | ||
| mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=mesh_axis_types) | ||
| with jax.set_mesh(mesh): | ||
| model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) | ||
| model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed)) | ||
| return config, model | ||
|
|
||
|
|
||
| def load_model( | ||
| model_name: str, | ||
| config_cls: type[ModelConfig], | ||
| model_cls: type[ModelForCausalLM], | ||
| mesh_axes: tuple[str, ...], | ||
| *, | ||
| mesh_shape: tuple[int, ...] | None = None, | ||
| mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, | ||
| seed: int = 0, | ||
| **config_kwargs: Any, | ||
| ) -> tuple[ModelConfig, ModelForCausalLM]: | ||
| """Create a JAX model and load weights from the HuggingFace cache.""" | ||
| config, model = create_model( | ||
| model_name, | ||
| config_cls, | ||
| model_cls, | ||
| mesh_axes, | ||
| mesh_shape=mesh_shape, | ||
| mesh_axis_types=mesh_axis_types, | ||
| seed=seed, | ||
| **config_kwargs, | ||
| ) | ||
| weights_dir = resolve_model_path(model_name) | ||
| load_safetensors(weights_dir, config, model) | ||
| return config, model |
There was a problem hiding this comment.
The new create_model helper function is a great addition for simplifying test setup. However, it currently doesn't allow passing arguments to AutoConfig.from_pretrained, which is necessary for models that require trust_remote_code=True (as seen in test_deepseekv3.py).
To make this helper more versatile and promote consistency across tests, consider adding an auto_config_kwargs parameter to both create_model and load_model. This would allow all tests to use this common helper, avoiding duplicated model creation logic.
def create_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
    seed: int = 0,
    auto_config_kwargs: dict[str, Any] | None = None,
    **config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model with initialized weights.

    Args:
        model_name: HuggingFace model identifier used to fetch the base config.
        config_cls: ModelConfig subclass to instantiate from the base config.
        model_cls: Model class constructed under the device mesh.
        mesh_axes: Names of the device-mesh axes.
        mesh_shape: Size per mesh axis; defaults to all ones (one device per axis).
        mesh_axis_types: Axis types for the mesh; defaults to Auto on every axis.
        seed: Seed for the nnx RNG used to initialize weights.
        auto_config_kwargs: Extra keyword arguments forwarded to
            ``AutoConfig.from_pretrained`` (e.g. ``trust_remote_code=True``).
        **config_kwargs: Extra keyword arguments forwarded to ``config_cls``.

    Returns:
        The ``(config, model)`` pair.
    """
    hf_kwargs = auto_config_kwargs if auto_config_kwargs is not None else {}
    base_config = AutoConfig.from_pretrained(model_name, **hf_kwargs)
    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
    n_axes = len(mesh_axes)
    shape = mesh_shape if mesh_shape is not None else (1,) * n_axes
    axis_types = (
        mesh_axis_types
        if mesh_axis_types is not None
        else (jax.sharding.AxisType.Auto,) * n_axes
    )
    mesh = jax.make_mesh(shape, mesh_axes, axis_types=axis_types)
    # Build the model under the mesh so parameter shardings are applied at init.
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(seed))
    return config, model
def load_model(
model_name: str,
config_cls: type[ModelConfig],
model_cls: type[ModelForCausalLM],
mesh_axes: tuple[str, ...],
*,
mesh_shape: tuple[int, ...] | None = None,
mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None,
seed: int = 0,
auto_config_kwargs: dict[str, Any] | None = None,
**config_kwargs: Any,
) -> tuple[ModelConfig, ModelForCausalLM]:
"""Create a JAX model and load weights from the HuggingFace cache."""
config, model = create_model(
model_name,
config_cls,
model_cls,
mesh_axes,
mesh_shape=mesh_shape,
mesh_axis_types=mesh_axis_types,
seed=seed,
auto_config_kwargs=auto_config_kwargs,
**config_kwargs,
)
weights_dir = resolve_model_path(model_name)
load_safetensors(weights_dir, config, model)
return config, model| model_name = "yujiepan/deepseek-v3-tiny-random" | ||
| base_config = PretrainedConfig.from_pretrained(model_name, trust_remote_code=True) | ||
|
|
||
| batch_size, seq_len = 2, 8 | ||
| mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) | ||
|
|
||
| results = {} | ||
| for use_checkpointing in [False, True]: | ||
| config = DeepseekV3Config( | ||
| base_config, | ||
| max_lora_adapters=1, | ||
| max_lora_rank=1, | ||
| shard_attention_heads=True, | ||
| gradient_checkpointing=use_checkpointing, | ||
| ) | ||
| with jax.set_mesh(mesh): | ||
| model = DeepseekV3ForCausalLM(config, dtype=jnp.float32, rngs=nnx.Rngs(0)) |
There was a problem hiding this comment.
This test duplicates model creation logic that is now available in the create_model helper function in conftest.py. To improve code reuse and consistency, this test should be refactored to use create_model.
Assuming the suggested change to create_model (to support auto_config_kwargs) is applied, this test can be simplified as follows:
model_name = "yujiepan/deepseek-v3-tiny-random"
batch_size, seq_len = 2, 8
mesh_axes = ("fsdp", "ep", "tp")
results = {}
for use_checkpointing in [False, True]:
config, model = create_model(
model_name,
DeepseekV3Config,
DeepseekV3ForCausalLM,
mesh_axes,
mesh_shape=(1, 1, 1),
auto_config_kwargs={"trust_remote_code": True},
max_lora_adapters=1,
max_lora_rank=1,
gradient_checkpointing=use_checkpointing,
seed=0,
)
See #1083