using Dates
using Random

using CUDA
using Flux
using Flux.Losses
import Optimisers
using NNlib
using TextGraphBART
using TextEncodeBase
import Transformers
using Transformers.HuggingFace
using Transformers: todevice
using NeuralAttentionlib
using NeuralAttentionlib: AttenMask

using TextGraphBART: inference_g2t, inference_t2d, inference_g2d

# Turn on GPU support in Transformers.jl (presumably this makes the imported `todevice`
# move arrays/models to the CUDA device — confirm against Transformers.jl docs).
Transformers.enable_gpu(true)

# Optional, currently disabled: set the CUDA memory pool's release threshold so freed
# memory is returned to the driver instead of being cached by the pool.
# release_threshold = 0
# attribute!(memory_pool(device()), CUDA.MEMPOOL_ATTR_RELEASE_THRESHOLD, UInt(release_threshold))
using NNlibCUDA
# Replace the zero-argument `NNlibCUDA.softmaxalgo` method so cuDNN softmax uses the
# ACCURATE algorithm. NOTE(review): this overwrites a method owned by NNlibCUDA
# (method overwrite / type piracy); it affects every caller in the session — keep it
# confined to this script.
NNlibCUDA.softmaxalgo() = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE
# Let CUDA.jl use reduced-precision (Float16) math where FAST_MATH permits, trading
# accuracy for speed.
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)

# WebNLG 2020 corpus (release v3.0, English), read from
# datasets/webnlg-dataset/release_v3.0/en next to this script.
const DATASET = TextGraphBART.WebNLG2020(joinpath(@__DIR__, "datasets", "webnlg-dataset", "release_v3.0", "en"))
# Training and development splits of that corpus.
const trainset = TextGraphBART.DataSet(TextGraphBART.TrainSet, DATASET)
const devset = TextGraphBART.DataSet(TextGraphBART.DevSet, DATASET)

# Text encoder/tokenizer shared by model construction, training and inference below.
const textenc = TextGraphBART.load_tokenizer()

# Append extra match patterns to the tokenizer's pre-tokenization rules (presumably each
# match is kept as its own piece — confirm against TextEncodeBase's MatchTokenization).
# First pattern: any single punctuation character or digit. Note that `|` inside the
# bracket expression is a literal character (not alternation); it is already covered by
# `[:punct:]`.
push!(textenc.tokenizer.tokenization.patterns, r"[[:punct:]|\d]")
# Second pattern: the "▁" (U+2581) character, presumably the SentencePiece word-boundary
# marker used by this vocabulary — confirm.
push!(textenc.tokenizer.tokenization.patterns, r"▁")

# Build the model with one output entry per tokenizer vocabulary item, 16 cross-attention
# heads of size 64, and `tanh` as the `proj_head_act` activation. `model` stays a plain
# (non-const) global because it is rebound later in the script.
model = let vocab_size = length(textenc.vocab)
    TextGraphBART.build_model(
        vocab_size;
        d_cross_head = 64,
        num_cross_attention_heads = 16,
        proj_head_act = "tanh",
    )
end

# Adam with learning rate 1f-4; `opt` is the optimiser-state tree Optimisers.jl
# derives from `model`'s structure.
opt_rule = Optimisers.Adam(1f-4)
opt = Optimisers.setup(opt_rule, model)

"""
    evaluate(textenc, model, dataset; domain_token = "[WEBNLG]", max_component = 22,
             max_length = 256, debug = false)

Predict a graph for every sample in `dataset` with `inference_t2d` and collect the results.

For each `i in 1:length(dataset)`, the first entry of `dataset[i].sample.texts` is used as
the input text. One `(; category, eid, graph)` named tuple per sample is returned, in
dataset order. `category` and `eid` are copied from the sample.

# Keywords
All keywords are forwarded unchanged to `inference_t2d`. The defaults are the values this
function previously hard-coded, so existing calls behave exactly as before.
- `domain_token`: domain token passed to `inference_t2d`.
- `max_component`: passed to `inference_t2d` as `max_component`.
- `max_length`: passed to `inference_t2d` as `max_length` (same default as `tevaluate`).
- `debug`: passed to `inference_t2d` as `debug`.

# Returns
A `Vector` whose `graph` field has the concrete named-tuple type `Graph` below. Each
`inference_t2d` result must be convertible to that type, otherwise storing it throws.
"""
function evaluate(textenc, model, dataset; domain_token = "[WEBNLG]", max_component = 22,
                  max_length = 256, debug = false)
    # Concrete element types, so the result vector is not abstractly typed.
    Graph = @NamedTuple{label::Vector{String}, type::Vector{Int}, id::Vector{Int},
                        prev::Vector{Int}, segment::Vector{Int}, head::Vector{Int},
                        tail::Vector{Int}}
    Result = @NamedTuple{category::String, eid::String, graph::Graph}

    len = length(dataset)
    results = Vector{Result}(undef, len)
    for i in 1:len
        sample = dataset[i].sample
        text = first(sample.texts)
        (; category, eid) = sample
        pred = inference_t2d(textenc, model, text; domain_token, debug,
                             max_component, max_length)
        # Storing converts `pred` to `Graph`.
        results[i] = (; category, eid, graph = pred)
    end
    return results
end

"""
    tevaluate(textenc, model, dataset; max_length = 256)

Generate text for each element of `dataset` with `inference_g2t`, visiting indices
`1:length(dataset)` in order. Return the predictions as a `Vector{String}`.

`max_length` is forwarded to `inference_g2t`; `debug` is always `false`.
"""
function tevaluate(textenc, model, dataset; max_length = 256)
    # NOTE(review): unlike `evaluate`, this passes the whole `dataset[idx]` element (not
    # its `.sample` field) to `inference_g2t` — confirm which form inference_g2t expects.
    # The typed comprehension converts each prediction to `String` on store, just as
    # filling a preallocated `Vector{String}` does.
    return String[
        inference_g2t(textenc, model, dataset[idx]; max_length = max_length, debug = false)
        for idx in 1:length(dataset)
    ]
end

# Start from the checkpoint "model_e5.bin". The `!` suggests the weights are loaded into
# the existing `model` in place — confirm against TextGraphBART.load_weight!.
TextGraphBART.load_weight!(model, "model_e5.bin")
# Rebind `model` to its `testmode` version (presumably disables dropout and similar
# train-only behaviour — confirm).
# NOTE(review): training below runs on this test-mode model, so that regularisation is
# presumably off during fine-tuning — confirm this is intended.
# NOTE(review): `opt` was created by Optimisers.setup from the model *before* this
# rebinding. This only works if testmode returns a model with the same parameter
# structure — verify.
model = Transformers.Layers.testmode(model);

# Fine-tune on the :t2g task for 100 epochs with batch_size 16. The remaining keywords
# are TextGraphBART.train! options; see that function for their meaning.
TextGraphBART.train!(model, textenc, opt, trainset; task = :t2g, batch_size = 16, epoch = 100, display_size = 400, update_size = 128, remap_max = 256, shuffle_id = false, γ = 1, force_trunc = 128, rand_drop_elm = false, pinned_mem = false, threads = true, shuffle_enc_id = false, ends_threshold = 0, dec_full_shuffle = false)
# Write the final weights. This script saves only once, after training returns.
# NOTE(review): an interrupted run loses all progress unless train! checkpoints internally.
TextGraphBART.save_weights("webnlg_ft.bin", model)
