using Dates
using Flux.Losses
using NeuralAttentionlib: AttenMask
using ChainRulesCore
using Optimisers
using ValSplit

"""
    guess_process_type(textenc)

Return the compiler-inferred return type of `batch_process_text(textenc, ::Vector{String})`.

`create_dataloader` uses this to fix the element type of its output channel. The value comes
from type inference (`Core.Compiler.return_type`), so it can be a wide or abstract type when
inference cannot pin the result down.
"""
function guess_process_type(textenc)
    signature = Tuple{typeof(textenc), Vector{String}}
    inferred = Core.Compiler.return_type(batch_process_text, signature)
    return inferred
end

# Entry point taking a plain task `Symbol`. `@valsplit` (ValSplit.jl) routes `task` to the
# matching `task_transform(::Val{task}, e, b)` method when one exists. Otherwise this body
# runs and the symbol is returned unchanged; this covers concrete tasks such as `:g2g`, which
# have no `Val` method of their own here.
@valsplit task_transform(Val(task::Symbol), e, b) = task
# Composite task names. Each one expands to a tuple of concrete tasks, and the `Tuple`
# method below draws one of them uniformly at random for every call.
function task_transform(::Val{:mix}, e, b)
    candidates = (:g2g, :g2t, :t2g, :t2t)
    return task_transform(candidates, e, b)
end

function task_transform(::Val{:short}, e, b)
    candidates = (:g2g, :g2t, :t2g)
    return task_transform(candidates, e, b)
end

function task_transform(::Val{:self}, e, b)
    candidates = (:g2g, :t2t)
    return task_transform(candidates, e, b)
end

# Uniformly sample one task from a tuple of candidates; `e` and `b` are ignored.
function task_transform(task::Tuple, e, b)
    pick = rand(1:length(task))
    return task[pick]
end

# Deterministic schedules that cycle g2g → g2t → t2g → t2t.
# `:roll` advances once per epoch `e`.
# `:batch_roll` advances once per batch `b`, starting from an offset set by the epoch.
function task_transform(::Val{:roll}, e, b)
    cycle = (:g2g, :g2t, :t2g, :t2t)
    return cycle[mod1(e, length(cycle))]
end

function task_transform(::Val{:batch_roll}, e, b)
    cycle = (:g2g, :g2t, :t2g, :t2t)
    return cycle[mod1(e + b - 1, length(cycle))]
end

# Entry point taking a task `Symbol`, normally the value returned by `task_transform`.
# `@valsplit` routes it to the matching `task_process(::Val{task}, ...)` method. Symbols
# with no such method fall through to this body, which raises an error.
@valsplit task_process(Val(task::Symbol), e, setting, textenc, ch, samples) = error("Unknown task: $task")

# Graph → graph. Both sides of the batch are built with `batch_process_graph`.
# - Encoder side: `full_shuffle = true` plus the encoder-specific settings `shuffle_enc_id`
#   and `rand_drop_elm`.
# - Decoder side: `shuffle_id`, `ends_threshold` (passed as `require_end`) and
#   `dec_full_shuffle`.
# Both sides share `remap_max` and `force_trunc`. Pushes
# `(e, length(samples), encoder, decoder)` into `ch`.
function task_process(::Val{:g2g}, e, setting, textenc, ch, samples)
    trunc_opts = (; remap_max = setting[:remap_max], force_trunc = setting[:force_trunc])
    encoded = batch_process_graph(textenc, samples;
                                  full_shuffle = true,
                                  shuffle_id = setting[:shuffle_enc_id],
                                  rand_drop_elm = setting[:rand_drop_elm],
                                  trunc_opts...)
    decoded = batch_process_graph(textenc, samples;
                                  shuffle_id = setting[:shuffle_id],
                                  require_end = setting[:ends_threshold],
                                  full_shuffle = setting[:dec_full_shuffle],
                                  trunc_opts...)
    return put!(ch, (e, length(samples), encoded, decoded))
end

# Graph → text.
# - Encoder side: `batch_process_graph` with `full_shuffle = true`, `shuffle_enc_id` and
#   `rand_drop_elm`.
# - Decoder side: `batch_process_text` with `require_end = true` and `shuffle_id`.
# Both sides share `remap_max` and `force_trunc`. Pushes
# `(e, length(samples), encoder, decoder)` into `ch`.
function task_process(::Val{:g2t}, e, setting, textenc, ch, samples)
    trunc_opts = (; remap_max = setting[:remap_max], force_trunc = setting[:force_trunc])
    graph_batch = batch_process_graph(textenc, samples;
                                      full_shuffle = true,
                                      shuffle_id = setting[:shuffle_enc_id],
                                      rand_drop_elm = setting[:rand_drop_elm],
                                      trunc_opts...)
    text_batch = batch_process_text(textenc, samples;
                                    require_end = true,
                                    shuffle_id = setting[:shuffle_id],
                                    trunc_opts...)
    return put!(ch, (e, length(samples), graph_batch, text_batch))
end

# Text → graph.
# - Encoder side: `batch_process_text` with `shuffle_enc_id` and `rand_drop_elm`.
# - Decoder side: `batch_process_graph` with `shuffle_id`, `ends_threshold` (passed as
#   `require_end`) and `dec_full_shuffle`.
# Both sides share `remap_max` and `force_trunc`. Pushes
# `(e, length(samples), encoder, decoder)` into `ch`.
function task_process(::Val{:t2g}, e, setting, textenc, ch, samples)
    trunc_opts = (; remap_max = setting[:remap_max], force_trunc = setting[:force_trunc])
    text_batch = batch_process_text(textenc, samples;
                                    shuffle_id = setting[:shuffle_enc_id],
                                    rand_drop_elm = setting[:rand_drop_elm],
                                    trunc_opts...)
    graph_batch = batch_process_graph(textenc, samples;
                                      shuffle_id = setting[:shuffle_id],
                                      require_end = setting[:ends_threshold],
                                      full_shuffle = setting[:dec_full_shuffle],
                                      trunc_opts...)
    return put!(ch, (e, length(samples), text_batch, graph_batch))
end

# Text → text. The text of each sample is extracted once with `extract_text`, and both sides
# are built from those strings with `batch_process_text`.
# - Encoder side: `shuffle_enc_id` and `rand_drop_elm`.
# - Decoder side: `require_end = true` and `shuffle_id`.
# Both sides share `remap_max` and `force_trunc`. Pushes
# `(e, length(samples), encoder, decoder)` into `ch`.
function task_process(::Val{:t2t}, e, setting, textenc, ch, samples)
    n_samples = length(samples)
    texts = extract_text.(samples)
    trunc_opts = (; remap_max = setting[:remap_max], force_trunc = setting[:force_trunc])
    src_batch = batch_process_text(textenc, texts;
                                   shuffle_id = setting[:shuffle_enc_id],
                                   rand_drop_elm = setting[:rand_drop_elm],
                                   trunc_opts...)
    tgt_batch = batch_process_text(textenc, texts;
                                   require_end = true,
                                   shuffle_id = setting[:shuffle_id],
                                   trunc_opts...)
    return put!(ch, (e, n_samples, src_batch, tgt_batch))
end

# Run every concrete task on the same `samples`, in the fixed order g2g, g2t, t2g, t2t.
# Each task pushes one item into `ch`, so one input batch yields four channel items. The
# return value is that of the last task.
function task_process(::Val{:all}, e, setting, textenc, ch, samples)
    emit(sub) = task_process(sub, e, setting, textenc, ch, samples)
    emit(Val(:g2g))
    emit(Val(:g2t))
    emit(Val(:t2g))
    return emit(Val(:t2t))
end

"""
    create_dataloader(textenc, dataset, epoch, batch_size, task, setting; threads = true)

Return a `Channel` (capacity 500, filled by a spawned task) of
`(epoch, samples_in_batch, encoder_data, decoder_data)` tuples.

For each of the `epoch` passes, the sample indices are reshuffled and split into batches of
`batch_size`. Each batch is mapped to a concrete task by `task_transform(task, e, b)` and
handed to `task_process`, which pushes its result(s) into the channel.

With `threads = true`, the batches of one epoch are processed under `Threads.@threads`, so
items within an epoch may arrive out of order. Epochs are still processed one after another.
"""
function create_dataloader(textenc, dataset, epoch, batch_size, task, setting; threads = true)
    T = guess_process_type(textenc)
    ET = Tuple{Int, Int, T, T}
    return Channel{ET}(500; spawn = true) do ch
        # Resolve the concrete task for batch `b` of epoch `ep` and let it push into `ch`.
        function process_batch(ep, b, ids)
            samples = dataset[ids]
            sub_task = task_transform(task, ep, b)
            task_process(sub_task, ep, setting, textenc, ch, samples)
            return nothing
        end
        order = collect(1:length(dataset))
        for ep in 1:epoch
            shuffle!(order)
            batches = enumerate(Iterators.partition(order, batch_size))
            if threads
                # `@threads` needs an indexable collection, so materialize the batch list.
                Base.Threads.@threads for (b, ids) in collect(batches)
                    process_batch(ep, b, ids)
                end
            else
                for (b, ids) in batches
                    process_batch(ep, b, ids)
                end
            end
        end
        return nothing
    end
end

"""
    prealloc_chn(dataloader0, force_trunc, batch_size; pinned_mem = true)

Optionally stage the batches of `dataloader0` through reusable preallocated buffers.

With `pinned_mem = false`, return `(nothing, dataloader0)` unchanged. Otherwise:
- Allocate an encoder and a decoder buffer set with `prealloc(force_trunc, batch_size)`.
- Return `(preallocs, dataloader)`, where `preallocs = (; cond, enc_preallocs, dec_preallocs)`.
- `dataloader` is an unbuffered (capacity 0) channel yielding `(e, real_batch_size, enc, dec)`.
  `enc` and `dec` are produced by `toprealloc` into those buffers.

Because the buffers are reused, the producer waits on `cond` after every hand-off. The
consumer must `notify(preallocs.cond)` (holding its lock) once it no longer needs the buffer
contents; otherwise the producer never advances.
"""
function prealloc_chn(dataloader0, force_trunc, batch_size; pinned_mem = true)
    if !pinned_mem
        return nothing, dataloader0
    end

    enc_preallocs = prealloc(force_trunc, batch_size)
    dec_preallocs = prealloc(force_trunc, batch_size)
    cond = Threads.Condition()
    preallocs = (; cond, enc_preallocs, dec_preallocs)

    dataloader = Channel{eltype(dataloader0)}(0; spawn = true) do ch
        for (e, n_samples, raw_enc, raw_dec) in dataloader0
            # Hold the lock from filling the buffers until `wait` releases it. A consumer
            # that notifies must therefore do so only after this task is already waiting.
            @lock cond begin
                staged_enc = toprealloc(enc_preallocs, raw_enc)
                staged_dec = toprealloc(dec_preallocs, raw_dec)
                put!(ch, (e, n_samples, staged_enc, staged_dec))
                wait(cond)
            end
        end
    end
    return preallocs, dataloader
end

"""
    model_forward(m, input, gamma)

Run `m` on `input`, score the result with `compute_loss`, and return `(lval, losses)`.

- `losses` is the value returned by `compute_loss`.
- `lval = losses.label_loss + gamma * (losses.head_loss + losses.tail_loss)`.
"""
function model_forward(m, input, gamma)
    losses = compute_loss(input, m(input))
    aux_loss = losses.head_loss + losses.tail_loss
    total = losses.label_loss + gamma * aux_loss
    return total, losses
end

"""
    model_forward_backward(model, input, gamma)

Return `Flux.pullback` of `m -> model_forward(m, input, gamma)` at `model`, i.e.
`((lval, losses), back)`.
"""
function model_forward_backward(model, input, gamma)
    objective(m) = model_forward(m, input, gamma)
    return Flux.pullback(objective, model)
end

"""
    train!(model, textenc, opt, dataset; kwargs...)

Train `model` on `dataset` with gradient accumulation.

Batches come from `create_dataloader`. When `pinned_mem = true` they are staged through the
reusable buffers of `prealloc_chn`. Each batch's gradient is accumulated with `grads!`. Once
at least `update_size` samples have been accumulated:
- the accumulated gradient is scaled by `1 / samples`;
- it is applied with `Optimisers._update!`;
- the accumulators are zeroed.

Any leftover partial accumulation is applied after the last batch. Accumulated losses are
divided by the number of samples seen since the previous log line, and logged whenever
`mod1(i, display_size) == 1`.

# Keywords
- `task`: task symbol forwarded to the data pipeline (default `:g2g`).
- `epoch`, `batch_size`: number of passes over `dataset` and samples per batch.
- `force_trunc`: required (must not be `nothing`). Forwarded to data processing and
  `prealloc_chn`.
- `shuffle_id`, `remap_max`, `shuffle_enc_id`, `rand_drop_elm`, `ends_threshold`,
  `dec_full_shuffle`: data-processing settings.
- `display_size`: logging period, in channel items.
- `update_size`: minimum number of accumulated samples per optimiser step.
- `γ`: weight of `head_loss + tail_loss` relative to `label_loss` (see `model_forward`).
- `pinned_mem`, `threads`: forwarded to `prealloc_chn` and `create_dataloader`.

# Throws
- `ArgumentError` if `force_trunc` is `nothing`.
"""
function train!(
    model, textenc, opt, dataset; task = :g2g, epoch = 1, batch_size = 16, force_trunc = nothing,
    shuffle_id = 0.3, remap_max = nothing,  shuffle_enc_id = true, rand_drop_elm = false,
    ends_threshold = 6, dec_full_shuffle = 0.05,
    display_size = 2000, update_size = 128, γ = 1,
    pinned_mem = true, threads = true,
)
    gamma = convert(Float32, γ)
    # `@assert` may be disabled at some optimization levels, so validate explicitly.
    isnothing(force_trunc) && throw(ArgumentError("`force_trunc` must be specified (got `nothing`)"))
    setting = (; force_trunc, shuffle_id, remap_max, shuffle_enc_id, rand_drop_elm, ends_threshold, dec_full_shuffle)

    dataloader0 = create_dataloader(textenc, dataset, epoch, batch_size, task, setting; threads)
    preallocs, dataloader = prealloc_chn(dataloader0, force_trunc, batch_size; pinned_mem)

    # Gradient accumulators keyed by optimiser leaf, filled by `grads!`.
    grads = IdDict{Optimisers.Leaf, Any}()
    update_i::Int = 0   # samples accumulated since the last optimiser step
    j::Int = 0          # samples seen since the last log line
    al::Float32 = zero(Float32)
    als::NTuple{3, Float32} = ntuple(_ -> zero(Float32), Val(3))
    for (i, data) in enumerate(dataloader)
        e, real_batch_size, _data_enc, _data_dec = data
        # The pred_mask/head_label/tail_label fields are not part of the encoder input.
        data_enc = Base.structdiff(_data_enc, NamedTuple{(:pred_mask, :head_label, :tail_label)}) |> todevice
        data_dec = todevice(_data_dec)
        input = (encoder_input = data_enc,
                 decoder_input = merge(data_dec, (cross_attention_mask =
                                                  AttenMask(data_dec.attention_mask, data_enc.attention_mask),)))
        if !isnothing(preallocs)
            # The staged host buffers are reused by the producer in `prealloc_chn`. Finish
            # pending device work (presumably the `todevice` copies) before releasing them.
            CUDA.synchronize()
            @lock preallocs.cond notify(preallocs.cond)
        end
        (_l, _losses), back = model_forward_backward(model, input, gamma)
        l::Float32 = convert(Float32, _l)
        losses = map(x -> convert(Float32, x),
                     _losses)::@NamedTuple{label_loss::Float32, head_loss::Float32, tail_loss::Float32}
        al += l
        als = als .+ values(losses)::NTuple{3, Float32}
        j += real_batch_size
        if isone(mod1(i, display_size))
            al /= Float32(j)
            als = als ./ Float32(j)
            @info "$(now()): (epoch=$e, batch=$i) loss" total=al losses=NamedTuple{keys(losses)}(als)
            j = 0
            al = zero(Float32)
            als = ntuple(_ -> zero(Float32), Val(3))
        end
        # Seed only the scalar loss; the `losses` NamedTuple output gets no cotangent.
        (grad_i,) = back((Flux.Zygote.sensitivity(l), nothing))
        maybe_free!(data_enc); maybe_free!(data_dec)
        grads!(grads, opt, model, grad_i)
        update_i += real_batch_size
        if update_i >= update_size
            # Scale the accumulated gradient by 1/(samples accumulated), step, then reset.
            foreachgrad(Base.Fix1(_scale!, inv(update_i)), grads)
            Optimisers._update!(opt, model; grads, params = IdDict())
            update_i = 0
            foreachgrad(Base.Fix2(fill!, 0), grads)
            GC.gc(false)
        end
    end

    # Apply the leftover partial accumulation, if any.
    if update_i != 0
        foreachgrad(Base.Fix1(_scale!, inv(update_i)), grads)
        Optimisers._update!(opt, model; grads, params = IdDict())
    end
    # The accumulators are dead from here on. Release them on every path, including when the
    # last batch exactly completed an update step.
    foreachgrad(maybe_free!, grads)
    CUDA.synchronize()
    GC.gc()
    return nothing
end

"""
    train16!(model, model16, textenc, opt, dataset; kwargs...)

Mixed-precision variant of `train!`. Forward and backward passes run on `model16`, while
gradients are accumulated and applied to `model`.

- `model16` is presumably a Float16 copy of `model` (verify against the caller). It is
  re-synced from `model` with `load_weight!` after every optimiser step.
- The very first batch is a warm-up: it runs the Float32 `model`. After it, a full GC and
  `CUDA.reclaim()` run.
- All later batches use `model16`, with `γ` converted to Float16.

The keywords are the same as for `train!`.

# Throws
- `ArgumentError` if `force_trunc` is `nothing`.
"""
function train16!(
    model, model16, textenc, opt, dataset; task = :g2g, epoch = 1, batch_size = 16, force_trunc = nothing,
    shuffle_id = 0.3, remap_max = nothing,  shuffle_enc_id = true, rand_drop_elm = false,
    ends_threshold = 6, dec_full_shuffle = 0.05,
    display_size = 2000, update_size = 128, γ = 1,
    pinned_mem = true, threads = true,
)
    gamma = convert(Float16, γ)
    # `@assert` may be disabled at some optimization levels, so validate explicitly.
    isnothing(force_trunc) && throw(ArgumentError("`force_trunc` must be specified (got `nothing`)"))
    setting = (; force_trunc, shuffle_id, remap_max, shuffle_enc_id, rand_drop_elm, ends_threshold, dec_full_shuffle)

    dataloader0 = create_dataloader(textenc, dataset, epoch, batch_size, task, setting; threads)
    preallocs, dataloader = prealloc_chn(dataloader0, force_trunc, batch_size; pinned_mem)

    # Gradient accumulators keyed by optimiser leaf of `model`, filled by `grads!`.
    grads = IdDict{Optimisers.Leaf, Any}()
    update_i::Int = 0   # samples accumulated since the last optimiser step
    j::Int = 0          # samples seen since the last log line
    al::Float32 = zero(Float32)
    als::NTuple{3, Float32} = ntuple(_ -> zero(Float32), Val(3))
    warm = false
    for (i, data) in enumerate(dataloader)
        e, real_batch_size, _data_enc, _data_dec = data
        # The pred_mask/head_label/tail_label fields are not part of the encoder input.
        data_enc = Base.structdiff(_data_enc, NamedTuple{(:pred_mask, :head_label, :tail_label)}) |> todevice
        data_dec = todevice(_data_dec)
        input = (encoder_input = data_enc,
                 decoder_input = merge(data_dec, (cross_attention_mask =
                                                  AttenMask(data_dec.attention_mask, data_enc.attention_mask),)))
        if !isnothing(preallocs)
            # The staged host buffers are reused by the producer in `prealloc_chn`. Finish
            # pending device work (presumably the `todevice` copies) before releasing them.
            CUDA.synchronize()
            @lock preallocs.cond notify(preallocs.cond)
        end

        if !warm
            # Warm-up step on the Float32 master model.
            _gamma = Float32(γ)
            (_l, _losses), back = Flux.pullback(m -> model_forward(m, input, _gamma), model)
            seed_type = Float32
        else
            (_l, _losses), back = Flux.pullback(m -> model_forward(m, input, gamma), model16)
            seed_type = Float16
        end
        l::Float32 = convert(Float32, _l)
        losses = map(x -> convert(Float32, x),
                     _losses)::@NamedTuple{label_loss::Float32, head_loss::Float32, tail_loss::Float32}
        al += l
        als = als .+ values(losses)::NTuple{3, Float32}
        j += real_batch_size
        if isone(mod1(i, display_size))
            al /= Float32(j)
            als = als ./ Float32(j)
            @info "$(now()): (epoch=$e, batch=$i) loss" total=al losses=NamedTuple{keys(losses)}(als)
            j = 0
            al = zero(Float32)
            als = ntuple(_ -> zero(Float32), Val(3))
        end
        # Seed the pullback in the precision of the model that produced it. The Float32
        # warm-up pass must not receive a Float16 cotangent, which could leak Float16 arrays
        # into its backward pass.
        (grad_i,) = back((Flux.Zygote.sensitivity(seed_type(l)), nothing))
        maybe_free!(data_enc); maybe_free!(data_dec)
        grads!(grads, opt, model, grad_i)
        update_i += real_batch_size
        if update_i >= update_size
            # Scale the accumulated gradient by 1/(samples accumulated), step, reset, and
            # copy the updated weights into `model16`.
            foreachgrad(Base.Fix1(_scale!, inv(update_i)), grads)
            Optimisers._update!(opt, model; grads, params = IdDict())
            update_i = 0
            foreachgrad(Base.Fix2(fill!, 0), grads)
            load_weight!(model16, model)
            GC.gc(false)
        end

        if !warm
            warm = true
            CUDA.synchronize()
            GC.gc(true)
            CUDA.reclaim()
            @info "warm up"
        end
    end

    # Apply the leftover partial accumulation, if any.
    if update_i != 0
        foreachgrad(Base.Fix1(_scale!, inv(update_i)), grads)
        Optimisers._update!(opt, model; grads, params = IdDict())
        load_weight!(model16, model)
    end
    # The accumulators are dead from here on. Release them on every path, including when the
    # last batch exactly completed an update step.
    foreachgrad(maybe_free!, grads)
    CUDA.synchronize()
    GC.gc()
    return nothing
end

"""
    _scale!(scale, dx)

Multiply every element of `dx` in place by `scale`. The factor is converted to `eltype(dx)`
first, so the array keeps its element type. Returns `nothing`.
"""
function _scale!(scale, dx)
    factor = convert(eltype(dx), scale)
    dx .= dx .* factor
    return nothing
end
