Skip to content

Commit

Permalink
Add option to test with FFTW backend
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Aug 18, 2022
1 parent 03ef58b commit 4347fa9
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 231 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ jobs:
- windows-latest
arch:
- x64
group:
- TestPlans
- FFTW
exclude:
- version: '1.0'
group: FFTW
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand All @@ -40,7 +46,10 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
file: lcov.info
flag-name: group-${{ matrix.group }} # unique name for coverage report of each group
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ julia = "^1.0"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
test = ["ChainRulesTestUtils", "FFTW", "Random", "Test", "Unitful"]
7 changes: 7 additions & 0 deletions test/testplans.jl → test/TestPlans.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module TestPlans

using AbstractFFTs
using AbstractFFTs: Plan

mutable struct TestPlan{T,N} <: Plan{T}
region
sz::NTuple{N,Int}
Expand Down Expand Up @@ -226,3 +231,5 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray)

return y
end

end
239 changes: 9 additions & 230 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

using AbstractFFTs
using AbstractFFTs: Plan
using ChainRulesTestUtils
Expand All @@ -12,235 +10,16 @@ import Unitful

Random.seed!(1234)

include("testplans.jl")

@testset "rfft sizes" begin
A = rand(11, 10)
@test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10)
@test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6)
A1 = rand(6, 10); A2 = rand(11, 6)
@test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10)
@test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10)
@test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2)
end

@testset "Custom Plan" begin
# DFT along last dimension, results computed using FFTW
for (x, fftw_fft) in (
(collect(1:7),
[28.0 + 0.0im,
-3.5 + 7.267824888003178im,
-3.5 + 2.7911568610884143im,
-3.5 + 0.7988521603655248im,
-3.5 - 0.7988521603655248im,
-3.5 - 2.7911568610884143im,
-3.5 - 7.267824888003178im]),
(collect(1:8),
[36.0 + 0.0im,
-4.0 + 9.65685424949238im,
-4.0 + 4.0im,
-4.0 + 1.6568542494923806im,
-4.0 + 0.0im,
-4.0 - 1.6568542494923806im,
-4.0 - 4.0im,
-4.0 - 9.65685424949238im]),
(collect(reshape(1:8, 2, 4)),
[16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
(collect(reshape(1:9, 3, 3)),
[12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
)
# FFT
dims = ndims(x)
y = AbstractFFTs.fft(x, dims)
@test y fftw_fft
P = plan_fft(x, dims)
@test eltype(P) === ComplexF64
@test P * x fftw_fft
@test P \ (P * x) x
@test fftdims(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) fftw_bfft
P = plan_bfft(x, dims)
@test P * y fftw_bfft
@test P \ (P * y) y
@test fftdims(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) fftw_ifft
P = plan_ifft(x, dims)
@test P * y fftw_ifft
@test P \ (P * y) y
@test fftdims(P) == dims

# real FFT
fftw_rfft = fftw_fft[
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
]
ry = AbstractFFTs.rfft(x, dims)
@test ry fftw_rfft
P = plan_rfft(x, dims)
@test eltype(P) === Int
@test P * x fftw_rfft
@test P \ (P * x) x
@test fftdims(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry fftw_brfft
@test P \ (P * ry) ry
@test fftdims(P) == dims
const GROUP = get(ENV, "GROUP", "All")

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry fftw_irfft
@test P \ (P * ry) ry
@test fftdims(P) == dims
end
end

@testset "Shift functions" begin
@test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2]
@test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6])) == [6 4 5; 3 1 2]
a = [0 0 0]
b = [0, 0, 0]
c = [0 0 0; 0 0 0]
@test (AbstractFFTs.fftshift!(a, [1 2 3]); a == [3 1 2])
@test (AbstractFFTs.fftshift!(b, [1, 2, 3]); b == [3, 1, 2])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6]); c == [6 4 5; 3 1 2])

@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2))) == [6 4 5; 3 1 2]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2)) == [6 4 5; 3 1 2]
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [6 4 5; 3 1 2])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1:2); c == [6 4 5; 3 1 2])

@test @inferred(AbstractFFTs.ifftshift([1 2 3])) == [2 3 1]
@test @inferred(AbstractFFTs.ifftshift([1, 2, 3])) == [2, 3, 1]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6])) == [5 6 4; 2 3 1]
@test (AbstractFFTs.ifftshift!(a, [1 2 3]); a == [2 3 1])
@test (AbstractFFTs.ifftshift!(b, [1, 2, 3]); b == [2, 3, 1])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6]); c == [5 6 4; 2 3 1])
include("TestPlans.jl")
include("testfft.jl")

@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2))) == [5 6 4; 2 3 1]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2)) == [5 6 4; 2 3 1]
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [5 6 4; 2 3 1])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1:2); c == [5 6 4; 2 3 1])
if GROUP == "All" || GROUP == "TestPlans"
using .TestPlans
testfft()
elseif GROUP == "All" || GROUP == "FFTW" # integration test with FFTW
using FFTW
testfft()
end

@testset "FFT Frequencies" begin
@test fftfreq(8) isa Frequencies
@test copy(fftfreq(8)) isa Frequencies

# N even
@test fftfreq(8) == [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125]
@test rfftfreq(8) == [0.0, 0.125, 0.25, 0.375, 0.5]
@test fftshift(fftfreq(8)) == -0.5:0.125:0.375

# N odd
@test fftfreq(5) == [0.0, 0.2, 0.4, -0.4, -0.2]
@test rfftfreq(5) == [0.0, 0.2, 0.4]
@test fftshift(fftfreq(5)) == -0.4:0.2:0.4

# Sampling Frequency
@test fftfreq(5, 2) == [0.0, 0.4, 0.8, -0.8, -0.4]
# <:Number type compatibility
@test eltype(fftfreq(5, ComplexF64(2))) == ComplexF64

@test_throws ArgumentError Frequencies(12, 10, 1)

@testset "scaling" begin
@test fftfreq(4, 1) * 2 === fftfreq(4, 2)
@test fftfreq(4, 1) .* 2 === fftfreq(4, 2)
@test 2 * fftfreq(4, 1) === fftfreq(4, 2)
@test 2 .* fftfreq(4, 1) === fftfreq(4, 2)

@test fftfreq(4, 1) / 2 === fftfreq(4, 1/2)
@test fftfreq(4, 1) ./ 2 === fftfreq(4, 1/2)

@test 2 \ fftfreq(4, 1) === fftfreq(4, 1/2)
@test 2 .\ fftfreq(4, 1) === fftfreq(4, 1/2)
end

@testset "extrema" begin
function check_extrema(freqs)
for f in [minimum, maximum, extrema]
@test f(freqs) == f(collect(freqs)) == f(fftshift(freqs))
end
end
for f in (fftfreq, rfftfreq), n in (8, 9), multiplier in (2, 1/3, -1/7, 1.0*Unitful.mm)
freqs = f(n, multiplier)
check_extrema(freqs)
end
end
end

@testset "normalization" begin
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
for dims in ((), 1, 2, (1,2), 1:2)
any(d > ndims(x) for d in dims) && continue

# type inference checks of `rrule` fail on old Julia versions
# for higher-dimensional arrays:
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"

test_frule(AbstractFFTs.fftshift, x, dims)
test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred)

test_frule(AbstractFFTs.ifftshift, x, dims)
test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred)
end
end
end

@testset "fft" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
for dims in unique((1, 1:N, N))
for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
end

test_frule(rfft, x, dims)
test_rrule(rfft, x, dims)

for f in (irfft, brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(f, x, d, dims)
test_rrule(f, x, d, dims)
test_frule(f, complex_x, d, dims)
test_rrule(f, complex_x, d, dims)
end
end
end
end
end
end
Loading

0 comments on commit 4347fa9

Please sign in to comment.