diff --git a/Cargo.lock b/Cargo.lock index 28a0699d..e172ba64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1053,6 +1053,7 @@ dependencies = [ "arrow-buffer", "arrow-schema", "arrow-select", + "half", "indexmap", "numpy", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 212784fb..90e093f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ arrow-csv = { git = "https://github.com/kylebarron/arrow-rs", rev = "9875740e625 arrow-ipc = { git = "https://github.com/kylebarron/arrow-rs", rev = "9875740e625b78bfb0b6545eab63a17f47c6a122" } arrow-schema = { git = "https://github.com/kylebarron/arrow-rs", rev = "9875740e625b78bfb0b6545eab63a17f47c6a122" } arrow-select = { git = "https://github.com/kylebarron/arrow-rs", rev = "9875740e625b78bfb0b6545eab63a17f47c6a122" } +half = "2" indexmap = "2" # numpy = "0.21" # TODO: Pin to released version once NumPy 2.0 support is merged diff --git a/pyo3-arrow/Cargo.toml b/pyo3-arrow/Cargo.toml index 4dda19b4..bcd41e88 100644 --- a/pyo3-arrow/Cargo.toml +++ b/pyo3-arrow/Cargo.toml @@ -17,6 +17,7 @@ arrow-buffer = { workspace = true } arrow-schema = { workspace = true } arrow = { workspace = true, features = ["ffi"] } pyo3 = { workspace = true, features = ["abi3-py38", "indexmap"] } +half = { workspace = true } indexmap = { workspace = true } numpy = { workspace = true, features = ["half"] } thiserror = { workspace = true } diff --git a/pyo3-arrow/src/interop/numpy/from_numpy.rs b/pyo3-arrow/src/interop/numpy/from_numpy.rs index 7a534076..1a5571a1 100644 --- a/pyo3-arrow/src/interop/numpy/from_numpy.rs +++ b/pyo3-arrow/src/interop/numpy/from_numpy.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use arrow::datatypes::{ - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use numpy::{dtype_bound, PyArray1, PyArrayDescr, PyUntypedArray}; @@ -21,7 +21,9 @@ pub fn from_numpy(py: Python, array: &PyUntypedArray) -> PyArrowResult }}; } let dtype = array.dtype(); - if is_type::(py, dtype) { + if is_type::(py, dtype) { + numpy_to_arrow!(half::f16, Float16Type) + } else if is_type::(py, dtype) { numpy_to_arrow!(f32, Float32Type) } else if is_type::(py, dtype) { numpy_to_arrow!(f64, Float64Type)