Skip to content

[NPU] Add fused_linear_cross_entropy operator#1164

Open
lowdy1 wants to merge 1 commit intolinkedin:mainfrom
lowdy1:fused_ce
Open

[NPU] Add fused_linear_cross_entropy operator#1164
lowdy1 wants to merge 1 commit intolinkedin:mainfrom
lowdy1:fused_ce

Conversation

@lowdy1
Copy link
Contributor

@lowdy1 lowdy1 commented Mar 25, 2026

Summary

To address the UB overflow issue observed in the benchmark, we introduced an operator with an NPU-friendly implementation of fused linear cross entropy. This fused operator relies on several underlying operations (e.g., large matrix multiplication, softmax, and cross entropy), so its current benchmark performance is not yet optimal. Further optimization may be needed.

Testing Done

Device: Atlas A2
python -m pytest ./test/transformers/test_fused_linear_cross_entropy.py
image

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@lowdy1
Copy link
Contributor Author

lowdy1 commented Mar 25, 2026

**************************************
     BENCHMARKING SPEED for FUSED_LINEAR_CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1167.7342529296875,
      1417.60009765625,
      1970.027587890625
    ],
    "y_values_20": [
      1167.7342529296875,
      1417.60009765625,
      1970.027587890625
    ],
    "y_values_80": [
      1167.7342529296875,
      1417.60009765625,
      1970.027587890625
    ],
    "timestamp": "2026-03-25 06:53:45",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1072.7669677734375,
      1327.0628662109375,
      1879.5020751953125
    ],
    "y_values_20": [
      1072.7669677734375,
      1327.0628662109375,
      1879.5020751953125
    ],
    "y_values_80": [
      1072.7669677734375,
      1327.0628662109375,
      1879.5020751953125
    ],
    "timestamp": "2026-03-25 06:54:34",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      23.48188018798828,
      47.42369079589844,
      94.18435668945312
    ],
    "y_values_20": [
      23.210458755493164,
      47.168094635009766,
      94.18435668945312
    ],
    "y_values_80": [
      23.85506820678711,
      47.67928695678711,
      94.18435668945312
    ],
    "timestamp": "2026-03-25 06:54:50",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      5.7671799659729,
      5.85099983215332,
      6.026599884033203
    ],
    "y_values_20": [
      5.755055904388428,
      5.825039863586426,
      6.023431777954102
    ],
    "y_values_80": [
      5.777199745178223,
      5.862380027770996,
      6.030032157897949
    ],
    "timestamp": "2026-03-25 06:55:10",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      5.715149879455566,
      5.824970245361328,
      6.033420085906982
    ],
    "y_values_20": [
      5.710140228271484,
      5.8231401443481445,
      6.029191970825195
    ],
    "y_values_80": [
      5.7175798416137695,
      5.832339763641357,
      6.040832042694092
    ],
    "timestamp": "2026-03-25 06:55:29",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      74.61641693115234,
      127.07272338867188,
      235.3399200439453
    ],
    "y_values_20": [
      74.61641693115234,
      127.07272338867188,
      235.3399200439453
    ],
    "y_values_80": [
      74.61641693115234,
      127.07272338867188,
      235.3399200439453
    ],
    "timestamp": "2026-03-25 06:55:48",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1178.111083984375,
      1423.9302978515625,
      1976.8585205078125
    ],
    "y_values_20": [
      1178.111083984375,
      1423.9302978515625,
      1976.8585205078125
    ],
    "y_values_80": [
      1178.111083984375,
      1423.9302978515625,
      1976.8585205078125
    ],
    "timestamp": "2026-03-25 06:56:39",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1081.96533203125,
      1332.771484375,
      1886.146484375
    ],
    "y_values_20": [
      1081.96533203125,
      1332.771484375,
      1886.146484375
    ],
    "y_values_80": [
      1081.96533203125,
      1332.771484375,
      1886.146484375
    ],
    "timestamp": "2026-03-25 06:57:28",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      97.71183776855469,
      175.24632263183594,
      331.6031494140625
    ],
    "y_values_20": [
      97.71183776855469,
      175.24632263183594,
      331.6031494140625
    ],
    "y_values_80": [
      97.71183776855469,
      175.24632263183594,
      331.6031494140625
    ],
    "timestamp": "2026-03-25 06:57:48",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1167.6370849609375,
      1417.8104248046875,
      1970.2938232421875
    ],
    "y_values_20": [
      1167.6370849609375,
      1417.8104248046875,
      1970.2938232421875
    ],
    "y_values_80": [
      1167.6370849609375,
      1417.8104248046875,
      1970.2938232421875
    ],
    "timestamp": "2026-03-25 06:58:39",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      1072.7044677734375,
      1326.4580078125,
      1879.5184326171875
    ],
    "y_values_20": [
      1072.7044677734375,
      1326.4580078125,
      1879.5184326171875
    ],
    "y_values_80": [
      1072.7044677734375,
      1326.4580078125,
      1879.5184326171875
    ],
    "timestamp": "2026-03-25 06:59:28",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      23.349929809570312,
      47.93915939331055,
      93.9185791015625
    ],
    "y_values_20": [
      23.10479164123535,
      47.75263214111328,
      93.9185791015625
    ],
    "y_values_80": [
      23.671571731567383,
      48.12568664550781,
      93.9185791015625
    ],
    "timestamp": "2026-03-25 06:59:44",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  }
]
**************************************
     BENCHMARKING MEMORY for FUSED_LINEAR_CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      7141.36962890625,
      7268.73291015625,
      7523.45947265625
    ],
    "y_values_20": [
      7141.36962890625,
      7268.73291015625,
      7523.45947265625
    ],
    "y_values_80": [
      7141.36962890625,
      7268.73291015625,
      7523.45947265625
    ],
    "timestamp": "2026-03-25 07:00:45",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      7141.369140625,
      7268.732421875,
      7523.458984375
    ],
    "y_values_20": [
      7141.369140625,
      7268.732421875,
      7523.458984375
    ],
    "y_values_80": [
      7141.369140625,
      7268.732421875,
      7523.458984375
    ],
    "timestamp": "2026-03-25 07:01:42",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      6108.03759765625,
      8208.068359375,
      14284.1298828125
    ],
    "y_values_20": [
      6108.03759765625,
      8208.068359375,
      14284.1298828125
    ],
    "y_values_80": [
      6108.03759765625,
      8208.068359375,
      14284.1298828125
    ],
    "timestamp": "2026-03-25 07:01:57",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.6.4"
  }
]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant