Skip to content

Commit

Permalink
Ensure all fft-like functions fallback to version with region when re…
Browse files Browse the repository at this point in the history
…gion not provided (#84)

* Ensure all fft-like functions fallback to version with region when
region not provided

* Add testset for default dims

* Add tests for complex float promotion

* Test complex float promotion for fft,ifft,bfft too
  • Loading branch information
gaurav-arya authored Mar 8, 2023
1 parent b2dd69c commit 1e3df24
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ _to1(::Tuple, x) = copy1(eltype(x), x)
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray) = (y = to1(x); $pf(y) * y)
$f(x::AbstractArray) = $f(x, 1:ndims(x))
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
end
Expand Down Expand Up @@ -207,9 +207,9 @@ bfft!
for f in (:fft, :bfft, :ifft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region)
$f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
end
end
Expand Down Expand Up @@ -297,16 +297,16 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
for f in (:brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, d::Integer) = $pf(x, d) * x
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
end
end

for f in (:brfft, :irfft)
@eval begin
$f(x::AbstractArray{<:Real}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region=1:ndims(x)) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
end
end

Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,33 @@ end
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

# Test that dims defaults to 1:ndims for fft-like functions
@testset "Default dims" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
@test fft(x) fft(x, 1:N)
@test ifft(x) ifft(x, 1:N)
@test bfft(x) bfft(x, 1:N)
@test rfft(x) rfft(x, 1:N)
d = 2 * size(x, 1) - 1
@test irfft(x, d) irfft(x, d, 1:N)
@test brfft(x, d) brfft(x, d, 1:N)
end
end

@testset "Complex float promotion" begin
for x in (rand(-5:5, 3), rand(-5:5, 3, 4), rand(-5:5, 3, 4, 5))
N = ndims(x)
@test fft(x) fft(complex.(x)) fft(complex.(float.(x)))
@test ifft(x) ifft(complex.(x)) ifft(complex.(float.(x)))
@test bfft(x) bfft(complex.(x)) bfft(complex.(float.(x)))
d = 2 * size(x, 1) - 1
@test irfft(x, d) irfft(complex.(x), d) irfft(complex.(float.(x)), d)
@test brfft(x, d) brfft(complex.(x), d) brfft(complex.(float.(x)), d)
end
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand Down

0 comments on commit 1e3df24

Please sign in to comment.