diff --git a/Project.toml b/Project.toml index 498fac8c..619ac895 100644 --- a/Project.toml +++ b/Project.toml @@ -19,9 +19,10 @@ julia = "^1.0" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"] diff --git a/README.md b/README.md index fedb8c23..149524fd 100644 --- a/README.md +++ b/README.md @@ -16,26 +16,5 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)` ## Developer information -To define a new FFT implementation in your own module, you should +To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation). -* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`. - This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the - inverse plan. - -* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). - -* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to - 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. - -* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method. - This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs. - -* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the - inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`. - -* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. - -The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\) x_j$ -for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with -$\exp\(+2 \pi i \cdot \frac{j k}{n}\)$. diff --git a/docs/Project.toml b/docs/Project.toml index ed025f5a..4ca9eda1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] diff --git a/docs/src/api.md b/docs/src/api.md index 5d8316b2..bb3b8492 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftdims +Base.adjoint AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 632a6026..7367fd4c 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -11,16 +11,31 @@ The following packages extend the functionality provided by AbstractFFTs: ## Defining a new implementation -Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or -`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support -pre-allocated output arrays. -We don't define `*` in terms of `mul!` generically here, however, because -of subtleties for in-place and real FFT plans. - -To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes -to have a `pinv::Plan` field, which caches the inverse plan, and which should be -initially undefined. -They should also implement `plan_inv(p)` to construct the inverse of a plan `p`. - -Implementations only need to provide the unnormalized backwards FFT, -similar to FFTW, and we do the scaling generically to get the inverse FFT. +To define a new FFT implementation in your own module, you should + +* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`. + This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the + inverse plan. + +* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). + +* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`. + +* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method. + This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs. + +* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the + inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`. + Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically + to get the inverse FFT. + +* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. + +* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return: + * `AbstractFFTs.NoProjectionStyle()`, + * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref), + * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. + +The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of +length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index d58f5fab..5ab5d2ee 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -159,4 +159,54 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) return y, ifftshift_pullback end +# plans +function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end + Δy = P * Δx + return y, Δy +end +function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end + project_x = ChainRulesCore.ProjectTo(x) + Pt = P' + function mul_plan_pullback(ȳ) + x̄ = project_x(Pt * ȳ) + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄ + end + return y, mul_plan_pullback +end + +function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end + Δy = P * Δx .+ (ΔP.scale / P.scale) .* y + return y, Δy +end +function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end + Pt = P' + scale = P.scale + project_x = ChainRulesCore.ProjectTo(x) + project_scale = ChainRulesCore.ProjectTo(scale) + function mul_scaledplan_pullback(ȳ) + x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ)) + scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale))) + plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent) + return ChainRulesCore.NoTangent(), plan_tangent, x̄ + end + return y, mul_scaledplan_pullback +end + end # module + diff --git a/src/definitions.jl b/src/definitions.jl index 1cf542b2..4ec176eb 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T # size(p) should return the size of the input array for p size(p::Plan, d) = size(p)[d] +output_size(p::Plan, d) = output_size(p)[d] ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int @@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale) ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) +output_size(p::ScaledPlan) = output_size(p.p) fftdims(p::ScaledPlan) = fftdims(p.p) @@ -578,3 +580,80 @@ Pre-plan an optimized real-input unnormalized transform, similar to the same as for [`brfft`](@ref). """ plan_brfft + +############################################################################## + +struct NoProjectionStyle end +struct RealProjectionStyle end +struct RealInverseProjectionStyle + dim::Int +end +const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} + +output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) +_output_size(p::Plan, ::NoProjectionStyle) = size(p) +_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) +_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) + +struct AdjointPlan{T,P<:Plan} <: Plan{T} + p::P + AdjointPlan{T,P}(p) where {T,P} = new(p) +end + +""" + (p::Plan)' + adjoint(p::Plan) + +Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of +the original plan. Note that this differs from the corresponding backwards plan in the case of real +FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). + +!!! note + Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, + coverage of `Base.adjoint` in downstream implementations may be limited. +""" +Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) +Base.adjoint(p::AdjointPlan) = p.p +# always have AdjointPlan inside ScaledPlan. +Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) + +size(p::AdjointPlan) = output_size(p.p) +output_size(p::AdjointPlan) = size(p.p) + +Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} + dims = fftdims(p.p) + N = normalization(T, size(p.p), dims) + return (p.p \ x) / N +end + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} + dims = fftdims(p.p) + N = normalization(T, size(p.p), dims) + halfdim = first(dims) + d = size(p.p, halfdim) + n = output_size(p.p, halfdim) + scale = reshape( + [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], + ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) + ) + return p.p \ (x ./ convert(typeof(x), scale)) +end + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} + dims = fftdims(p.p) + N = normalization(real(T), output_size(p.p), dims) + halfdim = first(dims) + n = size(p.p, halfdim) + d = output_size(p.p, halfdim) + scale = reshape( + [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], + ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) + ) + return (convert(typeof(x), scale) ./ N) .* (p.p \ x) +end + +# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). +plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) +inv(p::AdjointPlan) = adjoint(inv(p.p)) diff --git a/test/runtests.jl b/test/runtests.jl index 9cb528ac..c5f0659b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,10 @@ # This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license using AbstractFFTs -using AbstractFFTs: Plan +using AbstractFFTs: Plan, ScaledPlan using ChainRulesTestUtils +using FiniteDifferences +import ChainRulesCore using LinearAlgebra using Random @@ -66,6 +68,13 @@ end @test fftdims(P) == dims end + # in-place plan + P = plan_fft!(x, dims) + @test eltype(P) === ComplexF64 + xc64 = ComplexF64.(x) + @test P * xc64 ≈ fftw_fft + @test xc64 ≈ fftw_fft + fftw_bfft = complex.(size(x, dims) .* x) @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft P = plan_bfft(x, dims) @@ -73,6 +82,13 @@ end @test P \ (P * y) ≈ y @test fftdims(P) == dims + # in-place plan + P = plan_bfft!(x, dims) + @test eltype(P) === ComplexF64 + yc64 = ComplexF64.(y) + @test P * yc64 ≈ fftw_bfft + @test yc64 ≈ fftw_bfft + fftw_ifft = complex.(x) @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft # test plan_ifft and also inv and plan_inv of plan_fft, which should all give @@ -84,6 +100,13 @@ end @test fftdims(P) == dims end + # in-place plan + P = plan_ifft!(x, dims) + @test eltype(P) === ComplexF64 + yc64 = ComplexF64.(y) + @test P * yc64 ≈ fftw_ifft + @test yc64 ≈ fftw_ifft + # real FFT fftw_rfft = fftw_fft[ (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., @@ -213,6 +236,90 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +@testset "output size" begin + @testset "complex fft output size" begin + for x_shape in ((3,), (3, 4), (3, 4, 5)) + N = length(x_shape) + real_x = randn(x_shape) + complex_x = randn(ComplexF64, x_shape) + for x in (real_x, complex_x) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test @inferred(AbstractFFTs.output_size(P)) == size(x) + @test AbstractFFTs.output_size(P') == size(x) + Pinv = plan_ifft(x) + @test AbstractFFTs.output_size(Pinv) == size(x) + @test AbstractFFTs.output_size(Pinv') == size(x) + end + end + end + end + @testset "real fft output size" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths + N = ndims(x) + for dims in unique((1, 1:N, N)) + P = plan_rfft(x, dims) + Px_sz = size(P * x) + @test AbstractFFTs.output_size(P) == Px_sz + @test AbstractFFTs.output_size(P') == size(x) + y = randn(ComplexF64, Px_sz) + Pinv = plan_irfft(y, size(x)[first(dims)], dims) + @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) + @test AbstractFFTs.output_size(Pinv') == size(y) + end + end + end +end + +@testset "adjoint" begin + @testset "complex fft adjoint" begin + for x_shape in ((3,), (3, 4), (3, 4, 5)) + N = length(x_shape) + real_x = randn(x_shape) + complex_x = randn(ComplexF64, x_shape) + y = randn(ComplexF64, x_shape) + for x in (real_x, complex_x) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test (P')' === P # test adjoint of adjoint + @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint + @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint + @test dot(y, P \ x) ≈ dot(P' \ y, x) # test inv of adjoint + @test dot(y, P \ x) ≈ dot(AbstractFFTs.plan_inv(P') * y, x) # test plan_inv of adjoint + Pinv = plan_ifft(y) + @test (Pinv')' * y == Pinv * y + @test size(Pinv') == AbstractFFTs.output_size(Pinv) + @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) + @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + @test dot(x, Pinv \ y) ≈ dot(AbstractFFTs.plan_inv(Pinv') * x, y) + @test_throws MethodError mul!(x, P', y) + end + end + end + end + @testset "real fft adjoint" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths + N = ndims(x) + for dims in unique((1, 1:N, N)) + P = plan_rfft(x, dims) + y = randn(ComplexF64, size(P * x)) + @test (P')' * x == P * x + @test size(P') == AbstractFFTs.output_size(P) + @test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x) + @test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) ≈ dot(P \ y, x) + @test dot(real.(y), real.(AbstractFFTs.plan_inv(P') * x)) + + dot(imag.(y), imag.(AbstractFFTs.plan_inv(P') * x)) ≈ dot(P \ y, x) + Pinv = plan_irfft(y, size(x)[first(dims)], dims) + @test (Pinv')' * y == Pinv * y + @test size(Pinv') == AbstractFFTs.output_size(Pinv) + @test dot(x, Pinv * y) ≈ dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x)) + @test dot(x, Pinv' \ y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) + @test dot(x, AbstractFFTs.plan_inv(Pinv') * y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) + end + end + end +end + # Test that dims defaults to 1:ndims for fft-like functions @testset "Default dims" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) @@ -261,20 +368,47 @@ end end @testset "fft" begin - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) - N = ndims(x) - complex_x = complex.(x) + # Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256 + InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan} + function FiniteDifferences.to_vec(x::InnerPlan) + function FFTPlan_from_vec(x_vec::Vector) + return x + end + return Bool[], FFTPlan_from_vec + end + ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) = true + ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent() + + for x_shape in ((2,), (2, 3), (3, 4, 5)) + N = length(x_shape) + x = randn(x_shape) + complex_x = randn(ComplexF64, x_shape) + Δ = (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesTestUtils.rand_tangent(complex_x)) for dims in unique((1, 1:N, N)) + # fft, ifft, bfft for f in (fft, ifft, bfft) test_frule(f, x, dims) test_rrule(f, x, dims) test_frule(f, complex_x, dims) test_rrule(f, complex_x, dims) end + for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) + test_frule(*, pf(x, dims), x) + test_rrule(*, pf(x, dims), x) + test_frule(*, pf(complex_x, dims), complex_x) + test_rrule(*, pf(complex_x, dims), complex_x) + + @test_throws ArgumentError ChainRulesCore.frule(Δ, *, pf!(complex_x, dims), complex_x) + @test_throws ArgumentError ChainRulesCore.rrule(*, pf!(complex_x, dims), complex_x) + end + # rfft test_frule(rfft, x, dims) test_rrule(rfft, x, dims) + test_frule(*, plan_rfft(x, dims), x) + test_rrule(*, plan_rfft(x, dims), x) + # irfft, brfft for f in (irfft, brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) test_frule(f, x, d, dims) @@ -283,6 +417,12 @@ end test_rrule(f, complex_x, d, dims) end end + for pf in (plan_irfft, plan_brfft) + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) + test_frule(*, pf(complex_x, d, dims), complex_x) + test_rrule(*, pf(complex_x, d, dims), complex_x) + end + end end end end diff --git a/test/testplans.jl b/test/testplans.jl index 31609a9b..09b3f671 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -1,18 +1,18 @@ -mutable struct TestPlan{T,N} <: Plan{T} - region +mutable struct TestPlan{T,N,G} <: Plan{T} + region::G sz::NTuple{N,Int} pinv::Plan{T} - function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function TestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} + return new{T,N,G}(region, sz) end end -mutable struct InverseTestPlan{T,N} <: Plan{T} - region +mutable struct InverseTestPlan{T,N,G} <: Plan{T} + region::G sz::NTuple{N,Int} pinv::Plan{T} - function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function InverseTestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} + return new{T,N,G}(region, sz) end end @@ -21,6 +21,9 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N Base.size(p::InverseTestPlan) = p.sz Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N +AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle() +AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle() + function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} return TestPlan{T}(region, size(x)) end @@ -89,24 +92,27 @@ end Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) -mutable struct TestRPlan{T,N} <: Plan{T} - region +mutable struct TestRPlan{T,N,G} <: Plan{T} + region::G sz::NTuple{N,Int} pinv::Plan{Complex{T}} - TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz) + TestRPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} = new{T,N,G}(region, sz) end -mutable struct InverseTestRPlan{T,N} <: Plan{Complex{T}} +mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}} d::Int - region + region::G sz::NTuple{N,Int} pinv::Plan{T} - function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N} + function InverseTestRPlan{T}(d::Int, region::G, sz::NTuple{N,Int}) where {T,N,G} sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions") - return new{T,N}(d, region, sz) + return new{T,N,G}(d, region, sz) end end +AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() +AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d) + function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) end @@ -226,3 +232,25 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray) return y end + +# In-place plans +# (simple wrapper of out-of-place plans that does not support inverses) +struct InplaceTestPlan{T,P<:Plan{T}} <: Plan{T} + plan::P +end + +Base.size(p::InplaceTestPlan) = size(p.plan) +Base.ndims(p::InplaceTestPlan) = ndims(p.plan) +AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan) + +function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) + return InplaceTestPlan(plan_fft(x, region; kwargs...)) +end +function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...) + return InplaceTestPlan(plan_bfft(x, region; kwargs...)) +end + +function LinearAlgebra.mul!(y::AbstractArray, p::InplaceTestPlan, x::AbstractArray) + return mul!(y, p.plan, x) +end +Base.:*(p::InplaceTestPlan, x::AbstractArray) = copyto!(x, p.plan * x)