import utils
import models
from datasets import batch_chunk_image, get_loader
from diffusion import DiffusionUtils
from lr_schedulers import get_schedule_fn

import torch
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import kendalltau
import numpy as np
import wandb
from loguru import logger
from tqdm import tqdm

def save_checkpoint(config, epoch, model, optimizer, scheduler, ckpt_dir="./saved_models"):
    """Write a resumable training checkpoint for the current run.

    The checkpoint is written to ``{ckpt_dir}/ckpt_{config.train.run_name}.pth``
    and overwrites any previous checkpoint of the same run.

    Args:
        config: Run configuration; only ``config.train.run_name`` is read.
        epoch: Zero-based index of the epoch that just finished. ``epoch + 1``
            is stored, i.e. the epoch a resumed run should start from.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: Learning-rate scheduler whose ``state_dict()`` is stored,
            or ``None``, in which case ``None`` is stored under the same key.
        ckpt_dir: Target directory; created (including parents) if missing.
    """
    import os

    # torch.save does not create missing directories, so make sure it exists.
    os.makedirs(ckpt_dir, exist_ok=True)

    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # The training loop treats the scheduler as optional; mirror that here.
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
    }
    torch.save(checkpoint, f"{ckpt_dir}/ckpt_{config.train.run_name}.pth")

def load_checkpoint(config, model, optimizer, scheduler, ckpt_dir="./saved_models"):
    """Restore training state from the checkpoint of the current run.

    Reads ``{ckpt_dir}/ckpt_{config.train.run_name}.pth`` and loads the stored
    state dicts into the given objects in place.

    Args:
        config: Run configuration; only ``config.train.run_name`` is read.
        model: Module that receives the stored model state.
        optimizer: Optimizer that receives the stored optimizer state.
        scheduler: Learning-rate scheduler that receives the stored scheduler
            state, or ``None`` if no scheduler is in use.
        ckpt_dir: Directory containing the checkpoint file.

    Returns:
        A tuple ``(epoch, model, optimizer, scheduler)`` where ``epoch`` is the
        value stored under the ``"epoch"`` key and the remaining entries are
        the same objects that were passed in.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # restored anywhere; load_state_dict copies values into the existing
    # parameters/state, which stay on whatever device they already live on.
    ckpt = torch.load(
        f"{ckpt_dir}/ckpt_{config.train.run_name}.pth",
        map_location="cpu",
    )
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    # The scheduler is optional: skip it when none is in use or none was saved.
    scheduler_state = ckpt.get("scheduler_state_dict")
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    return ckpt["epoch"], model, optimizer, scheduler

def init_model(config):
    """Construct the ``models.ReverseDiffusion`` network described by ``config``.

    The model is moved to ``cuda:0`` when CUDA is available and to the CPU
    otherwise.

    Args:
        config: Run configuration. Reads the dataset name, piece/image/digit
            sizes, the ``CNN`` and ``transformer`` sections, and the diffusion
            ``transition`` / ``reverse`` settings under ``train.diffusion``.

    Returns:
        The freshly constructed model on the selected device.
    """
    diffusion_cfg = config.train.diffusion

    # Output-adjustment mode string handed to the model; it depends on the
    # combination of forward transition and reverse parameterization.
    if diffusion_cfg.transition == "swap" and diffusion_cfg.reverse == "original":
        d_out_adjust = "1"
    elif diffusion_cfg.reverse == "generalized_PL":
        d_out_adjust = "square"
    else:
        d_out_adjust = "0"

    # Positional encoding is always enabled.
    use_pos_enc = True

    cnn = config.CNN
    transformer = config.transformer
    constructor_args = (
        config.dataset, cnn.in_channels, config.num_pieces, config.image_size,
        cnn.hidden_channels1, cnn.kernel_size1, cnn.stride1, cnn.padding1,
        cnn.hidden_channels2, cnn.kernel_size2, cnn.stride2, cnn.padding2, config.num_digits,
        transformer.embd_dim, transformer.nhead, transformer.d_hid, transformer.n_layers, transformer.dropout,
        d_out_adjust, use_pos_enc,
    )

    target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return models.ReverseDiffusion(*constructor_args).to(target_device)

def init_diffusion_utils(config):
    """Construct the ``DiffusionUtils`` helper described by ``config``.

    The helper is moved to ``cuda:0`` when CUDA is available and to the CPU
    otherwise.

    Args:
        config: Run configuration. Reads the ``train`` section (including its
            ``diffusion`` sub-section) and ``beam_size``.

    Returns:
        The constructed ``DiffusionUtils`` instance on the selected device.
    """
    train_cfg = config.train
    diffusion_cfg = train_cfg.diffusion

    # Positional arguments, in the order the DiffusionUtils constructor is called with.
    constructor_args = (
        diffusion_cfg.num_timesteps,
        train_cfg.sample_N,
        diffusion_cfg.transition,
        diffusion_cfg.latent,
        train_cfg.reinforce_N,
        train_cfg.reinforce_ema_rate,
        train_cfg.entropy_reg_rate,
        diffusion_cfg.reverse,
        diffusion_cfg.reverse_steps,
        train_cfg.loss,
        config.beam_size,
    )

    target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return DiffusionUtils(*constructor_args).to(target_device)

def train_diffusion(diffusion_utils, reverse_model, optimizer, scheduler, train_loader, config, start_epoch, ckpt_dir):
    """Run the diffusion training loop and checkpoint after every epoch.

    Args:
        diffusion_utils: Object providing ``training_losses(x_start, model)``;
            put into train mode once before the loop.
        reverse_model: Network being trained; put into train mode every epoch.
        optimizer: Optimizer stepped once per mini-batch.
        scheduler: Learning-rate scheduler stepped once per mini-batch, or a
            falsy value to disable scheduling.
        train_loader: Iterable of training mini-batches.
        config: Run configuration (dataset name, ``num_pieces``, ``train``
            section).
        start_epoch: First epoch index to run (non-zero when resuming).
        ckpt_dir: Directory passed to ``save_checkpoint``.

    Raises:
        NotImplementedError: If ``config.dataset`` is not one of the supported
            dataset names.
    """
    target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_pieces = config.num_pieces
    diffusion_utils.train()

    for epoch in tqdm(range(start_epoch, config.train.epochs)):
        reverse_model.train()

        for step, batch in enumerate(train_loader):
            dataset_name = config.dataset

            if dataset_name == "sort-MNIST":
                shuffled_pieces, perm_list = batch
                shuffled_pieces = shuffled_pieces.to(target_device)
                perm_list = perm_list.to(target_device)
                # Apply the ground-truth permutation to obtain the ordered pieces.
                x_start = utils.permute_image_perm_list(perm_list, shuffled_pieces)
            elif dataset_name == "unscramble-CIFAR10":
                images, _labels = batch
                # Only the first returned value (the ground-truth pieces) is used here.
                chunked, _unused_a, _unused_b = batch_chunk_image(images, num_pieces)
                x_start = chunked.to(target_device)
            else:
                raise NotImplementedError

            # The ground-truth pieces play the role of x_start for the diffusion loss.
            loss = diffusion_utils.training_losses(x_start, reverse_model)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if config.train.record_wandb:
                wandb.log({"diffusion_loss": loss})

            # Report every 50th mini-batch, skipping the very first one.
            if step > 0 and step % 50 == 0:
                logger.info(f'Epoch {epoch}, minibatch {step}, current loss = {loss.item()}')

            # The schedule advances per mini-batch, not per epoch.
            if scheduler:
                scheduler.step()

        save_checkpoint(config, epoch, reverse_model, optimizer, scheduler, ckpt_dir)

def train(config, ckpt_dir):
    """Train a model end to end and save its final weights.

    Builds the model, data loader, AdamW optimizer and learning-rate
    scheduler, optionally resumes from the run's checkpoint, optionally starts
    a wandb run, runs ``train_diffusion`` and finally writes the model weights
    to ``{ckpt_dir}/{config.train.run_name}.pth``.

    Args:
        config: Run configuration.
        ckpt_dir: Directory used for per-epoch checkpoints and the final
            weights; created if it does not exist.

    Returns:
        The trained model.
    """
    import os

    # Create the output directory up front so a missing directory is not
    # discovered only when the first checkpoint is written after a full epoch.
    os.makedirs(ckpt_dir, exist_ok=True)

    model = init_model(config)
    model.train()

    train_loader, _ = get_loader(config)

    optimizer = optim.AdamW(
        params=model.parameters(),
        lr=config.train.learning_rate,
        betas=(0.9, 0.98),
        eps=1.0e-9,
    )
    # The scheduler is stepped once per mini-batch, so the total step count is
    # batches-per-epoch times epochs.
    num_training_steps = len(train_loader) * config.train.epochs
    scheduler = get_schedule_fn(
        config.train.scheduler,
        num_training_steps=num_training_steps,  # cosine-decay scheduler only
        warmup_steps=config.train.warmup_steps,  # transformer scheduler only
        dim_embed=config.transformer.embd_dim,  # transformer scheduler only
    )(optimizer)

    start_epoch = 0
    if config.train.resume:
        start_epoch, model, optimizer, scheduler = load_checkpoint(config, model, optimizer, scheduler, ckpt_dir)

    if config.train.record_wandb:
        # The wandb project is named after the dataset.
        wandb.init(
            project=config.dataset,
            name=config.train.run_name,
            config=config.to_dict(),
        )

    diffusion_utils = init_diffusion_utils(config)

    train_diffusion(diffusion_utils, model, optimizer, scheduler, train_loader, config, start_epoch, ckpt_dir)

    # Count trainable parameters only.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if config.train.record_wandb:
        wandb.run.summary["total_params"] = total_params
    logger.info(f"Total number of parameters: {total_params}")

    torch.save(model.state_dict(), f"{ckpt_dir}/{config.train.run_name}.pth")

    return model

@torch.inference_mode()
def eval_image_dataset(config, ckpt_dir):
    """Evaluate the saved model of the current run on the test split.

    Loads ``{ckpt_dir}/{config.train.run_name}.pth``, reconstructs the piece
    order for every test mini-batch (beam search or deterministic sampling,
    depending on ``config.beam_search``) and reports:

    * mean Kendall-Tau between ground-truth and predicted permutations,
    * mean MSE / RMSE / L1 between ground-truth and reconstructed pieces
      (each averaged over mini-batches),
    * permutation accuracy (exact match) and proportion of correct positions,
    * pixel-wise accuracy (all elements ``torch.isclose``) and proportion of
      correct pieces.

    Metrics are logged through ``logger`` and, when
    ``config.train.record_wandb`` is set, written to the wandb run summary,
    after which the wandb run is finished.

    Args:
        config: Run configuration.
        ckpt_dir: Directory containing the final model weights.

    Raises:
        NotImplementedError: If ``config.dataset`` is not supported.
    """
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    model = init_model(config)
    # map_location lets weights saved on a GPU machine load on a CPU-only one.
    model.load_state_dict(
        torch.load(f"{ckpt_dir}/{config.train.run_name}.pth", map_location=device)
    )

    if config.train.record_wandb and config.eval_only:
        # Counting parameters does not depend on train/eval mode.
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        wandb.run.summary["total_params"] = total_params

    model.eval()

    diffusion_utils = init_diffusion_utils(config)

    _, test_loader = get_loader(config)

    kendall_taus = []
    mean_mse = []
    mean_rmse = []
    mean_l1 = []
    l1_loss = torch.nn.L1Loss()

    perm_all_correct = 0
    perm_piece_correct = 0

    image_all_correct = 0
    image_piece_correct = 0

    total_images = 0
    total_pieces = 0

    # Running counter used to name the saved plots of wrong reconstructions.
    wrong_cnt = 0

    for batch_idx, data in enumerate(test_loader):
        logger.info(f"Mini-batch {batch_idx}")
        if config.dataset == "sort-MNIST":
            random_pieces, gt_perm_list = data
            random_pieces, gt_perm_list = random_pieces.to(device), gt_perm_list.to(device)
            pieces = utils.permute_image_perm_list(gt_perm_list, random_pieces)

        elif config.dataset == "unscramble-CIFAR10":
            inputs, _ = data
            pieces, random_pieces, gt_perm_list = batch_chunk_image(inputs, config.num_pieces)
            pieces, random_pieces, gt_perm_list = pieces.to(device), random_pieces.to(device), gt_perm_list.to(device)
            # NOTE(review): argsort presumably converts the returned permutation
            # into the convention used by the sampler's output - confirm.
            gt_perm_list = torch.argsort(gt_perm_list)

        else:
            raise NotImplementedError

        if config.beam_search:
            ordered_pieces, predicted_perm_list = diffusion_utils.p_sample_beam_search(random_pieces, model)
        else:
            ordered_pieces, predicted_perm_list = diffusion_utils.p_sample_loop(random_pieces, model, deterministic=True)

        # Obtain the Kendall-Tau correlation coefficient for the target
        # and predicted list of permutation matrices.
        for p1, p2 in zip(gt_perm_list, predicted_perm_list):
            p1, p2 = p1.cpu(), p2.cpu()
            kendall_taus.append(kendalltau(p1, p2)[0])
            perm_all_correct += torch.equal(p1, p2)
            perm_piece_correct += torch.eq(p1, p2).sum().item()

        total_images += gt_perm_list.size(0)
        total_pieces += torch.numel(gt_perm_list)

        # True where a reconstructed element differs from the ground truth;
        # flattened to [batch, piece, elements-per-piece].
        compare_images = ~torch.isclose(pieces.flatten(start_dim=2), ordered_pieces.flatten(start_dim=2))
        # An image is correct when none of its elements differ.
        compare_result = (compare_images.sum(dim=(1, 2)) == 0)
        correct_images = compare_result.sum()

        if config.save_wrong_images:
            for image_idx in range(compare_result.size(0)):
                if not compare_result[image_idx]:
                    utils.plot_image(ordered_pieces[image_idx], config.num_pieces, save=f"{wrong_cnt}")
                    wrong_cnt += 1

        # A piece is correct when none of its elements differ.
        correct_image_pieces = (compare_images.sum(-1) == 0).sum()

        image_all_correct += correct_images.item()
        image_piece_correct += correct_image_pieces.item()

        mse = F.mse_loss(pieces, ordered_pieces, reduction="mean")  # mean over all elements
        mean_mse.append(mse.cpu())
        rmse = mse.sqrt()
        mean_rmse.append(rmse.cpu())
        l1 = l1_loss(pieces, ordered_pieces)  # mean absolute error
        mean_l1.append(l1.cpu())

    mean_kendall_tau = np.mean(kendall_taus)
    # Per-batch values are averaged with equal weight per mini-batch.
    mean_mse = torch.stack(mean_mse).mean()
    mean_rmse = torch.stack(mean_rmse).mean()
    mean_l1 = torch.stack(mean_l1).mean()

    perm_accuracy = perm_all_correct / total_images
    perm_prop_correct_pieces = perm_piece_correct / total_pieces

    image_accuracy = image_all_correct / total_images
    image_prop_correct_pieces = image_piece_correct / total_pieces

    if config.train.record_wandb:
        wandb.run.summary["mean_kendall_tau"] = mean_kendall_tau
        wandb.run.summary["mean_mse"] = mean_mse
        wandb.run.summary["mean_rmse"] = mean_rmse
        wandb.run.summary["mean_l1"] = mean_l1
        wandb.run.summary["permutation_accuracy"] = perm_accuracy
        wandb.run.summary["permutation_prop_correct_pieces"] = perm_prop_correct_pieces
        wandb.run.summary["pixel_wise_accuracy"] = image_accuracy
        wandb.run.summary["pixel_wise_prop_correct_pieces"] = image_prop_correct_pieces
        wandb.finish()

    logger.info(f"Mean Kendall-Tau: {mean_kendall_tau}")
    logger.info(f"Mean mse: {mean_mse}")
    logger.info(f"Mean root mse: {mean_rmse}")
    logger.info(f"Mean l1 loss: {mean_l1}")
    logger.info(f"Permutation Accuracy: {perm_accuracy}")
    logger.info(f"Permutation Prop. of correct pieces: {perm_prop_correct_pieces}")
    logger.info(f"Pixel-wise Accuracy: {image_accuracy}")
    logger.info(f"Pixel-wise Prop. of correct pieces: {image_prop_correct_pieces}")

@torch.inference_mode()
def eval(config, ckpt_dir):
    """Entry point for evaluation.

    When wandb recording is enabled and this is an evaluation-only run, a
    dedicated wandb run named ``EVAL_<run_name>`` is started in the project
    named after the dataset. Evaluation is then delegated to
    ``eval_image_dataset``.

    Args:
        config: Run configuration.
        ckpt_dir: Directory containing the final model weights.

    Raises:
        NotImplementedError: If ``config.dataset`` is not a supported image
            dataset.
    """
    needs_own_wandb_run = config.train.record_wandb and config.eval_only
    if needs_own_wandb_run:
        wandb.init(
            project=config.dataset,
            name=f"EVAL_{config.train.run_name}",
            config=config.to_dict(),
        )

    supported_datasets = ("sort-MNIST", "unscramble-CIFAR10")
    if config.dataset not in supported_datasets:
        raise NotImplementedError

    eval_image_dataset(config, ckpt_dir)

@torch.inference_mode()
def demo(config, ckpt_dir="./saved_models"):
    """Reconstruct one test mini-batch and plot the first example.

    Loads ``{ckpt_dir}/{config.train.run_name}.pth``, runs the sampler (beam
    search or deterministic sampling, depending on ``config.beam_search``) on
    the first mini-batch of the test loader and plots the shuffled and the
    reconstructed pieces of the first image. For ``unscramble-CIFAR10`` the
    ground-truth arrangement is plotted as well.

    Args:
        config: Run configuration.
        ckpt_dir: Directory containing the final model weights. Defaults to
            ``./saved_models``.

    Raises:
        NotImplementedError: If ``config.dataset`` is not supported.
    """
    _, test_loader = get_loader(config)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Only the first mini-batch is used.
    data = next(iter(test_loader))

    model = init_model(config)
    # map_location lets weights saved on a GPU machine load on a CPU-only one.
    model.load_state_dict(
        torch.load(f"{ckpt_dir}/{config.train.run_name}.pth", map_location=device)
    )
    model.eval()

    diffusion_utils = init_diffusion_utils(config)

    if config.dataset == "sort-MNIST":
        random_pieces, gt_perm_list = data
        random_pieces, gt_perm_list = random_pieces.to(device), gt_perm_list.to(device)
        pieces = utils.permute_image_perm_list(gt_perm_list, random_pieces)

    elif config.dataset == "unscramble-CIFAR10":
        inputs, _ = data
        pieces, random_pieces, gt_perm_list = batch_chunk_image(inputs, config.num_pieces)
        pieces, random_pieces, gt_perm_list = pieces.to(device), random_pieces.to(device), gt_perm_list.to(device)
        gt_perm_list = torch.argsort(gt_perm_list)

    else:
        raise NotImplementedError

    if config.beam_search:
        ordered_pieces, predicted_perm_list = diffusion_utils.p_sample_beam_search(random_pieces, model)
    else:
        ordered_pieces, predicted_perm_list = diffusion_utils.p_sample_loop(random_pieces, model, deterministic=True)

    # Select an image from the batch.
    batch_idx = 0

    if config.dataset == "sort-MNIST":
        utils.plot_image_serial(random_pieces[batch_idx], config.num_pieces)
        utils.plot_image_serial(ordered_pieces[batch_idx], config.num_pieces)

    elif config.dataset == "unscramble-CIFAR10":
        utils.plot_CIFAR10_image(random_pieces[batch_idx], config.num_pieces)
        utils.plot_CIFAR10_image(ordered_pieces[batch_idx], config.num_pieces)
        utils.plot_CIFAR10_image(pieces[batch_idx], config.num_pieces)
