using LinearAlgebra
using Functors
using ChainRulesCore
using ChainRulesCore: ignore_derivatives
using NeuralAttentionlib: scaled_matmul
using Transformers
using Transformers.HuggingFace
using Transformers.Layers:
    WithArg, WithOptArg, CompositeEmbedding, Embed, SinCosPositionEmbed,
    ApplyEmbed, Branch, Seq2Seq


"""
    _cummean_length(dims, x)

Return the Int32 counts `1:size(x, dims)` reshaped to run along dimension `dims`, with every
other dimension of size 1, so the result broadcasts against `x`.
"""
function _cummean_length(dims::Int, x)
    counts = Base.OneTo{Int32}(size(x, dims))
    shape = ntuple(d -> d == dims ? Colon() : 1, Val(ndims(x)))
    return reshape(counts, shape)
end
ChainRulesCore.@non_differentiable _cummean_length(dims, x)

"""
    cummean(x; dims::Int)

Running mean of `x` along `dims`: entry `k` along `dims` is the mean of the first `k` entries.
"""
cummean(x; dims::Int) = cumsum(x; dims) ./ _cummean_length(dims, x)

"""
    l2norm(x; dims = 1)

Scale `x` so that every slice along `dims` has unit Euclidean norm.
A slice that is entirely zero yields NaNs, because nothing guards the division.
"""
l2norm(x; dims = 1) = x ./ sqrt.(sum(abs2, x; dims))


"""
    load_tokenizer()

Load the locally cached `t5-small` tokenizer and relabel its sentinel tokens
`<extra_id_0>` … `<extra_id_13>` as this model's marker tokens (`[EOT]`, `[NODE]`, `[EDGE]`,
`[TEXT]`, and the dataset tags).

Throws an `ErrorException` listing any sentinel that the vocabulary does not contain
(an index of 0 from `lookup` is treated as "not found").

NOTE(review): the relabelling writes directly into `textenc.vocab.list.data`. If the
vocabulary also keeps a separate string→index map, lookups by the new names may not
resolve. Confirm against the TextEncodeBase vocabulary implementation.
"""
function load_tokenizer()
    textenc = HuggingFace.load_tokenizer("t5-small"; local_files_only = true)
    token_map = Dict{String, String}(
        "<extra_id_0>" => "[EOT]",
        "<extra_id_1>" => "[NODE]",
        "<extra_id_2>" => "[EDGE]",
        "<extra_id_3>" => "[TEXT]",
        "<extra_id_4>" => "[EVENTNARRATIVE]",
        "<extra_id_5>" => "[WEBNLG]",
        "<extra_id_6>" => "[GENWIKI]",
        "<extra_id_7>" => "[TEKGEN]",
        "<extra_id_8>" => "[ATOMIC]",
        "<extra_id_9>" => "[AMR]",
        "<extra_id_10>" => "[OPENIE]",
        "<extra_id_11>" => "[DEP]",
        "<extra_id_12>" => "[NER]",
        "<extra_id_13>" => "[DART]",
    )
    # Resolve each sentinel once and reuse the index for both the check and the write.
    indices = Dict(k => lookup(textenc.vocab, k) for k in keys(token_map))
    # An explicit error, unlike `@assert`, cannot be compiled away.
    missing_tokens = sort!([k for (k, i) in indices if i == 0])
    isempty(missing_tokens) ||
        error("t5-small vocabulary is missing sentinel tokens: ", join(missing_tokens, ", "))
    for (k, v) in token_map
        textenc.vocab.list.data[indices[k]] = v
    end
    # Disabled option: this would add a digit/punctuation regex to the tokenizer's patterns.
    # push!(textenc.tokenizer.tokenization.patterns, r"[0-9_,\.\\\/\[\]\(\)`!@#$%^&?~\*\"'\+\-:;<>|{}=]")
    return textenc
end

# Negated natural logarithm, `-log(x)`. Passed as the function argument of
# `SinCosPositionEmbed` in the `TokenEmbedding` constructor.
function nlog(x)
    return -(log(x))
end

"""
    TokenEmbedding

Input embedding layer. Its forward pass sums:
- the `label` and `type` embeddings;
- projected `id_embed` embeddings of the `id + prev`, `segment`, `head` and `tail` fields;

and then applies `ln` to the sum.
"""
struct TokenEmbedding{L, T, I, D1, D2, D3, D4, LN} <: Layers.LayerStruct
    # embedding of the `label` field
    label_embed::L
    # embedding of the `type` field
    type_embed::T
    # embedding shared by the `prev`, `id`, `segment`, `head` and `tail` fields
    # (a SinCosPositionEmbed in the default constructor; evaluated under ignore_derivatives)
    id_embed::I
    # projection of the `id + prev` embedding sum
    pos_proj::D1
    # projection of the `segment` embedding
    segment_proj::D2
    # projection of the `head` embedding
    head_proj::D3
    # projection of the `tail` embedding
    tail_proj::D4
    # normalization applied to the summed embedding
    ln::LN
end
@functor TokenEmbedding

# Input fields consumed by the forward pass; they are dropped (via `Base.structdiff`) from
# the NamedTuple it returns, while any other input fields are passed through.
Layers.argument_names(::TokenEmbedding) = (:label, :type, :id, :prev, :segment, :head, :tail)

"""
    TokenEmbedding(hidden_size, vocab_size)

Build a `TokenEmbedding` with default sub-layers:
- an `Embed` over `vocab_size` labels and an `Embed` over 2 token types;
- `SinCosPositionEmbed(nlog, hidden_size, true)` for the id-like fields;
- four `hidden_size → hidden_size` `Dense` projections (pos, segment, head, tail);
- a `LayerNorm`.
"""
function TokenEmbedding(hidden_size, vocab_size)
    square_dense() = Layers.Dense(hidden_size, hidden_size)
    # Arguments are evaluated left to right, so sub-layers are created in field order.
    return TokenEmbedding(
        Embed(hidden_size, vocab_size),
        Embed(hidden_size, 2),
        SinCosPositionEmbed(nlog, hidden_size, true),
        square_dense(),
        square_dense(),
        square_dense(),
        square_dense(),
        Layers.LayerNorm(hidden_size),
    )
end

# Apply `id_embed` to `x`. The first argument is used only for dispatch: by default the
# embedding is returned unchanged.
type_preserve_id_embed(::Any, id_embed, x) = id_embed(x)

# When the reference array is a Float16 CuArray, convert the embedding to a Float16 CuArray
# and eagerly release the intermediate GPU buffer.
# NOTE(review): assumes `id_embed` returns a freshly allocated array that nothing else holds.
function type_preserve_id_embed(::CuArray{Float16}, id_embed, x)
    embedding = id_embed(x)
    half_precision = CuArray{Float16}(embedding)
    CUDA.unsafe_free!(embedding)
    return half_precision
end
ChainRulesCore.@non_differentiable type_preserve_id_embed(args...)

# Forward pass of `TokenEmbedding`.
# Reads `nt.label`, `nt.type`, `nt.prev`, `nt.id`, `nt.segment`, `nt.head` and `nt.tail`.
# Returns `nt` without those fields, merged with `hidden_state` (the layer-normed sum) and
# every intermediate embedding.
function (te::TokenEmbedding)(nt::NamedTuple)
    label_emb = te.label_embed(nt.label)
    type_emb = te.type_embed(nt.type)
    # The id-like embeddings are computed outside AD. `type_emb` only selects the output
    # array type (see `type_preserve_id_embed`).
    prev_emb = ignore_derivatives(() -> type_preserve_id_embed(type_emb, te.id_embed, nt.prev))
    id_emb = ignore_derivatives(() -> type_preserve_id_embed(type_emb, te.id_embed, nt.id))
    pos_emb = ignore_derivatives(() -> id_emb + prev_emb)
    segment_emb = ignore_derivatives(() -> type_preserve_id_embed(type_emb, te.id_embed, nt.segment))
    head_emb = ignore_derivatives(() -> type_preserve_id_embed(type_emb, te.id_embed, nt.head))
    tail_emb = ignore_derivatives(() -> type_preserve_id_embed(type_emb, te.id_embed, nt.tail))
    summed = label_emb + type_emb + te.pos_proj(pos_emb) +
        te.segment_proj(segment_emb) + te.head_proj(head_emb) + te.tail_proj(tail_emb)
    outputs = (
        hidden_state = te.ln(summed),
        label_embedding = label_emb,
        type_embedding = type_emb,
        pos_embedding = pos_emb,
        prev_embedding = prev_emb,
        segment_embedding = segment_emb,
        head_embedding = head_emb,
        tail_embedding = tail_emb,
        id_embedding = id_emb,
    )
    passthrough = Base.structdiff(nt, NamedTuple{Layers.argument_names(te)})
    return merge(passthrough, outputs)
end

"""
    IDPreNormResidual(layer, norm)

Pre-norm residual wrapper that injects an `id` embedding. The forward pass:
1. applies `norm`;
2. adds `nt.id` to the normalized hidden state;
3. runs `layer`;
4. adds the original (un-normalized, id-free) `nt.hidden_state` back as the residual.
"""
struct IDPreNormResidual{L, N} <: Layers.LayerStruct
    # wrapped sub-layer, applied to the normalized, id-augmented state
    layer::L
    # normalization applied before `layer`
    norm::N
end
@functor IDPreNormResidual

# `:hidden_state` and `:id` first, followed by every other argument name of the wrapped layer.
Layers.argument_names(l::IDPreNormResidual) = Base.merge_names((:hidden_state, :id), Layers.remove_name(Layers.argument_names(l.layer), :hidden_state))

# Forward pass of `IDPreNormResidual`.
# The `id` embedding is added after normalization, so only the sub-layer sees it. The
# residual skip path adds back the original `nt.hidden_state`.
function (l::IDPreNormResidual)(nt::NamedTuple)
    normed = Layers.apply_on_namedtuple(l.norm, nt)
    layer_input = Layers.return_hidden_state(normed, normed.hidden_state + nt.id)
    out = Layers.apply_on_namedtuple(l.layer, layer_input)
    return Layers.return_hidden_state(out, out.hidden_state + nt.hidden_state)
end

"""
    TGStage1(block, label_proj)

First prediction stage. It runs `block` on the incoming hidden states (passing
`id_embedding` as the block's `id` field), then projects the result to label logits.
"""
struct TGStage1{T, D} <: Layers.LayerStruct
    # transformer block; called with `hidden_state`, `id` and `attention_mask`
    block::T
    # maps the block's hidden states to label logits
    label_proj::D
end
@functor TGStage1

Layers.argument_names(::TGStage1) = (:hidden_state, :id_embedding, :attention_mask)

# Forward pass of `TGStage1`.
# `attention_mask` is optional in `nt` (it defaults to `nothing`).
# Returns `(hidden_state, label_logits)`.
function (t::TGStage1)(nt::NamedTuple)
    mask = ignore_derivatives(() -> get(nt, :attention_mask, nothing))
    block_input = (hidden_state = nt.hidden_state, id = nt.id_embedding, attention_mask = mask)
    hidden = t.block(block_input).hidden_state
    logits = t.label_proj((hidden_state = hidden,)).hidden_state
    return (hidden_state = hidden, label_logits = logits)
end

"""
    TGStage2(block, head_tail_proj)

Second prediction stage. It:
1. adds the segment embedding to the hidden states;
2. runs `block`;
3. projects the result into separate head and tail states;
4. scores both against the id embeddings to produce head and tail logits.
"""
struct TGStage2{T, D} <: Layers.LayerStruct
    # transformer block; called with `hidden_state` and `attention_mask` only
    block::T
    # its output `hidden_state` is destructured into (head_state, tail_state)
    head_tail_proj::D
end
@functor TGStage2

Layers.argument_names(::TGStage2) = (:hidden_state, :segment_embedding, :id_embedding, :attention_mask)

# Forward pass of `TGStage2`. Returns `(head_logits, tail_logits)`.
function (t::TGStage2)(nt::NamedTuple)
    mask = ignore_derivatives(() -> get(nt, :attention_mask, nothing))
    block_input = (hidden_state = nt.hidden_state + nt.segment_embedding, attention_mask = mask)
    hidden = t.block(block_input).hidden_state
    head_state, tail_state = t.head_tail_proj((hidden_state = hidden,)).hidden_state
    # NOTE(review): only `scaled_matmul` is imported from NeuralAttentionlib at the top of
    # this file. The module name itself must be in scope from elsewhere for this line to work.
    id_keys = ignore_derivatives(() -> NeuralAttentionlib.collapsed_adjoint(nt.id_embedding))
    return (head_logits = scaled_matmul(id_keys, head_state),
            tail_logits = scaled_matmul(id_keys, tail_state))
end

"""
    TGBart(embed, seq2seq, stage1, stage2)

Full model:
- a `TokenEmbedding` shared by the encoder and decoder inputs;
- a `Seq2Seq` backbone;
- two prediction stages applied to the decoder states.
"""
struct TGBart{E<:TokenEmbedding, T<:Seq2Seq, S1, S2} <: Layers.LayerStruct
    # embedding applied to both `encoder_input` and `decoder_input`
    embed::E
    # encoder-decoder backbone
    seq2seq::T
    # label-prediction stage (presumably a TGStage1)
    stage1::S1
    # head/tail-prediction stage (presumably a TGStage2)
    stage2::S2
end
@functor TGBart

Layers.argument_names(::TGBart) = (:encoder_input, :decoder_input)

# Embed the encoder and decoder inputs with the shared embedding and run the backbone.
# Returns `(backbone hidden_state, encoder embeddings, decoder embeddings, full backbone output)`.
function _tgbart_feature(b::TGBart, nt::NamedTuple)
    encoder_embs = b.embed(nt.encoder_input)
    decoder_embs = b.embed(nt.decoder_input)
    backbone_out = b.seq2seq((encoder_input = encoder_embs, decoder_input = decoder_embs))
    return backbone_out.hidden_state, encoder_embs, decoder_embs, backbone_out
end

# Non-copying view that drops the last position along dimension 2.
_past_feature(hidden_state) = view(hidden_state, :, 1:lastindex(hidden_state, 2)-1, :)

# Build the stage-1 input: hidden states at positions `1:end-1`, paired with the decoder id
# embeddings at positions `2:end`. Gradients go through the custom `rrule` defined below.
function _tgbart_stage1_input(hidden_state, dec_embs, attention_mask)
    past_states = _past_feature(hidden_state)
    next_ids = view(dec_embs.id_embedding, :, 2:lastindex(dec_embs.id_embedding, 2), :)
    return (hidden_state = past_states, id_embedding = next_ids, attention_mask = attention_mask)
end

# Custom reverse rule for `_tgbart_stage1_input`.
# Only `hidden_state` receives a gradient. The cotangent of its `1:end-1` view is written into
# a zero array of the full shape; the final position is unused by stage 1, so it stays zero.
# `dec_embs` and `attention_mask` get `NoTangent()`. The id embedding read from `dec_embs` is
# produced under `ignore_derivatives` in the `TokenEmbedding` forward pass.
function ChainRulesCore.rrule(config::RuleConfig, ::typeof(_tgbart_stage1_input), hidden_state, dec_embs, attention_mask)
    _hidden_state = _past_feature(hidden_state)
    id_embedding = @view(dec_embs.id_embedding[:, 2:end, :])
    function _tgbart_stage1_input_pullback(Ȳ)
        # The output cotangent may arrive as a thunk or as an all-zero tangent.
        ȳ = unthunk(Ȳ)
        d_past = ȳ isa ChainRulesCore.AbstractZero ? ȳ : unthunk(ȳ.hidden_state)
        if d_past isa ChainRulesCore.AbstractZero
            # `hidden_state` is differentiable, so a zero gradient is `ZeroTangent`, not `NoTangent`.
            dhidden_state = ZeroTangent()
        else
            # Allocate lazily and per call: no buffer is held from the forward pass, and
            # repeated pullback calls never share or overwrite one another's result.
            dhidden_state = fill!(similar(hidden_state), 0)
            _past_feature(dhidden_state) .= d_past
        end
        return (NoTangent(), dhidden_state, NoTangent(), NoTangent())
    end
    return (; hidden_state = _hidden_state, id_embedding, attention_mask), _tgbart_stage1_input_pullback
end

# Run the label-prediction stage on the shifted decoder states.
# Returns `(label_logits, stage-1 hidden state)`.
function _tgbart_stage1(b, hidden_state, dec_embs, attention_mask)
    out = b.stage1(_tgbart_stage1_input(hidden_state, dec_embs, attention_mask))
    return out.label_logits, out.hidden_state
end

# Run the head/tail stage on the stage-1 hidden states. Inputs passed to the stage:
# - decoder segment embeddings at positions `2:end`;
# - the full, unsliced decoder id embeddings (both outside AD).
# Returns `(head_logits, tail_logits)`.
function _tgbart_stage2(b, stage1_hidden_state, dec_embs, attention_mask)
    next_segments = ignore_derivatives(() -> @view(dec_embs.segment_embedding[:, 2:end, :]))
    all_ids = ignore_derivatives(() -> dec_embs.id_embedding)
    out = b.stage2((
        hidden_state = stage1_hidden_state,
        segment_embedding = next_segments,
        id_embedding = all_ids,
        attention_mask = attention_mask,
    ))
    return out.head_logits, out.tail_logits
end

# Forward pass of `TGBart`.
# Reads `nt.encoder_input`, `nt.decoder_input` and `nt.decoder_input.attention_mask`.
# Returns the backbone state, the stage logits, the mask used by both stages, the encoder
# and decoder outputs, and both embedding NamedTuples.
function (b::TGBart)(nt::NamedTuple)
    hidden, enc, dec, backbone = _tgbart_feature(b, nt)
    # `- 1` presumably shortens the decoder mask by one position to match the shifted stage
    # inputs. NOTE(review): confirm against NeuralAttentionlib mask arithmetic.
    mask = ignore_derivatives(() -> nt.decoder_input.attention_mask - 1)
    label_logits, stage1_hidden = _tgbart_stage1(b, hidden, dec, mask)
    head_logits, tail_logits = _tgbart_stage2(b, stage1_hidden, dec, mask)
    # The key names (`encoder_embeddings` vs `decoder_embedding`) are kept for existing callers.
    return (
        hidden_state = hidden,
        label_logits = label_logits,
        head_logits = head_logits,
        tail_logits = tail_logits,
        attention_mask = mask,
        encoder_output = backbone.encoder_output,
        decoder_output = backbone.decoder_output,
        encoder_embeddings = enc,
        decoder_embedding = dec,
    )
end

# Shared `text/plain` display for `TGBart` and `TokenEmbedding`:
# - top level (no `:typeinfo`, e.g. the REPL) uses Flux's big tree view;
# - inside a non-compact container uses the single-layer view;
# - otherwise falls back to plain `show`.
function _show_tg_layer(io::IO, x)
    if isnothing(get(io, :typeinfo, nothing))
        return Flux._big_show(io, x)
    elseif !get(io, :compact, false)
        return Flux._layer_show(io, x)
    else
        return show(io, x)
    end
end

Base.show(io::IO, ::MIME"text/plain", x::TGBart) = _show_tg_layer(io, x)

Base.show(io::IO, ::MIME"text/plain", x::TokenEmbedding) = _show_tg_layer(io, x)

"""
    get_cfg(vocab_size; head_num = 16, hidden_size = 512, head_size = 32, ffn_size = 2048, kws...)

Build the `HuggingFace.HGFConfig{:textgraphbart}` for a model over `vocab_size` tokens.

Layer counts:
- if the keyword `N` is given, it sets both `encoder_layers` and `decoder_layers`, taking
  precedence over any per-stack keyword;
- otherwise `encoder_layers` and `decoder_layers` are read from `kws`, defaulting to 6.

`hidden_size`, `head_size`, `head_num` and `ffn_size` become `d_model`, `d_head`,
`num_attention_heads` and `d_ffn`. All remaining keywords, including `N`, are stored in the
config unchanged and may override those fields.
"""
function get_cfg(vocab_size; head_num = 16, hidden_size = 512, head_size = 32, ffn_size = 2048, kws...)
    N = get(kws, :N, nothing)
    if isnothing(N)
        encoder_layers = get(kws, :encoder_layers, 6)
        decoder_layers = get(kws, :decoder_layers, 6)
    else
        encoder_layers = decoder_layers = N
    end
    fields = merge(
        (; vocab_size, encoder_layers, decoder_layers,
         d_model = hidden_size, d_head = head_size, num_attention_heads = head_num, d_ffn = ffn_size),
        values(kws),
        # Merge the resolved counts last. Otherwise a raw `encoder_layers`/`decoder_layers`
        # entry in `kws` would override the value derived from `N`. Key order is unchanged.
        (; encoder_layers, decoder_layers),
    )
    return HuggingFace.HGFConfig{:textgraphbart}(fields)
end

"""
    build_model(vocab_size; float16 = false, kws...)

Construct a `TGBart` from `get_cfg(vocab_size; kws...)` via `HuggingFace.load_model`.
When `float16` is true, the parameters are converted with `Flux.paramtype(Float16, ⋅)`.
The model is passed through `todevice` before it is returned.
"""
function build_model(vocab_size; float16 = false, kws...)
    config = get_cfg(vocab_size; kws...)
    built = HuggingFace.load_model(TGBart, config)
    converted = float16 ? Flux.paramtype(Float16, built) : built
    return todevice(converted)
end

# Leaf case: copy a dense array into a parameter in place and return the parameter.
# `copyto!` alone overwrites only the first `length(state)` elements when `state` is
# shorter, which would silently leave stale weights behind, so lengths are checked first.
function load_weight!(val::AbstractArray, state::Union{Array, CuArray})
    length(val) == length(state) || throw(DimensionMismatch(
        "cannot load weight of size $(size(state)) into parameter of size $(size(val))"))
    return copyto!(val, state)
end
# Any other array-like state (lazy or wrapped arrays) is materialized with `collect` first.
load_weight!(val::AbstractArray, state) = load_weight!(val, collect(state))
# Adjoint-wrapped parameters: load into the parent, transposing the state to match.
load_weight!(val::Adjoint, state) = load_weight!(val', state')
# Without this method, a call with `(Adjoint, Array)` or `(Adjoint, CuArray)` arguments is
# ambiguous between the two methods above and throws a MethodError.
load_weight!(val::Adjoint, state::Union{Array, CuArray}) = load_weight!(val', state')
# Load from a checkpoint file read with `HuggingFace.Pickle.Torch.THload`.
load_weight!(model, file::AbstractString) = load_weight!(model, HuggingFace.Pickle.Torch.THload(file))
# Copy every entry of `state_dict` whose key matches a key of the model's own state dict.
# Model parameters absent from `state_dict` are skipped with a debug log. Extra entries in
# `state_dict` are ignored.
function load_weight!(model, state_dict)
    cur_state = HuggingFace.get_state_dict(model)
    for (key, val) in cur_state
        state = get(state_dict, key, nothing)
        if isnothing(state)
            @debug "$key not found in state_dict, skip."
        else
            load_weight!(val, state)
        end
    end
    return model
end

# Copy all matching weights from `model2` into `model`.
load_weight!(model::TGBart, model2::TGBart) = load_weight!(model, HuggingFace.get_state_dict(model2))

# Write the model's state dict to `file` with `HuggingFace.Pickle.Torch.THsave`, returning
# whatever `THsave` returns.
function save_weights(file::AbstractString, model)
    state = HuggingFace.get_state_dict(model)
    return HuggingFace.Pickle.Torch.THsave(file, state)
end
