-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chain rules for FFT plans via AdjointPlans #67
Changes from all commits
ad71816
c91ad50
061eef9
497ff4d
5d5c06c
ef84edf
d7ff394
aa8e575
9d99886
ac7c78c
8474141
769c090
3ed83df
552d49f
1e9ece2
8ddfa97
87758c8
09b8b38
d967aa2
2a423e2
eedba14
25bb86b
2a2d685
fe3b06a
266c88f
403ce47
e137ae3
e601347
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,4 +159,54 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) | |
return y, ifftshift_pullback | ||
end | ||
|
||
# plans | ||
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) | ||
y = P * x | ||
if Base.mightalias(y, x) | ||
throw(ArgumentError("differentiation rules are not supported for in-place plans")) | ||
end | ||
Δy = P * Δx | ||
return y, Δy | ||
end | ||
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) | ||
y = P * x | ||
if Base.mightalias(y, x) | ||
throw(ArgumentError("differentiation rules are not supported for in-place plans")) | ||
end | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
Pt = P' | ||
function mul_plan_pullback(ȳ) | ||
x̄ = project_x(Pt * ȳ) | ||
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄ | ||
end | ||
return y, mul_plan_pullback | ||
end | ||
|
||
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) | ||
y = P * x | ||
if Base.mightalias(y, x) | ||
throw(ArgumentError("differentiation rules are not supported for in-place plans")) | ||
end | ||
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it inconsistent at all that here we use the tangent of the scale part of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm it seems plans are assumed to be constant (AFAICT from the initial version of the PR) but the scaling might change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess there's probably never a good reason a user would want (co)tangents for a |
||
return y, Δy | ||
end | ||
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) | ||
y = P * x | ||
if Base.mightalias(y, x) | ||
throw(ArgumentError("differentiation rules are not supported for in-place plans")) | ||
end | ||
Pt = P' | ||
scale = P.scale | ||
project_x = ChainRulesCore.ProjectTo(x) | ||
project_scale = ChainRulesCore.ProjectTo(scale) | ||
function mul_scaledplan_pullback(ȳ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that would require the FFT plan to support fused cache = get_cache(plan)
copy!(cache, y)
mul!(y, plan, x)
axpby!(b, cache, a, y) Feels out of scope for this PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is my understanding that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I quickly checked the Julia repo, and there are a few open issues that show that at least in practice such a guarantee does not exist: https://github.com/JuliaLang/julia/issues/49332 JuliaLang/julia#46865 Arguably these are just bugs but on the other hand the docstring of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For both cases, the allocation size is independent of the array size indicating that the arrays are not being allocated. Looks like a spurious size tuple allocation to me. Examples: julia> versioninfo()
Julia Version 1.9.1
Commit 147bdf428cd (2023-06-07 08:27 UTC)
Platform Info:
OS: macOS (arm64-apple-darwin22.4.0)
CPU: 8 × Apple M2 https://github.com/JuliaLang/julia/issues/49332 julia> using LinearAlgebra, BenchmarkTools
julia> A = rand(ComplexF64,4,4,1000,1000);
julia> B = similar(A);
julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));
julia> @btime mul!($b,$a,$a); # 4x4 * 4x4
311.283 ns (10 allocations: 608 bytes)
julia> A = rand(ComplexF64,128,128,10,10);
julia> B = similar(A);
julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));
julia> @btime mul!($b,$a,$a); # 128x128 * 128x128
170.542 μs (10 allocations: 608 bytes) julia> N = 5_000;
julia> A = rand(N, N); B = rand(N, N); C = rand(N, N);
julia> @time mul!(C, A, B, true, true);
1.729141 seconds (1 allocation: 16 bytes)
julia> @time mul!(C, A, B);
1.637079 seconds
julia> @time A * B; # allocates N x N array
1.421422 seconds (2 allocations: 190.735 MiB, 0.13% gc time) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That was my understanding from skimming through the issues - and why I wrote arguably these could be considered to be bugs. My main point: There are no guarantees in Julia regarding allocation, the language or the JIT-compiler does not enforce any contracts, so it's only possible to document interfaces and trust people to implement them accordingly. But in the case of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do think it makes sense for AbstractFTTs to ultimately support downstream packages implementing either 3-arg or 5-arg |
||
x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ)) | ||
scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale))) | ||
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent) | ||
return ChainRulesCore.NoTangent(), plan_tangent, x̄ | ||
end | ||
return y, mul_scaledplan_pullback | ||
end | ||
|
||
end # module | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T | |
|
||
# size(p) should return the size of the input array for p | ||
size(p::Plan, d) = size(p)[d] | ||
output_size(p::Plan, d) = output_size(p)[d] | ||
ndims(p::Plan) = length(size(p)) | ||
length(p::Plan) = prod(size(p))::Int | ||
|
||
|
@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale) | |
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) | ||
|
||
size(p::ScaledPlan) = size(p.p) | ||
output_size(p::ScaledPlan) = output_size(p.p) | ||
|
||
fftdims(p::ScaledPlan) = fftdims(p.p) | ||
|
||
|
@@ -578,3 +580,80 @@ Pre-plan an optimized real-input unnormalized transform, similar to | |
the same as for [`brfft`](@ref). | ||
""" | ||
plan_brfft | ||
|
||
############################################################################## | ||
|
||
struct NoProjectionStyle end | ||
struct RealProjectionStyle end | ||
struct RealInverseProjectionStyle | ||
dim::Int | ||
end | ||
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} | ||
|
||
output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) | ||
_output_size(p::Plan, ::NoProjectionStyle) = size(p) | ||
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) | ||
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) | ||
|
||
struct AdjointPlan{T,P<:Plan} <: Plan{T} | ||
p::P | ||
AdjointPlan{T,P}(p) where {T,P} = new(p) | ||
end | ||
|
||
""" | ||
(p::Plan)' | ||
adjoint(p::Plan) | ||
|
||
Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of | ||
the original plan. Note that this differs from the corresponding backwards plan in the case of real | ||
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). | ||
|
||
!!! note | ||
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, | ||
coverage of `Base.adjoint` in downstream implementations may be limited. | ||
""" | ||
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) | ||
Base.adjoint(p::AdjointPlan) = p.p | ||
# always have AdjointPlan inside ScaledPlan. | ||
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) | ||
|
||
size(p::AdjointPlan) = output_size(p.p) | ||
output_size(p::AdjointPlan) = size(p.p) | ||
|
||
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) | ||
|
||
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} | ||
dims = fftdims(p.p) | ||
N = normalization(T, size(p.p), dims) | ||
return (p.p \ x) / N | ||
end | ||
|
||
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} | ||
dims = fftdims(p.p) | ||
N = normalization(T, size(p.p), dims) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's right, since we should only expect an Also, regarding the use of |
||
halfdim = first(dims) | ||
d = size(p.p, halfdim) | ||
n = output_size(p.p, halfdim) | ||
scale = reshape( | ||
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], | ||
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) | ||
) | ||
return p.p \ (x ./ convert(typeof(x), scale)) | ||
end | ||
|
||
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} | ||
dims = fftdims(p.p) | ||
N = normalization(real(T), output_size(p.p), dims) | ||
halfdim = first(dims) | ||
n = size(p.p, halfdim) | ||
d = output_size(p.p, halfdim) | ||
scale = reshape( | ||
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], | ||
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) | ||
) | ||
return (convert(typeof(x), scale) ./ N) .* (p.p \ x) | ||
end | ||
|
||
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). | ||
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) | ||
inv(p::AdjointPlan) = adjoint(inv(p.p)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add tests that this error (and the others below) are thrown?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a bit more involved if we actually want to probe it with FFT plans since currently the test suite does not contain any in-place test plans. I guess the easiest option would be to just re-use the existing out-of-place plans and define in-place updates re-using their implementations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added tests 🙂