diff --git a/Project.toml b/Project.toml index 86ce0d0..c02f119 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" +version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -9,10 +9,12 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +AbstractFFTsForwardDiffExt = "ForwardDiff" AbstractFFTsTestExt = "Test" [compat] @@ -20,6 +22,7 @@ Aqua = "0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" FiniteDifferences = "0.12" +ForwardDiff = "0.10" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" Test = "<0.0.1, 1" @@ -31,9 +34,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"] +test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "ForwardDiff", "Random", "Test", "Unitful"] diff --git a/ext/AbstractFFTsForwardDiffExt.jl b/ext/AbstractFFTsForwardDiffExt.jl new file mode 100644 index 0000000..029e09d --- /dev/null +++ b/ext/AbstractFFTsForwardDiffExt.jl @@ -0,0 +1,59 @@ +module AbstractFFTsForwardDiffExt + +using AbstractFFTs +using AbstractFFTs.LinearAlgebra +import ForwardDiff +import ForwardDiff: Dual +import AbstractFFTs: Plan, mul!, dualplan, dual2array + + +AbstractFFTs._fftfloat(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,AbstractFFTs._fftfloat(V),N} + +dual2array(x::StridedArray{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::StridedArray{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::StridedArray{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::StridedArray{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) + + +######## +# DualPlan +# represents a plan acting on dual numbers. We wrap a plan acting on a higher dimensional tensor +# as an array of duals can be reinterpreted as a higher dimensional array. +# This allows standard FFTW plans to act on arrays of duals. +##### +struct DualPlan{T,P} <: Plan{T} + p::P + DualPlan{T,P}(p) where {T,P} = new(p) +end + +DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{T}) where {Tag,T<:Real,V,N} = DualPlan{Dual{Tag,T,N},typeof(p)}(p) +DualPlan(::Type{Dual{Tag,V,N}}, p::Plan{Complex{T}}) where {Tag,T<:Real,V,N} = DualPlan{Complex{Dual{Tag,T,N}},typeof(p)}(p) +dualplan(D, p) = DualPlan(D, p) +Base.size(p::DualPlan) = Base.tail(size(p.p)) +Base.:*(p::DualPlan{DT}, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) +Base.:*(p::DualPlan{Complex{DT}}, x::AbstractArray{Complex{DT}}) where DT<:Dual = array2dual(DT, p.p * dual2array(x)) + +function LinearAlgebra.mul!(y::AbstractArray{<:Dual}, p::DualPlan, x::AbstractArray{<:Dual}) + LinearAlgebra.mul!(dual2array(y), p.p, dual2array(x)) # even though `Dual` are immutable, when in an `Array` they can be modified. + y +end + +function LinearAlgebra.mul!(y::AbstractArray{<:Complex{<:Dual}}, p::DualPlan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) + copyto!(y, p*x) # Complex duals cannot be reinterpret in-place +end + + +for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) + @eval begin + AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...)) + AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...)) + end +end + + +for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? + @eval AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims; kwds...)) +end + + +end # module \ No newline at end of file diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 3225916..52538bf 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -8,6 +8,10 @@ export fft, ifft, bfft, fft!, ifft!, bfft!, include("definitions.jl") include("TestUtils.jl") +# Create function used by multiple extension as loading order is not guaranteed +function dualplan end +function dual2array end + if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") include("../ext/AbstractFFTsTestExt.jl") diff --git a/test/abstractfftsforwarddiff.jl b/test/abstractfftsforwarddiff.jl new file mode 100644 index 0000000..8a2b3e7 --- /dev/null +++ b/test/abstractfftsforwarddiff.jl @@ -0,0 +1,68 @@ +using AbstractFFTs +using ForwardDiff +using Test +using ForwardDiff: Dual, partials, value + +# Needed until https://github.com/JuliaDiff/ForwardDiff.jl/pull/732 is merged +complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k) + +@testset "ForwardDiff extension tests" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + c1 = Dual.(1:4.0, 2:5, 3:6) + im*Dual.(2:5.0, 3:6, 3:6) + + @test AbstractFFTs.complexfloat(x1)[1] === Dual(1.0, 2.0, 3.0) + 0im + @test AbstractFFTs.realfloat(x1)[1] === Dual(1.0, 2.0, 3.0) + + @test fft(x1, 1)[1] isa Complex{<:Dual} + @test plan_fft(x1, 1) * x1 == fft(x1, 1) + @test size(plan_fft(x1,1)) == (4,) + + @testset "$f" for f in (fft, ifft, rfft, bfft) + @test value.(f(x1)) == f(value.(x1)) + @test complexpartials.(f(x1), 1) == f(partials.(x1, 1)) + @test complexpartials.(f(x1), 2) == f(partials.(x1, 2)) + end + + @test ifft(fft(x1)) ≈ x1 + @test irfft(rfft(x1), length(x1)) ≈ x1 + @test brfft(rfft(x1), length(x1)) ≈ 4x1 + + f = x -> real(fft([x; 0; 0])[1]) + @test ForwardDiff.derivative(f,0.1) ≈ 1 + + r = x -> real(rfft([x; 0; 0])[1]) + @test ForwardDiff.derivative(r,0.1) ≈ 1 + + + n = 100 + θ = range(0,2π; length=n+1)[1:end-1] + # emperical from Mathematical + @test ForwardDiff.derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 + + @testset "matrix" begin + A = x1 * (1:10)' + @test value.(fft(A)) == fft(value.(A)) + @test complexpartials.(fft(A), 1) == fft(partials.(A, 1)) + @test complexpartials.(fft(A), 2) == fft(partials.(A, 2)) + + @test value.(fft(A, 1)) == fft(value.(A), 1) + @test complexpartials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) + @test complexpartials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) + + @test value.(fft(A, 2)) == fft(value.(A), 2) + @test complexpartials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) + @test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) + end + + @testset "complex" begin + @test fft(c1) ≈ fft(real(c1)) + im*fft(imag(c1)) + dest = similar(c1) + @test mul!(dest, plan_fft(x1), x1) == fft(x1) == dest + @test mul!(dest, plan_fft(c1), c1) == fft(c1) == dest + + C = c1 * ((1:10) .+ im*(2:11))' + @test fft(C) ≈ fft(real(C)) + im*fft(imag(C)) + dest = similar(C) + @test mul!(dest, plan_fft(C), C) == fft(C) == dest + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0560174..ceba516 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,7 @@ Random.seed!(1234) # Load example plan implementation. include("TestPlans.jl") -# Run interface tests for TestPlans +# Run interface tests for TestPlans AbstractFFTs.TestUtils.test_complex_ffts(Array) AbstractFFTs.TestUtils.test_real_ffts(Array) @@ -180,17 +180,17 @@ end p0 = plan_fft(zeros(ComplexF64, 3)) p = TestPlans.WrapperTestPlan(p0) u = rand(ComplexF64, 3) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u # rfft p0 = plan_rfft(zeros(3)) p = TestPlans.WrapperTestPlan(p0) u = rand(ComplexF64, 2) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u # brfft p0 = plan_brfft(zeros(ComplexF64, 3), 5) p = TestPlans.WrapperTestPlan(p0) u = rand(Float64, 5) - @test p' * u ≈ p0' * u + @test p' * u ≈ p0' * u end @testset "ChainRules" begin @@ -238,7 +238,7 @@ end test_frule(f, complex_x, dims) test_rrule(f, complex_x, dims) end - for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) + for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) test_frule(*, pf(x, dims), x) test_rrule(*, pf(x, dims), x) test_frule(*, pf(complex_x, dims), complex_x) @@ -248,7 +248,7 @@ end @test_throws ArgumentError ChainRulesCore.rrule(*, pf!(complex_x, dims), complex_x) end - # rfft + # rfft test_frule(rfft, x, dims) test_rrule(rfft, x, dims) test_frule(*, plan_rfft(x, dims), x) @@ -266,11 +266,14 @@ end for pf in (plan_irfft, plan_brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) test_frule(*, pf(complex_x, d, dims), complex_x) - test_rrule(*, pf(complex_x, d, dims), complex_x) + test_rrule(*, pf(complex_x, d, dims), complex_x) end end end end end end - + +if isdefined(Base, :get_extension) + include("abstractfftsforwarddiff.jl") +end \ No newline at end of file