@@ -55,6 +55,7 @@ NAMatMulKernel::NAMatMulKernel(NAMatMulKernelDescriptor descriptor, MTL::Device
5555 useBias = descriptor.useBias ;
5656 loadM = descriptor.loadM ;
5757 groupM = descriptor.groupM ;
58+ groupN = descriptor.groupN ;
5859
5960 /// The number of threads per group.
6061 source = createSource ();
@@ -146,6 +147,7 @@ inline uint2 morton_decode_rectangular_2d(uint code,
146147 source.SetValue (" BLOCK_DIMENSIONS_K_2" , std::to_string (blockDimensions[2 ] * 2 ));
147148 source.SetValue (" SPLIT_K" , std::to_string (splitK));
148149 source.SetValue (" GROUP_M" , std::to_string (groupM));
150+ source.SetValue (" GROUP_N" , std::to_string (groupN));
149151
150152 source += createConstants ();
151153
@@ -255,12 +257,12 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
255257 }
256258 if (transposed (' B' )) {
257259 source.SetValue (" B_SLICE" , std::to_string (blockDimensions[2 ]) + " , " + std::to_string (blockDimensions[1 ]));
258- source.SetValue (" B_MATRIX_SIZE" , " K, N " );
259- source.SetValue (" B_TILE_0_SIZE" , " 0, tgid.x * " + std::to_string (blockDimensions[ 1 ]) );
260- source.SetValue (" B_TILE_K1_SIZE" , " k, tgid.x * " + std::to_string (blockDimensions[ 1 ]) );
261- source.SetValue (" B_TILE_K2_SIZE" , " k + " + std::to_string (blockDimensions[2 ]) + " , tgid.x * " + std::to_string (blockDimensions[ 1 ]) );
262- source.SetValue (" B_TILE_LAST_K2_SIZE" , " K / " + std::to_string (blockDimensions[2 ] * 2 ) + " * " + std::to_string (blockDimensions[2 ] * 2 ) + " , tgid.x * " + std::to_string (blockDimensions[ 1 ]) );
263- source.SetValue (" B_TILE_LAST_K_SIZE" , " K / " + std::to_string (blockDimensions[2 ]) + " * " + std::to_string (blockDimensions[2 ]) + " , tgid.x * " + std::to_string (blockDimensions[ 1 ]) );
260+ source.SetValue (" B_MATRIX_SIZE" , " K, N_group_size " );
261+ source.SetValue (" B_TILE_0_SIZE" , " 0, N_group_offset " );
262+ source.SetValue (" B_TILE_K1_SIZE" , " k, N_group_offset " );
263+ source.SetValue (" B_TILE_K2_SIZE" , " k + " + std::to_string (blockDimensions[2 ]) + " , N_group_offset " );
264+ source.SetValue (" B_TILE_LAST_K2_SIZE" , " K / " + std::to_string (blockDimensions[2 ] * 2 ) + " * " + std::to_string (blockDimensions[2 ] * 2 ) + " , N_group_offset " );
265+ source.SetValue (" B_TILE_LAST_K_SIZE" , " K / " + std::to_string (blockDimensions[2 ]) + " * " + std::to_string (blockDimensions[2 ]) + " , N_group_offset " );
264266 source.SetValue (" B_RESIDUAL_SLICE" , " dynamic_extent, " + std::to_string (blockDimensions[1 ]));
265267 } else {
266268 source.SetValue (" B_SLICE" , std::to_string (blockDimensions[1 ]) + " , " + std::to_string (blockDimensions[2 ]));
@@ -333,6 +335,29 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
333335 const uint M_group_start = M_block_start;
334336 const uint M_group_offset = 0;
335337 const uint M_group_size = M - M_group_start;
338+ )" ;
339+ }
340+ if (transposed (' B' )) {
341+ if (groupN > 0 ) {
342+ source += R"(
343+ // Rebase transposed B to shared N-column groups for the same reason as
344+ // groupM: keep neighboring threadgroups on stable base pointers when N is
345+ // large without changing the global C layout.
346+ const uint N_block_start = tgid.x * {{BLOCK_DIMENSIONS_N}};
347+ const uint N_group_start = N_block_start / {{GROUP_N}} * {{GROUP_N}};
348+ const uint N_group_offset = N_block_start - N_group_start;
349+ const uint N_group_size = N - N_group_start;
350+ )" ;
351+ } else {
352+ source += R"(
353+ const uint N_block_start = tgid.x * {{BLOCK_DIMENSIONS_N}};
354+ const uint N_group_start = N_block_start;
355+ const uint N_group_offset = 0;
356+ const uint N_group_size = N - N_group_start;
357+ )" ;
358+ }
359+ source += R"(
360+ B_buf = B_buf + N_group_start * K;
336361)" ;
337362 }
338363 if (!transposed (' A' )) {
@@ -358,8 +383,16 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
358383 auto A = tensor<device {{MEMORY_NAME_A}}, dextents<int32_t, 2>, tensor_inline>(A_buf, dextents<int32_t, 2>({{A_MATRIX_SIZE}}));
359384)" ;
360385 }
361- source += R"(
386+ if (transposed (' B' )) {
387+ source += R"(
388+ auto B = tensor<device {{MEMORY_NAME_B}}, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>(K, N_group_size));
389+ )" ;
390+ } else {
391+ source += R"(
362392 auto B = tensor<device {{MEMORY_NAME_B}}, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>({{B_MATRIX_SIZE}}));
393+ )" ;
394+ }
395+ source += R"(
363396 auto C = tensor<device {{MEMORY_NAME_C}}, dextents<int32_t, 2>, tensor_inline>(C_buf, dextents<int32_t, 2>(N * {{SPLIT_K}}, M_group_size));
364397)" ;
365398 if (useBias) {
0 commit comments