Skip to content

Commit

Permalink
Add some vararg of fixed size to fix inference
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 19, 2023
1 parent 682cd9b commit b85f46e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions ext/CUDAEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ end

# Expand Enzyme activity and arguments to pass each pointer (primal or shadow) as an indivdual argument
@inline expand() = ()
@inline expand(a::Const, args...) = (typeof(a), a.val, expand(args...)...)
@inline expand(a::Duplicated, args...) = (typeof(a), a.val, a.dval, expand(args...)...)
@inline expand(a::DuplicatedNoNeed, args...) =
@inline expand(a::Const, args::Vararg{Any, N}) where N = (typeof(a), a.val, expand(args...)...)
@inline expand(a::Duplicated, args::Vararg{Any, N}) where N = (typeof(a), a.val, a.dval, expand(args...)...)
@inline expand(a::DuplicatedNoNeed, args::Vararg{Any, N}) where N =
(typeof(a), a.val, a.dval, expand(args...)...)
@inline expand(a::BatchDuplicated, args...) =
@inline expand(a::BatchDuplicated, args::Vararg{Any, N}) where N =
(typeof(a), a.val, a.dval..., expand(args...)...)
@inline expand(a::BatchDuplicatedNoNeed, args...) =
@inline expand(a::BatchDuplicatedNoNeed, args::Vararg{Any, N}) where N =
(typeof(a), a.val, a.dval..., expand(args...)...)

# Contract expanded Enzyme activity and arguments into corresponding structs
@inline contract() = ()
@inline contract(::Type{Const{T}}, a, args...) where {T} = (Const(a), contract(args...)...)
@inline contract(::Type{Duplicated{T}}, a, b, args...) where {T} =
@inline contract(::Type{Const{T}}, a, args::Vararg{Any, N}) where {T,N} = (Const(a), contract(args...)...)
@inline contract(::Type{Duplicated{T}}, a, b, args::Vararg{Any, N}) where {T,N} =
(Duplicated(a, b), contract(args...)...)
@inline contract(::Type{DuplicatedNoNeed{T}}, a, b, args...) where {T} =
@inline contract(::Type{DuplicatedNoNeed{T}}, a, b, args::Vararg{Any, N}) where {T, N} =
(DuplicatedNoNeed(a, b), contract(args...)...)

function metaf(fn, args::Vararg{Any, N}) where N
Expand Down Expand Up @@ -77,7 +77,7 @@ function EnzymeCore.EnzymeRules.forward(
Type{ET},
p,
(
ntuple(Val(EnzymeCore.batch_size(ET))) do _
ntuple(Val(EnzymeCore.batch_size(arg))) do _
p
end
)...,
Expand Down

0 comments on commit b85f46e

Please sign in to comment.