batched_transpose with multiple batch dimensions #588
Comments
After some thinking and tinkering, I've concluded that […]

For my use case however, where I use it to define a custom chain rule, I needed to use the inner constructor with all the type parameters like so:

# permutation needs to be passed as type parameters directly so the type can be inferred
function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    # swapping dims 1 and 2 is its own inverse, so perm also serves as the inverse permutation
    PermutedDimsArray{T, N, perm, perm, typeof(A)}(A)
end

or else I would get an error:

function _batched_transpose(A::AbstractArray{T, N}) where {T, N}
    perm = (2, 1, 3:N...)
    PermutedDimsArray(A, perm)
end

using Test
@inferred _batched_transpose(rand(4, 5, 6))

# output:
ERROR: return type PermutedDimsArray{Float64, 3, (2, 1, 3), (2, 1, 3), Array{Float64, 3}} does not match inferred return type PermutedDimsArray{Float64, 3, _A, _B, Array{Float64, 3}} where {_A, _B}

I suspect this is because of the splat in the regular constructor:

function PermutedDimsArray(data::AbstractArray{T,N}, perm) where {T,N}
    length(perm) == N || throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
    iperm = invperm(perm)
    PermutedDimsArray{T,N,(perm...,),(iperm...,),typeof(data)}(data)
end

This isn't really related to the issue, but I figured I'd include it for documentation purposes. 😄

EDIT: it's probably not the splatting itself, but the fact that the permutation is derived from the type parameter N, so it's essentially a constant.

EDIT 2: somewhat expectedly, CUDA doesn't like this, as it ends up wanting to do scalar indexing.
When an array with multiple batch dimensions needs to be transposed for use in […]

function batched_mul_transpose1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose # call batched_transpose after flattening batch dimensions
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

This would be the same as […]
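The flatten-then-transpose idea above can be checked against a direct N-dimensional permutation using only Base. This is a minimal sketch, not NNlib code: both helper names are hypothetical, and Base's eager `permutedims` stands in for the lazy `batched_transpose` so that the example runs without NNlib.

```julia
# Hypothetical helper (not part of NNlib): swap the first two dimensions
# eagerly after flattening all batch dimensions into one.
function transpose_via_reshape(x::AbstractArray{T,N}) where {T,N}
    batch_size = size(x)[3:end]
    x3 = reshape(x, size(x, 1), size(x, 2), :)
    y3 = permutedims(x3, (2, 1, 3))  # eager stand-in for batched_transpose
    return reshape(y3, size(y3, 1), size(y3, 2), batch_size...)
end

# Hypothetical helper: the same swap applied lazily across all N dimensions.
function transpose_via_permuteddims(x::AbstractArray{T,N}) where {T,N}
    perm = (2, 1, ntuple(i -> i + 2, N - 2)...)
    return PermutedDimsArray(x, perm)
end

x = rand(4, 5, 6, 7)
a = transpose_via_reshape(x)
b = transpose_via_permuteddims(x)

@assert size(a) == size(b) == (5, 4, 6, 7)
@assert a == b                                     # both routes agree elementwise
@assert b[:, :, 2, 3] == transpose(x[:, :, 2, 3])  # each slice is transposed
```

The reshape route allocates a copy here because `permutedims` is eager, while the `PermutedDimsArray` route only wraps the parent array. Neither helper is claimed to be type-stable.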
It's tricky. Perhaps there need to be methods of […] Or perhaps the reshaping to 3D should be done by a utility function which knows about BatchedAdjoint, not just […]

Xref #391 about other questions about […]

(Also, some regret that we didn't go with an interface like […])
Motivation and description
There exists a method for batched_mul that reshapes arrays to allow for an arbitrary number of batch dimensions: […]

It would be useful to have support for this with batched_transpose and batched_adjoint as well.

Possible Implementation

The existing code is quite sophisticated and "lazy", so something like this wouldn't fly: […]

I imagine it would be possible to generalize the code beyond three dimensions though. Indexing methods are currently hard-coded. Things like the strides would also need to be generalized: […]

Is it better to just use PermutedDimsArray?
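On the strides point, here is a minimal sketch of what a generalized stride computation for a lazily transposed N-dimensional array could look like, compared against what Base's `PermutedDimsArray` reports for the same permutation. The helper name `transposed_strides` is hypothetical and is not taken from NNlib.

```julia
# Hypothetical helper: strides of a lazy view of `A` with dims 1 and 2 swapped.
# The first two strides trade places; all batch strides are unchanged.
function transposed_strides(A::AbstractArray)
    s = strides(A)
    return (s[2], s[1], s[3:end]...)
end

A = rand(4, 5, 6, 7)
B = PermutedDimsArray(A, (2, 1, 3, 4))

@assert strides(A) == (1, 4, 20, 120)
@assert transposed_strides(A) == (4, 1, 20, 120)
@assert transposed_strides(A) == strides(B)  # matches Base's permuted strides
```

Since `PermutedDimsArray` already permutes strides this way for any number of dimensions, it covers this part of the generalization without extra code.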