import os
import sys
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from pyrootutils.pyrootutils import setup_root


from transformers import (
    Trainer,
    TrainingArguments,
    TrainerCallback,
    DataCollatorWithPadding,
    BertForSequenceClassification,
    AutoTokenizer,
)

from datasets import load_dataset

# Resolve the project root using the ".root" marker file and, because
# pythonpath=True, make it importable. The `utils.*` imports below presumably
# rely on this path setup, so this call must stay above them.
root = setup_root(__file__, ".root", pythonpath=True)

from utils.hessian import hessian_step
from utils.param import collect_grad
from utils.encode.codelength import cal_gradient_length

# CUDA-related environment variables, set before the remaining project imports.
# NOTE(review): TORCH_CUDA_ARCH_LIST is presumably consumed by a CUDA extension
# built by one of the later imports — confirm it is still needed.
os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;7.0;7.5;8.0;8.6"
# Expose GPUs 0 and 1 to this process and to the workers it spawns.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

from ddp_hook import ddp_eg_coding, EGHookState, _noop
from test_llm_models import load_model

# ---- Optimisation / logging settings ----
lr = 2e-5  # learning rate; passed to TrainingArguments and to hessian_step()
epoch = 10  # number of training epochs
batch_size = 25  # per-device training batch size
gradient_accumulation_steps = 1  # forwarded to TrainingArguments
collect_step = 100  # gradient code length is computed/printed every N steps
grad_save_step = 10  # not referenced anywhere in this file's visible code
logging_step = 10  # Trainer logging interval, in steps

# ---- Experiment identity ----
model_name = "bert"  # only used to build output_dir
method = "grad"  # "grad" or "norm"; "norm" skips the hessian_step() call
task_name = "decline"  # not referenced anywhere in this file's visible code

# ---- Arguments of the project helpers ----
lamb = 1e3  # passed to hessian_step(); multiplied by `decline` after each epoch
epsilon = 1e-5  # passed to hessian_step()
quantizer_step = 32  # passed to cal_gradient_length()
decline = 0.8  # per-epoch multiplicative factor applied to `lamb`

# ---- Per-run log file ----
# The run id is the integer Unix timestamp at import time.
# NOTE(review): mp.spawn presumably re-imports this module in every worker, in
# which case each worker computes its own RUN_ID and re-opens the log in "w"
# mode — confirm whether workers and parent end up sharing one file.
RUN_ID = str(int(time.time()))
os.makedirs(f"./runs/{RUN_ID}/", exist_ok=True)
LOG_FILE = open(f"./runs/{RUN_ID}/log.txt", "w")

# Trainer output directory (not tied to RUN_ID) and final checkpoint location.
output_dir = f"./runs/{model_name}_{method}/"
save_path = f"{root}/checkpoints/distribute/{method}"


# Number of worker processes spawned by main() and the process-group size.
WORLD_SIZE = 2

# Model identifier written to the log; train() loads from a literal path.
MODEL = "bert"


def log(msg: str):
    """
    Echo ``msg`` to stdout and append it, newline-terminated, to the run log.
    """
    print(msg)
    line = msg + "\n"
    LOG_FILE.write(line)
    # Flush right away so the file stays current if the run is interrupted.
    LOG_FILE.flush()


# Record which model identifier this run is configured for.
log(f"Model: {MODEL}")


# State object handed to the DDP communication hook in EGTrainer._wrap_model;
# train() reports its counters (calls, params, bytes, profiling) on rank 0.
hook_state = EGHookState()


class GradientCallback(TrainerCallback):
    """Trainer callback that decays the module-level ``lamb`` once per epoch."""

    def on_epoch_end(self, args, state, control, **kwargs):
        """Scale the global ``lamb`` by ``decline`` and hand ``control`` back."""
        global lamb
        lamb = lamb * decline
        return control


class EGTrainer(Trainer):
    """
    ``Trainer`` subclass that runs the complete optimisation step itself.

    ``training_step`` performs forward, backward, the optional
    ``hessian_step`` and ``optimizer.step()`` in one go, and clears the
    gradients before and after so that the surrounding ``Trainer`` loop has
    no gradients left to act on. The wrapped model always carries the
    ``ddp_eg_coding`` communication hook.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Count of training_step() calls; drives the periodic code-length print.
        self.n_step = 0

    def _wrap_model(self, model, *args, **kwargs):
        """
        Wrap ``model`` as the parent class does, make sure the result is a
        ``DistributedDataParallel`` instance, and register the gradient
        communication hook on it with the shared ``hook_state``.
        """
        model = super()._wrap_model(model, *args, **kwargs)

        # The parent may hand back an unwrapped module; the comm hook can only
        # be registered on a DDP instance, so wrap it here when needed.
        if not isinstance(model, DDP):
            model = DDP(model)
        model.register_comm_hook(state=hook_state, hook=ddp_eg_coding)

        return model

    def training_step(self, model, inputs, *args):
        """
        Run one full optimisation step on ``inputs``.

        Args:
            model: the (wrapped) model to train.
            inputs: keyword arguments for the model's forward pass; the output
                must expose a ``.loss`` attribute.
            *args: extra positional arguments from the ``Trainer`` loop; ignored.

        Returns:
            The loss of this batch, detached from the autograd graph.
        """
        # NOTE(review): unlike the stock Trainer.training_step, this override
        # never calls model.train() — confirm that the current module mode
        # (dropout on/off) is the intended one.

        # Start from clean gradients so outside code cannot interfere.
        self.optimizer.zero_grad()

        # Forward and backward pass.
        loss = model(**inputs).loss
        loss.backward()
        del inputs

        # Weight update. Every method except "norm" first runs the project's
        # hessian_step() with the current learning rate, lamb and epsilon.
        if method != "norm":
            hessian_step(model, lr, lamb, epsilon)
        self.optimizer.step()

        # Periodically compute and print the gradient code length.
        if self.n_step % collect_step == 0:
            grads = collect_grad(model).detach().cpu()
            length = cal_gradient_length(grads, quantizer_step)
            print("EG Length: ", length, self.args.local_rank)
        self.n_step += 1

        # Clear the gradients again so the Trainer loop's own optimizer step
        # finds nothing to apply.
        self.optimizer.zero_grad()

        # Return a detached loss, as the stock training_step does: the Trainer
        # loop adds this value to a running total, and an attached tensor would
        # keep autograd graph nodes alive from step to step.
        return loss.detach()


def train(rank):
    """
    Worker entry point: train the BERT classifier on one rank.

    Joins the gloo process group, tokenizes the IMDB training split, trains
    with ``EGTrainer`` and, on rank 0 only, logs timing and communication-hook
    statistics and writes the final model and tokenizer to ``save_path``.

    Args:
        rank: index of this worker, supplied by ``mp.spawn``; used as both
            the global and the local rank.
    """
    # Publish this worker's distributed identity through the environment so
    # that libraries reading RANK / LOCAL_RANK / WORLD_SIZE see the same
    # topology as the process group created below.
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(WORLD_SIZE)

    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)
    try:
        # model, tokenizer, train_dataset = load_model(MODEL)
        tokenizer = AutoTokenizer.from_pretrained(f"{root}/models/bert")
        dataset = load_dataset(
            "parquet", data_files=[f"{root}/datasets/imdb/train.parquet"]
        )["train"]
        # Tokenize without padding; the data collator pads each batch instead.
        train_dataset = dataset.map(
            lambda examples: tokenizer(
                examples["text"], max_length=512, truncation=True, padding=False
            ),
            batched=True,
            remove_columns=["text"],
        )

        model = BertForSequenceClassification.from_pretrained(f"{root}/models/bert")
        model = model.to(torch.bfloat16)

        log(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

        print("Training.")
        training_args = TrainingArguments(
            output_dir=output_dir,
            learning_rate=lr,
            num_train_epochs=epoch,
            per_device_train_batch_size=batch_size,
            lr_scheduler_type="constant",
            logging_strategy="steps",
            logging_steps=logging_step,
            gradient_accumulation_steps=gradient_accumulation_steps,
            save_strategy="epoch",
            save_total_limit=1,
            bf16=True,
            ddp_backend="gloo",
        )
        trainer = EGTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer, padding=True),
            callbacks=[GradientCallback()],
        )

        time_start = time.time()
        trainer.train()
        elapse = time.time() - time_start

        # Only rank 0 reports, so the summary appears once.
        if rank == 0:
            log(f"Training time: {elapse:.2f} seconds")
            log(f"DDP Hook calls: {hook_state.calls}")
            log(f"Total params transferred: {hook_state.params}")
            log(f"Total bytes transferred: {hook_state.bytes}")
            log(f"Profiling: {hook_state.profiling[:7]}")

        print(f"Rank {rank} finished.")

        # Every rank targets the same save_path, so only rank 0 writes the
        # checkpoint; concurrent writers could corrupt the files.
        if rank == 0:
            model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)
    finally:
        # Release the process group even if training or saving failed.
        dist.destroy_process_group()


def main():
    """
    Spawn ``WORLD_SIZE`` worker processes running ``train`` and wait for them.

    The parent's log file handle is closed afterwards, including when a
    worker fails and ``mp.spawn`` raises.
    """
    try:
        mp.spawn(
            train,
            args=(),
            nprocs=WORLD_SIZE,
            join=True,
        )
    finally:
        # With join=True a failing worker surfaces as an exception here; close
        # the log regardless so buffered output is not left to interpreter exit.
        LOG_FILE.close()


if __name__ == "__main__":
    # Address and port the workers use to rendezvous when the process group
    # is initialised in train().
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29501")
    main()
