
[GRPO] add chunked grpo streaming over vocab#1160

Open
kashif wants to merge 7 commits into linkedin:main from kashif:chunked_grpo_streaming_origin_main

Conversation


@kashif kashif commented Mar 23, 2026

Summary

This PR fixes the chunked GRPO loss to compute only selected-token log-probs by streaming over the vocab dimension. This reduces peak memory for the fused-linear chunked path and preserves the existing high-level fused-linear API.
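The streaming idea can be illustrated with a small sketch (not the PR's actual implementation; `selected_token_logprobs` and `vocab_chunk` are hypothetical names). The selected-token log-prob is the selected logit minus the logsumexp over the vocab, and the logsumexp can be accumulated chunk by chunk so the full `(T, V)` log-softmax is never materialized:

```python
import torch

def selected_token_logprobs(logits, target_ids, vocab_chunk=1024):
    # logits: (T, V); target_ids: (T,)
    # log p(target) = logits[t, target] - logsumexp_v(logits[t, :]),
    # with the logsumexp accumulated over vocab chunks.
    T, V = logits.shape
    running = torch.full(
        (T,), float("-inf"), dtype=logits.dtype, device=logits.device
    )
    for start in range(0, V, vocab_chunk):
        chunk = logits[:, start:start + vocab_chunk]
        # logaddexp merges the chunk's logsumexp into the running total.
        running = torch.logaddexp(running, chunk.logsumexp(dim=-1))
    selected = logits.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    return selected - running
```

Peak memory for the normalizer now scales with `T * vocab_chunk` instead of `T * V`, which is the trade-off the fused-linear chunked path is designed around.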

This PR also fixes the loss reduction in the chunked path to match TRL exactly, and tightens the torch.compile boundary so that only the pure loss computation is compiled, rather than compiling through the closure that calls torch.autograd.grad.

[Benchmark plots: two_way_grpo_time, two_way_grpo_memory]

With these changes, the two implementations correctly reflect the trade-offs of their respective designs.

Testing Done

  • Hardware Type:
  • ran make test to ensure correctness
  • ran make checkstyle to ensure code style
  • ran make test-convergence to ensure convergence

@kashif

kashif commented Mar 23, 2026

cc @vaibhavjindal for your review

@kashif kashif force-pushed the chunked_grpo_streaming_origin_main branch from 0e54614 to e25f787 on March 24, 2026 at 21:56