diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index 5ab5d2e..a26bc59 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -30,16 +30,15 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) - scale = reshape( + scale = convert(typeof(y), reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) ybar = ChainRulesCore.unthunk(ȳ) - _scale = convert(typeof(ybar),scale) - x̄ = project_x(brfft(ybar ./ _scale, d, dims)) + x̄ = project_x(brfft(ybar ./ scale, d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -74,16 +73,15 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN - scale = reshape( + scale = convert(typeof(y), reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) ybar = ChainRulesCore.unthunk(ȳ) - _scale = convert(typeof(ybar),scale) - x̄ = project_x(_scale .* rfft(real.(ybar), dims)) + x̄ = project_x(scale .* rfft(real.(ybar), dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback @@ -115,10 +113,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - scale = reshape( + scale = convert(typeof(y), reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ)