Commit e8e1cf6

[skyrl-train][dependencies] separate vllm + megatron + bump vllm back to 0.11.0 + pin minimum uv version for extra-build-dependencies (NovaSky-AI#528)
## Separates vllm + megatron deps

After NovaSky-AI#481, there were some megatron flashinfer issues with `--extra vllm`. This PR separates out the version of vllm that megatron relies on from the general vllm version, allowing us to bump vllm to 0.11.0 for the rest of the training stack.

## Update flash-attn installation

Updates the flash-attn installation to use the `extra-build-dependencies` feature from uv, requiring a uv version >= 0.8.10. This feature allows us to do the following, removing the need to deal with markers + extras to specify a URL source for each set of extras:

```toml
[tool.uv.extra-build-dependencies]
flash-attn = [{ requirement = "torch", match-runtime = true }]

[tool.uv.extra-build-variables]
flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }

[project.optional-dependencies]
vllm = [
    "vllm==0.11.0",
    "flash-attn==2.8.3",
    ...
]
mcore = [
    "flash-attn==2.7.4.post1",
    ...
]
```
1 parent 247959a commit e8e1cf6
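The commit pins a uv version floor (`required-version = ">=0.8.10"`) and a versioned installer URL in the Dockerfiles. A small sanity-check like the one below could verify an installed uv against that floor; `version_ge` and the hard-coded `installed` value are illustrative helpers, not part of this PR:

```shell
# Illustrative check (not part of this PR): verify that an installed uv
# satisfies the new [tool.uv] required-version floor (>=0.8.10), which
# the extra-build-dependencies feature needs.
version_ge() {
  # True when $1 >= $2 under version-sort ordering.
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
}

required="0.8.10"
installed="0.9.4"  # e.g. parsed from: uv --version

if version_ge "$installed" "$required"; then
  echo "uv $installed satisfies >=$required"
else
  echo "uv $installed is too old; need >=$required" >&2
  exit 1
fi
```

Pinning the installer URL (`astral.sh/uv/0.9.4/install.sh`) in the Dockerfiles makes this check moot inside the images, but it is useful for local environments.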

File tree

11 files changed (+677, -520 lines)

docker/Dockerfile

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@ RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-es
 RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \
     && sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run

-RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+RUN curl -LsSf https://astral.sh/uv/0.9.4/install.sh | sh
 RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc

 RUN sudo apt-get update \
```

docker/Dockerfile.megatron

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@ RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-es
 RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \
     && sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run

-RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+RUN curl -LsSf https://astral.sh/uv/0.9.4/install.sh | sh
 RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc

```

skyrl-train/docs/examples/megatron.rst

Lines changed: 2 additions & 7 deletions
```diff
@@ -104,13 +104,8 @@ After following the installation instructions, set the following environment var

 Flash Attention
 ~~~~~~~~~~~~~~~
-Next, in order to use flash attention with the megatron backend, you must use ``flash_attn`` version ``2.7.4.post1`` or lower for compatibility with ``TransformerEngine==2.5.0``.
-You can replace the ``flash-attn`` wheel in the ``pyproject.toml`` file with the following to use the ``2.7.4.post1`` release, and you can find wheels for other versions `here <https://github.com/Dao-AILab/flash-attention/releases>`_.
-
-.. code-block:: bash
-
-    flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" }
-
+In order to use flash attention with the megatron backend, you must use ``flash_attn`` version ``2.7.4.post1`` or lower for compatibility with ``TransformerEngine==2.5.0``.
+This is handled in the ``pyproject.toml`` file for the ``mcore`` extra.

 Configuration
 -------------
```

skyrl-train/examples/megatron/run_megatron.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ export SKYRL_PYTHONPATH_EXPORT=1
 # make sure PYTHONPATH is set to the location of TransformerEngine installation
 export PYTHONPATH="$HOME/anaconda3/lib/python3.12/site-packages"

-uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entrypoints.main_base \
+uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
   data.train_data="['$DATA_DIR/train.parquet']" \
   data.val_data="['$DATA_DIR/validation.parquet']" \
   trainer.algorithm.advantage_estimator="grpo" \
```

skyrl-train/examples/megatron/run_megatron_moonlight.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ export SKYRL_PYTHONPATH_EXPORT=1
 # make sure PYTHONPATH is set to the location of TransformerEngine installation
 export PYTHONPATH="$HOME/anaconda3/lib/python3.12/site-packages"

-uv run --isolated --extra $INFERENCE_BACKEND --extra mcore --with blobfile -m skyrl_train.entrypoints.main_base \
+uv run --isolated --extra mcore --with blobfile -m skyrl_train.entrypoints.main_base \
   data.train_data="['$DATA_DIR/train.parquet']" \
   data.val_data="['$DATA_DIR/validation.parquet']" \
   trainer.algorithm.advantage_estimator="grpo" \
```

skyrl-train/examples/megatron/run_megatron_qwen3-235b-a22b.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,7 +51,7 @@ export SKYRL_PYTHONPATH_EXPORT=1
 # make sure PYTHONPATH is set to the location of TransformerEngine installation
 export PYTHONPATH="$HOME/anaconda3/lib/python3.12/site-packages"

-uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entrypoints.main_base \
+uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
   data.train_data="['$DATA_DIR/train.parquet']" \
   data.val_data="['$DATA_DIR/validation.parquet']" \
   trainer.algorithm.advantage_estimator="grpo" \
```

skyrl-train/examples/megatron/run_megatron_qwen3-30b-a3b.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ export SKYRL_PYTHONPATH_EXPORT=1
 # make sure PYTHONPATH is set to the location of TransformerEngine installation
 export PYTHONPATH="$HOME/anaconda3/lib/python3.12/site-packages"

-uv run --isolated --extra $INFERENCE_BACKEND --extra mcore -m skyrl_train.entrypoints.main_base \
+uv run --isolated --extra mcore -m skyrl_train.entrypoints.main_base \
   data.train_data="['$DATA_DIR/train.parquet']" \
   data.val_data="['$DATA_DIR/validation.parquet']" \
   trainer.algorithm.advantage_estimator="grpo" \
```

skyrl-train/examples/megatron/run_search_megatron.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -30,7 +30,7 @@ export SKYRL_PYTHONPATH_EXPORT=1
 # make sure PYTHONPATH is set to the location of TransformerEngine installation
 export PYTHONPATH="$HOME/anaconda3/lib/python3.12/site-packages"

-uv run --isolated --frozen --extra mcore --extra vllm -m skyrl_train.entrypoints.main_base \
+uv run --isolated --frozen --extra mcore -m skyrl_train.entrypoints.main_base \
   data.train_data="['${DATA_DIR}/train.parquet']" \
   data.val_data="['${DATA_DIR}/validation.parquet']" \
   trainer.algorithm.advantage_estimator="grpo" \
```

skyrl-train/pyproject.toml

Lines changed: 29 additions & 8 deletions
```diff
@@ -45,6 +45,7 @@ dependencies = [
 ]

 [tool.uv]
+required-version = ">=0.8.10"
 conflicts = [
     [
         { extra = "vllm" },
@@ -61,22 +62,28 @@ conflicts = [
     ],
     [
         { extra = "mcore" },
+        { extra = "vllm" },
         { extra = "sglang" },
         { extra = "flashrl" },
     ]
 ]

+[tool.uv.extra-build-dependencies]
+flash-attn = [{requirement = "torch", match-runtime = true}]
+
+[tool.uv.extra-build-variables]
+flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"}
+
 [tool.uv.sources]
 skyrl-gym = { path = "./skyrl-gym", editable = true }
 torch = { index = "pytorch-cu128" }
 torchvision = { index = "pytorch-cu128" }
-flash-attn = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" }
-# NOTE (sumanthrh): We explictly use a flashinfer wheel from their index.
-# The wheels on PyPI don't come with pre-compiled kernels and the package will JIT compile them at runtime which is slow.
-# additionally, different inference engines may pin different compatible flashinfer versions, so we provide the option to pin different versions for vllm/sglang
+# We use `flashinfer-jit-cache` to avoid slow JIT compilation on first run.
+# Different inference engines may pin different compatible flashinfer versions, so we provide the option to pin different versions for vllm/sglang
+flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "extra == 'vllm'" }
 flashinfer-python = [
-    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'vllm'" },
-    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'vllm'" }
+    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'mcore' and extra != 'vllm'" },
+    { url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" }
 ]

 [project.optional-dependencies]
@@ -104,14 +111,17 @@ sandboxes = [
     "litellm[proxy]>=1.67.5",
 ]
 vllm = [
-    "vllm==0.10.1.1",
-    "torch==2.7.1",
+    "vllm==0.11.0",
+    "flash-attn==2.8.3",
+    "torch==2.8.0",
     "flashinfer-python",
+    "flashinfer-jit-cache",
     "torchvision"
 ]
 sglang = [
     "sglang[srt,openai,torch_memory_saver]==0.4.8.post1",  # 0.4.9.post1 causes non-colocate weight broadcast to hang
     "flashinfer-python",
+    "flash-attn==2.8.3",
     "torch==2.7.1",
     "torchvision",
 ]
@@ -126,12 +136,18 @@ mcore = [
     # export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}"
     # uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose
     # "transformer-engine[pytorch]==2.5.0",
+    "flash-attn==2.7.4.post1",
+    "vllm==0.10.1.1",
+    "torch==2.7.1",
+    "flashinfer-python",
+    "torchvision",
     "mbridge==0.15.1",
     "megatron-core==0.13.0",
 ]
 flashrl = [
     # NOTE: Custom vLLM wheel must be installed separately.
     # See examples/flash_rl/README.md for installation instructions.
+    "flash-attn==2.8.3",
     "torch==2.7.0",
     "flashinfer-python",
     "torchvision",
@@ -148,6 +164,11 @@ name = "pytorch-cu128"
 url = "https://download.pytorch.org/whl/cu128"
 explicit = true

+[[tool.uv.index]]
+name = "flashinfer-cu128"
+url = "https://flashinfer.ai/whl/cu128"
+explicit = true
+
 [tool.setuptools]
 include-package-data = true

```
skyrl-train/tests/gpu/test_megatron_worker.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 """
 Run with:
-SKYRL_PYTHONPATH_EXPORT=1 uv run --isolated --extra dev --extra vllm --extra mcore -- pytest tests/gpu/test_megatron_worker.py
+export PYTHONPATH=/home/ray/anaconda3/lib/python3.12/site-packages
+SKYRL_PYTHONPATH_EXPORT=1 uv run --isolated --extra dev --extra mcore -- pytest tests/gpu/test_megatron_worker.py
 """

 import ray
```
