Skip to content

Commit e3bd5f9

Browse files
jenniewCuiYifeng
and authored
Add addmm, mm, bmm, baddbmm on SparseCsrXPU (#2758)
Add addmm, mm, bmm, baddbmm support on SparseCsrXPU. Enable related tests. Related issues: #2211 #2213 --------- Co-authored-by: Cui, Yifeng <yifeng.cui@intel.com>
1 parent 205fa95 commit e3bd5f9

File tree

3 files changed

+348
-1
lines changed

3 files changed

+348
-1
lines changed

src/ATen/native/sparse/xpu/SparseCsrTensorMath.cpp

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
* http://www.apache.org/licenses/LICENSE-2.0
99
*/
1010

11+
#include <ATen/ExpandUtils.h>
1112
#include <ATen/SparseCsrTensorUtils.h>
13+
#include <ATen/TensorOperators.h>
14+
#include <ATen/native/Resize.h>
15+
#include <ATen/native/sparse/SparseCsrTensorMath.h>
1216
#include <ATen/native/sparse/SparseStubs.h>
1317
#include <ATen/native/sparse/xpu/sycl/SparseCsrTensorMathKernels.h>
1418
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
@@ -19,6 +23,11 @@
1923
#include <ATen/NativeFunctions.h>
2024
#else
2125
#include <ATen/ops/add.h>
26+
#include <ATen/ops/addmm.h>
27+
#include <ATen/ops/baddbmm.h>
28+
#include <ATen/ops/copy_native.h>
29+
#include <ATen/ops/mul.h>
30+
#include <ATen/ops/sparse_compressed_tensor.h>
2231
#endif
2332

2433
namespace at::native {
@@ -62,6 +71,318 @@ Tensor _sparse_csr_prod_xpu(
6271
input, dims_to_reduce, keepdim, dtype);
6372
}
6473

74+
// Computes alpha * (mat1 @ mat2) + beta * input by densifying any sparse
// operand and running the dense mm kernel.
// When beta == 0 the input tensor is never read, so NaN/Inf values in it
// cannot propagate into the result (matches torch.addmm semantics).
// Returns a strided (dense) tensor; callers convert back to sparse if needed.
Tensor addmm_calculation(
    const Tensor& input,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha) {
  // Convert a sparse operand to strided; pass strided tensors through as-is.
  const auto densify = [](const Tensor& t) -> Tensor {
    return t.layout() == kStrided ? t : t.to_dense();
  };

  Tensor out = densify(mat1).mm(densify(mat2)) * alpha;
  // Skip the add entirely for beta == 0 so `input` is ignored, per docs.
  if (beta.toComplexDouble() != 0.) {
    out.add_(densify(input), beta);
  }
  return out;
}
90+
91+
void addmm_out_sparse_csr(
92+
const Tensor& input,
93+
const Tensor& mat1,
94+
const Tensor& mat2,
95+
const Scalar& beta,
96+
const Scalar& alpha,
97+
Tensor& result) {
98+
TORCH_INTERNAL_ASSERT(
99+
!((mat1.layout() == kStrided) && (mat2.layout() == kStrided) &&
100+
(result.layout() == kStrided)),
101+
"Expected at least one sparse input");
102+
103+
// Layout checks are nested mat1, mat2, result
104+
// Conditions are ordered strided, csr, csc, bsr, bsc.
105+
// Valid combinations terminate in a return
106+
// Invalid combinations are omitted and will fall though to the TORCH check
107+
// generating an informative error message
108+
109+
if ((mat1.layout() == kSparseBsr) && (mat2.layout() == kStrided) &&
110+
(result.layout() == kStrided)) {
111+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
112+
result.copy_(result_dense);
113+
return;
114+
}
115+
116+
if ((mat1.layout() == kStrided) && (mat2.layout() == kSparseBsc) &&
117+
(result.layout() == kStrided)) {
118+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
119+
result.copy_(result_dense);
120+
return;
121+
}
122+
123+
if (mat1.layout() == kStrided) {
124+
if ((mat2.layout() == kSparseCsr) && (result.layout() == kStrided)) {
125+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
126+
result.copy_(result_dense);
127+
return;
128+
}
129+
if ((mat2.layout() == kSparseCsc) && (result.layout() == kStrided)) {
130+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
131+
result.copy_(result_dense);
132+
return;
133+
}
134+
}
135+
if (mat1.layout() == kSparseCsr) {
136+
if ((mat2.layout() == kStrided) && (result.layout() == kStrided)) {
137+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
138+
result.copy_(result_dense);
139+
return;
140+
}
141+
if ((mat2.layout() == kSparseCsr) && (result.layout() == kSparseCsr)) {
142+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
143+
result = result_dense.to_sparse_csr();
144+
return;
145+
}
146+
if ((mat2.layout() == kSparseCsc) && (result.layout() == kSparseCsr)) {
147+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
148+
result = result_dense.to_sparse_csr();
149+
return;
150+
}
151+
}
152+
if (mat1.layout() == kSparseCsc) {
153+
if ((mat2.layout() == kStrided) && (result.layout() == kStrided)) {
154+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
155+
result.copy_(result_dense);
156+
return;
157+
}
158+
if ((mat2.layout() == kSparseCsr) && (result.layout() == kSparseCsr)) {
159+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
160+
result = result_dense.to_sparse_csr();
161+
return;
162+
}
163+
if (mat2.layout() == kSparseCsc) {
164+
if (result.layout() == kSparseCsr) {
165+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
166+
result = result_dense.to_sparse_csr();
167+
return;
168+
}
169+
if (result.layout() == kSparseCsc) {
170+
Tensor result_dense = addmm_calculation(input, mat1, mat2, beta, alpha);
171+
result = result_dense.to_sparse_csc();
172+
return;
173+
}
174+
}
175+
}
176+
TORCH_CHECK(
177+
false,
178+
"addmm: computation on XPU is not implemented for ",
179+
result.layout(),
180+
" + ",
181+
mat1.layout(),
182+
" @ ",
183+
mat2.layout());
184+
}
185+
186+
// result = beta * self + alpha * (mat1 @ mat2)
187+
Tensor& addmm_out_sparse_compressed_xpu(
188+
const Tensor& self,
189+
const Tensor& mat1,
190+
const Tensor& mat2,
191+
const Scalar& beta,
192+
const Scalar& alpha,
193+
Tensor& result) {
194+
TORCH_CHECK(
195+
self.is_xpu(),
196+
"Expected all tensors to be on the same device. addmm expected self to be XPU tensor, but got ",
197+
self.device(),
198+
" tensor");
199+
TORCH_CHECK(
200+
mat1.is_xpu(),
201+
"Expected all tensors to be on the same device. addmm expected mat1 to be XPU tensor, but got ",
202+
mat1.device(),
203+
" tensor");
204+
TORCH_CHECK(
205+
mat2.is_xpu(),
206+
"Expected all tensors to be on the same device. addmm expected mat2 to be XPU tensor, but got ",
207+
mat2.device(),
208+
" tensor");
209+
TORCH_CHECK(
210+
result.is_xpu(),
211+
"Expected all tensors to be on the same device. addmm expected result to be XPU tensor, but got ",
212+
result.device(),
213+
" tensor");
214+
215+
// Same checks as in TORCH_META_FUNC(addmm) at
216+
// aten/src/ATen/native/LinearAlgebra.cpp
217+
sparse::impl::_check_dim(mat1, 2, "mat1");
218+
sparse::impl::_check_dim(mat2, 2, "mat2");
219+
220+
TORCH_CHECK(
221+
mat1.size(1) == mat2.size(0),
222+
"mat1 and mat2 shapes cannot be multiplied (",
223+
mat1.size(0),
224+
"x",
225+
mat1.size(1),
226+
" and ",
227+
mat2.sizes()[0],
228+
"x",
229+
mat2.sizes()[1],
230+
")");
231+
232+
c10::MaybeOwned<at::Tensor> self_;
233+
// Don't expand self if this is an in-place operation
234+
if (&result == &self) {
235+
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
236+
} else {
237+
self_ = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
238+
}
239+
240+
sparse::impl::_check_dim(*self_, 2, "self");
241+
TORCH_CHECK(
242+
((self_->dim() == 2) && (self_->size(0) == mat1.size(0)) &&
243+
(self_->size(1) == mat2.size(1))),
244+
"The input tensor must be a matrix with size ",
245+
mat1.size(0),
246+
"x",
247+
mat2.size(1),
248+
", but got a ",
249+
self_->dim(),
250+
"-D tensor with size ",
251+
self_->size(0),
252+
"x",
253+
self_->size(1));
254+
255+
if (!result.is_same(self)) {
256+
if (result.layout() == kStrided) {
257+
at::native::resize_output(result, self_->sizes());
258+
} else {
259+
result.resize_as_sparse_(*self_);
260+
}
261+
}
262+
263+
if (result.numel() == 0) {
264+
return result;
265+
}
266+
267+
if (sparse::impl::_is_sparse_and_zero(mat1) ||
268+
sparse::impl::_is_sparse_and_zero(mat2)) {
269+
// According to docs, when beta==0 values in self should be ignored.
270+
// nans and infs should not propagate
271+
const auto beta_val = beta.toComplexDouble();
272+
if (beta_val == 0.) {
273+
result.zero_();
274+
} else {
275+
if (!result.is_same(self)) {
276+
result.copy_(*self_);
277+
}
278+
if (beta_val != 1.) {
279+
result.mul_(beta);
280+
}
281+
}
282+
return result;
283+
}
284+
285+
addmm_out_sparse_csr(*self_, mat1, mat2, beta, alpha, result);
286+
return result;
287+
}
288+
289+
// If a sparse-compressed tensor stores fewer batch dimensions in its indices
// and values than its logical sizes() imply, expand the indices/values so
// every logical batch dimension is materialized, and rebuild the tensor.
// Tensors whose indices already carry all batch dimensions are rebuilt as-is.
Tensor expand_batch_if_necessary(const Tensor& mat) {
  auto [compressed_indices, plain_indices] =
      sparse_csr::getCompressedPlainIndices(mat);
  Tensor values = mat.values();

  const auto stored_batch_ndim = sparse_csr::numBatchDimensions(mat);
  const auto logical_batch_ndim = mat.sizes().size() - 2;

  if (stored_batch_ndim < logical_batch_ndim) {
    // Leading batch dims present in sizes() but absent from the indices.
    const auto sizes = mat.sizes();
    const std::vector<int64_t> missing_batch(
        sizes.begin(),
        sizes.begin() + (logical_batch_ndim - stored_batch_ndim));

    // Prepend the missing batch dims to a tensor's shape and expand into them.
    const auto prepend_and_expand = [&missing_batch](const Tensor& t) {
      auto expanded_shape = t.sizes().vec();
      expanded_shape.insert(
          expanded_shape.begin(), missing_batch.begin(), missing_batch.end());
      return t.expand(expanded_shape);
    };

    compressed_indices = prepend_and_expand(compressed_indices);
    plain_indices = prepend_and_expand(plain_indices);
    values = prepend_and_expand(values);
  }

  return at::sparse_compressed_tensor(
      compressed_indices, plain_indices, values, mat.sizes(), mat.options());
}
324+
325+
// result = beta * self + alpha * (mat1 @ mat2), batched, with mat1 in CSR
// layout and self/mat2/result strided. Densifies mat1 (after batch
// expansion) and defers to the dense baddbmm kernel.
Tensor& baddbmm_out_sparse_csr_xpu(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.is_sparse_csr());

  TORCH_CHECK(
      self.layout() == kStrided,
      "torch.baddbmm: Expected self to be strided, but got layout ",
      self.layout());
  TORCH_CHECK(
      mat2.layout() == kStrided,
      "torch.baddbmm: Expect mat2 to be strided, but got ",
      mat2.layout());
  TORCH_CHECK(
      result.layout() == kStrided,
      "torch.baddbmm: Expect result to be strided, but got ",
      result.layout());

  const bool in_place = result.is_same(self);
  if (!in_place) {
    at::native::resize_output(result, self.sizes());
  }

  if (mat1._nnz() == 0) {
    // mat1 @ mat2 is all zeros, leaving only the beta * self term.
    // According to docs, when beta==0 values in self should be ignored:
    // nans and infs should not propagate.
    const auto beta_val = beta.toComplexDouble();
    if (beta_val == 0.) {
      result.zero_();
    } else {
      if (!in_place) {
        result.copy_(self);
      }
      if (beta_val != 1.) {
        result.mul_(beta);
      }
    }
    return result;
  }

  // Broadcast batch of sparse indices and values if not compatible with
  // sizes before to_dense(). to_dense issue:
  // https://github.com/intel/torch-xpu-ops/issues/2801
  const Tensor mat1_expanded = expand_batch_if_necessary(mat1);

  at::baddbmm_out(result, self, mat1_expanded.to_dense(), mat2, beta, alpha);
  return result;
}
375+
376+
// Batched matrix multiply for a sparse-CSR mat1: implemented as baddbmm with
// beta == 0 and alpha == 1. `result` also serves as the (ignored, since
// beta == 0) self operand so no extra tensor is materialized.
Tensor& bmm_out_sparse_csr_xpu(
    const Tensor& mat1,
    const Tensor& mat2,
    Tensor& result) {
  return at::native::baddbmm_out_sparse_csr_xpu(
      result, mat1, mat2, Scalar(0.0), Scalar(1.0), result);
}
385+
65386
Tensor& add_out_sparse_compressed_xpu(
66387
const Tensor& self,
67388
const SparseCsrTensor& other,

test/xpu/test_sparse_csr_xpu.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
skipCUDAIfNoSparseGeneric,
4747
skipCUDAIfRocm,
4848
skipMeta,
49+
tol,
50+
toleranceOverride,
4951
)
5052
from torch.testing._internal.common_dtype import (
5153
all_types_and_complex,
@@ -2142,6 +2144,7 @@ def test_csr_matvec(self, device, dtype):
21422144

21432145
@onlyOn(["cuda", "xpu"])
21442146
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2147+
@precisionOverride({torch.float64: 2e-6})
21452148
def test_baddbmm(self, device, dtype):
21462149
# TODO: disable the invariant checks within torch.baddbmm that
21472150
# constructs unconventional csr tensors leading to
@@ -2803,7 +2806,9 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
28032806
)
28042807
)
28052808
@dtypesIfXPU(*floating_and_complex_types_and(torch.half, torch.bfloat16))
2806-
@precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
2809+
@precisionOverride(
2810+
{torch.bfloat16: 3.5e-2, torch.float16: 1e-2, torch.float64: 2e-6}
2811+
)
28072812
def test_sparse_addmm(self, device, dtype):
28082813
def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
28092814
if alpha_beta is None:
@@ -2845,6 +2850,7 @@ def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
28452850
torch.cdouble: 1e-8,
28462851
}
28472852
)
2853+
@toleranceOverride({torch.double: tol(atol=2e-6, rtol=1e-6)})
28482854
@dtypesIfCUDA(
28492855
*floating_types_and(
28502856
torch.complex64,
@@ -2987,6 +2993,7 @@ def maybe_transpose(cond, m):
29872993
torch.cdouble: 1e-8,
29882994
}
29892995
)
2996+
@toleranceOverride({torch.double: tol(atol=2e-6, rtol=1e-6)})
29902997
def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
29912998
M = torch.randn(n, m, device=device).to(dtype)
29922999
m1 = torch.randn(n, k, device=device).to(dtype)

0 commit comments

Comments
 (0)