RoPE 基础版本,包含 RoPE 在 Llama 中的最小实现。
包含以下内容:
- rope_f32_kernel
- rope_f32x4_kernel(float4向量化版本)
- PyTorch bindings
# 只测试 Ada 架构;若不指定 TORCH_CUDA_ARCH_LIST,默认会编译所有架构(Volta, Ampere, Ada, Hopper, ...),耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 rope.py
输出:
----------------------------------------------------------------------------------------------------
M=4096, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.006247ms
out_f32x4_pack: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.005484ms
out_f32_th: ['1.066324 ', '-1.06176651 ', '-0.16482249 '], time:0.734866ms
----------------------------------------------------------------------------------------------------
M=4096, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.010335ms
out_f32x4_pack: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:0.008714ms
out_f32_th: ['-0.52068412 ', '1.20729053 ', '0.93223286 '], time:1.447463ms
----------------------------------------------------------------------------------------------------
M=8192, N=512
----------------------------------------------------------------------------------------------------
out_f32: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.010288ms
out_f32x4_pack: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:0.008750ms
out_f32_th: ['-0.19190802 ', '0.43925601 ', '0.58010447 '], time:1.434934ms
----------------------------------------------------------------------------------------------------
M=8192, N=1024
----------------------------------------------------------------------------------------------------
out_f32: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.018394ms
out_f32x4_pack: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:0.015330ms
out_f32_th: ['1.07467616 ', '-0.41201836 ', '-0.34494475 '], time:2.518094ms
----------------------------------------------------------------------------------------------------