Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: vfmaddsub and friends for complex matmul #150

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/complex_matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,133 @@ for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error
_C
end

# WIP entry point for complex-by-complex matmul built on the `cmatmul_*` kernels.
# It builds real-valued views of the complex arguments, derives conjugation flags
# from `ArrayInterface.is_lazy_conjugate`, and (so far) only dispatches the
# un-conjugated case. Returns `_C`.
#
# NOTE(review): this definition is incomplete in the visible diff — the `if` below
# has no matching `end` before the function's closing `end`, and α, β, nthread,
# MKN and contig_axis are accepted but never read in the visible body.
function _matmul_v2!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
# C, A, B = map(real_rep, (_C, _A, _B))
# C and A are reinterpreted to the real element type T; B goes through `real_rep`.
# NOTE(review): `_A` has element type Complex{U} but is reinterpreted with T —
# confirm whether U was intended here.
C = reinterpret(T, _C)
A = reinterpret(T, _A)
B = real_rep(_B)

# StaticInt(-1) when the argument is lazily conjugated, StaticInt(1) otherwise.
η_bool = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
θ_bool = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
# Swapped (+, -) pair when `_C` is lazily conjugated.
# NOTE(review): not used in the visible part of this function.
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
# ηθ = η*θ

# Alternating sign vector: +1 at odd positions, -1 at even positions.
# NOTE(review): the length comes from `pick_vector_width(Float64)` even though the
# entries are of type T, and `signs` is not used in the visible body — confirm.
signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...)
# NOTE(review): `η_bool`/`θ_bool` hold StaticInt values (±1), not Bools, so `!` on
# them presumably does not express "not conjugated" — confirm the intended test.
# NOTE(review): the call below is `cmatmul_ab`, while the kernel defined in this
# file is spelled `cmatmul_ab!`.
if !η_bool & !θ_bool
cmatmul_ab(C, A, B)

Copy link
Collaborator

@chriselrod chriselrod Aug 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
end

maybe?
I just glanced at the code to see if there was any obvious cause to the error.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is all still very WIP. I'll work on this again tomorrow (hard maybe), just wanted to get this pushed since was switching around on machines.

# Hand back the caller's original complex output array.
_C
end

# Kernel for the un-conjugated complex product `C = A * B`, operating on real
# views of the complex data.
#
# Arguments (shapes inferred from the indexing below and from the visible caller;
# TODO confirm against `real_rep`):
# - `C`: 2-D real array, written as `C[m, n]`; assumes rows interleave re/im parts.
# - `A`: 2-D real array, read as `A[m, k]`; same interleaved layout assumed.
# - `B`: 3-D real array, read as `B[1, k, n]` (real part) and `B[2, k, n]`
#   (imaginary part).
#
# Scalar reference for what each (re, im) pair accumulates:
#     Cmn_re += A_re * B_re - A_im * B_im
#     Cmn_im += A_re * B_im + A_im * B_re
#
# NOTE(review): α/β scaling is not handled here — neither is a parameter of this
# kernel, so the caller must apply it (or the signature must grow) — TODO.
function cmatmul_ab!(C, A, B)
    @tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1)
        # Take the accumulator type from the output array; no type parameter `T`
        # is in scope inside this function.
        Cmn = zero(eltype(C))
        for k ∈ indices((A, B), (2, 2))
            Amk = A[m,k]
            # Presumably swaps the re/im lanes within each complex pair
            # (vpermilps with immediate 177) — confirm.
            Aperm = vpermilps177(Amk)

            # A B: `vfmaddsub` presumably subtracts in even lanes and adds in odd
            # lanes (x86 fmaddsub convention), which yields the reference formulas
            # above — confirm against the VectorizationBase definition.
            Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn))
        end
        C[m, n] = Cmn
    end
end

# WIP kernel for the conjugated-A case (the active accumulation line is the one
# labelled `A^* B` below).
#
# NOTE(review): this function takes no arguments, yet its body reads C, A, B, T,
# α, β, Cmn_re, Cmn_im, +ᶻ and -ᶻ, none of which are defined inside it — confirm
# the intended signature before use.
function cmatmul_astarb()

# Alternating sign vector: +1 at odd positions, -1 at even positions.
# NOTE(review): the length comes from `pick_vector_width(Float64)` while the
# entries use T — confirm.
signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...)

@tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1)
Cmn = zero(T)
for k ∈ indices((A, B), (2, 2))
Amk = A[m,k]
# Presumably swaps the re/im lanes within each complex pair — confirm.
Aperm = vpermilps177(Amk)

# The four candidate update rules are listed below; only `A^* B` is active here.
# A B
# Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn))
# A^* B
Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn))
# A B^*
# Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn))
# A^* B^*
# Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn))

# Scalar reference formulas (η, θ are the conjugation signs of A and B):
# Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
# Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
end
# NOTE(review): the next two stores index C with three subscripts and use
# Cmn_re/Cmn_im, which this loop never assigns; the final store indexes C with
# two subscripts. These look like leftovers from a split re/im kernel — confirm.
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
C[m, n] = Cmn
end
end

# WIP kernel for the conjugated-B case (the active accumulation line is the one
# labelled `A B^*` below).
#
# NOTE(review): this function takes no arguments, yet its body reads C, A, B, T,
# α, β, Cmn_re, Cmn_im, +ᶻ and -ᶻ, none of which are defined inside it — confirm
# the intended signature before use.
function cmatmul_abstar()
@tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1)
Cmn = zero(T)
for k ∈ indices((A, B), (2, 2))
Amk = A[m,k]
# Presumably swaps the re/im lanes within each complex pair — confirm.
Aperm = vpermilps177(Amk)

# TODO: I don't yet know how to pick the correct branch
# based on η and θ.
# The four candidate update rules are listed below; only `A B^*` is active here.
# A B
# Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn))
# A^* B
# Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn))
# A B^*
Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn))
# A^* B^*
# Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn))

# Scalar reference formulas (η, θ are the conjugation signs of A and B):
# Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
# Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
end
# NOTE(review): the next two stores index C with three subscripts and use
# Cmn_re/Cmn_im, which this loop never assigns; the final store indexes C with
# two subscripts. These look like leftovers from a split re/im kernel — confirm.
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
C[m, n] = Cmn
end
end

# WIP kernel for the case where both A and B are conjugated (the active
# accumulation line is the one labelled `A^* B^*` below).
#
# NOTE(review): this function takes no arguments, yet its body reads C, A, B, T,
# α, β, Cmn_re, Cmn_im, +ᶻ and -ᶻ, none of which are defined inside it — confirm
# the intended signature before use.
function cmatmul_astarbstar()

# Alternating sign vector: +1 at odd positions, -1 at even positions.
# NOTE(review): the length comes from `pick_vector_width(Float64)` while the
# entries use T — confirm.
signs = Vec(ntuple(x -> ifelse(iseven(x), -one(T), one(T)), pick_vector_width(Float64))...)

@tturbo vectorize=2 for n ∈ indices((C, B), (2,3)), m ∈ indices((C, A), 1)
Cmn = zero(T)
for k ∈ indices((A, B), (2, 2))
Amk = A[m,k]
# Presumably swaps the re/im lanes within each complex pair — confirm.
Aperm = vpermilps177(Amk)

# The four candidate update rules are listed below; only `A^* B^*` is active here.
# A B
# Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(Aperm, B[2,k,n], Cmn))
# A^* B
# Cmn = signs * vfmsubadd(Amk, B[1,k,n], vfmadd(Aperm, B[2,k,n], Cmn))
# A B^*
# Cmn = vfmaddsub(Amk, B[1,k,n], vfmaddsub(-Aperm, B[2,k,n], Cmn))
# A^* B^*
Cmn = signs * vfmaddsub(Amk, B[1,k,n], vfmsub(Aperm, B[2,k,n], Cmn))

# Scalar reference formulas (η, θ are the conjugation signs of A and B):
# Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
# Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
end
# NOTE(review): the next two stores index C with three subscripts and use
# Cmn_re/Cmn_im, which this loop never assigns; the final store indexes C with
# two subscripts. These look like leftovers from a split re/im kernel — confirm.
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
C[m, n] = Cmn
end
end

@inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, B = map(real_rep, (_C, _B))
Expand Down