From 1e3df24dc91cc77ada2e3847f8281a8fa787b7ad Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Mar 2023 07:14:17 -0500 Subject: [PATCH] Ensure all fft-like functions fallback to version with region when region not provided (#84) * Ensure all fft-like functions fallback to version with region when region not provided * Add testset for default dims * Add tests for complex float promotion * Test complex float promotion for fft,ifft,bfft too --- src/definitions.jl | 12 ++++++------ test/runtests.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 4532650..1cf542b 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -59,7 +59,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x) for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray) = (y = to1(x); $pf(y) * y) + $f(x::AbstractArray) = $f(x, 1:ndims(x)) $f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y) $pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...)) end @@ -207,9 +207,9 @@ bfft! for f in (:fft, :bfft, :ifft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...) end end @@ -297,7 +297,7 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) = for f in (:brfft, :irfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray, d::Integer) = $pf(x, d) * x + $f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x)) $f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x $pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...) end @@ -305,8 +305,8 @@ end for f in (:brfft, :irfft) @eval begin - $f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region) + $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region) end end diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..9cb528a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -213,6 +213,33 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +# Test that dims defaults to 1:ndims for fft-like functions +@testset "Default dims" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + @test fft(x) ≈ fft(x, 1:N) + @test ifft(x) ≈ ifft(x, 1:N) + @test bfft(x) ≈ bfft(x, 1:N) + @test rfft(x) ≈ rfft(x, 1:N) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(x, d, 1:N) + @test brfft(x, d) ≈ brfft(x, d, 1:N) + end +end + +@testset "Complex float promotion" begin + for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5)) + N = ndims(x) + @test fft(x) ≈ fft(complex.(x)) ≈ fft(complex.(float.(x))) + @test ifft(x) ≈ ifft(complex.(x)) ≈ ifft(complex.(float.(x))) + @test bfft(x) ≈ bfft(complex.(x)) ≈ bfft(complex.(float.(x))) + d = 2 * size(x, 1) - 1 + @test irfft(x, d) ≈ irfft(complex.(x), d) ≈ irfft(complex.(float.(x)), d) + @test brfft(x, d) ≈ brfft(complex.(x), d) ≈ brfft(complex.(float.(x)), d) + end +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5))