Skip to content

Commit be7ee34

Browse files
authored
[train] Fix issue with unset pad_token_id (#1232)
# What does this PR do? Fixes #1231 <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1232" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end --> --------- Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
1 parent 83d2bd7 commit be7ee34

File tree

5 files changed

+29
-20
lines changed

5 files changed

+29
-20
lines changed

skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.distributed
44
import ray
5-
from transformers import AutoTokenizer, AutoConfig
5+
from transformers import AutoConfig
66
from huggingface_hub import snapshot_download
77

88
import os
@@ -42,7 +42,7 @@
4242
from skyrl.backends.skyrl_train.workers.megatron.megatron_model_wrapper import MegatronModelWrapper
4343
from skyrl.backends.skyrl_train.utils.profiler import Profiler
4444
from skyrl.backends.skyrl_train.weight_sync import WeightExtractor, WeightChunk
45-
45+
from skyrl.utils.tok import get_tokenizer
4646

4747
if TYPE_CHECKING:
4848
from skyrl.backends.skyrl_train.inference_engines.base import InferenceEngineInterface
@@ -205,7 +205,7 @@ def init_configs(
205205
"""
206206
Initialize the Megatron-Bridge bridge and provider objects + hf_config and tokenizer
207207
"""
208-
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
208+
tokenizer = get_tokenizer(model_path, trust_remote_code=True)
209209
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
210210

211211
override_config_kwargs = {

skyrl/backends/skyrl_train_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
create_ray_wrapped_inference_engines,
3131
)
3232
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
33+
from skyrl.utils.tok import get_tokenizer
3334

3435

3536
class SkyRLTrainBackendOverrides(BaseModel, extra="allow"):
@@ -112,7 +113,7 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
112113
self._model_metadata: types.ModelMetadata | None = None
113114
self._cfg = None
114115
self._dispatch: WorkerDispatch | None = None
115-
self._tokenizer = AutoTokenizer.from_pretrained(self.base_model)
116+
self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model)
116117
self._inference_engine_client = None
117118
self._inference_engines_initialized = False
118119

skyrl/train/entrypoints/main_base.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ray.util.placement_group import placement_group, PlacementGroup
66

7-
from transformers import AutoTokenizer, PreTrainedTokenizerBase
7+
from transformers import PreTrainedTokenizerBase
88
from skyrl.train.dataset import PromptDataset
99
from skyrl.train.utils import validate_cfg
1010

@@ -24,6 +24,7 @@
2424
import os
2525
from loguru import logger
2626
from skyrl.train.utils.tracking import Tracking
27+
from skyrl.utils.tok import get_tokenizer
2728
import multiprocessing as mp
2829
import asyncio
2930

@@ -122,7 +123,12 @@ def __init__(self, cfg: SkyRLTrainConfig):
122123
cfg: The fully resolved SkyRLTrainConfig instance.
123124
"""
124125
self.cfg = cfg
125-
self.tokenizer = self.get_tokenizer()
126+
self.tokenizer = get_tokenizer(
127+
self.cfg.trainer.policy.model.path,
128+
trust_remote_code=True,
129+
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
130+
padding_side="left",
131+
)
126132
self.train_dataset = self.get_train_dataset()
127133
self.eval_dataset = self.get_eval_dataset()
128134
self.colocate_pg = self.get_colocate_pg()
@@ -135,19 +141,6 @@ def __init__(self, cfg: SkyRLTrainConfig):
135141
def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
136142
return get_config_as_yaml_str(cfg)
137143

138-
def get_tokenizer(self, padding_side="left"):
139-
"""Initializes a tokenizer for the given model."""
140-
tokenizer = AutoTokenizer.from_pretrained(
141-
self.cfg.trainer.policy.model.path,
142-
trust_remote_code=True,
143-
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
144-
)
145-
tokenizer.padding_side = padding_side
146-
if tokenizer.pad_token is None:
147-
tokenizer.pad_token = tokenizer.eos_token
148-
tokenizer.pad_token_id = tokenizer.eos_token_id
149-
return tokenizer
150-
151144
def get_train_dataset(self):
152145
"""Initializes the training dataset.
153146

skyrl/utils/tok.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Tokenization related utilities"""
2+
3+
from transformers import AutoTokenizer
4+
5+
6+
def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
    """Load the tokenizer for the given base model, guaranteeing a pad token.

    Args:
        model_name_or_path: HF hub id or local path, forwarded to
            ``AutoTokenizer.from_pretrained``.
        **tokenizer_kwargs: Extra keyword arguments forwarded verbatim
            (e.g. ``trust_remote_code``, ``use_fast``, ``padding_side``).

    Returns:
        The loaded tokenizer. If the model defines no pad token, the EOS
        token is reused as the pad token so padding-based collation works.

    Raises:
        ValueError: If the tokenizer defines neither a pad token nor an
            EOS token to fall back on.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        # Fail loudly here rather than letting a None pad_token_id surface
        # later as an opaque error inside padding/collation.
        if tokenizer.eos_token_id is None:
            raise ValueError(
                f"Tokenizer for {model_name_or_path!r} has neither a pad token "
                "nor an EOS token to fall back on; set one explicitly."
            )
        # Assign the token string first: setting `pad_token` keeps the
        # token/id pair in sync inside the tokenizer; the explicit id
        # assignment below is then a no-op safety net.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer

tests/backends/skyrl_train/gpu/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient
3636
from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup
3737
from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter
38+
from skyrl.utils.tok import get_tokenizer
3839

3940
TEST_DATA_PATH = os.path.expanduser("~/data/gsm8k/validation.parquet")
4041

@@ -460,7 +461,7 @@ def create(
460461
# Extract served_model_name from config if set
461462
served_model_name = ie_cfg.served_model_name
462463

463-
tokenizer = AutoTokenizer.from_pretrained(cfg.trainer.policy.model.path)
464+
tokenizer = get_tokenizer(cfg.trainer.policy.model.path)
464465

465466
# Return both router and server group if created to keep references alive
466467
router = None

0 commit comments

Comments
 (0)