模板内核函数

模板内核函数的参数

使用模板内核函数能够显式声明和调整变量,非常适用于一些需要(在 kernel 函数外面)动态划分,但是在 kernel 函数中却又为常量的变量

  1. 可以在 kernel 外面进行动态调整【增加程序适用性】
  2. 会进行预编译替换常量值【类似 #define】减少运行时间

这两个条件,简直为线程块相关数值变量量身定制!

运用设想

在 kernel 函数中提取线程块相关数值变量

对 kernel 函数进行包装

在包装函数内对相关变量进行动态划分

运用示例

下面这个实例是本人在对卷积算子进行 im2col 优化操作时看到的一个 SEMM 算子

kernel 函数:

template <const int BM, const int BN, const int BK, const int TM, const int TN>
// __launch_bounds__((BM * BN) / (TM * TN), 1)
__global__ void 
    sgemm_blocktiling_2d_kernel(half *A,
                                half *B,
                                half *C,
                                int M,
                                int N,
                                int K,
                                int batch,
                                int kernel_num,
                                int y_height,
                                int y_width)
{
    // BM, BN, BK:block tile 的尺寸,分别代表 block 处理的 M 方向、N 方向和 K 方向的大小。
    // TM, TN:单个线程计算的子块大小,每个线程计算 TM × TN 个元素。
    // A、B、C:输入矩阵 A、B 和输出矩阵 C(存储在全局内存中)。
    // M、N、K:矩阵的维度,使 C = A × B,其中:
        // A 维度为 (M × K)
        // B 维度为 (K × N)
        // C 维度为 (M × N)


    // the output block that w e want to compute in this threadblock
    const uint c_row = blockIdx.y;  // 当前 block 负责的 C 的行块索引
    const uint c_col = blockIdx.x;  // 当前 block 负责的 C 的列块索引

    // 每个 thread block 计算 BM × BN 大小的 C 的子块,c_row 和 c_col 决定了该 block 处理 C 矩阵的哪一部分。

    const uint num_threads_block_tile = (BM * BN) / (TM * TN);
    // 一个 block 共有 (BM * BN) / (TM * TN) 个线程,即 block 计算的区域被 TM × TN 大小的线程组分割。

    __shared__ half A_shared[BM * BK];
    __shared__ half B_shared[BK * BN];
    // 共享内存用于缓存 A 和 B 的子块,减少全局内存访问,提高计算效率。
 
    // the inner row & col that we're accessing in this thread
    const uint thread_row = threadIdx.x / (BN / TN);  // 线程在 block tile 内的行索引
    const uint thread_col = threadIdx.x % (BN / TN);  // 线程在 block tile 内的列索引

    // advance pointers to the starting positions
    // 调整 A 和 B 指针,使其指向 block 负责的区域,同时计算 C 的起始索引。
    A += c_row * BM * K;
    B += c_col * BN;
    int global_c_index = c_row * BM * N + c_col * BN;

    // use to avoid out-of-bounds accesses
    // 用于避免越界访问,确保加载 A 和 B 的元素不超过矩阵尺寸。
    int global_m_pos = c_row * BM * K;
    int global_n_pos = c_col * BN;
    const uint m_size = M * K;
    const uint n_size = N * K;

    assert((BM * BN) / (TM * TN) == blockDim.x);

    // 通过线程索引计算 A 和 B 在共享内存中的存储位置,采用 stride_a 和 stride_b 进行跨步加载。
    const uint A_inner_row = threadIdx.x / BK; // warp-level GMEM coalescing
    const uint A_inner_col = threadIdx.x % BK;
    const uint stride_a = num_threads_block_tile / BK;
    const uint B_inner_row = threadIdx.x / BN; // warp-level GMEM coalescing
    const uint B_inner_col = threadIdx.x % BN;
    const uint stride_b = num_threads_block_tile / BN;

    // allocate thread-local cache for results in registerfile
    // 分配寄存器用于计算
    // 每个线程存储 TM × TN 个计算结果,避免频繁访问全局内存。
    float thread_results[TM * TN] = {0.0};
    half reg_m[TM] = {0.0};
    half reg_n[TN] = {0.0};

    // outer loop over block tiles
    for (uint bk_idx = 0; bk_idx < K; bk_idx += BK)
    {
        // load the next block of the input matrices into shared memory
        for (uint load_offset = 0; load_offset < BM; load_offset += stride_a)
        {
            A_shared[(A_inner_row + load_offset) * BK + A_inner_col] =
                (global_m_pos + (A_inner_row + load_offset) * K + A_inner_col < m_size) ? A[(A_inner_row + load_offset) * K + A_inner_col] : __float2half(0.0f);;
        }
        for (uint load_offset = 0; load_offset < BK; load_offset += stride_b)
        {
            B_shared[(B_inner_row + load_offset) * BN + B_inner_col] =
                (global_n_pos + (B_inner_row + load_offset) * N + B_inner_col < n_size) ? B[(B_inner_row + load_offset) * N + B_inner_col] : __float2half(0.0f);;
        }

        // wait for all threads to finish loading
        __syncthreads();

        // advance the pointers
        A += BK;
        B += BK * N;
        global_m_pos += BK;
        global_n_pos += BK * N;

        // compute the partial sum
        for (uint dot_idx = 0; dot_idx < BK; dot_idx++)
        {
            // load relevant As & Bs entries into registers
            for (uint i = 0; i < TM; i++)
            {
                reg_m[i] = A_shared[(thread_row * TM + i) * BK + dot_idx];
            }
            for (uint i = 0; i < TN; i++)
            {
                reg_n[i] = B_shared[dot_idx * BN + thread_col * TN + i];
            }
            
            // float tem = 0;
            // perform outer product on register cache, accumulate
            // into threadResults
            for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++)
            {
                for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++)
                {
                    // tem +=
                    thread_results[res_idx_m * TN + res_idx_n] += static_cast<float>(reg_m[res_idx_m]) * static_cast<float>(reg_n[res_idx_n]);
                }
            }
        }

        // wait for all threads to finish computing
        __syncthreads();
    }

    int inner_y_size = y_height * y_width;
    int res_inner_index, g_index, batch_id, channel_id, inner_offset;

    int conv_idx;

    if (global_c_index >= M * N)
    {
        return;
    }

    for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++)
    {
        for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++)
        {
            if (c_row * BM + thread_row * TM + res_idx_m < M && c_col * BN + thread_col * TN + res_idx_n < N)
            {
                res_inner_index = (thread_row * TM + res_idx_m) * N + thread_col * TN + res_idx_n;
                g_index = global_c_index + res_inner_index;
                inner_offset = g_index % inner_y_size;
                batch_id = (g_index % (inner_y_size * batch)) / inner_y_size;
                channel_id = g_index / (inner_y_size * batch);
                conv_idx = batch_id * (kernel_num * y_height * y_width) + channel_id * (y_height * y_width) + inner_offset;
                C[conv_idx] = (half)thread_results[res_idx_m * TN + res_idx_n];
            }
        }
    }
}

包装函数:

#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)

void cuda_gemm(half *A,     // 卷积核
    half *B,                // 转置后的矩阵
    half *C,                // 输出结果
    int M,                  // 转置后卷积核的高
    int N,                  // 转置后矩阵的宽
    int K,                  // 转置卷积核的宽
    int batch,              // 转置矩阵的高?
    int kernel_num,         // 卷积核数量
    int y_height,           // 输出图像的高
    int y_width)            // 输出图像的宽
{
    const uint BK = 8;
    const uint TM = 8;
    const uint TN = 8;
    if (M >= 128 && N >= 128)
    {
        const uint BM = 128;
        const uint BN = 128;
        dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
        dim3 block_size((BM * BN) / (TM * TN));
        sgemm_blocktiling_2d_kernel<BM, BN, BK, TM, TN>
        <<<grid_size, block_size>>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width);
     }
     else
     {
        const uint BM = 64;
        const uint BN = 64;
        dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM));
        dim3 block_size((BM * BN) / (TM * TN));
        sgemm_blocktiling_2d_kernel<BM, BN, BK, TM, TN>
        <<<grid_size, block_size>>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width);
      }
}

如何声明模板变量和调用模板 kernel

一下以上面的 SEMM 为例进行解释

1. 模板参数如何定义?

模板 GPU 内核的定义方式通常如下:

template <int BLOCK_SIZE>
__global__ void addKernel(int *d_a, int *d_b, int *d_c, int size) {
int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;
if (idx < size) {
d_c[idx] = d_a[idx] + d_b[idx];
}
}

这里的 BLOCK_SIZE 是一个模板参数,它是 编译时常量,在编译时就被替换,不会增加运行时的计算开销。

2. 模板 GPU 内核如何调用?

在调用模板内核时,必须显式指定模板参数:

addKernel<256><<<blocksPerGrid, 256>>>(d_a, d_b, d_c, size);

解释

  • <256> 指定了模板参数 BLOCK_SIZE 的值,这里 BLOCK_SIZE = 256
  • <<<blocksPerGrid, 256>>> 指定了 CUDA 网格 (Grid) 和块 (Block) 的大小,其中 blockDim.x = 256,与 BLOCK_SIZE 保持一致。

3. 多个模板参数的含义

在更复杂的 GPU 计算中,可能会有多个模板参数。例如:

template <const int BM, const int BN, const int BK, const int TM, const int TN>
__global__ void sgemm_blocktiling_2d_kernel(half *A, half *B, half *C, int M, int N, int K) {
    // 计算线程块索引
    const uint c_row = blockIdx.y;
    const uint c_col = blockIdx.x;

    // 计算线程索引
    const uint num_threads_block_tile = (BM * BN) / (TM * TN);

    // 共享内存分配
    __shared__ half A_shared[BM * BK];
    __shared__ half B_shared[BK * BN];

    // 计算线程在块中的行列索引
    const uint thread_row = threadIdx.x / (BN / TN);
    const uint thread_col = threadIdx.x % (BN / TN);

    // 计算偏移量
    A += c_row * BM * K;
    B += c_col * BN;
    int global_c_index = c_row * BM * N + c_col * BN;

    assert((BM * BN) / (TM * TN) == blockDim.x);
}

调用:

sgemm_blocktiling_2d_kernel<128, 128, 32, 8, 8> <<<grid_size, block_size>>>(A, B, C, M, N, K);

解释各个模板参数的意义

参数作用
BM线程块 (Block) 处理的矩阵行数
BN线程块 (Block) 处理的矩阵列数
BK分块计算时的 K 维度大小
TM一个线程计算的矩阵行数
TN一个线程计算的矩阵列数

在编译时:

  • BM = 128BN = 128,表示每个线程块计算 128×128 的矩阵块
  • BK = 32,表示计算时的分块大小(即每次加载 32 个 K 方向的数据)。
  • TM = 8TN = 8,表示每个线程计算 8×8 的子矩阵

4. 确定模板参数的来源

(1) 查看内核代码

  • sgemm_blocktiling_2d_kernel 代码中,BM, BN, BK, TM, TN 是用于矩阵分块计算的参数。
  • BM, BN 影响 每个线程块的大小
  • TM, TN 影响 每个线程的计算任务
  • BK 影响 分块计算时的 K 维度大小

(2) 查看调用代码

sgemm_blocktiling_2d_kernel<128, 128, 32, 8, 8> <<<grid_size, block_size>>>(A, B, C, M, N, K);
  • <128, 128, 32, 8, 8>:告诉编译器如何优化计算
  • 这些参数被代入 sgemm_blocktiling_2d_kernel,用于计算线程布局、共享内存大小、寄存器优化等

(3) 确定 CUDA 网格和块的大小

通常,我们会根据 BMBN 来计算 CUDA 网格:

dim3 grid_size((N + BN - 1) / BN, (M + BM - 1) / BM);
dim3 block_size((BM * BN) / (TM * TN));

这样,CUDA 线程块的数量每个线程块的线程数 都是根据模板参数计算出来的。

5. 模板参数 vs 普通参数

类型作用传递方式计算时间
模板参数编译期常量kernel<128, 128, 32, 8, 8><<<grid, block>>>(...)编译时计算
普通参数运行时变量kernel<<<grid, block>>>(M, N, K, ...)运行时计算

为什么用模板参数?

  1. 减少运行时计算BM, BN, BK, TM, TN固定的,提前传递给编译器,减少运行时的计算。
  2. 优化性能:编译器可以利用 BM, BN, TM, TN 进行 循环展开、寄存器分配优化,提高性能。
  3. 提高灵活性:可以根据不同硬件选择 最优的参数组合,在不同 GPU 设备上优化执行效率。

6. 结论

如何确定模板参数的意义?

  1. 查看模板内核代码
    • 找到 模板参数的使用位置,看它们如何影响 线程布局、内存访问和计算方式
  2. 查看内核调用方式
    • 例如 kernel<128, 128, 32, 8, 8><<<grid, block>>>(...),可以推测 128×128 计算块、32 维度分块、8×8 线程任务
  3. 结合 CUDA 网格和块计算
    • 通常 BM, BN 影响 grid 计算TM, TN 影响 block 内计算
2025年3月29日 创建
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇
Copyright 2025-2025 @ Ziyang
Running Time days H M S