Description
Is your feature request related to a problem? Please describe.
Currently, the cuBLAS GEMM wrapper in RAFT has no option to specify which compute type (cublasComputeType_t) the GEMM should use. The compute type is derived solely from the input and output types.
Describe the solution you'd like
Add an optional argument to raft::linalg::gemm for specifying the compute type. If the user does not specify one, we can default to the type returned by get_matmul_type. One way to achieve this is to overload the GEMM call, with and without the compute type argument.
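As a rough illustration of the overloading idea, here is a minimal CPU-only sketch. It is not RAFT's actual API: ComputeType is a hypothetical stand-in for cublasComputeType_t, default_compute_type is a hypothetical stand-in for get_matmul_type (whose real signature is not shown in this issue), and the gemm functions are a plain reference loop rather than a cuBLAS call. It only shows how the overload without a compute type could forward to the one with it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for cublasComputeType_t so the sketch builds without CUDA.
enum class ComputeType { F32, F32FastTF32 };

// Hypothetical stand-in for get_matmul_type: the compute type derived from the
// data type alone. The real helper's signature is an assumption here.
template <typename T>
constexpr ComputeType default_compute_type()
{
  return ComputeType::F32;
}

// Overload WITH an explicit compute type. This reference loop computes
// C = A * B for row-major matrices and ignores compute_type numerically;
// a real wrapper would forward it to the cuBLAS call instead.
// The selected compute type is returned so callers can observe it.
template <typename T>
ComputeType gemm(const std::vector<T>& a,
                 const std::vector<T>& b,
                 std::vector<T>& c,
                 std::size_t m,
                 std::size_t n,
                 std::size_t k,
                 ComputeType compute_type)
{
  assert(a.size() == m * k && b.size() == k * n);
  c.assign(m * n, T{0});
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t p = 0; p < k; ++p) {
        c[i * n + j] += a[i * k + p] * b[p * n + j];
      }
    }
  }
  return compute_type;
}

// Overload WITHOUT a compute type: existing call sites keep compiling and get
// the type-derived default, which matches the current behavior.
template <typename T>
ComputeType gemm(const std::vector<T>& a,
                 const std::vector<T>& b,
                 std::vector<T>& c,
                 std::size_t m,
                 std::size_t n,
                 std::size_t k)
{
  return gemm(a, b, c, m, n, k, default_compute_type<T>());
}
```

With this shape, a caller that needs the faster mode would pass the extra argument, for example gemm(a, b, c, m, n, k, ComputeType::F32FastTF32), while all other callers are unaffected.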
Describe alternatives you've considered
NA
Additional context
We needed this option while working on 1-NN distance computation. See https://github.com/vinaydes/cuvs/blob/distance-nn/cpp/src/distance/unfused_distance_nn.cuh#L164. There, instead of the default CUBLAS_COMPUTE_32F, we needed CUBLAS_COMPUTE_32F_FAST_TF32, which is faster and adequate for the requirement.