Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions skyrl-train/skyrl_train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ResumeMode,
DynamicSamplingState,
)
from skyrl_train.utils.utils import configure_ray_worker_logging


class RayPPOTrainer:
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
self.dynamic_sampling_state: Optional[DynamicSamplingState] = None

self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None
configure_ray_worker_logging()

def build_dataloader(self, dataset: PromptDataset, is_train=True):
"""
Expand Down
40 changes: 39 additions & 1 deletion skyrl-train/skyrl_train/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import time

import sys
import logging
import ray
import torch
from loguru import logger
Expand Down Expand Up @@ -465,6 +466,43 @@ def prepare_runtime_environment(cfg: DictConfig) -> dict[str, str]:
return env_vars


def configure_ray_worker_logging() -> None:
"""
In Ray workers, stderr/stdout are not TTYs, so Loguru disables color.
`configure_worker_logging` is used within each Ray worker to
force color and formatting (e.g., bold) and route stdlib `logging`
through Loguru so third‑party logs match formatting
"""
# 1) Loguru formatting (force colors)
logger.remove()
logger.level("INFO", color="<bold><green>")
logger.add(
sys.stderr,
colorize=True, # keep ANSI even without a TTY
enqueue=True,
backtrace=False,
diagnose=False,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level>",
)

# 2) Route stdlib logging -> Loguru (so vLLM/transformers/etc. are formatted)
class _InterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
logger.opt(depth=6, exception=record.exc_info).log(level, record.getMessage())

logging.root.handlers = [_InterceptHandler()]
level_name = os.getenv("LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_name, logging.INFO)
logging.root.setLevel(level)


def initialize_ray(cfg: DictConfig):
"""
Initialize Ray cluster with prepared runtime environment.
Expand Down