using Base.Broadcast: broadcasted
using Statistics
using NeuralAttentionlib: Masks, AbstractSequenceMask
using ChainRulesCore
using Transformers
using Transformers.Static

"""
    shift_decode_loss(logits, trg, trg_mask)

Sum-aggregated `logitcrossentropy` of `logits` against the target `trg` with its first
position along dimension 2 removed, masked by `trg_mask`.

The sliced labels are computed inside `ignore_derivatives`, so no gradient flows to `trg`.
NOTE(review): `logits` is used as-is (not shifted here); presumably the caller already
aligns its sequence length with `trg[:, 2:end, :]` — confirm.
"""
function shift_decode_loss(logits, trg, trg_mask)
    # Slice (copy, not view) so the label keeps its original array type.
    next_tokens = ChainRulesCore.ignore_derivatives() do
        trg[:, 2:end, :]
    end
    return logitcrossentropy(sum, logits, next_tokens, trg_mask)
end

"""
    shift_id_decode_loss(logits, trg, trg_mask)

Sum-aggregated `unsafe_logitcrossentropy` of `logits` against integer id labels taken from
`trg` with its first row dropped (a non-copying view `trg[2:end, :]`), masked by `trg_mask`.

The label view is built inside `ignore_derivatives`, so no gradient flows to `trg`.
"""
function shift_id_decode_loss(logits, trg, trg_mask)
    next_ids = ChainRulesCore.ignore_derivatives() do
        @view trg[2:end, :]
    end
    return unsafe_logitcrossentropy(sum, logits, next_ids, trg_mask)
end

"""
    structure_loss(input, states)

Return `(; head_loss, tail_loss)`. Each entry is a `shift_id_decode_loss` of
`states.head_logits` / `states.tail_logits` against `input.decoder_input.head_label` /
`input.decoder_input.tail_label`. Both share the mask `input.decoder_input.pred_mask`.
"""
function structure_loss(input, states)
    dec = input.decoder_input
    mask = dec.pred_mask
    head_loss = shift_id_decode_loss(states.head_logits, dec.head_label, mask)
    tail_loss = shift_id_decode_loss(states.tail_logits, dec.tail_label, mask)
    return (; head_loss, tail_loss)
end

"""
    compute_loss(input::NamedTuple, states::NamedTuple)

Return the named tuple `(; label_loss, head_loss, tail_loss)`.

`label_loss` is `shift_decode_loss` of `states.label_logits` against
`input.decoder_input.label`, masked by `states.attention_mask`. `head_loss` and
`tail_loss` are taken from `structure_loss(input, states)`.
"""
function compute_loss(input::NamedTuple, states::NamedTuple)
    # NOTE(review): this loss is masked by `states.attention_mask`, while the head/tail
    # losses are masked by `decoder_input.pred_mask` — confirm the difference is intended.
    label_loss = shift_decode_loss(states.label_logits, input.decoder_input.label,
                                   states.attention_mask)
    structural = structure_loss(input, states)
    return (; label_loss,
            head_loss = structural.head_loss,
            tail_loss = structural.tail_loss)
end

# Three-argument form: aggregate over sequences with `mean` by default.
function unsafe_logitcrossentropy(ŷ::AbstractArray, y::AbstractArray{Int32},
                                  m::AbstractSequenceMask)
    return unsafe_logitcrossentropy(mean, ŷ, y, m)
end
"""
    unsafe_logitcrossentropy(agg, ŷ, c, m)

Masked cross entropy between logits `ŷ` and `Int32` class labels `c` under the sequence
mask `m`.

The log-softmax over dimension 1 of `ŷ` is built from lazy, fused broadcasts. The terms
selected by `Transformers._bcg` are reduced over the dimensions `Transformers._tn2(m)`,
giving one partial sum per sequence. Each partial sum is divided (`Transformers._sdiv`)
by the matching entry of `Masks.lengths(m)`, and the results are summed. With
`agg === mean`, that total is further divided by the number of per-sequence partial
sums.
"""
function unsafe_logitcrossentropy(agg::Union{typeof(sum), typeof(mean)},
                                  ŷ::AbstractArray, c::AbstractArray{Int32}, m::AbstractSequenceMask)
    T = eltype(ŷ)
    fsub = Base.FastMath.sub_fast

    # Stable log-softmax along dim 1: shift by the column maximum, then subtract the
    # log of the (materialized) sum of exponentials. `logprobs` stays lazy.
    colmax = maximum(ŷ; dims = 1)
    shifted = Broadcast.instantiate(broadcasted(fsub, ŷ, colmax))
    expsum = sum(Broadcast.instantiate(broadcasted(Transformers._exp, shifted));
                 dims = 1, init = zero(T))
    logexpsum = broadcasted(Base.FastMath.log_fast, expsum)
    logprobs = Broadcast.instantiate(broadcasted(fsub, shifted, logexpsum))

    # Pick the label terms under the mask and reduce them to one value per sequence.
    maskref = Transformers.Refm(m, ŷ)
    picked = Broadcast.instantiate(broadcasted(Transformers._bcg, Ref(logprobs), c,
                                               CartesianIndices(c), maskref))
    perseq = sum(picked; dims = Transformers._tn2(m), init = zero(T))

    # Length-normalize each sequence and add them up.
    normalized = Broadcast.instantiate(broadcasted(Transformers._sdiv, reshape(perseq, :),
                                                   Masks.lengths(m)))
    total = sum(normalized)
    return agg isa typeof(mean) ? total / oftype(total, length(perseq)) : total
end

# Hand-written reverse rule for the 4-argument `unsafe_logitcrossentropy`.
# The forward computation repeats the primal method step for step, so the returned `loss`
# is the same value. It also keeps `logp` (lazy log-probabilities), `refm` and the
# sequence lengths `ls` for the pullback. Only `ŷ` receives a gradient; `agg`, the labels
# `c` and the mask `m` get `NoTangent()`.
function ChainRulesCore.rrule(::typeof(unsafe_logitcrossentropy), agg::Union{typeof(sum), typeof(mean)},
                              ŷ::AbstractArray, c::AbstractArray{Int32}, m::AbstractSequenceMask)
    # Stable log-softmax along dim 1: subtract the per-column maximum ...
    xmax = maximum(ŷ; dims = 1)
    xdiff = Broadcast.instantiate(broadcasted(Base.FastMath.sub_fast, ŷ, xmax))
    # ... reduce the exponentials along dim 1 into a materialized array ...
    sexp = sum(Broadcast.instantiate(broadcasted(Transformers._exp, xdiff)); dims = 1, init = zero(eltype(ŷ)))
    # ... and form the log-probabilities lazily (`logp` is a Broadcasted, not an array).
    logp = Broadcast.instantiate(broadcasted(Base.FastMath.sub_fast, xdiff, broadcasted(Base.FastMath.log_fast, sexp)))
    refm = Transformers.Refm(m, ŷ)
    # Reduce the `_bcg` terms over the dims given by `_tn2(m)`: one partial sum per sequence
    # (presumably the masked negative log-likelihood of the label — see Transformers.jl).
    losses = sum(Broadcast.instantiate(broadcasted(Transformers._bcg, Ref(logp), c, CartesianIndices(c), refm));
                 dims = Transformers._tn2(m), init = zero(eltype(ŷ)))
    ls = Masks.lengths(m)
    loss = sum(Broadcast.instantiate(broadcasted(Transformers._sdiv, reshape(losses, :), ls)))
    # 1 for `sum`; the number of sequences for `mean`. This presumably equals the primal's
    # `length(losses)`, since `dlosses` below is reshaped with `length(ls)` as its last dim.
    scale = oftype(loss, agg isa typeof(mean) ? length(ls) : 1)
    function logitcrossentropy_pullback(Ybar)
        # Undo the `mean` scaling applied to the scalar loss.
        Ȳ = unthunk(Ybar) / scale
        # d(loss)/d(losses[b]) is Ȳ divided (`_sdiv`) by sequence b's length. Reshape to
        # the reduced shape of `losses`: singleton dims followed by one entry per sequence.
        dlosses = reshape(Transformers._sdiv.(Ȳ, ls), (ntuple(one, static(ndims(m)) - static(1))..., length(ls)))
        dlogp = fill!(similar(ŷ), 0)
        # `mapreduce` runs `∇_bcg!` over every label position only for its side effect:
        # it writes the `dlosses` contributions into `dlogp`. The reduced value is discarded.
        mapreduce(identity, Transformers._z, Broadcast.instantiate(broadcasted(
            Transformers.∇_bcg!, Ref(dlogp), Ref(dlosses), c, CartesianIndices(c), refm)); init = 0)
        # Back-propagate through the log-softmax to get the gradient w.r.t. `ŷ`.
        dy = Transformers.∇logsoftmax_data!(dlogp, logp; dims = 1)
        # Tangents for (function, agg, ŷ, c, m).
        return (NoTangent(), NoTangent(), dy, NoTangent(), NoTangent())
    end
    return loss, logitcrossentropy_pullback
end
