Skip to content

Commit cf4ccc0

Browse files
committed
Do groupN too.
1 parent 6b11ecf commit cf4ccc0

File tree

5 files changed

+57
-11
lines changed

5 files changed

+57
-11
lines changed

lib/nnc/mfa/kernels/NAMatMulDescriptor.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ static uint32_t groupM(const uint32_t M) noexcept {
1414
return (M >= 4096) ? 4096 : 0;
1515
}
1616

17+
/// Decide the shared N-column group stride for a matrix with N columns.
/// Grouping only pays off for large matrices, so anything below the 4096
/// threshold disables grouping by returning 0 (the sentinel the kernel
/// descriptor treats as "no N grouping"). Mirrors groupM for the M dimension.
static uint32_t groupN(const uint32_t columns) noexcept {
  if (columns < 4096) {
    return 0;
  }
  return 4096;
}
20+
1721
bool NAMatMulDescriptor::operator==(const NAMatMulDescriptor& rhs) const {
1822
auto lhsMatrixDimensions = matrixDimensions;
1923
auto rhsMatrixDimensions = rhs.matrixDimensions;
@@ -201,7 +205,8 @@ std::pair<NAMatMulKernelDescriptor, PipelineValue<NAMatMulKernel> *> NAMatMulDes
201205

202206
uint16_t splitK = this->splitK();
203207
const uint32_t groupMValue = groupM(this->matrixDimensions[0]);
204-
auto kernelDesc = NAMatMulKernelDescriptor(simd::ushort3 { 128, 64, 64 }, this->memoryPrecisions, registerPrecisions, splitK, 4, this->transposeState, this->useBias, this->loadM, groupMValue);
208+
const uint32_t groupNValue = this->transposeState[1] ? groupN(this->matrixDimensions[1]) : 0;
209+
auto kernelDesc = NAMatMulKernelDescriptor(simd::ushort3 { 128, 64, 64 }, this->memoryPrecisions, registerPrecisions, splitK, 4, this->transposeState, this->useBias, this->loadM, groupMValue, groupNValue);
205210
NAMatMulKernel* kernel = createKernel(kernelDesc);
206211
auto pipelines = createPipeline(kernel->library.get(), splitK, (this->matrixDimensions[1] % 2) == 0);
207212

lib/nnc/mfa/kernels/NAMatMulKernel.cpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ NAMatMulKernel::NAMatMulKernel(NAMatMulKernelDescriptor descriptor, MTL::Device
5555
useBias = descriptor.useBias;
5656
loadM = descriptor.loadM;
5757
groupM = descriptor.groupM;
58+
groupN = descriptor.groupN;
5859

5960
/// The number of threads per group.
6061
source = createSource();
@@ -146,6 +147,7 @@ inline uint2 morton_decode_rectangular_2d(uint code,
146147
source.SetValue("BLOCK_DIMENSIONS_K_2", std::to_string(blockDimensions[2] * 2));
147148
source.SetValue("SPLIT_K", std::to_string(splitK));
148149
source.SetValue("GROUP_M", std::to_string(groupM));
150+
source.SetValue("GROUP_N", std::to_string(groupN));
149151

150152
source += createConstants();
151153

@@ -255,12 +257,12 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
255257
}
256258
if (transposed('B')) {
257259
source.SetValue("B_SLICE", std::to_string(blockDimensions[2]) + ", " + std::to_string(blockDimensions[1]));
258-
source.SetValue("B_MATRIX_SIZE", "K, N");
259-
source.SetValue("B_TILE_0_SIZE", "0, tgid.x * " + std::to_string(blockDimensions[1]));
260-
source.SetValue("B_TILE_K1_SIZE", "k, tgid.x * " + std::to_string(blockDimensions[1]));
261-
source.SetValue("B_TILE_K2_SIZE", "k + " + std::to_string(blockDimensions[2]) + ", tgid.x * " + std::to_string(blockDimensions[1]));
262-
source.SetValue("B_TILE_LAST_K2_SIZE", "K / " + std::to_string(blockDimensions[2] * 2) + " * " + std::to_string(blockDimensions[2] * 2) + ", tgid.x * " + std::to_string(blockDimensions[1]));
263-
source.SetValue("B_TILE_LAST_K_SIZE", "K / " + std::to_string(blockDimensions[2]) + " * " + std::to_string(blockDimensions[2]) + ", tgid.x * " + std::to_string(blockDimensions[1]));
260+
source.SetValue("B_MATRIX_SIZE", "K, N_group_size");
261+
source.SetValue("B_TILE_0_SIZE", "0, N_group_offset");
262+
source.SetValue("B_TILE_K1_SIZE", "k, N_group_offset");
263+
source.SetValue("B_TILE_K2_SIZE", "k + " + std::to_string(blockDimensions[2]) + ", N_group_offset");
264+
source.SetValue("B_TILE_LAST_K2_SIZE", "K / " + std::to_string(blockDimensions[2] * 2) + " * " + std::to_string(blockDimensions[2] * 2) + ", N_group_offset");
265+
source.SetValue("B_TILE_LAST_K_SIZE", "K / " + std::to_string(blockDimensions[2]) + " * " + std::to_string(blockDimensions[2]) + ", N_group_offset");
264266
source.SetValue("B_RESIDUAL_SLICE", "dynamic_extent, " + std::to_string(blockDimensions[1]));
265267
} else {
266268
source.SetValue("B_SLICE", std::to_string(blockDimensions[1]) + ", " + std::to_string(blockDimensions[2]));
@@ -333,6 +335,29 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
333335
const uint M_group_start = M_block_start;
334336
const uint M_group_offset = 0;
335337
const uint M_group_size = M - M_group_start;
338+
)";
339+
}
340+
if (transposed('B')) {
341+
if (groupN > 0) {
342+
source += R"(
343+
// Rebase transposed B to shared N-column groups for the same reason as
344+
// groupM: keep neighboring threadgroups on stable base pointers when N is
345+
// large without changing the global C layout.
346+
const uint N_block_start = tgid.x * {{BLOCK_DIMENSIONS_N}};
347+
const uint N_group_start = N_block_start / {{GROUP_N}} * {{GROUP_N}};
348+
const uint N_group_offset = N_block_start - N_group_start;
349+
const uint N_group_size = N - N_group_start;
350+
)";
351+
} else {
352+
source += R"(
353+
const uint N_block_start = tgid.x * {{BLOCK_DIMENSIONS_N}};
354+
const uint N_group_start = N_block_start;
355+
const uint N_group_offset = 0;
356+
const uint N_group_size = N - N_group_start;
357+
)";
358+
}
359+
source += R"(
360+
B_buf = B_buf + N_group_start * K;
336361
)";
337362
}
338363
if (!transposed('A')) {
@@ -358,8 +383,16 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
358383
auto A = tensor<device {{MEMORY_NAME_A}}, dextents<int32_t, 2>, tensor_inline>(A_buf, dextents<int32_t, 2>({{A_MATRIX_SIZE}}));
359384
)";
360385
}
361-
source += R"(
386+
if (transposed('B')) {
387+
source += R"(
388+
auto B = tensor<device {{MEMORY_NAME_B}}, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>(K, N_group_size));
389+
)";
390+
} else {
391+
source += R"(
362392
auto B = tensor<device {{MEMORY_NAME_B}}, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>({{B_MATRIX_SIZE}}));
393+
)";
394+
}
395+
source += R"(
363396
auto C = tensor<device {{MEMORY_NAME_C}}, dextents<int32_t, 2>, tensor_inline>(C_buf, dextents<int32_t, 2>(N * {{SPLIT_K}}, M_group_size));
364397
)";
365398
if (useBias) {

lib/nnc/mfa/kernels/NAMatMulKernel.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ struct NAMatMulKernel {
3838

3939
uint32_t groupM;
4040

41+
uint32_t groupN;
42+
4143
/// The number of threads per group.
4244
uint16_t threadgroupSize(MTL::ComputePipelineState *const pipelineState, const NAMatMulDescriptor &descriptor) const noexcept;
4345

lib/nnc/mfa/kernels/NAMatMulKernelDescriptor.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ bool NAMatMulKernelDescriptor::operator==(const NAMatMulKernelDescriptor& rhs) c
1414
simd_all(transposeState == rhs.transposeState) &&
1515
(useBias == rhs.useBias) &&
1616
(loadM == rhs.loadM) &&
17-
(groupM == rhs.groupM);
17+
(groupM == rhs.groupM) &&
18+
(groupN == rhs.groupN);
1819
}
1920

2021
std::size_t std::hash<NAMatMulKernelDescriptor>::operator()(const NAMatMulKernelDescriptor& hash) const noexcept {
@@ -27,12 +28,13 @@ std::size_t std::hash<NAMatMulKernelDescriptor>::operator()(const NAMatMulKernel
2728
combine_32(seed, pack_32(simd::uchar4 { hash.transposeState[0], hash.transposeState[1], hash.transposeState[2], hash.useBias }));
2829
combine_32(seed, pack_32(simd::uchar4 { hash.loadM, 0, 0, 0 }));
2930
combine_32(seed, hash.groupM);
31+
combine_32(seed, hash.groupN);
3032
return seed;
3133
}
3234

3335
// MARK: - Initializer
3436

35-
NAMatMulKernelDescriptor::NAMatMulKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, GEMMOperandPrecisions registerPrecisions, uint16_t splitK, uint16_t executionSIMDGroups, simd::uchar3 transposeState, bool useBias, bool loadM, uint32_t groupM) noexcept {
37+
NAMatMulKernelDescriptor::NAMatMulKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, GEMMOperandPrecisions registerPrecisions, uint16_t splitK, uint16_t executionSIMDGroups, simd::uchar3 transposeState, bool useBias, bool loadM, uint32_t groupM, uint32_t groupN) noexcept {
3638
this->blockDimensions = blockDimensions;
3739
this->memoryPrecisions = memoryPrecisions;
3840
this->registerPrecisions = registerPrecisions;
@@ -42,4 +44,5 @@ NAMatMulKernelDescriptor::NAMatMulKernelDescriptor(simd::ushort3 blockDimensions
4244
this->useBias = useBias;
4345
this->loadM = loadM;
4446
this->groupM = groupM;
47+
this->groupN = groupN;
4548
}

lib/nnc/mfa/kernels/NAMatMulKernelDescriptor.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,15 @@ struct NAMatMulKernelDescriptor {
180180
/// Rebase A / C to shared M-row groups. 0 disables grouping.
181181
uint32_t groupM;
182182

183+
/// Rebase transposed B to shared N-column groups. 0 disables grouping.
184+
uint32_t groupN;
185+
183186
// MARK: - Functionality from GEMMDescriptor
184187

185188
NAMatMulKernelDescriptor() = delete;
186189

187190
/// Initialize the kernel descriptor.
188-
NAMatMulKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, GEMMOperandPrecisions registerPrecisions, uint16_t splitK, uint16_t executionSIMDGroups, simd::uchar3 transposeState, bool useBias, bool loadM, uint32_t groupM) noexcept;
191+
NAMatMulKernelDescriptor(simd::ushort3 blockDimensions, GEMMOperandPrecisions memoryPrecisions, GEMMOperandPrecisions registerPrecisions, uint16_t splitK, uint16_t executionSIMDGroups, simd::uchar3 transposeState, bool useBias, bool loadM, uint32_t groupM, uint32_t groupN) noexcept;
189192

190193
bool operator==(const NAMatMulKernelDescriptor& rhs) const;
191194
};

0 commit comments

Comments
 (0)