First commit to add weighted mean square #140

Open · wants to merge 6 commits into main · showing changes from 1 commit
56 changes: 56 additions & 0 deletions src/loss.jl
@@ -107,6 +107,62 @@ function (sl::SquaredLoss{<:AbstractArray{<:Number}})(
T(0.5) * s, p, pu
end

"""
WeightedSquaredLoss(target)

Calculates half of mean weighted squared loss of the target.
"""
struct WeightedSquaredLoss{Y, W<:AbstractVector{Y}} <: AbstractLoss{Y}
y::Y
weights::W
end
(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w)
WeightedSquaredLoss() = WeightedSquaredLoss(nothing)
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y) # maybe need to return both :y and :weights?
Contributor (chriselrod):

Yes, the target should be sliceable and the loss should be callable on target's result to create a new one.
It's used for slicing/iterating over batches.
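To make this contract concrete, here is a minimal self-contained sketch (hypothetical stand-in names, outside SimpleChains): the loss exposes its target, `getindex` slices every field, and the loss can be rebuilt from a sliced target.

```julia
# Hypothetical stand-ins illustrating the batching contract.
view_slice_last(v::AbstractVector, r) = view(v, r)  # stand-in for SimpleChains' helper

struct WSL{Y,W}  # stand-in for WeightedSquaredLoss
    y::Y
    weights::W
end
target(wsl::WSL) = (wsl.y, wsl.weights)  # expose every field needed to rebuild
WSL(t::Tuple) = WSL(t...)                # rebuild the loss from a sliced target
Base.getindex(wsl::WSL, r) = WSL(map(f -> view_slice_last(f, r), target(wsl)))

wsl = WSL([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 1.0, 1.0])
batch = wsl[1:2]  # both the target and the weights get sliced
```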

Author (marcobonici):

So the point is that WeightedSquaredLoss(target(wsl)) should work; did I get that right?

Contributor (chriselrod), May 10, 2023:

Yes.

$ rg 'target\('
docs/src/examples/custom_loss_layer.md
49:SimpleChains.target(loss::BinaryLogitCrossEntropyLoss) = loss.targets

src/optimize.jl
93:  tgt = view_slice_last(target(loss), f:l)
125:  tgt = target(loss)
177:  tgt = target(loss)
488:  t = target(_chn)
679:  tgt = target(chn)

src/loss.jl
25:target(_) = nothing
26:target(sc::SimpleChain) = target(last(sc.layers))
27:preserve_buffer(l::AbstractLoss) = target(l)
28:StrideArraysCore.object_and_preserve(l::AbstractLoss) = l, target(l)
31:iterate_over_losses(sc) = _iterate_over_losses(target(sc))
40:  align(length(first(target(sl))) * static_sizeof(T)), static_sizeof(T)
42:function _layer_output_size_needs_temp_of_equal_len_as_target(
47:  align(length(target(sl)) * static_sizeof(T)), static_sizeof(T)
66:target(sl::SquaredLoss) = getfield(sl, :y)
69:Base.getindex(sl::SquaredLoss, r) = SquaredLoss(view_slice_last(target(sl), r))
120:target(sl::AbsoluteLoss) = getfield(sl, :y)
127:  AbsoluteLoss(view_slice_last(target(sl), r))
197:target(sl::LogitCrossEntropyLoss) = getfield(sl, :y)
205:  _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
212:  _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
254:  LogitCrossEntropyLoss(view(target(sl), r))
273:  correct_count(Y, target(loss))
283:  ec = correct_count(Y, target(loss))

src/penalty.jl
68:target(c::AbstractPenalty) = target(getchain(c))

Note that we also need things like view_slice_last(target(loss), f:l) to work.

So view_slice_last should be implemented.
Some form of PtrArray(tgt) should also work, but you could define a different function to use there that calls PtrArray by default, as overloading constructors to return something else is generally frowned upon.

Author (marcobonici), May 11, 2023:

A few things, @chriselrod.

So, as I thought, target(wsl) needs to give back all the fields of the struct. This is needed because, as you pointed out, WeightedSquaredLoss(target(wsl)) needs to work.

So, I have updated the target method:

target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :weights)

Since this returns a tuple, I have added a constructor using the splat operator:

WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)

Do you have any thoughts on that? In the meantime, I'll focus on view_slice_last.

Author (marcobonici):

If I understand correctly, view_slice_last is used to slice the fields of the loss. If so, this could possibly work (dispatching on the tuple that target returns):

function view_slice_last(t::Tuple, r)
    return Tuple(view_slice_last(f, r) for f in t)
end

I am returning a Tuple, assuming this works with the constructor I just created.
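As a standalone check of that idea (with a hypothetical vector method standing in for SimpleChains' view_slice_last), slicing a tuple field-by-field behaves as expected:

```julia
# Hypothetical stand-in; SimpleChains defines view_slice_last for arrays.
view_slice_last(v::AbstractVector, r) = view(v, r)
# Tuple method: slice each field of a (target, weights) tuple.
view_slice_last(t::Tuple, r) = Tuple(view_slice_last(f, r) for f in t)

view_slice_last(([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]), 1:2)  # → ([1.0, 2.0], [0.1, 0.2])
```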


Base.getindex(wsl::WeightedSquaredLoss, r) = WeightedSquaredLoss(view_slice_last(target(wsl), r))

weighted_squared_loss(chn::SimpleChain, y, w) = add_loss(chn, WeightedSquaredLoss(y, w))

Base.show(io::IO, ::WeightedSquaredLoss) = print(io, "WeightedSquaredLoss")

@inline loss_multiplier(::AbstractLoss, N, ::Type{T}) where {T} = inv(T(N))
@inline loss_multiplier(::WeightedSquaredLoss, N, ::Type{T}) where {T} = T(2) / T(N)

function chain_valgrad!(
  _,
  arg::AbstractArray{T,D},
  layers::Tuple{WeightedSquaredLoss},
  p::Ptr,
  pu::Ptr{UInt8}
) where {T,D}
  y = getfield(getfield(layers, 1), :y)
  w = getfield(getfield(layers, 1), :weights)
  # invN = T(inv(static_size(arg, D)))
  s = zero(T)
  @turbo for i ∈ eachindex(arg)
    δ = arg[i] - y[i]
    s += δ * δ * w[i]
    arg[i] = w[i] * δ  # gradient of 0.5 * w * δ^2 with respect to arg is w * δ
  end
  T(0.5) * s, arg, pu
end
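A quick pure-Julia finite-difference check (independent of SimpleChains) of what the gradient of this loss should be: for L = 0.5 · Σᵢ wᵢ(argᵢ − yᵢ)², the derivative with respect to argᵢ is wᵢ · (argᵢ − yᵢ).

```julia
# Pure-Julia check: for L = 0.5 * sum(w .* (a .- y).^2),
# the gradient with respect to a[i] is w[i] * (a[i] - y[i]).
f(a, y, w) = 0.5 * sum(w .* (a .- y) .^ 2)
a = [2.0, 3.0]; y = [1.0, 1.0]; w = [1.0, 2.0]
g = w .* (a .- y)  # analytic gradient: [1.0, 4.0]
ε = 1e-6
g_fd = [(f(a .+ ε .* (1:2 .== i), y, w) - f(a, y, w)) / ε for i in 1:2]
```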
function (sl::WeightedSquaredLoss{<:AbstractArray{<:Number}})(
  arg::AbstractArray{T,N},
  p,
  pu
) where {T,N}
  y = getfield(sl, :y)
  w = getfield(sl, :weights)
  s = zero(T)
  @turbo for i ∈ eachindex(arg)
    δ = arg[i] - y[i]
    s += δ * δ * w[i]
  end
  # NOTE: we're not dividing by static_size(arg, N)
  T(0.5) * s, p, pu
end
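A plain-Julia sanity check of the forward pass (no @turbo; weighted_half_sq is a hypothetical standalone equivalent): the returned value is half the weighted sum of squared errors, with no division by the number of elements.

```julia
# Standalone equivalent of the forward pass: 0.5 * sum(w .* (arg .- y).^2).
function weighted_half_sq(arg, y, w)
    s = zero(eltype(arg))
    for i in eachindex(arg)
        δ = arg[i] - y[i]
        s += δ * δ * w[i]
    end
    0.5 * s  # half the weighted sum; no division by length, as noted above
end

weighted_half_sq([2.0, 3.0], [1.0, 1.0], [1.0, 2.0])  # 0.5 * (1^2*1 + 2^2*2) = 4.5
```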

"""
AbsoluteLoss
