Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chain rules for FFT plans via AdjointPlans #67

Merged
merged 28 commits into from
Jul 5, 2023

Conversation

gaurav-arya
Copy link
Contributor

@gaurav-arya gaurav-arya commented Jun 6, 2022

An rfft can be written as PF where F is the n x n Fourier transform and P is a projection operator that removes the redundant information due to conjuagate symmetry. Because of P, the adjoint of real FFTs (real inverse FFTs) require a special scaling before (after) applying the backwards transformation. As discussed in #63 this motivates supporting the Base.adjoint operation for plans to simplify the writing of backward rules for AD.

The following functions must be implemented by backends in order for output_size(p::Plan) and AdjointPlan to work:

  • projection_style(p::Plan) which can either be :none, :real, or :real_inv.
  • irfft_dim(p::Plan), only for those plans with :real_inv projection style, which gives the original length of the halved dimension.

Using the adjoint plan, we can simplify the writing of backwards rules. I test the adjoint plans both directly and indirectly through tests of the rrule's.

NB: The interface has changed since the initial PR message. See the updated implementation docs in the PR for accurate info.

src/definitions.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Jun 6, 2022

Codecov Report

Patch coverage: 100.00% and project coverage change: +4.13 🎉

Comparison is base (b5109aa) 87.32% compared to head (e137ae3) 91.45%.

❗ Current head e137ae3 differs from pull request most recent head e601347. Consider uploading reports for the commit e601347 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master      JuliaLang/julia#67      +/-   ##
==========================================
+ Coverage   87.32%   91.45%   +4.13%     
==========================================
  Files           3        3              
  Lines         213      281      +68     
==========================================
+ Hits          186      257      +71     
+ Misses         27       24       -3     
Impacted Files Coverage Δ
ext/AbstractFFTsChainRulesCoreExt.jl 100.00% <100.00%> (ø)
src/definitions.jl 83.33% <100.00%> (+9.29%) ⬆️

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

test/runtests.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/chainrules.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
@gaurav-arya gaurav-arya force-pushed the adjoint branch 3 times, most recently from c5c3755 to 6c81dfd Compare June 9, 2022 06:37
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
@gaurav-arya gaurav-arya force-pushed the adjoint branch 2 times, most recently from 534ddd4 to af74c54 Compare June 9, 2022 16:42
@gaurav-arya gaurav-arya force-pushed the adjoint branch 2 times, most recently from 40cce00 to 7cba04e Compare July 1, 2022 01:35
@gaurav-arya
Copy link
Contributor Author

This should be ready for another review (with #69 as a dependency)

@gaurav-arya gaurav-arya requested a review from devmotion July 1, 2022 05:06
@gaurav-arya gaurav-arya force-pushed the adjoint branch 4 times, most recently from 7149781 to 675c61a Compare July 14, 2022 19:17
test/runtests.jl Outdated
for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
end
for pf in (plan_fft, plan_ifft, plan_bfft)
test_frule(*, pf(x, dims) ⊢ NoTangent(), x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the NoTangent needed here?

Copy link
Contributor Author

@gaurav-arya gaurav-arya Jul 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It tells ChainRulesTestUtils to use a no tangent for the plan, i.e. do not try to differentiate w.r.t. the plan. See https://juliadiff.org/ChainRulesTestUtils.jl/dev/#Specifying-Tangents. It's necessary to manually specify this because with all the caching stuff and plans being mutable, ChainRules gets confused about the structure of Plan and rand_tangent errors

Unfortunately, there is one case where we do want to differentiate w.r.t. plan (as far as I can tell this is the only case), when someone makes a ScaledPlan whose scale depends on the parameter:

using AbstractFFTs
using AbstractFFTs: Plan
using Zygote
include("test/testplans.jl")

julia> function f(x)
           return sum(abs.(P * x))
           end
f (generic function with 1 method)

# correct 
julia> Zygote.gradient(f, [1,2,3])
([-0.732050807568877, 0.9999999999999996, 2.7320508075688776],)

julia> function f(x)
           return sum(abs.(x[1] * P * x))
           end
f (generic function with 1 method)

# silently wrong :(
julia> Zygote.gradient(f, [1,2,3])
([-0.732050807568877, 0.9999999999999996, 2.7320508075688776],)

I just spent some time trying to write an adjoint for ScaledPlan by replacing it with the right-associative P.scale * (P.p * x) and using rrule_via_ad, but the fundamental issue was that ChainRules was unable to come up with a tangent type for ScaledPlan, for the same reason (mutability, circular references, etc.)

A few thoughts:

  • Really, even if we keep the pinv field around, ScaledPlan ought to not be mutable as the caching can just happen at the level of the inner plan. If this were fixed, it would be easy to write an rrule, but this would be a separate PR that would probably require a lot of careful thought
  • Maybe we can somehow get the ScaledPlan differentiation to work, by adding a custom tangent type for the mutable struct. Don't have much ChainRules knowhow but I can give this a shot
  • I really really don't like that this silently gives incorrect results (although the current plan rules in Zygote do too for real FFTs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe making ScaledPlan immutable isn't so hard... give me a minute:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation. I assumed it's due to problems with CRTU but my main worry was exactly something like the ScaledPlan case: that it masks differentiation issues.

Copy link
Contributor Author

@gaurav-arya gaurav-arya Jul 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To try to resolve this, I've:

a) Made ScaledPlans immutable in #72 (and done the same for AdjointPlans here)

b) Added a backwards rule for `ScaledPlan that differentiates w.r.t. the plan's scale too

c) Added tests. Because of JuliaDiff/ChainRulesTestUtils.jl#256, this turned out to be rather difficult. I couldn't come up with a way of using ChainRulesTestUtils for properly checking the derivative w.r.t. ScaledPlan without a PR there, and I have to pick my battles, so I ended up coming up with the best cludge I can think to preserve most of the automated testing w.r.t. the other tangents and adding a manual FD test for the plan scale. (In an ideal world I'd have been able to just get rid of the NoTangent: once ChainRulesTestUtils is able to accommodate this case, we should do that here.)

What do you think of the approach?

@devmotion
Copy link
Member

I guess this is still waiting for JuliaLang/julia#78? I think it would be good to also make ChainRulesCore a weak dependency on new Julia versions (see, e.g., how it is done in SpecialFunctions).

@vpuri3
Copy link
Contributor

vpuri3 commented Jun 26, 2023

@gaurav-arya you might want to rebase after 3a3f0e4.

@vpuri3
Copy link
Contributor

vpuri3 commented Jun 26, 2023

Question: what would the projection style be for real-to-real DCTs/ DSTs in FFTW.jl?

@sethaxen sethaxen mentioned this pull request Jun 30, 2023
@devmotion
Copy link
Member

I updated the PR (and fixed a few issues with types and functions that were not available in the extension).

Since in-place plans are not supported (or at least not tested?) currently, maybe a final thing to add would be to check in the ChainRules definitions that y = P * x is not aliased with x, and throw a descriptive error otherwise.

@gaurav-arya
Copy link
Contributor Author

I don't have the time right now to revisit this PR, but if it looks good and would be helpful, please feel free to fix anything remaining and merge. Thanks!

@devmotion
Copy link
Member

@sethaxen since you were interested in differentiation rules as well (#103), I guess it might be valuable if you would have a look at the PR before I go ahead and merge it?

@sethaxen
Copy link
Contributor

sethaxen commented Jul 4, 2023

Sure @devmotion, will take a quick look tonight.

Copy link
Contributor

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of my comments are documentation suggestions, but there are also a few other minor ones.

docs/src/implementations.md Outdated Show resolved Hide resolved
docs/src/implementations.md Outdated Show resolved Hide resolved
docs/src/implementations.md Outdated Show resolved Hide resolved
docs/src/implementations.md Outdated Show resolved Hide resolved
docs/src/implementations.md Outdated Show resolved Hide resolved
if Base.mightalias(y, x)
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
end
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it inconsistent at all that here we use the tangent of the scale part of P but none of the tangent of the wrapped plan?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm it seems plans are assumed to be constant (AFAICT from the initial version of the PR) but the scaling might change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess there's probably never a good reason a user would want (co)tangents for a Plan. In almost every case the scale of a Plan is just a constant that again a user would never want a (co)tangent for, but perhaps there is one user out there who does, so I can see the point in this.

scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
project_scale = ChainRulesCore.ProjectTo(scale)
function mul_scaledplan_pullback(ȳ)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if the mul!(y, p, x, a, b) API was supported by AbstractFFTs, because then ChainRules could also define an inplaceable thunk here, and Enzyme rules could avoid an allocation, but maybe outside the scope of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would require the FFT plan to support fused mul! which isn't guaranteed. To create a fallback implementation, the plan must cache y.

cache = get_cache(plan)
copy!(cache, y)
mul!(y, plan, x)
axpby!(b, cache, a, y)

Feels out of scope for this PR

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that mul! has to guarantee allocation-free or fused computations (but maybe I'm wrong). Usually, ! only indicates that some (usually but not necessarily the first) argument is updated in-place but sometimes other arguments are updated as well and/or the update is not allocation-free.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is my understanding that LinearAlgebra.mul! is allocation-free. That is what gives it performance advantage over Base.*. To my knowledge, no mutating LinearAlgebra routine allocates a copy of the base array.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is my understanding that LinearAlgebra.mul! is allocation-free.

I quickly checked the Julia repo, and there are a few open issues that show that at least in practice such a guarantee does not exist: https://github.com/JuliaLang/julia/issues/49332 JuliaLang/julia#46865 Arguably these are just bugs but on the other hand the docstring of mul! also does not make any such guarantees.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For both cases, the allocation size is independent of the array size indicating that the arrays are not being allocated. Looks like a spurious size tuple allocation to me.

Examples:

julia> versioninfo()
Julia Version 1.9.1
Commit 147bdf428cd (2023-06-07 08:27 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M2

https://github.com/JuliaLang/julia/issues/49332

julia> using LinearAlgebra, BenchmarkTools

julia> A = rand(ComplexF64,4,4,1000,1000);

julia> B = similar(A);

julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));

julia> @btime mul!($b,$a,$a); # 4x4 * 4x4
  311.283 ns (10 allocations: 608 bytes)

julia> A = rand(ComplexF64,128,128,10,10);

julia> B = similar(A);

julia> a,b = (view(B,:,:,1,1),view(A,:,:,1,1));

julia> @btime mul!($b,$a,$a); # 128x128 * 128x128
  170.542 μs (10 allocations: 608 bytes)

JuliaLang/julia#46865

julia> N = 5_000;

julia> A = rand(N, N); B = rand(N, N); C = rand(N, N);

julia> @time mul!(C, A, B, true, true);
  1.729141 seconds (1 allocation: 16 bytes)

julia> @time mul!(C, A, B);
  1.637079 seconds

julia> @time A * B; # allocates N x N array
  1.421422 seconds (2 allocations: 190.735 MiB, 0.13% gc time)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my understanding from skimming through the issues - and why I wrote arguably these could be considered to be bugs. My main point: There are no guarantees in Julia regarding allocation, the language or the JIT-compiler does not enforce any contracts, so it's only possible to document interfaces and trust people to implement them accordingly. But in the case of mul! no such guarantees are documented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think it makes sense for AbstractFTTs to ultimately support downstream packages implementing either 3-arg or 5-arg mul!, with each defaulting to the other (yes stackoverflow, but if implementing one of them is required, then no overflow can exist). But I do also think this needn't happen in this PR.

src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return convert(typeof(x), scale) ./ N .* (p.p \ x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could a parentheses be added here to make this easier to understand?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where would you like to add them?

@vpuri3
Copy link
Contributor

vpuri3 commented Jul 4, 2023

Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?

@devmotion
Copy link
Member

Now that adjoints are defined at the *(::Plan, ::AbstractArray) level, can the rules for fft, rfft etc be removed?

I'd argue no, they should be kept. Zygote also defined rules for fft etc. and plan_fft, but the IMO the main argument is that fft etc. are used for one-shot computations of FFTs whereas plan_fft etc. are intended for repeated calculations - so downstream packages might want to implement fft etc. in an optimized way knowing that there's no amortization, users might want to rather call fft etc. if they only apply it once, and its the only way for the rules to distinguish between both use cases.

src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
src/definitions.jl Outdated Show resolved Hide resolved
Copy link
Contributor

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@devmotion devmotion merged commit 8601a92 into JuliaMath:master Jul 5, 2023
@vpuri3 vpuri3 deleted the adjoint branch July 5, 2023 14:37
vpuri3 added a commit to vpuri3/Zygote.jl that referenced this pull request Jul 5, 2023
The potentially incorrect Zygote rules for FFT (FluxML#899) can be removed now that comprehensive Chain Rules have been added in JuliaMath/AbstractFFTs.jl#67
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants