其他
CUDA高性能计算经典问题②:前缀和
// Serial inclusive prefix sum (CPU reference implementation).
// output[i] = input[0] + input[1] + ... + input[i]
// A zero-length input is a no-op.
void PrefixSum(const int32_t* input, size_t n, int32_t* output) {
  int32_t running = 0;
  for (size_t idx = 0; idx != n; ++idx) {
    running += input[idx];
    output[idx] = running;
  }
}
2
ScanThenFan
将存储在Global Memory中的数据分为多个Parts,每个Part由一个Thread Block单独做内部的Scan,并将该Part的内部Sum存储到Global Memory中的PartSum数组中
对这个PartSum数组做Scan,我们使用BaseSum标识这个Scan后的数组
每个Part的每个元素都加上对应的BaseSum
3
Baseline
// Baseline part scan: each block grid-strides over parts; within a part,
// thread 0 alone serially scans input[part_begin:part_end] into output and
// stores the part's total into part[part_i] (naive implementation).
__global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part, int32_t* output, size_t n,
                                          size_t part_num) {
  for (size_t p = blockIdx.x; p < part_num; p += gridDim.x) {
    // part p covers input[begin:end); the last part may be ragged
    size_t begin = p * blockDim.x;
    size_t end = min((p + 1) * blockDim.x, n);
    if (threadIdx.x != 0) {
      continue;  // only thread 0 of each block does the work
    }
    int32_t running = 0;
    for (size_t i = begin; i < end; ++i) {
      running += input[i];
      output[i] = running;
    }
    part[p] = running;
  }
}
// Serially converts part[0:part_num] from per-part sums into an inclusive
// scan of those sums. Intended to be launched as a single thread (<<<1, 1>>>).
__global__ void ScanPartSumKernel(int32_t* part, size_t part_num) {
  int32_t running = 0;
  size_t i = 0;
  while (i < part_num) {
    running += part[i];
    part[i] = running;
    ++i;
  }
}
// Adds the scanned total of all preceding parts to every element of each
// part (part 0 needs no base). Grid-stride over parts; one element per thread.
//
// part:   inclusive scan of the per-part sums (part[i] = sum of parts 0..i)
// output: in/out — per-part scan results promoted to the global scan
// n:      total element count; guards the ragged tail of the last part
__global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                 size_t part_num) {
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    if (part_i == 0) {
      continue;
    }
    // size_t, not int32_t: part_i * blockDim.x overflows a 32-bit int once
    // n exceeds ~2^31, yielding a truncated index and silently skipped writes.
    size_t index = part_i * blockDim.x + threadIdx.x;
    if (index < n) {
      output[index] += part[part_i - 1];
    }
  }
}
// Inclusive prefix sum via "scan then fan":
// for i in range(n):
//   output[i] = input[0] + input[1] + ... + input[i]
// buffer: device scratch, at least part_num int32_t's.
// NOTE(review): no cudaGetLastError()/sync checks follow the launches —
// confirm error handling happens at a higher level.
void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                 size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  // cap at 128 blocks; the kernels grid-stride over parts
  size_t block_num = std::min<size_t>(part_num, 128);
  // use buffer[0:part_num] to save the metric of part
  int32_t* part = buffer;
  // after following step, part[i] = part_sum[i]
  ScanAndWritePartSumKernel<<<block_num, part_size>>>(input, part, output, n,
                                                      part_num);
  // after following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
  ScanPartSumKernel<<<1, 1>>>(part, part_num);
  // make final result
  AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
}
4
Shared Memory
// Inclusive scan of shm[0:blockDim.x] in shared memory, done serially by
// thread 0 (naive implementation). Must be called by ALL threads of the
// block: it ends with a barrier so every thread sees the scanned values.
__device__ void ScanBlock(int32_t* shm) {
  if (threadIdx.x == 0) {
    int32_t running = 0;
    for (size_t k = 0; k < blockDim.x; ++k) {
      running += shm[k];
      shm[k] = running;
    }
  }
  __syncthreads();
}
// Shared-memory part scan: each block grid-strides over parts, stages one
// part into dynamic shared memory, scans it with ScanBlock, then writes the
// scanned values to output and the part total to part[p].
// Requires blockDim.x * sizeof(int32_t) bytes of dynamic shared memory.
__global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                          int32_t* output, size_t n,
                                          size_t part_num) {
  extern __shared__ int32_t shm[];
  for (size_t p = blockIdx.x; p < part_num; p += gridDim.x) {
    // stage this thread's element, zero-padded past the end of the input
    size_t idx = p * blockDim.x + threadIdx.x;
    shm[threadIdx.x] = idx < n ? input[idx] : 0;
    __syncthreads();
    // inclusive scan of the staged part
    ScanBlock(shm);
    __syncthreads();
    if (idx < n) {
      output[idx] = shm[threadIdx.x];
    }
    // the last slot holds the inclusive total of the whole part
    if (threadIdx.x == blockDim.x - 1) {
      part[p] = shm[threadIdx.x];
    }
  }
}
5
ScanBlock
按照Warp组织,每个Warp内部先做Scan,将每个Warp的和存储到Shared Memory中,称为WarpSum
启动一个单独的Warp对WarpSum进行Scan
每个Warp将最终结果加上上一个Warp对应的WarpSum
// Inclusive scan of 32 consecutive shared-memory slots, performed serially
// by lane 0 (naive implementation). shm_data is expected to point at the
// calling lane's own slot, so lane 0's pointer is the warp's base slot;
// all other lanes do nothing here.
__device__ void ScanWarp(int32_t* shm_data, int32_t lane) {
  if (lane != 0) {
    return;
  }
  int32_t running = 0;
  for (int32_t k = 0; k < 32; ++k) {
    running += shm_data[k];
    shm_data[k] = running;
  }
}
// Block-wide inclusive scan over shared memory, organized by warps:
//   1) each warp scans its own 32 elements (ScanWarp),
//   2) lane 31 of each warp publishes the warp total into warp_sum,
//   3) warp 0 scans warp_sum,
//   4) every warp except warp 0 adds the preceding warps' total.
// shm_data is this thread's OWN slot (caller passes base + threadIdx.x);
// lane 0's pointer doubles as the warp's base slot for ScanWarp.
__device__ void ScanBlock(int32_t* shm_data) {
  int32_t warp_id = threadIdx.x >> 5;
  int32_t lane = threadIdx.x & 31;  // 31 = 00011111
  __shared__ int32_t warp_sum[32];  // supports up to 32 warps = 1024 threads
  // scan each warp
  ScanWarp(shm_data, lane);
  __syncthreads();
  // write sum of each warp to warp_sum (lane 31's slot holds the total)
  if (lane == 31) {
    warp_sum[warp_id] = *shm_data;
  }
  __syncthreads();
  // use a single warp to scan warp_sum
  if (warp_id == 0) {
    ScanWarp(warp_sum + lane, lane);
  }
  __syncthreads();
  // add base: total of all preceding warps
  if (warp_id > 0) {
    *shm_data += warp_sum[warp_id - 1];
  }
  __syncthreads();
}
6
ScanWarp
// Warp-level inclusive scan (Hillis-Steele, log2(32) = 5 steps) on shared
// memory. shm_data points at THIS lane's slot; vshm_data[-d] is the slot of
// the lane d positions lower, so the warp's 32 elements must be contiguous.
// volatile forces partial sums to be re-read from shared memory each step
// instead of being cached in registers; __syncwarp() keeps the lanes in
// step under independent thread scheduling (SM70+).
__device__ void ScanWarp(int32_t* shm_data) {
  int32_t lane = threadIdx.x & 31;
  volatile int32_t* vshm_data = shm_data;
  // distance-1 step: lanes >= 1 add their left neighbor
  if (lane >= 1) {
    vshm_data[0] += vshm_data[-1];
  }
  __syncwarp();
  // distance-2 step
  if (lane >= 2) {
    vshm_data[0] += vshm_data[-2];
  }
  __syncwarp();
  // distance-4 step
  if (lane >= 4) {
    vshm_data[0] += vshm_data[-4];
  }
  __syncwarp();
  // distance-8 step
  if (lane >= 8) {
    vshm_data[0] += vshm_data[-8];
  }
  __syncwarp();
  // distance-16 step: after this every lane holds its inclusive prefix
  if (lane >= 16) {
    vshm_data[0] += vshm_data[-16];
  }
  __syncwarp();
}
7
ZeroPadding
// Warp-organized block scan; same structure as the earlier ScanBlock, but
// warp_sum now lives in DYNAMIC shared memory: `extern __shared__` here
// aliases the same allocation as the calling kernel's `shm`, whose first
// 32 ints the caller reserves for this purpose.
// shm_data is this thread's own slot (base + threadIdx.x).
__device__ void ScanBlock(int32_t* shm_data) {
  int32_t warp_id = threadIdx.x >> 5;
  int32_t lane = threadIdx.x & 31;  // 31 = 00011111
  extern __shared__ int32_t warp_sum[];  // warp_sum[32] = kernel's shm[0:32]
  // scan each warp
  ScanWarp(shm_data);
  __syncthreads();
  // write sum of each warp to warp_sum (lane 31's slot holds the total)
  if (lane == 31) {
    warp_sum[warp_id] = *shm_data;
  }
  __syncthreads();
  // use a single warp to scan warp_sum
  if (warp_id == 0) {
    ScanWarp(warp_sum + lane);
  }
  __syncthreads();
  // add base: total of all preceding warps
  if (warp_id > 0) {
    *shm_data += warp_sum[warp_id - 1];
  }
  __syncthreads();
}
// Shared-memory scan kernel. Dynamic shared layout (ints):
//   shm[0:32]               — warp_sum scratch consumed inside ScanBlock
//   shm[32:32 + blockDim.x] — this part's staged input / scan result
// Requires (32 + blockDim.x) * sizeof(int32_t) dynamic shared memory.
__global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                          int32_t* output, size_t n,
                                          size_t part_num) {
  // the first 32 is used to save warp sum
  extern __shared__ int32_t shm[];
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    // store this part's input to shm (zero-padded past the end of input)
    size_t index = part_i * blockDim.x + threadIdx.x;
    shm[32 + threadIdx.x] = index < n ? input[index] : 0;
    __syncthreads();
    // scan on shared memory; each thread passes its own slot
    ScanBlock(shm + 32 + threadIdx.x);
    __syncthreads();
    // write result
    if (index < n) {
      output[index] = shm[32 + threadIdx.x];
    }
    // last thread's slot holds the inclusive total of the part
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = shm[32 + threadIdx.x];
    }
  }
}
// Turns the per-part sums in part[0:part_num] into an inclusive scan,
// serially. Meant to run as a single thread (<<<1, 1>>>).
__global__ void ScanPartSumKernel(int32_t* part, size_t part_num) {
  int32_t total = 0;
  for (size_t idx = 0; idx != part_num; ++idx) {
    total += part[idx];
    part[idx] = total;
  }
}
// Adds the scanned total of all preceding parts to every element of each
// part (part 0 needs no base). Grid-stride over parts; one element per thread.
//
// part:   inclusive scan of the per-part sums (part[i] = sum of parts 0..i)
// output: in/out — per-part scan results promoted to the global scan
// n:      total element count; guards the ragged tail of the last part
__global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                 size_t part_num) {
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    if (part_i == 0) {
      continue;
    }
    // size_t, not int32_t: part_i * blockDim.x overflows a 32-bit int once
    // n exceeds ~2^31, yielding a truncated index and silently skipped writes.
    size_t index = part_i * blockDim.x + threadIdx.x;
    if (index < n) {
      output[index] += part[part_i - 1];
    }
  }
}
// Inclusive prefix sum, shared-memory kernel version:
// for i in range(n):
//   output[i] = input[0] + input[1] + ... + input[i]
// buffer: device scratch, at least part_num int32_t's.
void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                 size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  size_t block_num = std::min<size_t>(part_num, 128);
  // use buffer[0:part_num] to save the metric of part
  int32_t* part = buffer;
  // after following step, part[i] = part_sum[i]
  // dynamic shared memory: 32 ints of warp_sum scratch + one int per thread
  size_t shm_size = (32 + part_size) * sizeof(int32_t);
  ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
      input, part, output, n, part_num);
  // after following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
  ScanPartSumKernel<<<1, 1>>>(part, part_num);
  // make final result
  AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
}
// Inclusive prefix sum, zero-padding version:
// for i in range(n):
//   output[i] = input[0] + input[1] + ... + input[i]
// Same driver as above, but the shared-memory layout adds 16 zero ints
// before the warp_sum area and before every warp's 32 data slots, which
// lets ScanWarp drop its `lane >= d` guards.
void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                 size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  size_t block_num = std::min<size_t>(part_num, 128);
  // use buffer[0:part_num] to save the metric of part
  int32_t* part = buffer;
  // after following step, part[i] = part_sum[i]
  size_t warp_num = part_size / 32;
  // layout (ints): [16 pad][32 warp_sum] then per warp: [16 pad][32 data]
  size_t shm_size = (16 + 32 + warp_num * (16 + 32)) * sizeof(int32_t);
  ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
      input, part, output, n, part_num);
  // after following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
  ScanPartSumKernel<<<1, 1>>>(part, part_num);
  // make final result
  AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
}
// Warp-level inclusive scan over shared memory with 16 zero-padded slots
// before each warp's 32 elements: out-of-range negative-offset reads land
// in the padding, so no `lane >= d` branches are needed.
// shm_data points at THIS lane's slot.
//
// Fix: the original had no barriers between the distance-d steps, relying
// on implicit warp-synchronous execution. Under independent thread
// scheduling (SM70+) warp lanes may fall out of lockstep, so each step
// must be separated by __syncwarp(); volatile only prevents register
// caching, it does not order the lanes.
__device__ void ScanWarp(int32_t* shm_data) {
  volatile int32_t* vshm_data = shm_data;
  vshm_data[0] += vshm_data[-1];
  __syncwarp();
  vshm_data[0] += vshm_data[-2];
  __syncwarp();
  vshm_data[0] += vshm_data[-4];
  __syncwarp();
  vshm_data[0] += vshm_data[-8];
  __syncwarp();
  vshm_data[0] += vshm_data[-16];
  __syncwarp();
}
// Warp-organized block scan, zero-padding variant. warp_sum aliases the
// calling kernel's dynamic shared memory: its first 16 ints are zero
// padding (so ScanWarp's unguarded negative-offset reads are safe) and the
// warp sums live at warp_sum[16:48].
// shm_data is this thread's own data slot inside the padded layout.
__device__ void ScanBlock(int32_t* shm_data) {
  int32_t warp_id = threadIdx.x >> 5;
  int32_t lane = threadIdx.x & 31;
  extern __shared__ int32_t warp_sum[];  // 16 zero-padding ints, then 32 sums
  // scan each warp
  ScanWarp(shm_data);
  __syncthreads();
  // write sum of each warp to warp_sum (lane 31's slot holds the total)
  if (lane == 31) {
    warp_sum[16 + warp_id] = *shm_data;
  }
  __syncthreads();
  // use a single warp to scan warp_sum
  if (warp_id == 0) {
    ScanWarp(warp_sum + 16 + lane);
  }
  __syncthreads();
  // add base: total of all preceding warps
  if (warp_id > 0) {
    *shm_data += warp_sum[16 + warp_id - 1];
  }
  __syncthreads();
}
// Zero-padding scan kernel. Dynamic shared layout (ints):
//   [0:16)                 zero padding for the warp_sum scan
//   [16:48)                warp_sum (written inside ScanBlock)
//   then for each warp w:  [48 + 48w : 48 + 48w + 16)  zero padding
//                          [48 + 48w + 16 : 48 + 48w + 48)  warp w's data
// The padding is written once before the part loop and never overwritten.
__global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                          int32_t* output, size_t n,
                                          size_t part_num) {
  // the first 16 + 32 is used to save warp sum
  extern __shared__ int32_t shm[];
  int32_t warp_id = threadIdx.x >> 5;
  int32_t lane = threadIdx.x & 31;
  // initialize the zero padding
  if (threadIdx.x < 16) {
    shm[threadIdx.x] = 0;
  }
  if (lane < 16) {
    shm[(16 + 32) + warp_id * (16 + 32) + lane] = 0;
  }
  __syncthreads();
  // process each part
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    // store this part's input to shm (zero-padded past the end of input)
    size_t index = part_i * blockDim.x + threadIdx.x;
    // this thread's data slot inside the padded per-warp region
    int32_t* myshm = shm + (16 + 32) + warp_id * (16 + 32) + 16 + lane;
    *myshm = index < n ? input[index] : 0;
    __syncthreads();
    // scan on shared memory
    ScanBlock(myshm);
    __syncthreads();
    // write result
    if (index < n) {
      output[index] = *myshm;
    }
    // last thread of the block holds the part's inclusive total
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = *myshm;
    }
  }
}
8
Recursion
// Recursive "scan then fan":
// for i in range(n):
//   output[i] = input[0] + input[1] + ... + input[i]
// Instead of a serial <<<1,1>>> scan of the part sums, recurse on them;
// each level consumes part_num ints of `buffer` and recurses with the rest.
// NOTE(review): buffer must be sized for the whole recursion (sum of part
// counts over all levels) — confirm at the call site.
void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                 size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  size_t block_num = std::min<size_t>(part_num, 128);
  // use buffer[0:part_num] to save the metric of part
  int32_t* part = buffer;
  // after following step, part[i] = part_sum[i]
  size_t warp_num = part_size / 32;
  // zero-padded layout: [16 pad][32 warp_sum] + per warp [16 pad][32 data]
  size_t shm_size = (16 + 32 + warp_num * (16 + 32)) * sizeof(int32_t);
  ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
      input, part, output, n, part_num);
  if (part_num >= 2) {
    // after following step
    // part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
    ScanThenFan(part, buffer + part_num, part, part_num);
    // make final result
    AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
  }
}
9
WarpShuffle
// Warp-level inclusive scan using register shuffles (no shared memory).
// Hillis-Steele: at each step every lane pulls the partial sum from the
// lane `offset` positions lower; lanes below `offset` keep their value.
// All 32 lanes must participate (full 0xffffffff mask). // requires SM70+
__device__ int32_t ScanWarp(int32_t val) {
  int32_t lane = threadIdx.x & 31;
#pragma unroll
  for (int32_t offset = 1; offset < 32; offset <<= 1) {
    int32_t below = __shfl_up_sync(0xffffffff, val, offset);
    if (lane >= offset) {
      val += below;
    }
  }
  return val;
}
10
PTX
// 声明寄存器
.reg .pred %p<11>;
.reg .b32 %r<39>;
// 读取参数到r35寄存器
ld.param.u32 %r35, [_Z8ScanWarpi_param_0];
// 读取threadIdx.x到r18寄存器
mov.u32 %r18, %tid.x;
// r1寄存器存储 lane = threadIdx.x & 31
and.b32 %r1, %r18, 31;
// r19寄存器存储0
mov.u32 %r19, 0;
// r20寄存器存储1
mov.u32 %r20, 1;
// r21寄存器存储-1
mov.u32 %r21, -1;
// r2|p1 = __shfl_up_sync(val, delta=1, 0, membermask=-1)
// 如果src lane在范围内,存储结果到r2中,并设置p1为True, 否则设置p1为False
// r2对应于我们代码中的tmp
shfl.sync.up.b32 %r2|%p1, %r35, %r20, %r19, %r21;
// p6 = (lane == 0)
setp.eq.s32 %p6, %r1, 0;
// 如果p6为真,则跳转到BB0_2
@%p6 bra BB0_2;
// val += tmp
add.s32 %r35, %r2, %r35;
// 偏移2
BB0_2:
mov.u32 %r23, 2;
shfl.sync.up.b32 %r5|%p2, %r35, %r23, %r19, %r21;
setp.lt.u32 %p7, %r1, 2;
@%p7 bra BB0_4;
add.s32 %r35, %r5, %r35;
...
// Warp-level inclusive scan written directly in PTX so that each
// shfl.sync.up.b32 returns its "source lane in range" predicate and the
// conditional add becomes a single predicated instruction (@p add.s32)
// instead of the compare-and-branch the compiler emits for `if (lane >= d)`.
// Registers r0..r4 hold the partial sums after the distance 1, 2, 4, 8, 16
// steps. All 32 lanes must participate (membermask -1). // requires SM70+
__device__ __forceinline__ int32_t ScanWarp(int32_t val) {
  int32_t result;
  asm("{"
      ".reg .s32 r<5>;"
      ".reg .pred p<5>;"
      "shfl.sync.up.b32 r0|p0, %1, 1, 0, -1;"
      "@p0 add.s32 r0, r0, %1;"
      "shfl.sync.up.b32 r1|p1, r0, 2, 0, -1;"
      "@p1 add.s32 r1, r1, r0;"
      "shfl.sync.up.b32 r2|p2, r1, 4, 0, -1;"
      "@p2 add.s32 r2, r2, r1;"
      "shfl.sync.up.b32 r3|p3, r2, 8, 0, -1;"
      "@p3 add.s32 r3, r3, r2;"
      "shfl.sync.up.b32 r4|p4, r3, 16, 0, -1;"
      "@p4 add.s32 r4, r4, r3;"
      "mov.s32 %0, r4;"
      "}"
      : "=r"(result)
      : "r"(val));
  return result;
}
// Block-wide inclusive scan over per-thread register values.
// Requires 32 * sizeof(int32_t) dynamic shared memory for warp_sum.
// NOTE(review): when the block has fewer than 32 warps, warp 0's scan of
// warp_sum reads uninitialized entries for lanes >= the warp count. The
// shuffle scan only propagates values from lower lanes upward and only
// warp_sum[warp_id - 1] (always initialized) is consumed, so the result is
// unaffected — but initcheck-style tools will flag the read; confirm.
__device__ __forceinline__ int32_t ScanBlock(int32_t val) {
  int32_t warp_id = threadIdx.x >> 5;
  int32_t lane = threadIdx.x & 31;
  extern __shared__ int32_t warp_sum[];
  // scan each warp
  val = ScanWarp(val);
  __syncthreads();
  // write sum of each warp to warp_sum (lane 31 holds the warp total)
  if (lane == 31) {
    warp_sum[warp_id] = val;
  }
  __syncthreads();
  // use a single warp to scan warp_sum
  if (warp_id == 0) {
    warp_sum[lane] = ScanWarp(warp_sum[lane]);
  }
  __syncthreads();
  // add base: total of all preceding warps
  if (warp_id > 0) {
    val += warp_sum[warp_id - 1];
  }
  __syncthreads();
  return val;
}
// Register-based part scan: each block grid-strides over parts, scans one
// part per iteration with ScanBlock (values held in registers), writes the
// scanned values to output and the part total (last thread's value) to
// part[p]. Requires 32 * sizeof(int32_t) dynamic shared memory.
__global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                          int32_t* output, size_t n,
                                          size_t part_num) {
  for (size_t p = blockIdx.x; p < part_num; p += gridDim.x) {
    size_t idx = p * blockDim.x + threadIdx.x;
    // zero-pad past the end so every thread participates in the scan
    int32_t v = idx < n ? input[idx] : 0;
    v = ScanBlock(v);
    __syncthreads();
    if (idx < n) {
      output[idx] = v;
    }
    if (threadIdx.x == blockDim.x - 1) {
      part[p] = v;
    }
  }
}
// Adds the scanned total of all preceding parts to every element of each
// part (part 0 needs no base). Grid-stride over parts; one element per thread.
//
// part:   inclusive scan of the per-part sums (part[i] = sum of parts 0..i)
// output: in/out — per-part scan results promoted to the global scan
// n:      total element count; guards the ragged tail of the last part
__global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                 size_t part_num) {
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    if (part_i == 0) {
      continue;
    }
    // size_t, not int32_t: part_i * blockDim.x overflows a 32-bit int once
    // n exceeds ~2^31, yielding a truncated index and silently skipped writes.
    size_t index = part_i * blockDim.x + threadIdx.x;
    if (index < n) {
      output[index] += part[part_i - 1];
    }
  }
}
// Inclusive prefix sum, warp-shuffle version:
// for i in range(n):
//   output[i] = input[0] + input[1] + ... + input[i]
// Recurses on the per-part sums; each level consumes part_num ints of buffer.
void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                 size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  size_t block_num = std::min<size_t>(part_num, 128);
  // use buffer[0:part_num] to save the metric of part
  int32_t* part = buffer;
  // after following step, part[i] = part_sum[i]
  // only the 32-int warp_sum scratch is needed in shared memory now
  size_t shm_size = 32 * sizeof(int32_t);
  ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
      input, part, output, n, part_num);
  if (part_num >= 2) {
    // after following step
    // part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
    ScanThenFan(part, buffer + part_num, part, part_num);
    // make final result
    AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
  }
}
11
ReduceThenScan
// Reduce phase of ReduceThenScan: each block grid-strides over parts and
// writes only the SUM of each part to part_sum (no scan output yet).
// Block size must be exactly 1024 to match cub::BlockReduce<int32_t, 1024>.
// NOTE(review): the `output` parameter is unused here — presumably kept for
// signature symmetry with the scan kernel; confirm.
__global__ void ReducePartSumKernel(const int32_t* input, int32_t* part_sum,
                                    int32_t* output, size_t n,
                                    size_t part_num) {
  using BlockReduce = cub::BlockReduce<int32_t, 1024>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    size_t index = part_i * blockDim.x + threadIdx.x;
    // zero-pad past the end of input
    int32_t val = index < n ? input[index] : 0;
    int32_t sum = BlockReduce(temp_storage).Sum(val);
    // only thread 0 receives the valid block-wide sum from BlockReduce
    if (threadIdx.x == 0) {
      part_sum[part_i] = sum;
    }
    // temp_storage is reused next iteration; barrier before overwriting it
    __syncthreads();
  }
}
// Scan phase of ReduceThenScan: scans each part in registers via ScanBlock,
// then adds the already-scanned base sum of all preceding parts, writing the
// final global scan directly to output.
// Requires 32 * sizeof(int32_t) dynamic shared memory (for ScanBlock).
__global__ void ScanWithBaseSum(const int32_t* input, int32_t* part_sum,
                                int32_t* output, size_t n, size_t part_num) {
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    size_t index = part_i * blockDim.x + threadIdx.x;
    // zero-pad past the end so every thread participates in ScanBlock
    int32_t val = index < n ? input[index] : 0;
    val = ScanBlock(val);
    __syncthreads();
    // part_sum[part_i - 1] = total of parts 0..part_i-1 (already scanned)
    if (part_i >= 1) {
      val += part_sum[part_i - 1];
    }
    if (index < n) {
      output[index] = val;
    }
  }
}
// Reduce-then-scan driver: first compute per-part sums and scan them
// (recursively), then run a single scan pass that folds the base sums in,
// writing output exactly once.
// buffer: device scratch; each recursion level consumes part_num ints.
void ReduceThenScan(const int32_t* input, int32_t* buffer, int32_t* output,
                    size_t n) {
  size_t part_size = 1024;  // tuned
  size_t part_num = (n + part_size - 1) / part_size;
  size_t block_num = std::min<size_t>(part_num, 128);
  int32_t* part_sum = buffer;  // use buffer[0:part_num]
  if (part_num >= 2) {
    ReducePartSumKernel<<<block_num, part_size>>>(input, part_sum, output, n,
                                                  part_num);
    // scan the part sums in place; recursion bottoms out at part_num == 1
    ReduceThenScan(part_sum, buffer + part_num, part_sum, part_num);
  }
  ScanWithBaseSum<<<block_num, part_size, 32 * sizeof(int32_t)>>>(
      input, part_sum, output, n, part_num);
}
// Single-pass reduce kernel: block blockIdx.x strides over its WHOLE part
// (part_size elements, not just blockDim.x), accumulating a per-thread
// partial sum, then reduces across the block with CUB.
// Block size must be exactly 1024 to match cub::BlockReduce<int32_t, 1024>.
__global__ void ReducePartSumKernelSinglePass(const int32_t* input,
                                              int32_t* g_part_sum, size_t n,
                                              size_t part_size) {
  // this block processes input[part_begin:part_end]
  size_t part_begin = blockIdx.x * part_size;
  size_t part_end = min((blockIdx.x + 1) * part_size, n);
  // per-thread partial sum over the block-strided range
  int32_t part_sum = 0;
  for (size_t i = part_begin + threadIdx.x; i < part_end; i += blockDim.x) {
    part_sum += input[i];
  }
  using BlockReduce = cub::BlockReduce<int32_t, 1024>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  part_sum = BlockReduce(temp_storage).Sum(part_sum);
  __syncthreads();
  // only thread 0 holds the valid block-wide result
  if (threadIdx.x == 0) {
    g_part_sum[blockIdx.x] = part_sum;
  }
}
// Single-pass scan kernel: block blockIdx.x scans its whole part
// (part_size elements) in blockDim.x-sized chunks, carrying a running
// base_sum in shared memory across chunks. g_base_sum supplies the total of
// all preceding parts; it may be nullptr when launched with a single block,
// since block 0 never reads it.
// Requires 32 * sizeof(int32_t) dynamic shared memory for ScanBlock.
// NOTE(review): the `debug` parameter is unused in this body.
__global__ void ScanWithBaseSumSinglePass(const int32_t* input,
                                          int32_t* g_base_sum, int32_t* output,
                                          size_t n, size_t part_size,
                                          bool debug) {
  // running base carried across chunk iterations
  __shared__ int32_t base_sum;
  if (threadIdx.x == 0) {
    if (blockIdx.x == 0) {
      base_sum = 0;
    } else {
      base_sum = g_base_sum[blockIdx.x - 1];
    }
  }
  __syncthreads();
  // this block processes input[part_begin:part_end]
  size_t part_begin = blockIdx.x * part_size;
  size_t part_end = (blockIdx.x + 1) * part_size;
  // the loop bound is NOT clamped to n so every thread runs the same number
  // of iterations (ScanBlock contains __syncthreads); reads past n are
  // zero-padded and writes past n are guarded instead.
  for (size_t i = part_begin + threadIdx.x; i < part_end; i += blockDim.x) {
    int32_t val = i < n ? input[i] : 0;
    val = ScanBlock(val);
    if (i < n) {
      output[i] = val + base_sum;
    }
    __syncthreads();
    // last thread's val is the inclusive total of this chunk
    if (threadIdx.x == blockDim.x - 1) {
      base_sum += val;
    }
    __syncthreads();
  }
}
// Two-pass reduce-then-scan driver with a fixed 1024 parts:
//   pass 1:  per-part sums (ReducePartSumKernelSinglePass),
//   pass 2a: scan the 1024 part sums with ONE block (g_base_sum = nullptr
//            is safe — block 0 never dereferences it),
//   pass 2b: scan all parts, each block adding its predecessor's base sum.
// NOTE(review): the trailing true/false arguments feed the unused `debug`
// parameter of ScanWithBaseSumSinglePass.
void ReduceThenScanTwoPass(const int32_t* input, int32_t* part_sum,
                           int32_t* output, size_t n) {
  size_t part_num = 1024;
  size_t part_size = (n + part_num - 1) / part_num;
  ReducePartSumKernelSinglePass<<<part_num, 1024>>>(input, part_sum, n,
                                                    part_size);
  ScanWithBaseSumSinglePass<<<1, 1024, 32 * sizeof(int32_t)>>>(
      part_sum, nullptr, part_sum, part_num, part_num, true);
  ScanWithBaseSumSinglePass<<<part_num, 1024, 32 * sizeof(int32_t)>>>(
      input, part_sum, output, n, part_size, false);
}