RoPE

0x00 Overview

A basic version of RoPE (rotary position embedding), containing a minimal implementation of RoPE as used in Llama.

It includes:

  • rope_f32_kernel
  • rope_f32x4_kernel (float4 vectorized version)
  • PyTorch bindings
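
For orientation, the rotation that both kernels are expected to compute can be sketched in plain Python. This is a minimal reference, not the code in this directory: the function name `rope_reference`, the flat row-major `[seq_len, head_dim]` layout, the interleaved pairing `(x[2i], x[2i+1])`, and the base `theta=10000.0` are assumptions borrowed from the usual Llama convention, and the actual kernels may index differently.

```python
import math

def rope_reference(x, seq_len, head_dim, theta=10000.0):
    """Apply RoPE to a flat row-major [seq_len, head_dim] list of floats.

    Assumption: adjacent elements (x[2i], x[2i+1]) form one rotated pair,
    and pair i at position pos is rotated by pos / theta**(2i / head_dim).
    """
    assert head_dim % 2 == 0, "head_dim must be even"
    assert len(x) == seq_len * head_dim, "x must hold seq_len * head_dim values"
    out = [0.0] * len(x)
    for pos in range(seq_len):
        for i in range(head_dim // 2):
            # Per-pair frequency: lower pair indices rotate faster.
            freq = 1.0 / (theta ** (2.0 * i / head_dim))
            angle = pos * freq
            c, s = math.cos(angle), math.sin(angle)
            j = pos * head_dim + 2 * i
            x0, x1 = x[j], x[j + 1]
            # Standard 2D rotation of the pair by `angle`.
            out[j] = x0 * c - x1 * s
            out[j + 1] = x0 * s + x1 * c
    return out

# Usage: two positions, head_dim = 4 (two pairs per position).
y = rope_reference([1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], seq_len=2, head_dim=4)
```

Each pair is independent of every other pair, which is what makes the operation easy to map onto one thread per pair. A float4 variant would presumably have each thread load four consecutive floats (two pairs) in a single memory transaction.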

Testing

# Test only the Ada architecture. If unspecified, all architectures are compiled by default, which takes longer: Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada 
python3 rope.py

Output:

----------------------------------------------------------------------------------------------------
                                        M=4096, N=512
----------------------------------------------------------------------------------------------------
             out_f32: ['1.066324    ', '-1.06176651 ', '-0.16482249 '], time:0.006247ms
      out_f32x4_pack: ['1.066324    ', '-1.06176651 ', '-0.16482249 '], time:0.005484ms
          out_f32_th: ['1.066324    ', '-1.06176651 ', '-0.16482249 '], time:0.734866ms
----------------------------------------------------------------------------------------------------
                                        M=4096, N=1024
----------------------------------------------------------------------------------------------------
             out_f32: ['-0.52068412 ', '1.20729053  ', '0.93223286  '], time:0.010335ms
      out_f32x4_pack: ['-0.52068412 ', '1.20729053  ', '0.93223286  '], time:0.008714ms
          out_f32_th: ['-0.52068412 ', '1.20729053  ', '0.93223286  '], time:1.447463ms
----------------------------------------------------------------------------------------------------
                                        M=8192, N=512
----------------------------------------------------------------------------------------------------
             out_f32: ['-0.19190802 ', '0.43925601  ', '0.58010447  '], time:0.010288ms
      out_f32x4_pack: ['-0.19190802 ', '0.43925601  ', '0.58010447  '], time:0.008750ms
          out_f32_th: ['-0.19190802 ', '0.43925601  ', '0.58010447  '], time:1.434934ms
----------------------------------------------------------------------------------------------------
                                        M=8192, N=1024
----------------------------------------------------------------------------------------------------
             out_f32: ['1.07467616  ', '-0.41201836 ', '-0.34494475 '], time:0.018394ms
      out_f32x4_pack: ['1.07467616  ', '-0.41201836 ', '-0.34494475 '], time:0.015330ms
          out_f32_th: ['1.07467616  ', '-0.41201836 ', '-0.34494475 '], time:2.518094ms
----------------------------------------------------------------------------------------------------
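
To read the timings above as relative speedups, a small helper like the following can be used. The numbers are copied from the M=4096, N=512 block of the output; the names `timings_ms` and `speedup` are illustrative and not part of rope.py.

```python
# Timings in milliseconds, copied from the M=4096, N=512 block above.
timings_ms = {
    "out_f32": 0.006247,
    "out_f32x4_pack": 0.005484,
    "out_f32_th": 0.734866,
}

def speedup(baseline, candidate, timings):
    """Return how many times faster `candidate` is than `baseline`."""
    return timings[baseline] / timings[candidate]

# float4 kernel relative to the scalar kernel, and relative to the PyTorch run.
print(f"f32x4 vs f32:    {speedup('out_f32', 'out_f32x4_pack', timings_ms):.2f}x")
print(f"f32x4 vs f32_th: {speedup('out_f32_th', 'out_f32x4_pack', timings_ms):.1f}x")
```

The same helper applies to the other three shapes by swapping in their timings.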