
Add caching allocator interface #576

Open · pxl-th wants to merge 21 commits into master from pxl-th/cache-alloc
Conversation

@pxl-th (Member) commented Dec 15, 2024

Since Julia's GC is not aware of GPU memory, in scenarios with lots of allocations we end up either in OOM situations or with excessively high memory usage, even though the program may require only a fraction of it.

To improve GPU memory utilization in a program with repeating blocks of code, we can wrap those regions in a scope that uses a caching allocator every time the program enters it during execution.

This is especially useful, for example, when training models, where you compute the loss, the gradients w.r.t. the loss, and then perform an in-place parameter update of the model.

cache = GPUArrays.AllocCache(CuArray)
model = ...  # model setup elided
for i in 1:1000
    GPUArrays.@enable cache begin
        loss, grads = ...  # forward and backward passes elided
        update!(optimizer, model, grads)
    end
end

Caching is keyed on: (ArrayType, current device, eltype, dims[, buffer type]).
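For illustration, a key over those fields could be folded into a single hash like this (a minimal sketch, not the PR's actual code; the helper name cache_key is hypothetical):

function cache_key(AT::Type, device, ::Type{T}, dims::Dims) where {T}
    h = hash(AT)          # array type, e.g. CuArray
    h = hash(device, h)   # current device
    h = hash(T, h)        # element type
    return hash(dims, h)  # dims; yields a UInt64 key for the cache's Dict
end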

Example

In the following example we apply the caching allocator at every iteration of the for-loop. Every iteration requires 8 GiB of GPU memory; without the caching allocator the GC wouldn't be able to free arrays in time, resulting in higher memory usage. With the caching allocator, memory usage stays at exactly 8 GiB.

After the loop, we free all cached memory, if there is any. Alternatively, it is freed automatically when the cache is collected by the GC.

using CUDA, GPUArrays

cache = GPUArrays.AllocCache(CuArray)
n = 1024^3  # 1024^3 Float32s = 4 GiB per array; two arrays per iteration = 8 GiB
CUDA.@sync for i in 1:1000
    GPUArrays.@enable cache begin
        sin.(CUDA.rand(Float32, n))
    end
end
GPUArrays.unsafe_free!(cache)

Performance impact

Executing the GaussianSplatting.jl benchmark (1k training iterations) on an RX 7900 XTX:

                        Without caching allocator   With caching allocator
GPU memory utilization  (see plots in the PR)       (see plots in the PR)
Time                    59.656476 seconds           46.365646 seconds

TODO

  • Support for 1.10.
  • Support bulk-freeing instead of caching.
  • Add PR description.
  • Documentation.
  • Tests.

@maleadt (Member) left a comment


Could you add some high-level design description to the PR?

As I mentioned on Slack, CUDA already has a caching allocator, so I'm not sure if for those back-ends this shouldn't boil down to basically batch-calling unsafe_free! at the end of each iteration, instead of actively caching arrays. Would be good to compare performance, if possible.
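A rough sketch of what that batch-freeing could look like, reusing n from the example above (illustrative only; the manual allocated bookkeeping is not an existing API):

allocated = CuArray[]  # track every array allocated inside the scope
for i in 1:1000
    buf = CUDA.rand(Float32, n)
    x = sin.(buf)
    push!(allocated, buf, x)
    # ... use `x` ...
    foreach(CUDA.unsafe_free!, allocated)  # batch-free at the end of the iteration
    empty!(allocated)
end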

@pxl-th (Member, Author) commented Dec 16, 2024

> Could you add some high-level design description to the PR?
>
> As I mentioned on Slack, CUDA already has a caching allocator, so I'm not sure if for those back-ends this shouldn't boil down to basically batch-calling unsafe_free! at the end of each iteration, instead of actively caching arrays. Would be good to compare performance, if possible.

Yeah, I'm planning to add both a detailed PR description and documentation, and to allow batch-freeing instead of caching the arrays (which can just be an option in the caching allocator).

@pxl-th (Member, Author) commented Dec 18, 2024

@maleadt, I've updated the PR. I've also added tests, but they are not enabled right now, because no backend currently has the implementation merged (including JLArrays, since the tests use its released version). However, they pass locally on my machines.

Let me know what you think.

@pxl-th force-pushed the pxl-th/cache-alloc branch from 051cd6d to d6a74b0 (January 6, 2025)
@pxl-th self-assigned this (January 6, 2025)
@pxl-th requested a review from maleadt (January 6, 2025)
@pxl-th (Member, Author) commented Jan 7, 2025

One difference I've found between Julia 1.10 and Julia 1.11:

  • Julia 1.10:
julia> GPUArrays.AllocCache.@enable CuArray :loop begin
           x1 = CuArray(rand(Float32, 1))
       end
1-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 0.680597

julia> x1
ERROR: UndefVarError: `x1` not defined
  • Julia 1.11:
julia> GPUArrays.AllocCache.@enable CuArray :loop begin
           x1 = CuArray(rand(Float32, 1))
       end
1-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 0.7703809

julia> x1
1-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 0.7703809

Not sure where it's coming from.

@maleadt force-pushed the pxl-th/cache-alloc branch from bc6dcd7 to ee377ea (January 8, 2025)
@github-actions (bot) left a comment: Runic suggested the following formatting changes.

Comment on lines 400 to 402
const JLACacheAllocator = GPUArrays.AllocCache.PerDeviceCacheAllocator(JLArray)

GPUArrays.AllocCache.cache_allocator(::Type{<:JLArray}) = JLACacheAllocator
@maleadt (Member) commented:
Why is this needed, now that you switched to the array type? Isn't all information there for the caller to construct an appropriate allocator cache?

@pxl-th (Member, Author) replied:
This is for the internal implementation to retrieve the actual cache for @enable. E.g. when CUDA calls alloc!, we retrieve its allocator cache based on its array type. Otherwise the user would have to pass the cache itself to the macro, no?

@maleadt (Member) replied:
I mean that we can get rid of the alias, and replace calls to cache_allocator by AllocCache.PerDeviceCacheAllocator(AT). Just trying to minimize the interface to be implemented by back-ends.

@pxl-th (Member, Author) replied:
I'm not opposed, just a bit confused. AllocCache.PerDeviceCacheAllocator(AT) is a call to the constructor, whereas with cache_allocator we retrieve an instance. In CUDA.jl we define a global variable that we then retrieve with cache_allocator. How are we going to access this variable with AllocCache.PerDeviceCacheAllocator(AT) if it's a call to the constructor?

Or is the point just to rename cache_allocator to AllocCache.PerDeviceCacheAllocator? But then it is ambiguous, because the constructor has the same method signature.
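For what it's worth, one way to reconcile the two (a sketch only, with a hypothetical _caches global and per_device_cache helper; not necessarily what the PR should do) would be to memoize construction, so that the constructor-style call always returns the same per-type instance:

const _caches = Dict{Type, Any}()

per_device_cache(AT::Type) = get!(_caches, AT) do
    GPUArrays.AllocCache.PerDeviceCacheAllocator(AT)  # construct once, then reuse
end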

Comment on lines 129 to 135
"""
invalidate!(AT::Type{AbstractGPUArray}, name::Symbol)

Free all memory held by `name`d cached allocator given array type `AT`.
"""
invalidate!(AT::Type{<:AbstractGPUArray}, name::Symbol) =
invalidate!(cache_allocator(AT), device(AT), name)
@maleadt (Member) commented:
Is it expected for users to need this? Why not have them wrap code in multiple @enable blocks?

@pxl-th (Member, Author) replied:
Yes. Because the cache stores arrays based on their dims, there may be a situation where the dims change (e.g. a different batch size, or the number of parameters of the model changes); then you need to invalidate the cache, because with the new dims the old arrays won't be retrieved.

E.g. with GaussianSplatting, where I enable the cache for the training step: at some point the number of parameters of the model changes, so we need to invalidate the cache, because the old dims are not used anymore.

> Why not have them wrap code in multiple @enable blocks?

IIUC, you mean something like this?

GPUArrays.AllocCache.@enable CuArray :train_step begin
    # some code
end

# some code outside of caching

GPUArrays.AllocCache.@enable CuArray :train_step begin
    # some code
end

But then again, when you either no longer need the cache (done training) or the dims change, you need to somehow invalidate it.

@maleadt (Member) replied:
Oh, so the cache persists until the user calls invalidate!? I somehow missed that. It seems like a dangerous design to me; if you forget to invalidate! on any path outside of the @enable, memory will leak?

@pxl-th (Member, Author) replied Jan 8, 2025:
Yes, although I think this is a fine tradeoff (technically, you can always call invalidate! and free the memory). As a last resort, we could invalidate the cache in the alloc/retry mechanism by registering a hook, similar to how we do with the FFT handle cache.
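Such a hook could look roughly like this (a sketch; CUDA.OutOfGPUMemoryError is CUDA.jl's OOM exception, while alloc_with_fallback and the single-retry policy are illustrative):

function alloc_with_fallback(f, cache)
    try
        return f()                     # attempt the allocation
    catch err
        err isa CUDA.OutOfGPUMemoryError || rethrow()
        GPUArrays.unsafe_free!(cache)  # last resort: drop all cached arrays
        return f()                     # retry once with the freed memory
    end
end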

@maleadt (Member) commented Jan 8, 2025

> One difference I've found between Julia 1.10 and Julia 1.11

Hmm, that seems problematic. Macros should not introduce scope:

❯ jl +1.10
julia> @time begin
       x1 = []
       end
  0.000002 seconds (1 allocation: 48 bytes)
Any[]

julia> x1
Any[]

@pxl-th (Member, Author) commented Jan 8, 2025

ScopedValues.jl on Julia 1.10 introduces a scope:

julia> using ScopedValues

julia> x = ScopedValue(1)
ScopedValue{Int64}(1)

julia> @with x => 2 begin
           x2 = x[]
           x3 = 1
       end
1

julia> x2
ERROR: UndefVarError: `x2` not defined
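Presumably (an assumption about the compat package's lowering, not verified here) the pre-1.11 @with lowers its body into a closure passed to with, which would explain why assignments don't escape:

# roughly what the macro expands to on 1.10:
with(x => 2) do
    x2 = x[]  # local to the closure, invisible after the block
    x3 = 1
end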

@maleadt (Member) commented Jan 8, 2025

Another fundamental question (sorry for stretching this out): Why do you even care about the array type in the @enable interface? Wouldn't it be better if the user didn't have to worry about this? The cache can be sharded internally so that back-ends can invalidate! only their portion (e.g. from within a retry_reclaim hook), but I don't see why it has to be provided.

Maybe the cache name should be optional as well. It could default to something derived from the current task's name, so that it's really convenient to do:

AllocCache.@enable begin
    for i in epochs
        ...
    end
end
AllocCache.invalidate!()

Just spitballing here, you probably have a better view regarding it based on your experiments with it already.


Seeing the above written out, I wonder if a wholly different API wouldn't be much more idiomatic, reifying the now-implicit stuff, like the name of the cache:

cache = AllocCache()
cache() do
    for i in epochs
        ...
    end
end
empty!(cache)

A finalizer could then also empty the cache, avoiding the risk of leaking memory if you forget to empty! it.
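That callable-cache API could be implemented along these lines (a sketch; enable! and disable! are hypothetical internals standing in for however allocations get routed to the cache):

function (cache::AllocCache)(f::Function)
    enable!(cache)       # hypothetical: start serving allocations from `cache`
    try
        return f()
    finally
        disable!(cache)  # hypothetical: restore the default allocator
    end
end

With that, `cache() do ... end` from the snippet above just works, and the try/finally guarantees the cache is detached even if the body throws.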

@github-actions (bot) left a comment: Runic suggested the following formatting changes.

@github-actions (bot) left a comment: Runic suggested the following formatting changes.

pxl-th and others added 3 commits on January 8, 2025 (co-authored-by: github-actions[bot])
@github-actions (bot) left a comment: Runic suggested the following formatting changes.

@pxl-th (Member, Author) commented Jan 8, 2025

> Seeing the above written out, I wonder if a wholly different API wouldn't be much more idiomatic, reifying the now-implicit stuff, like the name of the cache:

@maleadt, I've updated the implementation based on this; see the examples in the PR description for a TL;DR. It will now also free all memory in the finalizer.

@github-actions (bot) left a comment: Runic suggested the following formatting changes.

Comment on lines +14 to +19
function AllocCache(::Type{T}) where {T <: AbstractGPUArray}
    cache = new{T}(
        ReentrantLock(),
        Dict{UInt64, Vector{T}}(),
        Dict{UInt64, Vector{T}}()
    )
@github-actions (bot) suggested the following change:

function AllocCache(::Type{T}) where {T <: AbstractGPUArray}
    cache = new{T}(
        ReentrantLock(),
        Dict{UInt64, Vector{T}}(),
        Dict{UInt64, Vector{T}}()
    )
    return finalizer(unsafe_free!, cache)
end
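As an aside on the suggestion above: Base's finalizer(f, x) returns x itself, so ending the constructor with return finalizer(unsafe_free!, cache) both registers the cleanup hook and returns the cache in one expression.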
