# =========================================
# ================ Imports ================
# =========================================
import torch
from pyrootutils.pyrootutils import setup_root
import os

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import Trainer, TrainingArguments, TrainerCallback


root = setup_root(__file__, ".root", pythonpath=True)

from utils.hessian import hessian_step
from utils.param import collect_grad
from utils.encode.codelength import cal_gradient_length

# =========================================
# ============ Hyperparameters ============
# =========================================
lr: float = 2e-5  # learning rate for TrainingArguments; also passed to hessian_step
epoch: int = 20  # number of training epochs
batch_size: int = 32  # per-device training batch size
accumulate_steps: int = 1  # gradient accumulation steps given to TrainingArguments
collect_step: int = 1000  # report the gradient code length every this many steps
logging_step: int = 20  # Trainer logging interval, in steps

model_name: str = "deepseek"  # sub-directory of <root>/models to load from
method: str = "norm"  # "norm" skips hessian_step; any other value (e.g. "grad") runs it
device: str = "0,1"  # value written to CUDA_VISIBLE_DEVICES
print(method, model_name)

lamb: float = 1e5  # passed to hessian_step; multiplied by `decline` after every epoch
epsilon: float = 1e-5  # passed to hessian_step
quantizer_step: int = 16  # passed to cal_gradient_length
decline: float = 0.8  # per-epoch decay factor applied to `lamb`

# Every artefact of a run is named after the model/method pair.
_run_tag = f"{model_name}_{method}"
output_dir: str = f"training_results/{_run_tag}"  # Trainer output / checkpoint directory
output_path: str = f"outputs/{_run_tag}.txt"  # not referenced in the visible part of this script
save_path: str = f"checkpoints/{_run_tag}"  # where the final model and tokenizer are saved


# =========================================
# ============ Data processing ============
# =========================================

# Load the local SQuAD JSON file and keep its "train" split.
_train_file = f"{root}/datasets/squad/train.json"
dataset = load_dataset("json", data_files=[_train_file])
train_dataset = dataset["train"]

# Tokenizer stored alongside the selected model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(f"{root}/models/{model_name}")
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"


# =========================================
# ============ Training setup =============
# =========================================

# Devices: restrict the GPUs visible to this process before the model is
# loaded and placed.
os.environ["CUDA_VISIBLE_DEVICES"] = device

# Model: load the weights directly in bfloat16. Loading in the default dtype
# and casting afterwards materialises the full-precision weights first and
# makes device_map="auto" plan its placement for that larger footprint.
model = AutoModelForCausalLM.from_pretrained(
    f"{root}/models/{model_name}",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# model.config.pad_token_id = tokenizer.pad_token_id
# Kept so that every floating-point parameter and buffer ends up in bfloat16,
# exactly as before; this is a no-op for tensors already loaded in that dtype.
model = model.to(torch.bfloat16)

# =========================================
# ============ Model training =============
# =========================================


class GradientCallback(TrainerCallback):
    """Shrink the module-level ``lamb`` by the factor ``decline`` after each epoch."""

    def on_epoch_end(self, args, state, control, **kwargs):
        """Decay ``lamb``, print the new value and hand ``control`` back unchanged."""
        global lamb
        decayed = lamb * decline
        lamb = decayed
        print(decayed)
        return control


class GradientTrainer(Trainer):
    """Trainer whose ``training_step`` applies the weight update itself.

    Each call runs one forward/backward pass, calls ``hessian_step`` unless
    the module-level ``method`` is ``"norm"``, steps the optimizer, and clears
    the gradients before returning. Every ``collect_step`` calls, the gradient
    code length is computed and printed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of training_step calls so far; drives the periodic
        # gradient code-length report.
        self.n_step = 0

    def training_step(self, model, inputs, *args):
        """Run one forward/backward pass and update the weights in place.

        Args:
            model: Model to train; called as ``model(**inputs)`` and expected
                to return an object exposing a ``loss`` attribute.
            inputs: Keyword arguments for the model's forward pass.
            *args: Extra positional arguments supplied by the ``Trainer``
                loop; accepted for signature compatibility and ignored.

        Returns:
            The batch loss, detached from the autograd graph.
        """
        # from_pretrained returns the model in eval mode and the stock
        # Trainer.training_step is what normally switches it to train mode,
        # so this override has to do it or dropout stays disabled.
        model.train()

        # Discard any gradients left behind by code outside this method.
        self.optimizer.zero_grad()

        # Forward and backward pass.
        loss = model(**inputs).loss.mean()
        loss.backward()
        del inputs

        # Weight update; hessian_step (defined in utils.hessian) runs first
        # for every method other than "norm".
        if method != "norm":
            hessian_step(model, lr, lamb, epsilon)
        self.optimizer.step()

        # Gradient code length, reported on the first call and then every
        # collect_step calls. Gradients are still populated at this point.
        if self.n_step % collect_step == 0:
            grads = collect_grad(model).detach().cpu()
            length = cal_gradient_length(grads, quantizer_step)
            print("EG Length: ", length)
        self.n_step += 1

        # Clear the gradients so code outside this method does not act on
        # them a second time.
        self.optimizer.zero_grad()

        # Match the stock Trainer.training_step, which returns a detached
        # loss: handing back the live tensor keeps the autograd graph
        # reachable from whatever the caller accumulates it into.
        return loss.detach()


# Training configuration
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epoch,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumulate_steps,
    learning_rate=lr,
    lr_scheduler_type="constant",
    logging_strategy="steps",
    logging_steps=logging_step,
    save_strategy="epoch",
    save_total_limit=1,
)

# Build the collaborators first, then hand them to the trainer.
padding_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
epoch_callbacks = [GradientCallback()]
trainer = GradientTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=padding_collator,
    callbacks=epoch_callbacks,
)
trainer.train()


# Save the final model and tokenizer side by side
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_path)
