
Improved LoRA weight swap and robust transitions_to_training_data #1368

Open

ashutoshuiuc wants to merge 2 commits into NovaSky-AI:main from ashutoshuiuc:pr/lora-weight-swap-training-data

Conversation

@ashutoshuiuc

@ashutoshuiuc ashutoshuiuc commented Mar 23, 2026

Summary

  • LoRA weight swap: Abort in-flight generation before swapping, remove old adapter, add new, reset prefix cache, track active_lora_id explicitly instead of querying list_loras(), monkey-patch _maybe_get_adapters for consistent adapter lookup
  • transitions_to_training_data: Validate None/empty observations and actions, track logprobs validity per-datum (handle external actions without logprobs), explicit length-mismatch checks, skip all-zero mask datums

Split from #1298 per maintainer feedback.

Closes #1297



LoRA weight swap improvements:
- Abort in-flight generation before swapping adapters
- Remove old adapter before adding new one (prevents stale adapter buildup)
- Reset prefix cache after swap for correctness
- Track active_lora_id explicitly instead of querying list_loras()
- Monkey-patch _maybe_get_adapters for consistent adapter lookup in chat completions
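The swap sequence above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the `LoraSwapper` wrapper and the engine method names (`abort_generation`, `remove_lora`, `add_lora`, `reset_prefix_cache`) are assumptions standing in for the async vLLM engine interface.

```python
import time


class LoraSwapper:
    """Tracks the active LoRA adapter ID explicitly instead of querying list_loras()."""

    def __init__(self, engine):
        self.engine = engine
        self._active_lora_id = None  # explicit tracking of the active adapter

    async def swap_lora(self, lora_path: str) -> int:
        # 1. Abort in-flight generation so no request runs against a half-swapped adapter.
        await self.engine.abort_generation()
        # 2. Remove the previous adapter first, preventing stale adapter buildup.
        if self._active_lora_id is not None:
            await self.engine.remove_lora(self._active_lora_id)
        # 3. Register the new adapter under a fresh 31-bit ID.
        new_id = int(time.time_ns() % 0x7FFFFFFF)
        await self.engine.add_lora(new_id, lora_path)
        # 4. Reset the prefix cache: cached KV prefixes were computed with the old weights.
        await self.engine.reset_prefix_cache()
        self._active_lora_id = new_id
        return new_id
```

The ordering matters: aborting before removal guarantees no request holds a reference to the outgoing adapter, and the cache reset comes last so no stale prefix survives the swap.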

transitions_to_training_data robustness:
- Validate None/empty observations and actions per-transition
- Track logprobs validity per-datum (handle external actions without logprobs)
- Allow response_logprobs=None in TrainingDatum when logprobs unavailable
- Explicit length-mismatch checks between response tokens, logprobs, and mask
- Skip datums with no action tokens (all-zero mask)
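The validation rules above can be illustrated with a compact sketch. The `Transition` and `TrainingDatum` field names here are assumptions for illustration, not the repository's actual dataclasses.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Transition:
    observation_tokens: Optional[List[int]]
    action_tokens: Optional[List[int]]
    action_logprobs: Optional[List[float]]  # None for external actions


@dataclass
class TrainingDatum:
    response_tokens: List[int]
    response_logprobs: Optional[List[float]]  # None when logprobs are unavailable
    loss_mask: List[int]


def transitions_to_training_data(transitions: List[Transition]) -> List[TrainingDatum]:
    data = []
    for t in transitions:
        # Validate None/empty observations and actions per-transition.
        if not t.observation_tokens or not t.action_tokens:
            continue
        tokens = t.observation_tokens + t.action_tokens
        mask = [0] * len(t.observation_tokens) + [1] * len(t.action_tokens)
        # Track logprob validity per-datum, with an explicit length-mismatch check;
        # external actions may carry no logprobs at all.
        has_logprobs = (
            t.action_logprobs is not None
            and len(t.action_logprobs) == len(t.action_tokens)
        )
        logprobs = (
            [0.0] * len(t.observation_tokens) + t.action_logprobs
        ) if has_logprobs else None
        # Skip datums with no action tokens (all-zero mask) as a safety net.
        if not any(mask):
            continue
        data.append(TrainingDatum(tokens, logprobs, mask))
    return data
```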

Closes NovaSky-AI#1297
Copilot AI review requested due to automatic review settings on March 23, 2026 at 13:38.
devin-ai-integration[bot]

This comment was marked as resolved.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant robustness improvements to transitions_to_training_data and refactors LoRA weight swapping. The changes in transitions_to_training_data add comprehensive validation for transitions, observations, and actions, which is a great improvement. The LoRA weight swapping logic is now more explicit and robust, correctly handling adapter removal and tracking the active adapter.

My review includes two suggestions for improvement: one to simplify a boolean condition for better readability in utils.py, and another to use a more robust method for generating LoRA IDs in vllm_engine.py to prevent potential collisions.

Comment on lines +265 to +266
if not transition_has_valid_logprobs and has_valid_logprobs:
has_valid_logprobs = False

medium

This conditional logic can be simplified for better readability. The current implementation is correct, but a more direct and idiomatic way to express that has_valid_logprobs should become False if any transition_has_valid_logprobs is False is to use a boolean and operation.

        has_valid_logprobs = has_valid_logprobs and transition_has_valid_logprobs

except Exception as e:
logger.error(f"Failed removing old LoRA: {e}")

new_id = int(time.time_ns() % 0x7FFFFFFF)

medium

Using time.time_ns() for ID generation can lead to collisions if this function is called in rapid succession, especially with the modulo operation. A more robust approach is to use a random source. Since uuid4 is already imported, you can use it to generate a random 31-bit integer. This significantly reduces the chance of collision.

Note that the int() cast in the original code is redundant as time.time_ns() % ... already produces an integer.

Suggested change
new_id = int(time.time_ns() % 0x7FFFFFFF)
new_id = uuid4().int & 0x7FFFFFFF
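A quick way to sanity-check the suggestion: masking the 122-bit random `uuid4().int` with `0x7FFFFFFF` keeps only the low 31 bits, so the result is always a non-negative value that fits a signed 32-bit integer.

```python
from uuid import uuid4

# Keep the low 31 bits of the random UUID integer; the result is
# always in [0, 2**31 - 1], i.e. a valid positive signed 32-bit ID.
new_id = uuid4().int & 0x7FFFFFFF
```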


Copilot AI left a comment


Pull request overview

This PR improves runtime robustness in two areas: (1) vLLM LoRA adapter hot-swapping during training/inference weight sync, and (2) conversion of agent transitions into training datums, especially when observations/actions/logprobs may be missing or inconsistent.

Changes:

  • Update AsyncVLLMInferenceEngine LoRA swapping to explicitly track the active adapter ID, remove the previous adapter, add the new one, and reset the prefix cache; also monkey-patch adapter lookup for OpenAI-serving paths.
  • Make transitions_to_training_data more defensive: validate observations/actions, handle per-datum logprob availability, add explicit length checks, and skip invalid datums.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Reviewed files:

  • skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py: implements explicit active LoRA tracking and a more controlled LoRA swap flow for the async vLLM engine.
  • skyrl-agent/skyrl_agent/functional/utils.py: hardens transitions_to_training_data and updates TrainingDatum to allow missing logprobs.


Comment on lines +462 to +465
"""Swap LoRA adapter: abort in-flight requests, remove old, add new, reset cache."""
await self.abort_generation()

if self._active_lora_id is not None:

Copilot AI Mar 23, 2026


_load_lora_from_disk() calls await self.abort_generation(), but AsyncVLLMInferenceEngine (and its base classes) don't define abort_generation. As written this will raise AttributeError the first time a LoRA swap is attempted. Add an abort_generation() implementation (likely similar to the unfinished-request abort logic used in sleep()), or replace this with the correct existing control method (pause_generation/resume_generation or direct engine.abort of unfinished request IDs).
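One possible shape for the missing method, assuming the wrapper keeps a collection of unfinished request IDs and the inner engine exposes an `abort(request_id)` coroutine; both are assumptions about internals this review does not show, so treat this as a sketch rather than the fix itself.

```python
async def abort_generation(self):
    """Abort all unfinished requests so a LoRA swap cannot race in-flight generation."""
    # Snapshot the IDs first: aborting may mutate the underlying collection.
    unfinished = list(getattr(self, "_unfinished_request_ids", []))
    for request_id in unfinished:
        await self.engine.abort(request_id)
```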

- Add abort_generation() to AsyncVLLMInferenceEngine (was missing from base)
- Use uuid4 for LoRA adapter IDs instead of time.time_ns() (avoids collisions)
- Simplify has_valid_logprobs tracking with boolean AND

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

View 6 additional findings in Devin Review.


Comment on lines +408 to +425
if self._is_lora:
original = self.openai_serving_chat._maybe_get_adapters
async_engine = self # capture outer self safely

def patched(self_chat, request, *args, **kwargs):
active_lora_id = getattr(async_engine, "_active_lora_id", None)
if active_lora_id is not None:
return LoRARequest(
lora_name=str(active_lora_id),
lora_int_id=active_lora_id,
lora_path="/dummy_lora_path",
)
return original(request, *args, **kwargs)

self.openai_serving_chat._maybe_get_adapters = MethodType(
patched,
self.openai_serving_chat,
)


🔴 LoRA adapter not injected for /completions OpenAI endpoint

The PR monkey-patches _maybe_get_adapters on self.openai_serving_chat to inject the active LoRA adapter for chat completions, but the same patch is not applied to self.openai_serving_completion. In vLLM 0.16.0, the completion serving path (vllm/entrypoints/openai/completion/serving.py) also calls _maybe_get_adapters to resolve adapter requests. When a LoRA-enabled engine serves /completions requests (via vllm_engine.py:683), the active LoRA adapter won't be applied, and inference will fall back to the base model weights.

Prompt for agents
In skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py, the _maybe_get_adapters monkey-patch at lines 408-425 is applied only to self.openai_serving_chat but not to self.openai_serving_completion. Apply the same monkey-patch logic to self.openai_serving_completion as well. You can extract the patching logic into a helper function and call it for both serving objects. For example, after line 425, add the same patching for self.openai_serving_completion:

  original_completion = self.openai_serving_completion._maybe_get_adapters
  def patched_completion(self_completion, request, *args, **kwargs):
      active_lora_id = getattr(async_engine, "_active_lora_id", None)
      if active_lora_id is not None:
          return LoRARequest(
              lora_name=str(active_lora_id),
              lora_int_id=active_lora_id,
              lora_path="/dummy_lora_path",
          )
      return original_completion(request, *args, **kwargs)
  self.openai_serving_completion._maybe_get_adapters = MethodType(
      patched_completion,
      self.openai_serving_completion,
  )



Development

Successfully merging this pull request may close these issues.

MoE (Qwen3) + FSDP2 training fixes: expert patching, NCCL deadlocks, batched broadcast, NUMA, Ulysses SP
