Skip to content

Commit

Permalink
Add Enzyme Forward mode custom rule.
Browse files Browse the repository at this point in the history
Co-authored-by: Seth Axen <[email protected]>
Co-authored-by: "William S. Moses" <[email protected]>
  • Loading branch information
3 people committed Apr 28, 2024
1 parent 2a3cec9 commit 3e2a990
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
39 changes: 38 additions & 1 deletion ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,49 @@ module EnzymeCoreExt
using CUDA
import CUDA: GPUCompiler, CUDABackend

isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore)
if isdefined(Base, :get_extension)
using EnzymeCore
using EnzymeCore.EnzymeRules
else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end

function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type))
mi = GPUCompiler.methodinstance(F, TT)
return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device()))
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
::Type{<:Duplicated}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT}
res = ofn.val(f.val, tt.val; kwargs...)
return Duplicated(res, res)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
::Type{BatchDuplicated{T,N}}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT,T,N}
res = ofn.val(f.val, tt.val; kwargs...)
return BatchDuplicated(res, ntuple(Val(N)) do _
res
end)
end

function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, args...;
kwargs...) where {F,TT}

GC.@preserve args begin
args = ((cudaconvert(a) for a in args)...,)
T2 = (F, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(metaf, TT2)
res = cuf(ofn.val.f, args...; kwargs...)
end

return nothing
end

end # module

1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand Down
31 changes: 31 additions & 0 deletions test/libraries/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
using EnzymeCore
using GPUCompiler
using Enzyme

@testset "compiler_job_from_backend" begin
@test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob
end

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
sync_threads()
return nothing
end

# basic squaring on GPU
function square!(x)
@cuda blocks = 1 threads = length(x) square_kernel!(x)
return nothing
end

A = CUDA.rand(64)
dA = CUDA.ones(64)
A .= (1:1:64)
dA .= 1
Enzyme.autodiff(Forward, square!, Duplicated(A, dA))
@test all(dA .≈ (2:2:128))

A = CUDA.rand(32)
dA = CUDA.ones(32)
dA2 = CUDA.ones(32)
A .= (1:1:32)
dA .= 1
dA2 .= 3
Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2)))
@test all(dA .≈ (2:2:64))
@test all(dA2 .≈ 3*(2:2:64))

0 comments on commit 3e2a990

Please sign in to comment.