# =========================================
# =============== 导入相关库 ===============
# =========================================
import torch
from torch.utils.data import DataLoader
from pyrootutils.pyrootutils import setup_root
import os
import argparse

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import Trainer, TrainingArguments, TrainerCallback


root = setup_root(__file__, ".root", pythonpath=True)

from utils.hessian import hessian_step
from utils.param import collect_grad
from utils.encode.codelength import cal_gradient_length

# =========================================
# =============== 参数列表 =================
# =========================================

# `description` (not the first positional argument, which is `prog`) holds the
# help text, so the usage line still shows the real script name.
parser = argparse.ArgumentParser(description="参数处理")
parser.add_argument("--param", "-p", type=float, help="Lambda系数")
args = parser.parse_args()


# Optimisation settings.
lr = 2e-5  # learning rate (also passed directly to hessian_step)
epoch = 10
batch_size = 25
gradient_accumulation_steps = 1
collect_step = 100  # measure the gradient code length every `collect_step` steps
grad_save_step = 10  # save raw gradients every `grad_save_step` measurements
logging_step = 10

model_name = "bert"
method = "norm"  # grad, norm  ("norm" skips hessian_step; anything else applies it)
task_name = "decline"
device = "0"  # value written to CUDA_VISIBLE_DEVICES

# Lambda coefficient passed to hessian_step. The --param / -p command-line value
# overrides the default of 1e3; it was previously parsed but silently ignored.
lamb = args.param if args.param is not None else 1e3
epsilon = 1e-5
quantizer_step = 16  # passed to cal_gradient_length
decline = 0.8  # per-epoch multiplicative decay applied to `lamb`

output_dir = f"training_results/{model_name}_{method}"
save_path = f"checkpoints/{task_name}/{model_name}_{method}_{decline}"


# =========================================
# =============== 数据处理 =================
# =========================================

# Load the local IMDB parquet file; load_dataset exposes it as the "train" split.
dataset = load_dataset("parquet", data_files=[f"{root}/datasets/imdb/train.parquet"])
train_dataset = dataset["train"]

# Tokenize without padding here; batches are padded later by the data collator.
tokenizer = AutoTokenizer.from_pretrained(f"{root}/models/{model_name}")
# Disabled: tokenizer.pad_token = tokenizer.eos_token
# (only relevant for a tokenizer that lacks a pad token).


def _tokenize_batch(batch):
    """Tokenize a batch of raw texts, truncating each to at most 512 tokens."""
    return tokenizer(batch["text"], max_length=512, truncation=True, padding=False)


# Replace the raw "text" column with the tokenizer's output columns.
train_dataset = train_dataset.map(_tokenize_batch, batched=True, remove_columns=["text"])

# =========================================
# =============== 训练准备 =================
# =========================================

# Restrict the process to the GPU index held in `device`. This must happen before
# CUDA is initialised in this process to have any effect.
os.environ["CUDA_VISIBLE_DEVICES"] = device

# Load the sequence-classification model and cast every weight to bfloat16.
# NOTE(review): the weights themselves are stored in bf16 (not just autocast), so
# very small optimizer updates may be rounded away — confirm this is intended.
model = AutoModelForSequenceClassification.from_pretrained(
    f"{root}/models/{model_name}"
).to(torch.bfloat16)
# Disabled: model.config.pad_token_id = tokenizer.pad_token_id

# =========================================
# =============== 模型训练 =================
# =========================================


class GradientCallback(TrainerCallback):
    """Decay the module-level coefficient ``lamb`` once per training epoch."""

    def on_epoch_end(self, args, state, control, **kwargs):
        """Scale ``lamb`` by ``decline``, print the new value and hand back ``control``."""
        global lamb
        decayed = lamb * decline
        lamb = decayed
        print(decayed)
        return control


class GradientTrainer(Trainer):
    """Trainer whose ``training_step`` performs the complete update itself.

    Each step:
    - runs forward and backward;
    - applies ``hessian_step`` unless the module-level ``method`` is ``"norm"``;
    - steps the optimizer;
    - periodically measures (and occasionally saves) the collected gradients;
    - clears the gradients before returning.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of training_step calls so far; drives the collection/save schedule.
        self.n_step = 0

    def training_step(self, model, inputs, *args):
        """Run one self-contained optimization step and return the detached loss.

        Extra positional arguments (e.g. ``num_items_in_batch`` in newer
        ``transformers`` releases) are accepted and ignored.

        NOTE(review): the optimizer is stepped on every call, so
        ``gradient_accumulation_steps > 1`` is effectively ignored.
        """
        # The base Trainer.training_step does both of these; the override must too.
        # from_pretrained returns the model in eval mode, which would disable dropout.
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Drop any gradients left behind by outside code.
        self.optimizer.zero_grad(set_to_none=True)

        # Forward and backward pass.
        loss = model(**inputs).loss
        loss.backward()
        del inputs

        # Weight update.
        if method != "norm":
            hessian_step(model, lr, lamb, epsilon)
        self.optimizer.step()

        # Gradient code length, measured every `collect_step` calls.
        if self.n_step % collect_step == 0:
            grads = collect_grad(model).detach().cpu()
            length = cal_gradient_length(grads, quantizer_step)
            print("EG Length: ", length)

            save_every = collect_step * grad_save_step
            if self.n_step % save_every == 0:
                idx = self.n_step // save_every
                save_dir = f"data/{method}"
                # torch.save does not create missing directories.
                os.makedirs(save_dir, exist_ok=True)
                torch.save(grads, f"{save_dir}/{idx}.pt")
        self.n_step += 1

        # Set grads to None so that any optimizer step / clipping Trainer runs after
        # this returns finds no gradients to apply (torch optimizers skip parameters
        # whose .grad is None).
        self.optimizer.zero_grad(set_to_none=True)
        # Detach so Trainer's running-loss accumulator is not tied to the autograd graph.
        return loss.detach()


# Hugging Face training configuration (constant learning-rate schedule).
training_args = TrainingArguments(
    output_dir=output_dir,
    # optimisation
    learning_rate=lr,
    lr_scheduler_type="constant",
    num_train_epochs=epoch,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    bf16=True,
    # logging
    logging_strategy="steps",
    logging_steps=logging_step,
    # checkpointing: save once per epoch, keep at most one checkpoint on disk
    save_strategy="epoch",
    save_total_limit=1,
)


# Pad every batch to its longest sequence at collation time.
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

trainer = GradientTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collator,
    callbacks=[GradientCallback()],
)

# Run training.
trainer.train()


# 保存模型
# model.save_pretrained(save_path)
# tokenizer.save_pretrained(save_path)
