using Dates
using Random

using CUDA
using Flux
using Flux.Losses
import Optimisers
using NNlib
using TextGraphBART
using TextEncodeBase
import Transformers
using Transformers.HuggingFace
using Transformers: todevice
using NeuralAttentionlib
using NeuralAttentionlib: AttenMask

using TextGraphBART: train!, train16!

# GPU / numerics configuration. Everything in this group is a global side effect
# that is applied before the dataset, tokenizer and model below are created.

# Turn on Transformers' GPU switch.
# NOTE(review): presumably this makes `todevice` move data to the GPU — confirm
# against the Transformers.jl version in use.
Transformers.enable_gpu(true)

# Set the RELEASE_THRESHOLD attribute of the current device's CUDA memory pool
# to `release_threshold` (0 here), converted to `UInt`.
# NOTE(review): per the CUDA memory-pool documentation this attribute is the
# amount of reserved memory the pool holds on to before trying to release it
# back to the OS — confirm that 0 (release as early as possible) is intended.
release_threshold = 0
attribute!(memory_pool(device()), CUDA.MEMPOOL_ATTR_RELEASE_THRESHOLD, UInt(release_threshold))
using NNlibCUDA
# Define the zero-argument method `NNlibCUDA.softmaxalgo` from this script so
# that it returns `CUDNN_SOFTMAX_ACCURATE`.
# NOTE(review): this adds/overwrites a method on a function owned by NNlibCUDA;
# it presumably replaces the package's default softmax algorithm — verify that
# the installed NNlibCUDA version still consults `softmaxalgo()`.
NNlibCUDA.softmaxalgo() = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE
# Select CUDA's FAST_MATH math mode with `precision=:Float16`.
CUDA.math_mode!(CUDA.FAST_MATH; precision=:Float16)

# Training data handle, built from the `datasets` directory located next to
# this script, and the training split selected from it.
const DATASET = TextGraphBART.TrainingData(joinpath(@__DIR__, "datasets"))
const trainset = TextGraphBART.DataSet(TextGraphBART.TrainSet, DATASET)

# Tokenizer / text encoder; the length of its vocabulary is passed to
# `build_model` below.
const textenc = TextGraphBART.load_tokenizer()

# Model under training. `model` is intentionally a non-const global: it is
# rebound later in the script.
model = TextGraphBART.build_model(
    length(textenc.vocab);
    d_cross_head = 64,
    num_cross_attention_heads = 16,
    proj_head_act = "tanh",
)

# RAdam rule constructed with 1f-4 (the learning rate, going by Optimisers.jl's
# `RAdam` signature) and the optimiser state tree matching `model`.
opt_rule = Optimisers.RAdam(1.0f-4)
opt = Optimisers.setup(opt_rule, model)

# Training schedule: four consecutive `train!` runs over `trainset`, each
# followed by a weight checkpoint. The stages differ only in `epoch`,
# `update_size`, the checkpoint file name, and whether the model is first
# passed through `Transformers.Layers.testmode`; every other keyword is fixed.
# The checkpoint suffixes follow the cumulative epoch count (1, 1+2=3, 4, 5).
#
# Each tuple is (epoch, update_size, checkpoint file, switch to testmode first).
for (n_epochs, update_sz, checkpoint, use_testmode) in (
    (1, 256, "model_e1.bin", false),
    (2, 128, "model_e3.bin", false),
    (1, 256, "model_e4.bin", false),
    (1, 256, "model_e5.bin", true),
)
    if use_testmode
        # `global` is required so the script-level `model` binding is rebound
        # (a plain assignment inside a top-level loop would create a local).
        # NOTE(review): the last stage trains the model returned by `testmode`
        # while reusing `opt`, which was set up for the original model —
        # presumably a deliberate final stage with dropout disabled; confirm.
        global model = Transformers.Layers.testmode(model)
    end

    train!(
        model, textenc, opt, trainset;
        task = :all,
        batch_size = 16,
        epoch = n_epochs,
        display_size = 16000,
        update_size = update_sz,
        remap_max = 512,
        shuffle_id = 0.3,
        γ = 1,
        force_trunc = 256,
        rand_drop_elm = 0.15,
        pinned_mem = true,
        threads = true,
    )
    TextGraphBART.save_weights(checkpoint, model)
end
