Description
Some backstory on JAX PyCapsule requirements
I have been digging into JAX and subsequently wondering about PyO3 compatibility. To extend JAX in Rust (like in the C++ extending-jax repo, but for Rust) you have to pass Rust functions across to JAX in Python land (it either goes via ctypes or off to C++ via a library like nanobind; we are interested in the ctypes case).
This then requires a "capsule", which I hadn't encountered before (though I have used PyO3 a fair amount now).
- JAX docs: jax.ffi.pycapsule
The existing project developing this was segfaulting because the capsule ended up holding a pointer to the function pointer f rather than f itself, which I resolved by going through PyCapsule_New instead. That is, I only managed to meet the requirement for this JAX FFI setup by bypassing the pyo3 PyCapsule object-oriented method API and going via pyo3-ffi.
I'm pretty sure this was a hard requirement and not my mistake (but apologies if I'm wrong). The Send bound on the value you pass to PyCapsule::new means you can't pass a raw void* pointer (raw pointers are not Send), at least as far as I can work out.
- I think technically you say "a function pointer to void": *const void(*)() or *mut void(*)()
- This represents a pointer to a function pointer that returns void and takes no arguments, used to pass opaque function pointers through Python code, allowing C extension modules (which may even be FFI'd from Rust) to share function pointers across different modules.
I'm not entirely sure whether this is infeasible despite function pointers implementing Send, or whether I'm just overlooking how to do it. If it is infeasible, I presume that is not by design, and perhaps it ought to be added via another constructor method?
I see other libs like kornia and pyoxidizer also 'going around' the main PyCapsule constructors. I suspect there should actually be a helper here on PyCapsule itself.
Under the hood, JAX (in _src/ffi.py) needs the capsule's pointer to refer to the function itself; we cannot wrap it in a struct that implements Send. It must be the function pointer itself. I've tried to explain why as best I can below.
There are 2 parts under the hood in jax (the jaxlib package) which will use the capsule:
- the XLA client (Python registration) https://github.com/jax-ml/jax/blob/4efd7828b041f4d1a9cdd8b5c61a31cda378414a/jaxlib/xla_client.py#L406
- this receives PyCapsules from external code (like our extension) and registers them with JAX/XLA
- the FFI (C++ registration) https://github.com/jax-ml/jax/blob/4efd7828b041f4d1a9cdd8b5c61a31cda378414a/jaxlib/ffi.cc#L142
- this handles the actual registration of the custom call target by unwrapping the capsule passed in through Python.
The unwrapping step restricts what our capsule can look like.
Capsule unwrapping in FFI code (and PyO3 limitations therein)
This is the part where PyO3 comes in. To briefly follow what happens after the capsule arrives in ffi.cc:
There's a PyRegisterCustomCallTarget function, with two paths, a legacy untyped API api_version=0 and a new typed API api_version=1:
if (api_version == 0) {
nb::capsule capsule;
if (!nb::try_cast<nb::capsule>(fn, capsule)) {
return absl::InvalidArgumentError(...);
}
xla::CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule.data()), platform);
return absl::OkStatus();
}
- Here capsule.data() extracts whatever pointer you stored in the capsule, casts it to void*, and hands it to XLA's global registry. The registry stores that raw pointer, expecting it to be the function pointer itself (not a pointer to a pointer).
if (api_version == 1) {
nb::capsule capsule;
if (nb::try_cast<nb::capsule>(fn, capsule)) {
return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler(
xla::ffi::GetXlaFfiApi(), fn_name, platform,
reinterpret_cast<XLA_FFI_Handler*>(
static_cast<void*>(capsule.data()))));
}
// ...
}
- Same idea: capsule.data() is reinterpreted directly as an XLA_FFI_Handler* function pointer.
In both arms the C++ code does static_cast<void*>(capsule.data()) then either stores it directly or casts it to a function pointer type. This means the capsule must contain the function pointer value itself, not a pointer to a struct containing the function pointer.
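As a rough illustration of the consumer side, here is a minimal Rust sketch of what treating the capsule's data pointer as the function pointer amounts to. It does not use XLA or PyO3: the `Handler` type, `double_it`, and `call_capsule_data` are hypothetical stand-ins (a real XLA handler has a different signature), and it assumes function pointers and data pointers have the same size, as on common platforms.

```rust
use std::ffi::c_void;
use std::mem;

// Hypothetical handler signature, standing in for the real XLA handler type.
type Handler = extern "C" fn(i32) -> i32;

// Hypothetical handler used only for this illustration.
extern "C" fn double_it(x: i32) -> i32 {
    x * 2
}

/// Mimics the consumer: reinterpret the data pointer as the function
/// pointer itself and call it.
///
/// # Safety
/// `data` must have been produced by casting a `Handler` to `*mut c_void`.
unsafe fn call_capsule_data(data: *mut c_void, x: i32) -> i32 {
    let f: Handler = unsafe { mem::transmute::<*mut c_void, Handler>(data) };
    f(x)
}

fn main() {
    // The function pointer's value is used directly as the data pointer.
    let data = double_it as Handler as *mut c_void;
    let result = unsafe { call_capsule_data(data, 21) };
    println!("{}", result);
}
```

If `data` instead pointed at some storage holding the function pointer, the transmute would yield the storage address and the call would jump to non-code memory, which matches the segfault symptom.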
If I wrapped a function pointer in a struct like:
struct Wrapper {
f: extern "C" fn(...) -> ...
}
Then capsule.data() would return a pointer to Wrapper, which when cast to a function pointer would be wrong (giving my struct's address, not my function).
PyO3's PyCapsule::new has this signature:
pub fn new<T: 'static + Send + AssertNotZeroSized>(
py: Python<'_>,
value: T,
name: Option<CString>,
) -> PyResult<Bound<'_, Self>>
Function pointers implement Send by default in Rust (source),
The following traits are implemented for function pointers with any number of arguments and any ABI. These traits have implementations that are automatically generated by the compiler, so are not limited by missing language features: ... [list including Send]
so its trait bounds are fine.
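The Send claim can be checked at compile time with a small helper. This is only a sketch: `assert_send` and the `Handler` alias are names made up for the illustration, not part of PyO3.

```rust
use std::ffi::c_void;

// Hypothetical handler signature used only for this illustration.
type Handler = extern "C" fn(i32) -> i32;

// Compiles only when `T` implements `Send`; does nothing at runtime.
fn assert_send<T: Send>() {}

fn main() {
    // Function pointers are Send, so this line compiles.
    assert_send::<Handler>();

    // Raw pointers are not Send; uncommenting the next line is a compile error.
    // assert_send::<*mut c_void>();

    // Keep the import in use: the size of a raw pointer, for reference.
    println!("{}", std::mem::size_of::<*mut c_void>());
}
```

So a function pointer satisfies the bound on PyCapsule::new, while a raw `*mut c_void` does not.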
I think the issue is that PyCapsule::new stores a value and gives a pointer to that storage. But XLA expects the capsule's data pointer to be the function pointer value itself. Semantically that follows to me as other uses I've seen are passing 'nouns' (state, data) whereas I am passing a function, a callback to run some computation (a 'verb' rather than a 'noun').
So the T gets moved into the capsule and stored by value, and PyO3 gives us a pointer to where it stored T... which is fine for values, but XLA doesn't want a pointer to where we stored the function pointer - it wants the function pointer's value itself as the capsule's data pointer.
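The difference between the two storage models can be shown without PyO3. The sketch below is a simplified model, not PyO3's actual implementation: `data_ptr_by_value` assumes a value-storing constructor boxes the value and exposes a pointer to that storage, while `data_ptr_raw` uses the function pointer's value as the data pointer. All names here are hypothetical.

```rust
use std::ffi::c_void;

// Hypothetical handler signature used only for this illustration.
type Handler = extern "C" fn(i32) -> i32;

extern "C" fn double_it(x: i32) -> i32 {
    x * 2
}

/// Simplified model of a value-storing constructor: move the value to the
/// heap and hand out a pointer to where it is stored.
fn data_ptr_by_value(f: Handler) -> *mut c_void {
    Box::into_raw(Box::new(f)) as *mut c_void
}

/// Model of a raw-pointer constructor: the function pointer's value is the
/// data pointer.
fn data_ptr_raw(f: Handler) -> *mut c_void {
    f as *mut c_void
}

fn main() {
    let f: Handler = double_it;
    let by_value = data_ptr_by_value(f);
    let raw = data_ptr_raw(f);

    // The raw data pointer is the function's address.
    assert_eq!(raw as usize, f as usize);
    // The by-value data pointer is a heap address, one indirection away.
    assert_ne!(by_value as usize, f as usize);
    let stored = unsafe { *(by_value as *const Handler) };
    assert_eq!(stored as usize, f as usize);

    // Free the boxed storage allocated by the by-value model.
    unsafe { drop(Box::from_raw(by_value as *mut Handler)) };
}
```

A consumer that casts the data pointer straight to a function pointer only works with the second model.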
A workaround via pyo3-ffi
This was what led to this approach to bypass PyO3's safe API:
use pyo3::ffi::PyCapsule_New;
use std::ffi::CString;
use std::ptr;
extern "C" fn my_handler(...) -> ... { ... }
let name = CString::new("my_handler").unwrap();
let capsule = unsafe {
PyCapsule_New(
my_handler as *mut std::ffi::c_void, // the function pointer AS the data
name.as_ptr(),
None, // no destructor needed
)
};
This puts the function pointer's value directly as the capsule's internal pointer, which is exactly what capsule.data() will retrieve on the C++ side.
Gap in PyO3?
I think this maybe warrants a lower-level constructor like:
pub unsafe fn new_raw(
py: Python<'_>,
pointer: *mut c_void,
name: Option<&CStr>,
destructor: Option<...>
) -> PyResult<&PyCapsule>
The current API assumes you want to store a value and get a pointer to it, but PyCapsules are also used to pass opaque pointers (like function pointers) where the pointer is the value. I'm not sure about those examples I linked above in kornia/pyoxidizer but the FFI interop case for XLA is probably common enough to warrant direct support, if it's not too much trouble to extend the API in this way.
There's every chance that the answer could just be "you just have to use PyCapsule_New directly" but if this is a gap I thought it best to raise it.
It seems better ergonomics to do this:
#[pyfunction(name = "rms_norm")]
fn rms_norm_jax(py: Python<'_>) -> PyResult<Bound<'_, PyCapsule>> {
unsafe {
PyCapsule::new_pointer(py, ffi::RmsNorm as *mut c_void, None)
}
}
than this:
#[pyfunction(name = "rms_norm")]
fn rms_norm_jax(py: Python<'_>) -> PyResult<Bound<'_, PyCapsule>> {
let fn_ptr: *mut c_void = ffi::RmsNorm as *mut c_void;
let name = std::ptr::null();
unsafe {
let capsule = pyo3::ffi::PyCapsule_New(fn_ptr, name, None);
if capsule.is_null() {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
"Failed to create PyCapsule",
));
}
let any: Bound<'_, PyAny> = Bound::from_owned_ptr(py, capsule);
Ok(any.cast_into_unchecked::<PyCapsule>())
}
}
- PR in which this scenario arises: fix(lib): runtime segfault and NumPy OOB jeertmans/extending-jax#3
- @jeertmans discussed here previously in How to re-export C/C++ function pointers with PyCapsule? #4772
- I checked and both work, the first looks of course much neater!
- 🚢 Submitted as PR in feat: raw pointer capsule instantiation method #5689
Please let me know if I can explain anything further, thanks for reading 🙂