
Commit c76ecd7 — Merge pull request #37 from neurolabusc/WebGPU ("Web gpu")

2 parents 70387db + ba8f9a6


44 files changed: +11372 −97 lines

.gitignore

Lines changed: 9 additions & 0 deletions
```diff
@@ -4,3 +4,12 @@
 *.pyo
 *.pyd

+# Build artifacts
+*.egg-info/
+dist/
+build/
+
+# Test outputs
+*.trk
+*.trx
+*.nii.gz
```

CLAUDE.md

Lines changed: 122 additions & 0 deletions
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

GPUStreamlines (`cuslines`) is a GPU-accelerated tractography package for diffusion MRI. It supports **three GPU backends**: NVIDIA CUDA, Apple Metal (Apple Silicon), and WebGPU (cross-platform via wgpu-py). The backend is auto-detected at import time in `cuslines/__init__.py` (priority: Metal → CUDA → WebGPU). Kernels are compiled at runtime (NVRTC for CUDA, `MTLDevice.newLibraryWithSource` for Metal, `device.create_shader_module` for WebGPU/WGSL).
## Build & Run

```bash
# Install (pick your backend)
pip install ".[cu13]"    # CUDA 13
pip install ".[cu12]"    # CUDA 12
pip install ".[metal]"   # Apple Metal (Apple Silicon)
pip install ".[webgpu]"  # WebGPU (cross-platform: NVIDIA, AMD, Intel, Apple)

# From PyPI
pip install "cuslines[cu13]"
pip install "cuslines[metal]"
pip install "cuslines[webgpu]"

# GPU run (downloads HARDI dataset if no data passed)
python run_gpu_streamlines.py --output-prefix small --nseeds 1000 --ngpus 1

# Force a specific backend
python run_gpu_streamlines.py --device=webgpu --output-prefix small --nseeds 1000

# CPU reference run (for comparison/debugging)
python run_gpu_streamlines.py --device=cpu --output-prefix small --nseeds 1000

# Docker
docker build -t gpustreamlines .
```

There is no dedicated test or lint suite. Validate by comparing CPU vs GPU outputs on the same seeds.
## Architecture

**Two-layer design**: Python orchestration + GPU kernels compiled at runtime. Three parallel backend implementations share the same API surface.

```
run_gpu_streamlines.py          # CLI entry: DIPY model fitting → CPU or GPU tracking
cuslines/
  __init__.py                   # Auto-detects Metal → CUDA → WebGPU backend at import
  boot_utils.py                 # Shared bootstrap matrix preparation (OPDT/CSA) for all backends
  cuda_python/                  # CUDA backend
    cu_tractography.py          # GPUTracker: context manager, multi-GPU allocation
    cu_propagate_seeds.py       # SeedBatchPropagator: chunked seed processing
    cu_direction_getters.py     # Direction getter ABC + Boot/Prob/PTT implementations
    cutils.py                   # REAL_DTYPE, REAL3_DTYPE, checkCudaErrors(), ModelType enum
    _globals.py                 # AUTO-GENERATED from globals.h (never edit manually)
  cuda_c/                       # CUDA kernel source
    globals.h                   # Source of truth for constants (REAL_SIZE, thread config)
    generate_streamlines_cuda.cu, boot.cu, ptt.cu, tracking_helpers.cu, utils.cu
    cudamacro.h, cuwsort.cuh, ptt.cuh, disc.h
  metal/                        # Metal backend (mirrors cuda_python/)
    mt_tractography.py, mt_propagate_seeds.py, mt_direction_getters.py, mutils.py
    metal_shaders/              # MSL kernel source (mirrors cuda_c/)
      globals.h, types.h, philox_rng.h
      generate_streamlines_metal.metal, boot.metal, ptt.metal
      tracking_helpers.metal, utils.metal, warp_sort.metal
  webgpu/                       # WebGPU backend (mirrors metal/)
    wg_tractography.py, wg_propagate_seeds.py, wg_direction_getters.py, wgutils.py
    benchmark.py                # Cross-backend benchmark: python -m cuslines.webgpu.benchmark
    wgsl_shaders/               # WGSL kernel source (mirrors metal_shaders/)
      globals.wgsl, types.wgsl, philox_rng.wgsl
      utils.wgsl, warp_sort.wgsl, tracking_helpers.wgsl
      generate_streamlines.wgsl # Prob/PTT buffer bindings + Prob getNum/gen kernels
      boot.wgsl                 # Boot direction getter kernels (standalone module)
      disc.wgsl, ptt.wgsl       # PTT support
```

**Data flow**: DIPY preprocessing → seed generation → GPUTracker context → SeedBatchPropagator chunks seeds across GPUs → kernel launch → stream results to TRK/TRX output.

**Direction getters** (subclasses of `GPUDirectionGetter`):
- `BootDirectionGetter` — bootstrap sampling from SH coefficients (OPDT/CSA models)
- `ProbDirectionGetter` — probabilistic selection from ODF/PMF (CSD model)
- `PttDirectionGetter` — Parallel Transport Tractography (CSD model)

Each has `from_dipy_*()` class methods for initialization from DIPY models.
## Critical Conventions

- **`_globals.py` is auto-generated** from `cuslines/cuda_c/globals.h` during the `setup.py` build via `defines_to_python()`. Never edit it manually; change `globals.h` and rebuild.
- **GPU arrays must be C-contiguous** — always use `np.ascontiguousarray()` and the project's scalar types (`REAL_DTYPE`, `REAL_SIZE` from `cutils.py` or `mutils.py`).
- **All CUDA API calls must be wrapped** with `checkCudaErrors()`.
- **Angle units**: the CLI accepts degrees; internals convert to radians before the GPU layer.
- **Multi-GPU**: CUDA uses explicit `cudaSetDevice()` calls; Metal and WebGPU are single-GPU only.
- **CPU/GPU parity**: `run_gpu_streamlines.py` maintains parallel CPU and GPU code paths — keep both in sync when changing arguments or model-selection logic.
- **Logger**: use `logging.getLogger("GPUStreamlines")`.
- **Kernel compilation**: CUDA uses `cuda.core.Program` with NVIDIA headers. Metal uses `MTLDevice.newLibraryWithSource_options_error_()` with MSL source concatenated from `metal_shaders/`. WebGPU uses `device.create_shader_module()` with WGSL source concatenated from `wgsl_shaders/`.
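The contiguity convention above can be checked in isolation. A minimal numpy sketch — `REAL_DTYPE` is taken as float32 here for illustration (the only precision ported to Metal/WebGPU); in the codebase it comes from `cutils.py`/`mutils.py`:

```python
import numpy as np

# Illustration of the C-contiguity rule. REAL_DTYPE assumed float32 here.
REAL_DTYPE = np.float32

odf = np.zeros((100, 724), dtype=np.float64)
view = odf.T                      # a transpose view is NOT C-contiguous
assert not view.flags["C_CONTIGUOUS"]

# One call fixes both layout and dtype before handing the array to the GPU.
gpu_ready = np.ascontiguousarray(view, dtype=REAL_DTYPE)
assert gpu_ready.flags["C_CONTIGUOUS"]
assert gpu_ready.dtype == REAL_DTYPE
```

`np.ascontiguousarray` is a no-op (no copy) when the input is already contiguous with the right dtype, so it is cheap to apply defensively.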
## Metal Backend Notes

- **Unified memory**: Metal buffers use `storageModeShared` — numpy arrays are directly GPU-accessible (zero memcpys per batch, vs ~6 in CUDA).
- **float3 alignment**: All buffers use `packed_float3` (12 bytes) with `load_f3()`/`store_f3()` helpers, because Metal `float3` occupies 16 bytes in registers.
- **Page alignment**: Use `aligned_array()` from `mutils.py` for arrays passed to `newBufferWithBytesNoCopy`.
- **No double precision**: Only `REAL_SIZE=4` (float32) is ported.
- **Warp primitives**: `__shfl_sync` → `simd_shuffle`, `__ballot_sync` → `simd_ballot`. SIMD width = 32.
- **SH basis**: Always use `real_sh_descoteaux(legacy=True)` for all matrices. See `boot_utils.py`.
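The 12-byte vs 16-byte float3 distinction above can be made concrete with numpy structured dtypes (an illustration only — the real buffers are built by the helpers in `mutils.py`):

```python
import numpy as np

# packed_float3: x, y, z tightly packed, 12-byte stride.
packed_float3 = np.dtype([("x", "<f4"), ("y", "<f4"), ("z", "<f4")])
assert packed_float3.itemsize == 12

# The 16-byte layout a register-style float3 would imply (4 bytes padding).
padded_float3 = np.dtype({"names": ["x", "y", "z"],
                          "formats": ["<f4"] * 3,
                          "offsets": [0, 4, 8],
                          "itemsize": 16})
assert padded_float3.itemsize == 16

# Per 1000 points, packing saves 4 KB — and, more importantly, matches the
# 12-byte stride the load_f3()/store_f3() helpers expect.
assert 1000 * (padded_float3.itemsize - packed_float3.itemsize) == 4000
```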
## WebGPU Backend Notes

- **Cross-platform**: wgpu-py maps to Metal (macOS), Vulkan (Linux/Windows), or D3D12 (Windows). Install: `pip install "cuslines[webgpu]"`.
- **Explicit readbacks**: `device.queue.read_buffer()` for GPU→CPU transfers (~3 per seed batch, matching CUDA's cudaMemcpy pattern).
- **WGSL shaders**: Concatenated in dependency order by `compile_program()`. Boot compiles standalone; Prob/PTT share `generate_streamlines.wgsl`.
- **Buffer binding**: Boot needs 17 buffers across 3 bind groups; Prob/PTT use 2 bind groups. `layout="auto"` only includes reachable bindings.
- **Subgroups required**: Device feature `"subgroup"` (singular, not `"subgroups"`). Naga does NOT support the `enable subgroups;` directive.
- **WGSL constraints**: No `ptr<storage>` function parameters (use module-scope accessors). `var<workgroup>` sizes must be compile-time constants. `PhiloxState` is passed by value (kernels return result structs).
- **Boot standalone module**: `_kernel_files()` returns `[]` to avoid redefining the `params` struct.
- **Benchmark**: `python -m cuslines.webgpu.benchmark --nseeds 10000` — auto-detects all backends.
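The concatenation strategy above can be sketched in a few lines. WGSL has no `#include`, so modules must be joined so that declarations precede their uses; the file list mirrors the `wgsl_shaders/` layout shown earlier, and `concat_wgsl` is an illustrative stand-in for the real `compile_program()`:

```python
from pathlib import Path

# Illustrative dependency order for the Prob/PTT program (from the layout above).
PROB_PTT_ORDER = [
    "globals.wgsl", "types.wgsl", "philox_rng.wgsl",
    "utils.wgsl", "warp_sort.wgsl", "tracking_helpers.wgsl",
    "generate_streamlines.wgsl",
]


def concat_wgsl(shader_dir: Path, files) -> str:
    """Join WGSL modules so each declaration precedes its first use."""
    return "\n\n".join((shader_dir / name).read_text() for name in files)
```

The joined string is what ultimately goes to `device.create_shader_module(code=...)` in wgpu-py.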
## Key Dependencies

- `dipy` — diffusion models, CPU direction getters, seeding, stopping criteria
- `nibabel` — NIfTI/TRK file I/O (`StatefulTractogram`)
- `trx-python` — TRX format support (memory-mapped, for large outputs)
- `cuda-python` / `cuda-core` / `cuda-cccl` — CUDA Python bindings, kernel compilation, C++ headers
- `pyobjc-framework-Metal` / `pyobjc-framework-MetalPerformanceShaders` — Metal Python bindings (macOS only)
- `wgpu` — WebGPU Python bindings (wgpu-native, cross-platform)
- `numpy` — array operations throughout

Dockerfile

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@ SHELL ["/bin/bash", "-c"]

 ENV DEBIAN_FRONTEND=noninteractive

-RUN apt-get update && apt-get install --assume-yes curl
+RUN apt-get update && apt-get install --assume-yes curl git

 RUN curl -L "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
     -o "/tmp/Miniconda3.sh"
```

README.md

Lines changed: 21 additions & 3 deletions
````diff
@@ -1,14 +1,26 @@
 # GPUStreamlines

 ## Installation
-To install from pypi, simply run `pip install "cuslines[cu13]"` or `pip install "cuslines[cu12]"` depending on your CUDA version.
+To install from pypi:
+```
+pip install "cuslines[cu13]"   # CUDA 13 (NVIDIA)
+pip install "cuslines[cu12]"   # CUDA 12 (NVIDIA)
+pip install "cuslines[metal]"  # Apple Metal (Apple Silicon)
+pip install "cuslines[webgpu]" # WebGPU (cross-platform: NVIDIA, AMD, Intel, Apple)
+```

-To install from dev, simply run `pip install ".[cu13]"` or `pip install ".[cu12]"` in the top-level repository directory.
+To install from dev:
+```
+pip install ".[cu13]"   # CUDA 13
+pip install ".[cu12]"   # CUDA 12
+pip install ".[metal]"  # Apple Metal
+pip install ".[webgpu]" # WebGPU (any GPU)
+```

 ## Running the examples
 This repository contains several example usage scripts.

-The script `run_gpu_streamlines.py` demonstrates how to run any diffusion MRI dataset on the GPU. It can also run on the CPU for reference, if the argument `--device=cpu` is used. If not data is passed, it will donaload and use the HARDI dataset.
+The script `run_gpu_streamlines.py` demonstrates how to run any diffusion MRI dataset on the GPU. It can also run on the CPU for reference, if the argument `--device=cpu` is used. If no data is passed, it will download and use the HARDI dataset.

 To run the baseline CPU example on a random set of 1000 seeds, this is the command and example output:
 ```
@@ -52,6 +64,12 @@ Note that if you experience memory errors, you can adjust the `--chunk-size` fla

 To run on more seeds, we suggest setting the `--write-method trx` flag in the GPU script to not get bottlenecked by writing files.

+## GPU vs CPU differences
+
+GPU backends (CUDA, Metal, and WebGPU) operate in float32 while DIPY uses float64. This causes slightly different peak selection at fiber crossings where ODF peaks have similar magnitudes. In practice the GPU produces comparable streamline counts and commissural fiber density, with modestly longer fibers on average. See [cuslines/webgpu/README.md](cuslines/webgpu/README.md) for cross-platform benchmarks and [cuslines/metal/README.md](cuslines/metal/README.md) for Metal-specific details.
+
+The WebGPU backend runs on any GPU (NVIDIA, AMD, Intel, Apple) via [wgpu-py](https://github.com/pygfx/wgpu-py). It is auto-detected when no vendor-specific backend is available. See `python -m cuslines.webgpu.benchmark` for a self-contained benchmark across all available backends.
+
 ## Running on AWS with Docker
 First, set up an AWS instance with GPU and ssh into it (we recommend a P3 instance with at least 1 V100 16 GB GPU and a Deep Learning AMI Ubuntu 18.04 v 33.0.). Then do the following:
 1. Log in to GitHub docker registry:
````
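The float32-vs-float64 peak-selection effect described in the new README section can be reproduced in isolation: two peaks that differ only beyond float32's roughly seven significant digits tie after the cast, so `argmax` falls back to the first index (values below are illustrative, not from a real dataset):

```python
import numpy as np

# Two nearly equal ODF peak values (illustrative): distinguishable in
# float64, but they collapse to the same value in float32.
peaks64 = np.array([0.4999999901, 0.4999999902], dtype=np.float64)
peaks32 = peaks64.astype(np.float32)

assert np.argmax(peaks64) == 1   # float64 picks the true maximum
assert peaks32[0] == peaks32[1]  # float32 cannot tell them apart
assert np.argmax(peaks32) == 0   # ties resolve to the first index
```

At a fiber crossing this means the CPU and GPU can step into different (equally plausible) directions, which is why validation compares aggregate statistics rather than exact streamline equality.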

cuslines/__init__.py

Lines changed: 65 additions & 7 deletions
```diff
@@ -1,13 +1,71 @@
-from .cuda_python import (
-    GPUTracker,
-    ProbDirectionGetter,
-    PttDirectionGetter,
-    BootDirectionGetter
-)
+import platform as _platform
+
+
+def _detect_backend():
+    """Auto-detect the best available GPU backend."""
+    system = _platform.system()
+    if system == "Darwin":
+        try:
+            import Metal
+
+            if Metal.MTLCreateSystemDefaultDevice() is not None:
+                return "metal"
+        except ImportError:
+            pass
+    try:
+        from cuda.bindings import runtime
+
+        count = runtime.cudaGetDeviceCount()
+        if count[1] > 0:
+            return "cuda"
+    except (ImportError, Exception):
+        pass
+    try:
+        import wgpu
+
+        adapter = wgpu.gpu.request_adapter_sync()
+        if adapter is not None:
+            return "webgpu"
+    except (ImportError, Exception):
+        pass
+    return None
+
+
+BACKEND = _detect_backend()
+
+if BACKEND == "metal":
+    from cuslines.metal import (
+        MetalGPUTracker as GPUTracker,
+        MetalProbDirectionGetter as ProbDirectionGetter,
+        MetalPttDirectionGetter as PttDirectionGetter,
+        MetalBootDirectionGetter as BootDirectionGetter,
+    )
+elif BACKEND == "cuda":
+    from cuslines.cuda_python import (
+        GPUTracker,
+        ProbDirectionGetter,
+        PttDirectionGetter,
+        BootDirectionGetter,
+    )
+elif BACKEND == "webgpu":
+    from cuslines.webgpu import (
+        WebGPUTracker as GPUTracker,
+        WebGPUProbDirectionGetter as ProbDirectionGetter,
+        WebGPUPttDirectionGetter as PttDirectionGetter,
+        WebGPUBootDirectionGetter as BootDirectionGetter,
+    )
+else:
+    raise ImportError(
+        "No GPU backend available. Install either:\n"
+        "  - CUDA: pip install 'cuslines[cu13]' (NVIDIA GPU)\n"
+        "  - Metal: pip install 'cuslines[metal]' (Apple Silicon)\n"
+        "  - WebGPU: pip install 'cuslines[webgpu]' (cross-platform)"
+    )

 __all__ = [
     "GPUTracker",
     "ProbDirectionGetter",
     "PttDirectionGetter",
-    "BootDirectionGetter"
+    "BootDirectionGetter",
+    "BACKEND",
 ]
```
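The detection logic added above is a fall-through ladder: each probe either returns a backend name, returns None (bindings present but no device), or raises (bindings not installed), and the first success wins. A generic, backend-free sketch of that control flow (all probe names below are hypothetical stand-ins):

```python
def pick_first_backend(probes):
    """Sketch of the fall-through detection: first successful probe wins."""
    for probe in probes:
        try:
            name = probe()
        except Exception:
            continue        # bindings missing or probe failed: try the next
        if name is not None:
            return name     # device found
    return None             # nothing available -> caller raises ImportError


# Hypothetical probes modelling the three outcomes:
def no_metal():             # pyobjc Metal not installed
    raise ImportError("No module named 'Metal'")

def no_cuda_device():       # bindings import, but zero CUDA devices
    return None

def webgpu_ok():            # wgpu adapter found
    return "webgpu"


assert pick_first_backend([no_metal, no_cuda_device, webgpu_ok]) == "webgpu"
```

Swallowing the probe exceptions is what lets a single wheel work across machines: a missing binding on one platform simply advances the ladder instead of breaking the import.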

cuslines/boot_utils.py

Lines changed: 72 additions & 0 deletions
```python
"""Shared utilities for bootstrap direction getters (CUDA and Metal).

Extracts DIPY model matrices (H, R, delta_b, delta_q, sampling_matrix)
for OPDT and CSA models. Both backends need the same matrices — only
the GPU dispatch differs.
"""

from dipy.reconst import shm


def prepare_opdt(gtab, sphere, sh_order_max=6, full_basis=False,
                 sh_lambda=0.006, min_signal=1):
    """Build bootstrap matrices for the OPDT model.

    Returns dict with keys: model_type, min_signal, H, R, delta_b,
    delta_q, sampling_matrix, b0s_mask.
    """
    sampling_matrix, _, _ = shm.real_sh_descoteaux(
        sh_order_max, sphere.theta, sphere.phi,
        full_basis=full_basis, legacy=True,
    )
    model = shm.OpdtModel(
        gtab, sh_order_max=sh_order_max, smooth=sh_lambda,
        min_signal=min_signal,
    )
    delta_b, delta_q = model._fit_matrix

    H, R = _hat_and_lcr(gtab, model, sh_order_max)

    return dict(
        model_type="OPDT", min_signal=min_signal,
        H=H, R=R, delta_b=delta_b, delta_q=delta_q,
        sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask,
    )


def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False,
                sh_lambda=0.006, min_signal=1):
    """Build bootstrap matrices for the CSA model.

    Returns dict with keys: model_type, min_signal, H, R, delta_b,
    delta_q, sampling_matrix, b0s_mask.
    """
    sampling_matrix, _, _ = shm.real_sh_descoteaux(
        sh_order_max, sphere.theta, sphere.phi,
        full_basis=full_basis, legacy=True,
    )
    model = shm.CsaOdfModel(
        gtab, sh_order_max=sh_order_max, smooth=sh_lambda,
        min_signal=min_signal,
    )
    delta_b = model._fit_matrix
    delta_q = model._fit_matrix

    H, R = _hat_and_lcr(gtab, model, sh_order_max)

    return dict(
        model_type="CSA", min_signal=min_signal,
        H=H, R=R, delta_b=delta_b, delta_q=delta_q,
        sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask,
    )


def _hat_and_lcr(gtab, model, sh_order_max):
    """Compute hat matrix H and leveraged centered residuals matrix R."""
    dwi_mask = ~gtab.b0s_mask
    x, y, z = model.gtab.gradients[dwi_mask].T
    _, theta, phi = shm.cart2sphere(x, y, z)
    B, _, _ = shm.real_sh_descoteaux(sh_order_max, theta, phi, legacy=True)
    H = shm.hat(B)
    R = shm.lcr_matrix(H)
    return H, R
```
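The hat matrix computed by `_hat_and_lcr` (via `shm.hat`) is the ordinary least-squares projector onto the SH design matrix, H = B (BᵀB)⁻¹ Bᵀ; the residual matrix R is derived from I − H scaled by the leverages diag(H) (see `shm.lcr_matrix` for the exact form). A pure-numpy illustration of the projector's defining properties, using a random full-rank design matrix as a stand-in for B:

```python
import numpy as np

# Stand-in design matrix: 20 DWI directions, 6 SH coefficients.
rng = np.random.default_rng(0)
B = rng.standard_normal((20, 6))

# Hat (projection) matrix: what shm.hat(B) computes.
H = B @ np.linalg.inv(B.T @ B) @ B.T

assert np.allclose(H, H @ H)          # idempotent: H is a projector
assert np.allclose(np.trace(H), 6.0)  # trace equals the number of coefficients
assert np.allclose(H @ B, B)          # the basis columns are fixed points
```

Bootstrapping resamples residuals through R and re-projects through H, which is why both matrices are precomputed once on the CPU and shipped to every backend unchanged.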
