NaN with custom mask for MultiHeadAttention #572
Just a thought: maybe these values need clamping in …
I think the unusual thing here is that the mask (size 16×1×1×32) is constant for whole batches, and thus it's trying to set every value to -Inf before the softmax. That's not illegal according to the help text, but it is unusual. Are you sure this is what you want, rather than, e.g., running a smaller batch of randomly selected items? Here is a slightly shorter reproducer, followed by a case with a more orthodox mask shape, I think:
julia> struct MaskAttention{A<:MultiHeadAttention, M<:AbstractArray}
att::A
mask::M
end
julia> (m::MaskAttention)(x::AbstractArray) = first(m.att(x; m.mask))
julia> Flux.@layer :expand MaskAttention
julia> x = map(f->rand(Int32.(2:10), rand(8:16)), 1:32);
julia> x = reduce(hcat, rpad.(x, maximum(length.(x)), 1))
16×32 Matrix{Int32}:
7 7 7 6 9 5 2 8 9 7 3 4 … 9 2 6 4 8 2 9 9 4 5 3 3
2 5 3 10 7 7 9 7 9 6 2 3 7 7 7 10 8 7 7 8 7 10 10 7
6 3 5 5 3 4 2 4 9 9 2 7 9 4 8 4 9 6 3 9 5 4 2 3
julia> mask = permutedims(repeat((x .== 1), outer = [1, 1, 1, 1]), (1, 4, 3, 2))
16×1×1×32 BitArray{4}:
[:, :, 1, 1] =
0
0
0
julia> model = MaskAttention(MultiHeadAttention(16), mask)
MaskAttention(
MultiHeadAttention(16; nheads=8), # 1_024 parameters
Bool[0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1;;;; … ;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1], # 512 parameters
) # Total: 5 arrays, 1_536 parameters, 4.547 KiB.
julia> xx = randn32(16, 16, 32);
julia> model(xx) |> summary
"16×16×32 Array{Float32, 3}"
julia> model(xx) |> sum
NaN32
julia> findall(isnan, model(xx))
1280-element Vector{CartesianIndex{3}}:
CartesianIndex(1, 1, 11)
CartesianIndex(2, 1, 11)
CartesianIndex(3, 1, 11)
CartesianIndex(4, 1, 11)
CartesianIndex(5, 1, 11)
julia> loss, grads = Flux.withgradient(model) do m
sum(abs2, m(xx))
end
(val = NaN32, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), k_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), v_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing)), mask = nothing),))
julia> mask2 = rand(Bool, 16, 16);
julia> model2 = MaskAttention(MultiHeadAttention(16), mask2);
julia> model2(xx) |> sum
15.281105f0
julia> loss, grads = Flux.withgradient(model2) do m
sum(abs2, m(xx))
end
(val = 2290.0107f0, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[-12.446103 -4.3954763 … -23.305235 3.0878963; 17.259354 -6.7767124 … 25.798717 6.8329597; … ; -22.414658 -17.097672 … 93.836235 -9.206725; 29.134241 13.785215 … -41.797527 -23.159046], bias = nothing, σ = nothing), k_proj = (weight = Float32[10.827733 18.880678 … -35.907875 23.034206; 8.75228 23.103594 … 13.421799 -14.956886; … ; -71.6004 -13.324369 … -35.224113 61.402447; 27.011333 124.142815 … -21.26666 -63.877186], bias = nothing, σ = nothing), v_proj = (weight = Float32[34.734734 -24.169163 … 102.391365 -69.705055; -1.5541999 37.55378 … -69.58273 39.782215; … ; 5.4440746 -176.59694 … 171.69466 161.58585; -139.37288 181.93517 … -213.38739 -58.174618], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[-18.867039 0.40507406 … 92.7197 -36.512943; -30.138624 -14.419439 … -92.25858 -63.47702; … ; -89.09866 92.45964 … -212.48007 164.08275; 35.601555 -31.360823 … 128.91348 -104.323494], bias = nothing, σ = nothing)), mask = nothing),))
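To see why setting every value to -Inf gives NaN, here is a minimal, dependency-free sketch of a masked softmax over the key dimension (`dims=1`). It assumes the pattern discussed in this issue: masked logits are replaced with `typemin` of the eltype (`-Inf32` for `Float32`) and a mask value of `true` keeps a logit. `masked_softmax` is a hypothetical helper for illustration, not the Flux/NNlib implementation.

```julia
# Hypothetical helper (not Flux/NNlib): masked softmax along dims=1 (keys),
# filling masked logits with typemin(T), i.e. -Inf for floating-point T.
function masked_softmax(logits::AbstractMatrix{T}, mask::AbstractMatrix{Bool}) where {T<:AbstractFloat}
    # Assumed convention: mask == true keeps a logit, false replaces it with -Inf.
    filled = ifelse.(mask, logits, typemin(T))
    colmax = maximum(filled; dims=1)   # -Inf for a column in which every key is masked
    e = exp.(filled .- colmax)         # -Inf - (-Inf) is NaN, so that column becomes NaN
    return e ./ sum(e; dims=1)
end

logits = Float32[1 2; 3 4; 5 6]        # 3 keys × 2 queries
mask = Bool[1 0; 1 0; 0 0]             # query 2 has no allowed keys at all
p = masked_softmax(logits, mask)       # column 1 is a valid distribution, column 2 is all NaN
```

Without the max subtraction the result is the same: every `exp(-Inf)` is zero, and 0/0 is NaN. Once one output is NaN, `sum(abs2, ...)` and every gradient that flows through it become NaN too, which is consistent with the all-NaN `grad` in the first reproducer.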
So I do want to vary this mask per batch, because the sequences sampled into each batch vary in length, and so does the padding. This minimal example is just one batch to show the issue. I tried clamping before the softmax and the NaNs are gone. The idea is to mask out the padding tokens from attention in the encoder. If it's unusual, am I doing something wrong? I have three different kinds of masks: a padding mask in the encoder (this MWE), a causal mask in the decoder, and a padding mask in the loss function, which affects only the target sequence in the decoder.
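The clamping workaround can be sketched as filling masked logits with a large but finite negative number instead of -Inf. The exact clamping used here isn't shown, so this is a minimal sketch assuming `-floatmax(T)` as the fill value; `clamped_masked_softmax` is a hypothetical helper, not Flux/NNlib API, and it keeps the assumed convention that `true` means "keep".

```julia
# Hypothetical helper (not Flux/NNlib): masked softmax along dims=1 that fills
# masked logits with a finite value, so a fully masked column cannot produce NaN.
function clamped_masked_softmax(logits::AbstractMatrix{T}, mask::AbstractMatrix{Bool}) where {T<:AbstractFloat}
    fillval = -floatmax(T)                 # most negative finite value of T (assumed choice)
    filled = ifelse.(mask, logits, fillval)
    colmax = maximum(filled; dims=1)       # finite even when every key in a column is masked
    e = exp.(filled .- colmax)             # fully masked column: exp(0) == 1 for every key
    return e ./ sum(e; dims=1)
end

logits = Float32[1 2; 3 4; 5 6]            # 3 keys × 2 queries
mask = Bool[1 0; 1 0; 0 0]                 # query 2 has no allowed keys
p = clamped_masked_softmax(logits, mask)   # no NaN; column 2 is uniform (1/3 each)
```

The NaNs disappear, but a fully masked query now attends uniformly to keys that were meant to be excluded, so its output is finite yet meaningless. It still has to be ignored downstream, for example by the padding mask in the loss.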
Hi,
The background is that in the encoder-decoder model used for translation in "Attention Is All You Need", I wanted to mask out the padding in the sentence passed to the encoder's MultiHeadAttention, but I noticed that the `neginf` value computed for the mask from the logits' eltype might cause issues and lead to NaN. A minimal example is provided here: https://github.com/mashu/NaNTracker.jl
The result is
I hope it's not an issue with my understanding of how the mask should look, but to be honest, the documentation in Flux could use a couple of examples for this particular use case, in addition to just `make_causal_mask`.
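As a sketch of what such documentation examples might look like, here is one way to build the three masks described in this thread. It rests on an assumption worth checking: if I read NNlib's `dot_product_attention` correctly, logits are kept where the mask is `true` and replaced by `neginf` where it is `false`, the same orientation as `make_causal_mask` (`m[i, j] == i ≤ j`). Under that reading, `x .== 1` in the reproducer marks padding as the only keys to attend to, so a sequence with no padding gets an all-false mask, which would explain the NaN. `PAD`, `padding_mask`, `causal_mask` and `masked_mean` are hypothetical helpers, not Flux or NNlib API.

```julia
# Hypothetical helpers (not Flux/NNlib API). Assumed convention:
# mask[key, query, head, batch] == true means "this query may attend to this key".
const PAD = 1   # padding token id, as in the reproducer (rpad with 1)

# Encoder padding mask over keys, shaped (kv_len, 1, 1, batch) so it broadcasts
# over queries and heads. Note `.!=`: real tokens are true, padding is false.
padding_mask(tokens::AbstractMatrix{<:Integer}; pad = PAD) =
    reshape(tokens .!= pad, size(tokens, 1), 1, 1, size(tokens, 2))

# Decoder causal mask: key i is visible from query j iff i <= j.
causal_mask(n::Integer) = [i <= j for i in 1:n, j in 1:n]

# Loss mask: average a per-token loss (seq_len × batch) over non-padding targets only.
masked_mean(per_token_loss, tokens; pad = PAD) =
    sum(per_token_loss .* (tokens .!= pad)) / count(!=(pad), tokens)

# Two sequences of length 3 and 2, right-padded with PAD to length 4.
tokens = Int32[5 7; 3 2; 9 1; 1 1]
pm = padding_mask(tokens)                # 4×1×1×2 BitArray{4}
cm = causal_mask(size(tokens, 1))        # 4×4 Matrix{Bool}
dec_mask = pm .& cm                      # 4×4×1×2: padding and causal combined
loss = masked_mean(ones(Float32, 4, 2), tokens)
```

With the reproducer's `x`, `padding_mask(x)` has the same 16×1×1×32 shape as `mask` there, only with the opposite polarity. As long as every sequence has at least one real token, no query is left without an allowed key; query positions that are themselves padding still produce finite outputs, and `masked_mean` drops them from the loss.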