using Dates
using Random

using CUDA
using Flux
using Flux.Losses
import Optimisers
using NNlib
using TextGraphBART
using TextEncodeBase
import Transformers
using Transformers.HuggingFace
using Transformers: todevice
using NeuralAttentionlib
using NeuralAttentionlib: AttenMask

using TextGraphBART: inference_g2t, inference_t2d, inference_g2d

# Turn on GPU support in Transformers.jl.
# NOTE(review): presumably this is what makes `todevice` move data to the GPU — confirm.
Transformers.enable_gpu(true)

# Set the release-threshold attribute of the current device's CUDA memory pool to 0.
# NOTE(review): presumably chosen so that cached pool memory is returned to the driver
# as early as possible instead of being retained — confirm against the CUDA docs.
release_threshold = 0
attribute!(memory_pool(device()), CUDA.MEMPOOL_ATTR_RELEASE_THRESHOLD, UInt(release_threshold))
using NNlibCUDA
# Redefine NNlibCUDA's zero-argument `softmaxalgo` so it returns the "accurate" cuDNN
# softmax algorithm constant. This overrides a method owned by NNlibCUDA on purpose.
NNlibCUDA.softmaxalgo() = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE
# Select CUDA.jl's fast-math mode with Float16 precision.
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)

# EventNarrative dataset handle, pointing at `datasets/eventnarrative` inside the
# directory that contains this script.
const DATASET = TextGraphBART.EventNarrative(joinpath(@__DIR__, "datasets", "eventnarrative"))
# Train / dev / test views over the same dataset handle.
const trainset = TextGraphBART.DataSet(TextGraphBART.TrainSet, DATASET)
const devset = TextGraphBART.DataSet(TextGraphBART.DevSet, DATASET)
const testset = TextGraphBART.DataSet(TextGraphBART.TestSet, DATASET)

# Text encoder returned by the project's tokenizer loader; the length of its `vocab`
# field is used below as the first argument to `build_model`.
const textenc = TextGraphBART.load_tokenizer()

# Build the model; the positional argument is the number of entries in the text
# encoder's vocabulary.
# NOTE(review): the keyword meanings (per-head width and head count of the cross
# attention, activation of the projection head) are inferred from their names —
# confirm against `TextGraphBART.build_model`.
model = TextGraphBART.build_model(
    length(textenc.vocab);
    d_cross_head = 64, num_cross_attention_heads = 16,
    proj_head_act = "tanh",
)
# Lion optimiser rule constructed with 1f-5 (the learning rate in Optimisers.jl's
# `Lion` constructor), followed by the optimiser state tree matching `model`.
opt_rule = Optimisers.Lion(1f-5)
opt = Optimisers.setup(opt_rule, model)

"""
    evaluate(textenc, model, dataset; max_length = 256)

Run `inference_g2t` on every sample of `dataset` and return the predictions as a
`Vector{String}`, in dataset order.

# Arguments
- `textenc`: text encoder, passed through to `inference_g2t`.
- `model`: model, passed through to `inference_g2t`.
- `dataset`: collection supporting `length` and integer indexing starting at 1.

# Keywords
- `max_length`: forwarded unchanged to `inference_g2t` (default `256`).

# Returns
A `Vector{String}` with one entry per sample; each prediction is converted to
`String` when it is stored.
"""
function evaluate(textenc, model, dataset; max_length = 256)
    # Single-sample inference with the fixed settings used for every element.
    generate(sample) = inference_g2t(textenc, model, sample; max_length, debug = false)
    # Typed comprehension: allocates the String vector and converts each prediction.
    return String[generate(dataset[idx]) for idx in 1:length(dataset)]
end

# Load the weights stored in `models/model_e5.bin` into `model`. The path is relative
# to the current working directory, unlike the dataset path, which uses `@__DIR__`.
TextGraphBART.load_weight!(model, "models/model_e5.bin")

# NOTE(review): the model is switched to test mode *before* training. Presumably this
# is intentional (e.g. fine-tuning with dropout disabled) — confirm.
model = Transformers.Layers.testmode(model)

# Fine-tune on the training split for the graph-to-text task (`task = :g2t`).
TextGraphBART.train!(
    model, textenc, opt, trainset;
    task = :g2t,
    batch_size = 16,
    epoch = 20,
    display_size = 400,
    update_size = 128,
    remap_max = 256,
    shuffle_id = false,
    γ = 1,
    force_trunc = 256,
    rand_drop_elm = false,
    pinned_mem = true,
    threads = true,
    shuffle_enc_id = false,
    ends_threshold = 0,
    dec_full_shuffle = false,
)

# Write the fine-tuned weights to `eventnarrative_ft.bin` in the current working directory.
TextGraphBART.save_weights("eventnarrative_ft.bin", model)

