import pdb
import os
import wandb
import torch
import logging
from tqdm import tqdm

from cg.src.trainers.base import Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from cg.src.analysis_tools.dci_matrix import save_dci_matrix

from cg.src.constants import BASE_DATA

# Device that input batches are moved to in this module: CUDA when
# torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: configures logging at import time (INFO level); this is a module-level
# side effect of importing this file.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class BetaTCVAE_Trainer(Trainer):
    """Training / evaluation driver for a beta-TCVAE style model.

    The objective minimised here is::

        total_loss = reconst + alpha * mi + beta * tc + gamma * kld

    where the four terms are read from ``outputs[0]["obj"]`` of the model's
    forward pass and the weights come from ``self.args``.

    The methods read ``self.model``, ``self.optimizer``, ``self.scheduler``,
    ``self.loss_fn``, ``self.trainset``, ``self.testset``,
    ``self.train_dataloader``, ``self.test_dataloader``, ``self.global_step``,
    ``self.output_dir``, ``self.save_file``, ``self.run_file`` and
    ``self.args``; none of them are created in this class (presumably they are
    set up by the ``Trainer`` base class -- confirm there).
    """

    # During training, run ``eval()`` whenever the global step is a multiple
    # of this value. The check only happens on logging steps.
    EVAL_INTERVAL_STEPS = 10000
    # Number of best / worst reconstructions returned by ``qualitative()``.
    QUALITATIVE_TOP_K = 10
    # Batch size of the loader built inside ``qualitative()``.
    QUALITATIVE_BATCH_SIZE = 64

    # Names of the running sums tracked while training / evaluating.
    _METRIC_NAMES = ("elbo", "total_loss", "reconst", "kld", "mi", "tc")
    # (wandb key, running-sum name) pairs, in the order they are logged.
    # The wandb keys are kept byte-for-byte so existing dashboards still match.
    _LOG_KEYS = (
        ("00.ELBO", "elbo"),
        ("01.Total_Loss", "total_loss"),
        ("02.Reconstruction_Loss", "reconst"),
        ("03.Kullback-Reibler_Loss", "kld"),
        ("04.MI", "mi"),
        ("05.TC", "tc"),
    )

    def __init__(self, args):
        """Forward ``args`` unchanged to the base ``Trainer`` constructor."""
        super().__init__(args)

    def _compute_losses(self, outputs):
        """Read the loss terms from model outputs and build the total loss.

        Args:
            outputs: Return value of ``self.model(data, self.loss_fn)``;
                ``outputs[0]["obj"]`` must provide the keys ``"reconst"``,
                ``"kld"``, ``"mi"`` and ``"tc"``.

        Returns:
            Tuple ``(total_loss, reconst, kld, mi, tc)``. When
            ``self.args.n_gpu > 1`` every element is reduced with ``.mean()``
            so that ``.item()`` can be called on it.
        """
        objective = outputs[0]["obj"]
        reconst, kld = objective["reconst"], objective["kld"]
        mi, tc = objective["mi"], objective["tc"]

        total_loss = (
            reconst
            + self.args.alpha * mi
            + self.args.beta * tc
            + self.args.gamma * kld
        )

        if self.args.n_gpu > 1:
            # Same reduction as the original training loop; sharing it here
            # also fixes eval(), which referenced undefined names.
            total_loss = total_loss.mean()
            reconst = reconst.mean()
            kld = kld.mean()
            mi = mi.mean()
            tc = tc.mean()

        return total_loss, reconst, kld, mi, tc

    @staticmethod
    def _accumulate(totals, total_loss, reconst, kld, mi, tc):
        """Add one batch's scalar loss values to the running sums in ``totals``."""
        loss_value = total_loss.item()
        # The ELBO is tracked as the negative of the total loss.
        totals["elbo"] += -loss_value
        totals["total_loss"] += loss_value
        totals["reconst"] += reconst.item()
        totals["kld"] += kld.item()
        totals["mi"] += mi.item()
        totals["tc"] += tc.item()

    def _save_checkpoint(self):
        """Write model, args, optimizer and scheduler state for the current step.

        Files go to
        ``<output_dir>/<model_type>/<save_file>/checkpoint-<global_step>/``.
        """
        output_dir = os.path.join(
            self.output_dir,
            self.args.model_type,
            self.save_file,
            "checkpoint-{}".format(self.global_step),
        )
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

        # Unwrap wrappers that expose the real model as ``.module``.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)

        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    def train(self):
        """Run the training loop, logging to wandb and writing checkpoints.

        For every batch: forward pass, weighted loss, backward pass, gradient
        clipping to ``args.max_grad_norm``, optimizer and scheduler step.
        Every ``args.logging_steps`` steps (on local rank -1 or 0) the mean of
        each loss term since the previous log is sent to wandb. A checkpoint
        is written every ``args.save_steps`` steps and at the last step.
        Training stops early once ``args.max_steps`` (if positive) is reached.

        The wandb run is finished both on normal completion and on the
        ``max_steps`` early stop.

        Returns:
            None.
        """
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.trainset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # Running sums over the whole run, and their values at the last log;
        # the difference gives the mean over the most recent logging window.
        totals = dict.fromkeys(self._METRIC_NAMES, 0.0)
        logged = dict(totals)

        is_main_process = self.args.local_rank in [-1, 0]
        last_step = len(self.train_dataloader) * self.args.num_epoch

        self.model.zero_grad()
        wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)

        reached_max_steps = False
        epochs = tqdm(range(self.args.num_epoch), desc="Epoch")
        for _ in epochs:
            iteration = tqdm(self.train_dataloader, desc="Iteration")

            for data, _class_label in iteration:
                # Re-enter train mode every step: eval() below switches the
                # model to eval mode.
                self.model.train()
                data = data.to(device)
                outputs = self.model(data, self.loss_fn)

                total_loss, reconst, kld, mi, tc = self._compute_losses(outputs)
                self._accumulate(totals, total_loss, reconst, kld, mi, tc)

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()
                self.global_step += 1

                if (
                    is_main_process
                    and self.args.logging_steps > 0
                    and self.global_step % self.args.logging_steps == 0
                ):
                    logs = {
                        log_key: (totals[name] - logged[name]) / self.args.logging_steps
                        for log_key, name in self._LOG_KEYS
                    }
                    logged = dict(totals)
                    logs["learning_rate"] = self.scheduler.get_last_lr()[0]

                    if self.global_step % self.EVAL_INTERVAL_STEPS == 0:
                        results = self.eval()
                        logs["06.Eval_Reconstruction"] = results["reconst"]

                    wandb.log(logs)

                # Write model parameters (checkpoint).
                # NOTE(review): the two "last step" clauses are not gated by
                # local_rank, so every process saves at the final step --
                # confirm whether that is intended for distributed runs.
                if (
                    (
                        is_main_process
                        and self.args.save_steps > 0
                        and self.global_step % self.args.save_steps == 0
                    )
                    or self.global_step == self.args.max_steps
                    or self.global_step == last_step
                ):
                    self._save_checkpoint()

                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    iteration.close()
                    reached_max_steps = True
                    break

            if reached_max_steps:
                epochs.close()
                break

        # Reached on both normal completion and the max_steps early stop, so
        # the wandb run is always closed once training ends.
        wandb.finish()

    def eval(self):
        """Evaluate the model on ``self.test_dataloader`` without gradients.

        Side effects: puts ``self.model`` into eval mode and overwrites
        ``self.t_total`` with the number of test batches.

        Returns:
            dict with keys ``"elbo"``, ``"obj"``, ``"reconst"``, ``"kld"``,
            ``"mi"`` and ``"tc"``: each is the sum of the per-batch values
            divided by the number of test batches.
        """
        self.t_total = len(self.test_dataloader)

        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.testset))
        # The remaining lines report the training configuration from args.
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        totals = dict.fromkeys(self._METRIC_NAMES, 0.0)

        self.model.eval()
        with torch.no_grad():
            for data, _class_label in tqdm(self.test_dataloader, desc="Iteration"):
                data = data.to(device)
                outputs = self.model(data, self.loss_fn)
                total_loss, reconst, kld, mi, tc = self._compute_losses(outputs)
                self._accumulate(totals, total_loss, reconst, kld, mi, tc)

        return {
            "elbo": totals["elbo"] / self.t_total,
            "obj": totals["total_loss"] / self.t_total,
            "reconst": totals["reconst"] / self.t_total,
            "kld": totals["kld"] / self.t_total,
            "mi": totals["mi"] / self.t_total,
            "tc": totals["tc"] / self.t_total,
        }

    def qualitative(self):
        """Collect the best and worst reconstructions over ``self.testset``.

        The test set is read sequentially; for each sample the reconstruction
        (taken from ``outputs[2][0]``) is scored with
        ``self.loss_fn(reconstruction, original)``.

        Returns:
            Tuple ``(best, worst)``. Each is the concatenation along dim 0 of
            the selected original images followed by their reconstructions,
            for the lowest-error and highest-error samples respectively. Up to
            ``QUALITATIVE_TOP_K`` samples are selected (fewer if the test set
            is smaller).
        """
        logger.info("***********Qualitative Analysis***********")
        dataloader = DataLoader(
            self.testset,
            sampler=SequentialSampler(self.testset),
            batch_size=self.QUALITATIVE_BATCH_SIZE,
        )

        imgs, gen_imgs, reconst_errs = [], [], []
        self.model.eval()
        with torch.no_grad():
            for data, _class_label in tqdm(dataloader, desc="Iteration"):
                data = data.to(device)
                outputs = self.model(data, self.loss_fn)
                reconst_imgs = outputs[2][0]
                gen_imgs.append(reconst_imgs.detach().cpu())
                imgs.append(data.detach().cpu())
                # Per-sample error so that individual images can be ranked.
                for i in range(reconst_imgs.size(0)):
                    reconst_errs.append(
                        self.loss_fn(reconst_imgs[i], data[i]).detach().cpu()
                    )

        imgs = torch.cat(imgs, dim=0)
        gen_imgs = torch.cat(gen_imgs, dim=0)
        reconst_errs = torch.stack(reconst_errs, dim=0)

        # topk raises if k exceeds the number of samples, so clamp it.
        k = min(self.QUALITATIVE_TOP_K, reconst_errs.size(0))

        def originals_then_reconstructions(largest):
            _, idx = torch.topk(reconst_errs, k, largest=largest)
            return torch.cat([imgs[idx], gen_imgs[idx]], dim=0)

        best = originals_then_reconstructions(False)  # smallest errors
        worst = originals_then_reconstructions(True)  # largest errors

        return best, worst

    def analysis(self):
        """Call ``save_dci_matrix`` on the training set with this trainer's model.

        Uses a batch size of 64 and 100 iterations, and passes
        ``self.save_file`` as ``matrix_dir``.

        Returns:
            None.
        """
        save_dci_matrix(
            dataset=self.trainset,
            model=self.model,
            batch_size=64,
            iteration=100,
            loss_fn=self.loss_fn,
            matrix_dir=self.save_file,
            args=self.args,
        )
















