import os
import json
import argparse

import torch
import wandb
import swanlab

from tqdm import tqdm

from accelerate import Accelerator
from accelerate.utils import set_seed
from safetensors import safe_open
from torch import nn, optim
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from models.configs.configs import EConfig

from entrypoints.train_drafter.data_utils import (
    list_files,
    AddGaussianNoise,
    AddUniformNoise,
    CustomDataset,
    CoupledDataset,
    DataCollatorWithPadding,
    DataCollatorWithPaddingForCoupled,
)
from models.drafters.inter_cnets_llamagen import interfuse_softmax

torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    parser = argparse.ArgumentParser(description='Training drafter')

    # paths and directories
    parser.add_argument("--model", type=str)
    parser.add_argument('--base_path', type=str)
    parser.add_argument('--config_path', type=str)
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--save_dir', type=str)

    # dataset arguments
    parser.add_argument('--coupled', action='store_true', default=False)
    parser.add_argument('--train_data_ratio', type=float, default=0.95)
    parser.add_argument('--data_noise', type=str, default='uniform')
    parser.add_argument('--mean', type=float, default=0.0)
    parser.add_argument('--std', type=float, default=0.2)

    # training arguments
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--bs', type=int, default=4)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)

    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--warmup_steps_ratio', type=float, default=0.03)
    parser.add_argument('--is_warmup', action='store_true', default=True)

    parser.add_argument('--p_w', type=float, default=0.1)
    parser.add_argument('--cfg_loss', action='store_true', default=False)
    parser.add_argument('--cfg_scale', type=float, default=3.0)
    parser.add_argument('--embed_upscale', type=float, default=1.0)
    parser.add_argument('--grad_clip', type=float, default=0.5)
    parser.add_argument('--delta', type=float, default=0.5)  # for DRO strategy
    parser.add_argument('--beta', type=float, default=0.01)  # for KL loss
    parser.add_argument('--dro_lambda', type=float, default=0.1)  # for DRO weight

    parser.add_argument('--max_len', type=int, default=4096)
    parser.add_argument('--eval_freq', type=int, default=1)
    parser.add_argument('--save_freq', type=int, default=5)
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--swanlab', action='store_true', default=False)

    return parser


@torch.no_grad()
def top_accuracy(output, target, topk=(1,)):
    # output.shape (bs, num_classes), target.shape (bs, )
    """Computes the accuracy over the k top predictions for the specified values of k"""
    _, pred = output.topk(max(topk), 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k)
    return res


def update_metrics(out_head, target_head, loss_mask, top_3acc_lo, top_3acc_up):
    out_head_lo = out_head[:, :, :target_head.shape[-1]]
    out_head_up = out_head[:, :, target_head.shape[-1]:]
    _, predicted_lo = torch.max(out_head_lo, 2)
    _, predicted_up = torch.max(out_head_up, 2)
    _, target = torch.max(target_head, 2)

    total = loss_mask.sum().item()
    correct_lo = ((predicted_lo == target) * loss_mask.squeeze()).sum().item()
    correct_up = ((predicted_up == target) * loss_mask.squeeze()).sum().item()

    out_head__lo_flat = out_head_lo.reshape(-1, target_head.shape[-1])[loss_mask.reshape(-1) == 1]
    out_head__up_flat = out_head_up.reshape(-1, target_head.shape[-1])[loss_mask.reshape(-1) == 1]
    target_flat = target.reshape(-1)[loss_mask.reshape(-1) == 1]

    topkacc_lo = top_accuracy(out_head__lo_flat, target_flat, (1, 2, 3))
    topkacc_up = top_accuracy(out_head__up_flat, target_flat, (1, 2, 3))

    for top_i in range(len(topkacc_lo)):
        top_3acc_lo[top_i] += topkacc_lo[top_i]
        top_3acc_up[top_i] += topkacc_up[top_i]

    return correct_lo, correct_up, total


def log_metrics(optimizer, ploss_lo, klloss_lo, vloss_lo, ploss_up, klloss_up, vloss_up, loss_lo, loss_up, loss, correct_lo, correct_up, total, top_3acc_lo, top_3acc_up, phase, wandb):
    logdict = {
        f"{phase}/lr": optimizer.param_groups[0]["lr"] if phase == "train" else None,
        f"{phase}/vloss_lo": vloss_lo.item() if vloss_lo is not None else None,
        f"{phase}/ploss_lo": ploss_lo.item() if ploss_lo is not None else None,
        f"{phase}/klloss_lo": klloss_lo.item() if klloss_lo is not None else None,
        f"{phase}/vloss_up": vloss_up.item() if vloss_up is not None else None,
        f"{phase}/ploss_up": ploss_up.item() if ploss_up is not None else None,
        f"{phase}/klloss_up": klloss_up.item() if klloss_up is not None else None,
        f"{phase}/loss_lo": loss_lo.item() if loss_lo is not None else None,
        f"{phase}/loss_up": loss_up.item() if loss_up is not None else None,
        f"{phase}/loss": loss.item(),
        f"{phase}/acc_lo": correct_lo / total,
        f"{phase}/acc_up": correct_up / total,
    }

    for id, (i_lo, i_up) in enumerate(zip(top_3acc_lo, top_3acc_up)):
        logdict[f'{phase}/top_{id + 1}_acc_lo'] = i_lo.item() / total
        logdict[f'{phase}/top_{id + 1}_acc_up'] = i_up.item() / total
    if wandb:
        wandb.log(logdict)


def run_epoch(args, model, data_loader, optimizer, scheduler, criterion, head, Inter_head, accelerator, is_warmup, train_mode=True):
    model.train() if train_mode else model.eval()

    top_3acc_lo = [0 for _ in range(3)]
    top_3acc_up = [0 for _ in range(3)]
    top_3acc_mid = [0 for _ in range(3)]
    correct_lo, correct_up, correct_mid, total = 0, 0, 0, 0
    epoch_loss = 0
    num_batches = 0

    for data in tqdm(data_loader):
        with torch.set_grad_enabled(train_mode):
            if train_mode:
                optimizer.zero_grad()

            # with torch.no_grad():
            #     predict = model(data["hidden_states"], input_ids=data["input_ids"], attention_mask=data["attention_mask"])

            with torch.no_grad():
                target_head = head(data["target"])
                if args.cfg_loss:
                    """
                        Note that target_head[::2] is a conditioned logits and target_head[1::2] is an unconditioned logits.
                        Although the original formula for the CFG is cond + scale * (cond - uncond), we found that
                        the official implementation of Lumina-mGPT uses uncond + scale * (cond - uncond) instead and
                        thus we follow the same implementation. (This is equivalent to cond + (scale-1) * (cond - uncond)).

                        Note that here the size of target_head is half of the original target_head.
                    """

                    target_head = target_head[::2] + args.cfg_scale * (target_head[::2] - target_head[1::2])

                target_p = nn.Softmax(dim=2)(target_head).detach()

            out, out_up, out_lo, out_mid = Inter_head(data["target"])
            if args.cfg_loss:
                out_mid = out_mid[::2] + args.cfg_scale * (out_mid[::2] - out_mid[1::2])
                out_lo = out_lo[::2] + args.cfg_scale * (out_lo[::2] - out_lo[1::2])
                out_up = out_up[::2] + args.cfg_scale * (out_up[::2] - out_up[1::2])
                out = out[::2] + args.cfg_scale * (out[::2] - out[1::2])
            out_p = interfuse_softmax(out)


            out_p_lo = out_p[:, :, :target_head.shape[-1]]
            out_p_up = out_p[:, :, target_head.shape[-1]:]


            EPS = 1e-8
            MAX_LOSS_PER_SAMPLE = 50.0

            out_p_lo = torch.clamp(out_p_lo, min=EPS, max=1.0 - EPS)
            out_p_up = torch.clamp(out_p_up, min=EPS, max=1.0 - EPS)

            out_logp_lo = torch.log(out_p_lo)
            out_logp_up = torch.log(out_p_up)
            loss_mask = data["loss_mask"][:, :, None]
            if args.cfg_loss:
                p_loss_mask = loss_mask[::2]
            else:
                p_loss_mask = loss_mask

            kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
            out_logp_mid = nn.LogSoftmax(dim=2)(out_mid)
            klloss_mid = kl_loss_fn(out_logp_mid, target_p)
            plogp_mid = target_p * out_logp_mid
            ploss_mid = -torch.sum(torch.sum(p_loss_mask * plogp_mid, 2)) / (p_loss_mask.sum() + EPS)
            vloss_mid = torch.sum(torch.mean(p_loss_mask * criterion(out_mid, target_head), 2)) / (p_loss_mask.sum() + EPS)
            loss_mid = vloss_mid + args.p_w * ploss_mid + args.beta * klloss_mid

            plogp_up = target_p * out_logp_up
            ploss_up = -torch.sum(torch.sum(p_loss_mask * plogp_up, 2)) / (p_loss_mask.sum() + EPS)
            vloss_up = torch.sum(torch.mean(p_loss_mask * criterion(out_up, target_head), 2)) / (p_loss_mask.sum() + EPS)
            # klloss_up = torch.sum(target_p * (torch.log(target_p + 1e-10) - torch.log(out_logp_up + 1e-10))) / (target_p.numel())
            # klloss_up = kl_loss_fn(out_logp_up, target_p)
            loss_up = vloss_up + args.p_w * ploss_up

            plogp_lo = target_p * out_logp_lo
            vloss_lo = torch.sum(torch.mean(p_loss_mask * criterion(out_lo, target_head), 2)) / (p_loss_mask.sum() + EPS)
            ploss_lo = -torch.sum(torch.sum(p_loss_mask * plogp_lo, 2), dim=1)
            dro_weights = torch.softmax(ploss_lo.detach(), dim=0)  # shape: [batch_size]
            weighted_ploss_lo = torch.sum(dro_weights * ploss_lo) / (p_loss_mask.sum() + EPS)  # scalar
            # klloss_lo = kl_loss_fn(out_logp_lo, target_p)
            loss_lo = vloss_lo + weighted_ploss_lo


            loss = loss_up + loss_lo + loss_mid

            assert not torch.isnan(out_p_lo).any(), "out_p_lo contains NaN!"
            assert not torch.isinf(ploss_lo).any(), "ploss_lo contains inf!"

            if train_mode:
                accelerator.backward(loss)
                accelerator.clip_grad_value_(Inter_head.parameters(), args.grad_clip)
                optimizer.step()
                if is_warmup:
                    scheduler.step()

        with torch.no_grad():
            if not args.cfg_loss and args.coupled:
                """
                    Even when the CFG loss is not used, we need to use CFG for accuracy calculation in the coupled setting.
                    Note that if the dataset is not coupled, CFG for accuracy calculation is not available.
                """
                target_p = target_p[::2] + args.cfg_scale * (target_p[::2] - target_p[1::2])
                out_p = out_p[::2] + args.cfg_scale * (out_p[::2] - out_p[1::2])
                p_loss_mask = loss_mask[::2]

            correct_batch_lo, correct_batch_up, correct_batch_mid, total_batch = update_metrics_v5(out_p, target_p, out_mid, p_loss_mask, top_3acc_lo, top_3acc_up, top_3acc_mid)
            correct_lo += correct_batch_lo
            correct_up += correct_batch_up
            correct_mid += correct_batch_mid
            total += total_batch

        if accelerator.is_main_process and loss_mask.sum().item() != 0 and train_mode:
            log_metrics_v5(optimizer, weighted_ploss_lo, None, vloss_lo, ploss_up, loss_mid, vloss_up, loss_lo, loss_up, loss, correct_lo, correct_up, correct_mid, total, top_3acc_lo, top_3acc_up, top_3acc_mid, "train", args.wandb)

        epoch_loss += loss.item()
        num_batches += 1

    correct_lo = torch.tensor(correct_lo, dtype=torch.float32).to(accelerator.device)
    correct_up = torch.tensor(correct_up, dtype=torch.float32).to(accelerator.device)
    correct_mid = torch.tensor(correct_mid, dtype=torch.float32).to(accelerator.device)
    total = torch.tensor(total, dtype=torch.float32).to(accelerator.device)
    epoch_loss = torch.tensor(epoch_loss, dtype=torch.float32).to(accelerator.device)

    correct_lo, correct_up, correct_mid, total, epoch_loss = accelerator.gather_for_metrics((correct_lo, correct_up, correct_mid, total, epoch_loss))
    correct_lo = correct_lo.sum().item()
    correct_up = correct_up.sum().item()
    correct_mid = correct_mid.sum().item()
    total = total.sum().item()
    epoch_loss = epoch_loss.mean()

    top_3acc_lo = [accelerator.gather_for_metrics(torch.tensor(acc_lo, dtype=torch.float32).to(accelerator.device)).sum() for
                acc_lo in top_3acc_lo]
    top_3acc_up = [accelerator.gather_for_metrics(torch.tensor(acc_up, dtype=torch.float32).to(accelerator.device)).sum() for
                acc_up in top_3acc_up]
    top_3acc_mid = [accelerator.gather_for_metrics(torch.tensor(acc_mid, dtype=torch.float32).to(accelerator.device)).sum() for
                acc_mid in top_3acc_mid]

    return epoch_loss / num_batches, correct_lo, correct_up, correct_mid, total, top_3acc_lo, top_3acc_up, top_3acc_mid


def run_train_drafter(args):
    # if args.cfg_loss and not args.coupled:
    #     raise ValueError("--cfg_loss can not be activated without --coupled.")

    set_seed(0)
    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=args.gradient_accumulation_steps, )

    if accelerator.is_main_process:
        if args.wandb:
            wandb.login(key="")
        elif args.swanlab:
            swanlab.login(api_key="", save=True)

        run_name = f"{args.model}_lr{args.lr}_p_w{args.p_w}_bsz{args.bs}_gradacc_{args.gradient_accumulation_steps}"
        run_name += f"_epochs{args.num_epochs}"
        if args.coupled:
            run_name += "_coupled"
        if args.cfg_loss:
            run_name += f"_cfgloss_cfgscale_{args.cfg_scale}"
        if args.embed_upscale > 1.0:
            run_name += f"_embed_upscale_{args.embed_upscale}"
        run_name += "_mscoco2017train30k"
        if args.wandb:
            wandb.init(project="Inter_Head_Training", name=run_name, config=args)
        elif args.swanlab:
            swanlab.init(project="Inter_Head_Training", experiment_name=run_name, config=args)
    if args.model == "lumina_mgpt":
        from models.configs.configuration_lumina_mgpt import ChameleonConfig
        from models.drafters.cnets_lumina_mgpt import Model
        base_config = ChameleonConfig.from_pretrained(args.base_path)
    elif args.model == "anole":
        from models.configs.configuration_anole import ChameleonConfig
        from models.drafters.Inter_cnets_anole import Model, InterHead
        base_config = ChameleonConfig.from_pretrained(args.base_path)
    elif "llamagen" in args.model:
        from transformers import AutoConfig
        from models.drafters.inter_cnets_llamagen import Model, InterHead
        base_config = AutoConfig.from_pretrained(args.base_path)
    else:
        raise ValueError("Invalid model name.")

    ### LOAD `lm_head` ########################################################################
    try:
        with open(os.path.join(args.base_path, "model.safetensors.index.json"), "r") as f:
            index_json = json.loads(f.read())
            head_path = index_json["weight_map"]["lm_head.weight"]
        with safe_open(os.path.join(args.base_path, head_path),
                       framework="pt",
                       device="cpu") as f:
            tensor_slice = f.get_slice("lm_head.weight")
            _, hidden_dim = tensor_slice.get_shape()
            tensor = tensor_slice[:, :hidden_dim].float()
    except:
        try:
            head_path = "model.safetensors"
            with safe_open(os.path.join(args.base_path, head_path),
                           framework="pt",
                           device="cpu") as f:
                tensor_slice = f.get_slice("lm_head.weight")
                vocab_size, hidden_dim = tensor_slice.get_shape()
                tensor = tensor_slice[:, :hidden_dim].float()
        except:
            head_path = "pytorch_model.bin"
            weights = torch.load(os.path.join(args.base_path, head_path), weights_only=True)
            tensor = weights["lm_head.weight"].float()

    Inter_head = InterHead(base_config.hidden_size, base_config.vocab_size)
    # Inter_head.load_freeze_head(tensor)

    ###########################################################################################

    if args.data_noise == "uniform":
        aug = AddUniformNoise(std=args.std)
    elif args.data_noise == "gaussian":
        aug = AddGaussianNoise(mean=args.mean, std=args.std)
    else:
        aug = None

    data_path = list_files(args.data_dir)

    train_data_path = data_path[:int(len(data_path) * args.train_data_ratio)]
    test_data_path = data_path[int(len(data_path) * args.train_data_ratio):]


    train_dataset = CustomDataset(train_data_path, max_len=args.max_len, transform=aug, model=args.model)
    test_dataset = CustomDataset(test_data_path, max_len=args.max_len, model=args.model)

    train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                              collate_fn=DataCollatorWithPadding(), num_workers=0,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.bs, shuffle=False,
                             collate_fn=DataCollatorWithPadding(), num_workers=0, pin_memory=True)

    if accelerator.is_main_process:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

    ### LOAD `model` and original `head` ########################################################################
    config = EConfig.from_pretrained(args.config_path)
    model = Model(config, load_emb=True, path=args.base_path)
    model.eval()

    for param in model.parameters():
        param.requires_grad = False

    head = torch.nn.Linear(base_config.hidden_size, base_config.vocab_size, bias=False)

    head.weight.data = tensor
    head.eval()

    for param in head.parameters():
        param.requires_grad = False
    ###########################################################################################

    criterion = nn.SmoothL1Loss(reduction="none")
    optimizer = optim.AdamW(Inter_head.parameters(), lr=args.lr, betas=(0.9, 0.95))

    if args.is_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=args.warmup_steps_ratio * len(train_loader),
                                                    num_training_steps=args.num_epochs * len(train_loader))

        model, head, Inter_head, optimizer, train_loader, test_loader, scheduler = accelerator.prepare(
            model, head, Inter_head, optimizer, train_loader, test_loader, scheduler
        )
    else:
        model, head, Inter_head, optimizer, train_loader, test_loader = accelerator.prepare(
            model, head, Inter_head, optimizer, train_loader, test_loader
        )
    if args.wandb:
        args.wandb = wandb
    elif args.swanlab:
        args.wandb = swanlab

    for epoch in range(args.num_epochs):
        epoch_loss, epoch_correct_lo, epoch_crrect_up, epoch_correct_mid, epoch_total, epoch_top3_lo, epoch_top3_up, epoch_top3_mid \
            = run_epoch(args, model, train_loader, optimizer, scheduler, criterion, head, Inter_head, accelerator,
                        args.is_warmup, train_mode=True)

        if accelerator.is_main_process:
            log_metrics(optimizer, None, None, None, None, None, None, None, None, epoch_loss, epoch_correct_lo, epoch_crrect_up, epoch_correct_mid, epoch_total, epoch_top3_lo, epoch_top3_up, epoch_top3_mid, "epoch", args.wandb)


        if (epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.num_epochs:
            test_loss, test_correct_lo, test_correct_up, test_correct_mid, test_total, test_top3_lo, test_top3_up, test_top3_mid \
                = run_epoch(args, model, test_loader, optimizer, scheduler, criterion, head, Inter_head, accelerator, args.is_warmup, train_mode=False)

            if accelerator.is_main_process:
                log_metrics(optimizer, None, None, None, None, None, None, None, None, test_loss, test_correct_lo, test_correct_up, test_correct_mid, test_total, test_top3_lo, test_top3_up, test_top3_mid, "test", args.wandb)

        if (epoch + 1) % args.save_freq == 0 or (epoch + 1) == args.num_epochs:
            if accelerator.is_main_process:
                accelerator.save_state(output_dir=f"{args.save_dir}/{run_name}/state_{epoch + 1}")


if __name__ == "__main__":
    parser = parse_args()
    args = parser.parse_args()

    run_train_drafter(args)
