
[FEA] Add an option argument to raft::linalg::gemm for specifying compute type #2988

@vinaydes

Description


Is your feature request related to a problem? Please describe.
Currently, the cuBLAS GEMM wrapper in RAFT has no option to specify which compute type (cublasComputeType_t) the GEMM should use. The compute type is derived from the input and output types alone.

Describe the solution you'd like
Add an optional argument to raft::linalg::gemm for specifying the compute type. If the user does not specify one, we can default to the type returned by get_matmul_type. This could perhaps be achieved by overloading the GEMM call, with and without the compute-type argument.
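The overloading approach proposed above can be sketched on the host as follows. This is a minimal, hypothetical illustration, not RAFT's API: `compute_type` is a stand-in enum for cublasComputeType_t so the code compiles without CUDA, `default_compute_type` plays the role of get_matmul_type under the assumption that float maps to 32F and double to 64F, and the naive loop stands in for the cuBLAS call a real wrapper would make.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for cublasComputeType_t so this sketch builds without CUDA.
// The enumerator names mirror real cuBLAS values, but this enum is not cuBLAS's.
enum class compute_type { k16F, k32F, k32F_fast_tf32, k64F };

// Hypothetical default, in the spirit of get_matmul_type: derive the compute
// type from the element types (the real RAFT mapping may differ).
template <typename A, typename B, typename C>
constexpr compute_type default_compute_type()
{
  static_assert(std::is_same_v<A, B> && std::is_same_v<B, C>,
                "this sketch only handles uniform element types");
  if constexpr (std::is_same_v<A, double>) {
    return compute_type::k64F;
  } else {
    return compute_type::k32F;
  }
}

// Overload 1: explicit compute type. Computes C = alpha * A * B + beta * C for
// column-major A (m x k), B (k x n), C (m x n) on the host and returns the compute
// type it was asked to use. A real wrapper would forward `ct` to cuBLAS instead.
template <typename T>
compute_type gemm(int m, int n, int k, T alpha, const T* a, const T* b, T beta, T* c,
                  compute_type ct)
{
  assert(m >= 0 && n >= 0 && k >= 0);
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      T acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += a[i + p * m] * b[p + j * k];
      }
      c[i + j * m] = alpha * acc + beta * c[i + j * m];
    }
  }
  return ct;
}

// Overload 2: no compute type given, so fall back to the type-derived default.
template <typename T>
compute_type gemm(int m, int n, int k, T alpha, const T* a, const T* b, T beta, T* c)
{
  return gemm(m, n, k, alpha, a, b, beta, c, default_compute_type<T, T, T>());
}
```

Because the two overloads differ only by the trailing argument, existing call sites keep compiling and keep their current compute type, while callers such as the 1-NN code can opt into a different one explicitly.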

Describe alternatives you've considered
NA

Additional context
We needed this option while working on 1-NN distance computation; see https://github.com/vinaydes/cuvs/blob/distance-nn/cpp/src/distance/unfused_distance_nn.cuh#L164. There, instead of the default CUBLAS_COMPUTE_32F, we needed CUBLAS_COMPUTE_32F_FAST_TF32, which is faster and accurate enough for that use case.
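To illustrate the accuracy trade-off, the sketch below emulates TF32 on the host. TF32 keeps FP32's 8-bit exponent but only 10 explicit mantissa bits, and Tensor Core TF32 operations accumulate in FP32. The helper names `to_tf32` and `dot_tf32` are hypothetical, and the conversion here truncates, whereas the hardware conversion may round, so individual results can differ in the last retained bit.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Approximate an FP32 -> TF32 conversion by zeroing the 13 low mantissa bits,
// keeping the sign, the 8 exponent bits and the top 10 mantissa bits.
// This truncates; real hardware may round to nearest instead.
float to_tf32(float x)
{
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFFE000u;
  float y;
  std::memcpy(&y, &bits, sizeof y);
  return y;
}

// Dot product with both inputs reduced to TF32 and the accumulator kept in FP32,
// roughly what CUBLAS_COMPUTE_32F_FAST_TF32 permits for each output element.
float dot_tf32(const float* x, const float* y, int n)
{
  assert(n >= 0);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += to_tf32(x[i]) * to_tf32(y[i]);
  }
  return acc;
}
```

With this emulation, an input such as 1.00048828125f (1 + 2^-11) becomes exactly 1.0f, so each operand carries a relative error below 2^-10 (about 1e-3). The issue reports that this is adequate for the 1-NN use case.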

Metadata

Assignees: No one assigned
Type: No type
Projects: Status Todo
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
