import pdb
import os
import wandb
import torch
import logging
from tqdm import tqdm

from cg.src.trainers.base import Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from cg.src.analysis_tools.dci_matrix import save_dci_matrix

from cg.src.constants import BASE_DATA


# Target device for all tensors: the default CUDA device when one is available,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Import-time logging setup: INFO-level root config (a no-op if the root logger
# already has handlers) plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BetaMAGAVAE_Trainer(Trainer):
    """Trainer for the beta-MAGA-VAE model.

    Optimizes ``reconst + beta * kld + gamma * latent_loss``, where the three
    terms are read from ``outputs[0]["obj"]`` of the model's forward pass. It
    logs windowed loss averages to Weights & Biases and writes checkpoints
    under ``output_dir/<model_type>/<save_file>/checkpoint-<step>``.

    Uses attributes that are presumably set up by the base ``Trainer`` (not
    visible here): ``args``, ``model``, ``optimizer``, ``scheduler``,
    ``loss_fn``, ``trainset``/``testset``, ``train_dataloader``/
    ``test_dataloader``, ``global_step``, ``output_dir``, ``save_file`` and
    ``run_file``.
    """

    # W&B metric key for each running training-loss sum. The "Reibler"
    # misspelling is kept as-is so logged metric names stay stable across runs.
    _TRAIN_METRIC_KEYS = {
        "elbo": "00.ELBO",
        "total": "01.Total_Loss",
        "reconst": "02.Reconstruction_Loss",
        "kld": "03.Kullback-Reibler_Loss",
        "latent": "04.Latent_Loss",
    }
    # eval() runs on steps divisible by this value that are also logging steps.
    eval_every_steps = 10000

    def __init__(self, args):
        super(BetaMAGAVAE_Trainer, self).__init__(args)

    def _compute_losses(self, outputs):
        """Return ``(total_loss, reconst, kld, latent_loss)`` for one forward pass.

        The weighted objective is formed before any reduction. When
        ``args.n_gpu > 1`` every term is reduced with ``.mean()``. Presumably
        the wrapped model then returns one value per device (confirm), so the
        reduction yields scalars for ``.item()`` and ``.backward()``.
        """
        obj = outputs[0]["obj"]
        reconst, kld, latent_loss = obj["reconst"], obj["kld"], obj["latent_loss"]
        total_loss = reconst + self.args.beta * kld + self.args.gamma * latent_loss
        if self.args.n_gpu > 1:
            total_loss = total_loss.mean()
            reconst = reconst.mean()
            kld = kld.mean()
            latent_loss = latent_loss.mean()
        return total_loss, reconst, kld, latent_loss

    def _save_checkpoint(self):
        """Save model/args/optimizer/scheduler state for the current ``global_step``."""
        output_dir = os.path.join(
            self.output_dir,
            self.args.model_type,
            self.save_file,
            "checkpoint-{}".format(self.global_step),
        )
        os.makedirs(output_dir, exist_ok=True)
        # Unwrap a wrapper that exposes ``.module`` (e.g. a data-parallel wrapper)
        # so the saved state_dict belongs to the underlying model.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    def train(self):
        """Train for ``args.num_epoch`` epochs, or until ``args.max_steps`` steps.

        Each batch gets one optimizer step with gradient clipping at
        ``args.max_grad_norm``.

        On the main process (``local_rank`` in ``[-1, 0]``):

        - every ``args.logging_steps`` steps, the mean of each loss term over
          the window and the current learning rate are logged to W&B;
        - on steps that are also multiples of ``eval_every_steps``, the
          reconstruction loss from :meth:`eval` is logged as well;
        - a checkpoint is written every ``args.save_steps`` steps, and at the
          final step (``args.max_steps`` or the end of the last epoch).

        The W&B run is finished on both the normal and the ``max_steps`` exit.
        """
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.trainset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # Running sums of each loss term, plus their values at the last W&B log.
        # The difference divided by logging_steps is the mean over the window.
        running = dict.fromkeys(self._TRAIN_METRIC_KEYS, 0.0)
        logged = dict(running)

        is_main_process = self.args.local_rank in [-1, 0]
        # global_step value at the end of the final epoch (assumes it starts at 0).
        last_step = len(self.train_dataloader) * self.args.num_epoch

        self.model.zero_grad()
        wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)

        for _epoch in tqdm(range(self.args.num_epoch), desc="Epoch"):
            iteration = tqdm(self.train_dataloader, desc="Iteration")

            for data, _ in iteration:
                self.model.train()
                outputs = self.model(data.to(device), self.loss_fn)
                total_loss, reconst, kld, latent_loss = self._compute_losses(outputs)

                # ELBO is reported as the negated training objective.
                running["elbo"] += -total_loss.item()
                running["total"] += total_loss.item()
                running["reconst"] += reconst.item()
                running["kld"] += kld.item()
                running["latent"] += latent_loss.item()

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()
                self.global_step += 1

                if (
                    is_main_process
                    and self.args.logging_steps > 0
                    and self.global_step % self.args.logging_steps == 0
                ):
                    window = self.args.logging_steps
                    logs = {
                        key: (running[name] - logged[name]) / window
                        for name, key in self._TRAIN_METRIC_KEYS.items()
                    }
                    logged = dict(running)
                    logs["learning_rate"] = self.scheduler.get_last_lr()[0]

                    if self.global_step % self.eval_every_steps == 0:
                        logs["05.Eval_Reconstruction"] = self.eval()["reconst"]

                    wandb.log(logs)

                # Checkpoint periodically and at the final step. Only the main
                # process writes, so distributed ranks never race on the same files.
                is_save_step = (
                    self.args.save_steps > 0
                    and self.global_step % self.args.save_steps == 0
                )
                is_last_step = self.global_step in (self.args.max_steps, last_step)
                if is_main_process and (is_save_step or is_last_step):
                    self._save_checkpoint()

                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    iteration.close()
                    # Close the W&B run on this early exit as well.
                    wandb.finish()
                    return

        wandb.finish()
        return

    def eval(self):
        """Evaluate on ``test_dataloader`` and return the mean per-batch losses.

        ``args.test_batch_size`` training samples ("pivots") are drawn at
        random once. They are prepended to every test batch, so each forward
        pass (and the reported losses) covers pivots + test batch together.
        The model is left in eval mode.

        Returns:
            dict with keys ``"elbo"``, ``"obj"``, ``"reconst"``, ``"kld"`` and
            ``"latent_loss"``. Each value is a sum divided by the number of
            test batches; all are 0.0 when the test loader is empty.
        """
        num_batches = len(self.test_dataloader)
        # Still stored on the instance, as before, in case other code reads it.
        self.t_total = num_batches

        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.testset))
        logger.info(" Num batches = %d", num_batches)
        logger.info(" Pivot samples per batch = %d", self.args.test_batch_size)

        pivot_idx = torch.randint(
            low=0, high=len(self.trainset.data), size=(self.args.test_batch_size,)
        )
        # Only the data part of each (data, label) training item is used.
        pivot_data = torch.stack(
            [self.trainset[i.item()][0] for i in pivot_idx], dim=0
        ).to(device)

        sums = dict.fromkeys(("elbo", "obj", "reconst", "kld", "latent_loss"), 0.0)
        self.model.eval()
        with torch.no_grad():
            for data, _ in tqdm(self.test_dataloader, desc="Iteration"):
                batch = torch.cat([pivot_data, data.to(device)], dim=0)
                outputs = self.model(batch, self.loss_fn)
                total_loss, reconst, kld, latent_loss = self._compute_losses(outputs)

                sums["elbo"] += -total_loss.item()
                sums["obj"] += total_loss.item()
                sums["reconst"] += reconst.item()
                sums["kld"] += kld.item()
                sums["latent_loss"] += latent_loss.item()

        # max(..., 1) avoids ZeroDivisionError when the test loader is empty.
        denom = max(num_batches, 1)
        return {name: total / denom for name, total in sums.items()}

    def qualitative(self, max_batches=1, num_cases=10):
        """Return the best- and worst-reconstructed test images.

        64 random training samples, passed through ``trainset.transforms``,
        serve as pivots. They are prepended to each 64-image test batch
        (sequential order, incomplete last batch dropped). Only the test
        images' reconstructions are kept. Each image's error is
        ``self.loss_fn(reconstruction, original)``.

        Args:
            max_batches: number of test batches to scan, or ``None`` for the
                whole test set. Every scanned image and reconstruction is kept
                in CPU memory. The default of 1 keeps the previous behaviour,
                which only ever looked at the first batch.
            num_cases: how many lowest-/highest-error images to return.
                Clamped to the number of images scanned.

        Returns:
            ``(best, worst)``. Each is the selected originals followed by their
            reconstructions, concatenated along dim 0. Returns ``None`` when
            the test loader yields no full batch.

        Raises:
            ValueError: if ``max_batches`` is not ``None`` and is less than 1.
        """
        if max_batches is not None and max_batches < 1:
            raise ValueError("max_batches must be a positive int or None")

        batch_size = 64
        pivot_idx = torch.randint(low=0, high=len(self.trainset.data), size=(batch_size,))
        pivot_data = torch.stack(
            [self.trainset.transforms(self.trainset.data[i]) for i in pivot_idx], dim=0
        )
        num_pivots = pivot_data.size(0)

        logger.info("***********Qualitative Analysis***********")
        test_dataloader = DataLoader(
            self.testset,
            sampler=SequentialSampler(self.testset),
            batch_size=batch_size,
            drop_last=True,
        )

        imgs, gen_imgs, reconst_errs = [], [], []
        self.model.eval()
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(tqdm(test_dataloader, desc="Iteration")):
                new_data = torch.cat([pivot_data, data], dim=0).to(device)
                outputs = self.model(new_data, self.loss_fn)
                # outputs[2][0] is used as the batch of reconstructions; drop
                # the leading pivot rows so only the test images remain.
                reconst_imgs = outputs[2][0][num_pivots:]

                data_on_device = data.to(device)
                reconst_errs.extend(
                    self.loss_fn(recon, original).detach().cpu()
                    for recon, original in zip(reconst_imgs, data_on_device)
                )
                gen_imgs.append(reconst_imgs.detach().cpu())
                imgs.append(data.detach().cpu())

                # Stop right after the last requested batch, without loading another.
                if max_batches is not None and batch_idx + 1 >= max_batches:
                    break

        if not imgs:
            # Nothing to rank: the test loader produced no full batch.
            return None

        imgs = torch.cat(imgs, dim=0)
        gen_imgs = torch.cat(gen_imgs, dim=0)
        reconst_errs = torch.stack(reconst_errs, dim=0)
        k = min(num_cases, reconst_errs.size(0))

        _, best_idx = torch.topk(reconst_errs, k, largest=False)
        _, worst_idx = torch.topk(reconst_errs, k)
        best = torch.cat([imgs[best_idx], gen_imgs[best_idx]], dim=0)
        worst = torch.cat([imgs[worst_idx], gen_imgs[worst_idx]], dim=0)
        return best, worst

    def analysis(self):
        """Compute and save a DCI matrix for the model on the training set.

        Delegates to ``save_dci_matrix`` with ``batch_size=64`` and
        ``iteration=100``; the exact meaning of these is defined by that
        helper. The output location comes from ``matrix_dir=self.save_file``.
        """
        save_dci_matrix(
            dataset=self.trainset,
            model=self.model,
            batch_size=64,
            iteration=100,
            loss_fn=self.loss_fn,
            matrix_dir=self.save_file,
            args=self.args,
        )
        return
