import pdb
import os
import wandb
import torch
import logging
from tqdm import tqdm

from cg.src.constants import FACTOR_INFORM
from cg.src.trainers.cmcs.base import CMCS_Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from cg.src.analysis_tools.dci_matrix import save_dci_matrix

from cg.src.constants import BASE_DATA

# Configure the root logger at INFO level (basicConfig is a no-op when the
# root logger already has handlers) and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CMCS_Super_Trainger(CMCS_Trainer):
    def __init__(self, args):
        """Set up the trainer by delegating entirely to the base class.

        Args:
            args: Configuration object, forwarded unchanged to the parent
                constructor.
        """
        super().__init__(args)

    def train(self):
        """Train the model, logging to Weights & Biases and saving checkpoints.

        For every batch the optimized objective is::

            reconst + beta * label_loss + alpha * code_loss
                    + gamma * canonical_loss + lamb * dec_equiv

        The individual terms are read from ``outputs[0]["obj"]`` and the
        weights from ``self.args``. ``kld`` is accumulated for logging only;
        it is not part of the optimized loss. The logged "ELBO" is simply the
        negated total loss.

        Side effects:
            * Starts a W&B run and finishes it on every normal exit path,
              including the early stop at ``args.max_steps``.
            * On rank -1/0, every ``args.logging_steps`` steps, logs each
              metric averaged over the steps since the previous log, plus the
              current learning rate. When such a logging step is also a
              multiple of 10000, ``self.eval()`` is run and its ``"reconst"``
              value is logged as well.
            * On rank -1/0, writes model, args, optimizer and scheduler state
              to ``<output_dir>/<model_type>/<save_file>/checkpoint-<step>``
              every ``args.save_steps`` steps and at the final step.
            * Increments ``self.global_step`` once per optimizer step.

        Returns:
            None.
        """
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.trainset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # Internal metric name -> W&B key. The keys (including the historical
        # "Reibler" spelling) are kept byte-identical so existing dashboards
        # keep lining up with new runs.
        log_keys = {
            "elbo": "00.ELBO",
            "total_loss": "01.Total_Loss",
            "reconst": "02.Reconstruction_Loss",
            "kld": "03.Kullback-Reibler_Loss",
            "dec_equiv": "04.Decoder_Equivariant",
            "code_loss": "05.Code_Loss",
            "canonical_loss": "06.canonical_loss",
            "label_loss": "07.Label_loss",
            "label_acc": "08.Label_ACC",
        }
        # Running sums since the start of training, and their values at the
        # last logging step; the difference gives the per-window average.
        running = dict.fromkeys(log_keys, 0.0)
        logged = dict.fromkeys(log_keys, 0.0)

        eval_interval = 10000
        iteration_per_epoch = len(self.train_dataloader)
        is_master = self.args.local_rank in [-1, 0]
        reached_max_steps = False

        self.model.zero_grad()
        wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)
        # wandb.require("legacy-service")
        for epoch in tqdm(range(self.args.num_epoch), desc="Epoch"):
            iteration = tqdm(self.train_dataloader, desc="Iteration")

            for data, class_label in iteration:
                self.model.freeze()
                data = data.to(device)
                class_label = class_label.to(device)

                outputs = self.model(data, self.loss_fn, class_label.long())
                obj = outputs[0]["obj"]
                terms = {
                    name: obj[name]
                    for name in (
                        "reconst",
                        "kld",
                        "dec_equiv",
                        "code_loss",
                        "canonical_loss",
                        "label_loss",
                        "label_acc",
                    )
                }
                terms["total_loss"] = (
                    terms["reconst"]
                    + self.args.beta * terms["label_loss"]
                    + self.args.alpha * terms["code_loss"]
                    + self.args.gamma * terms["canonical_loss"]
                    + self.args.lamb * terms["dec_equiv"]
                )

                # With several GPUs each term is reduced to a scalar by
                # averaging before it is used.
                if self.args.n_gpu > 1:
                    terms = {name: value.mean() for name, value in terms.items()}

                total_loss = terms["total_loss"]
                terms["elbo"] = -total_loss
                for name, value in terms.items():
                    running[name] += value.item()

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()
                self.global_step += 1

                if (
                    is_master
                    and self.args.logging_steps > 0
                    and self.global_step % self.args.logging_steps == 0
                ):
                    logs = {
                        wandb_key: (running[name] - logged[name]) / self.args.logging_steps
                        for name, wandb_key in log_keys.items()
                    }
                    logged.update(running)
                    logs["learning_rate"] = self.scheduler.get_last_lr()[0]

                    if self.global_step % eval_interval == 0:
                        results = self.eval()
                        logs["09.Eval_Reconstruction"] = results["reconst"]

                    wandb.log(logs)

                # Write model parameters (checkpoint). Only the master rank
                # writes, for periodic AND final saves, so distributed workers
                # never write the same files concurrently.
                periodic_save = (
                    self.args.save_steps > 0
                    and self.global_step % self.args.save_steps == 0
                )
                final_save = (
                    self.global_step == self.args.max_steps
                    or self.global_step == iteration_per_epoch * self.args.num_epoch
                )
                if is_master and (periodic_save or final_save):
                    output_dir = os.path.join(
                        self.output_dir,
                        self.args.model_type,
                        self.save_file,
                        "checkpoint-{}".format(self.global_step),
                    )
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap a wrapper exposing `.module` so the saved state
                    # dict belongs to the underlying model.
                    model_to_save = self.model.module if hasattr(self.model, "module") else self.model
                    torch.save(
                        model_to_save.state_dict(), os.path.join(output_dir, "model.pt")
                    )
                    torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(
                        self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")
                    )
                    torch.save(
                        self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")
                    )
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    iteration.close()
                    reached_max_steps = True
                    break

            # Leave the epoch loop too, so the W&B run is closed below instead
            # of being abandoned by an early return.
            if reached_max_steps:
                break

        wandb.finish()
        return

    def eval(self):
        """Evaluate the model on ``self.test_dataloader`` and return mean metrics.

        A random "pivot" batch of ``args.test_batch_size`` samples is drawn
        once from the training set and prepended to every test batch before
        the forward pass, so the reported values cover pivot and test samples
        together. The objective is assembled the same way as in ``train``
        (``kld`` is reported but not part of it).

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, so calling this
        in the middle of training does not leave the model in eval mode.

        Side effects:
            Sets ``self.t_total`` to the number of test batches.

        Returns:
            dict: per-batch averages under the keys ``"elbo"``, ``"obj"``,
            ``"reconst"``, ``"kld"``, ``"code_loss"``, ``"canonical_loss"``,
            ``"decoder_equiv"``, ``"label_loss"`` and ``"label_acc"``. If the
            test dataloader yields no batches, every value is NaN.
        """
        self.t_total = len(self.test_dataloader)

        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.testset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        totals = dict.fromkeys(
            (
                "elbo",
                "obj",
                "reconst",
                "kld",
                "code_loss",
                "canonical_loss",
                "decoder_equiv",
                "label_loss",
                "label_acc",
            ),
            0.0,
        )

        # Nothing to average over: report NaN instead of dividing by zero.
        if self.t_total == 0:
            logger.warning("Test dataloader is empty; returning NaN metrics.")
            return dict.fromkeys(totals, float("nan"))

        # Draw the pivot batch once; it is reused for every test batch.
        pivot_indices = torch.randint(
            low=0, high=len(self.trainset.data), size=(self.args.test_batch_size,)
        )
        pivot_data, pivot_class = [], []
        for index in pivot_indices:
            sample_data, sample_class = self.trainset[index.item()]
            pivot_data.append(sample_data)
            pivot_class.append(sample_class)
        pivot_data = torch.stack(pivot_data, dim=0)
        pivot_class = torch.stack(pivot_class, dim=0)

        iteration = tqdm(self.test_dataloader, desc="Iteration")

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for data, class_label in iteration:
                    data = torch.cat([pivot_data, data], dim=0).to(device)
                    class_label = torch.cat([pivot_class, class_label], dim=0).to(device)

                    outputs = self.model(data, self.loss_fn, class_label.long())
                    obj = outputs[0]["obj"]
                    batch_terms = {
                        "reconst": obj["reconst"],
                        "kld": obj["kld"],
                        "decoder_equiv": obj["dec_equiv"],
                        "code_loss": obj["code_loss"],
                        "canonical_loss": obj["canonical_loss"],
                        "label_loss": obj["label_loss"],
                        "label_acc": obj["label_acc"],
                    }
                    batch_terms["obj"] = (
                        batch_terms["reconst"]
                        + self.args.beta * batch_terms["label_loss"]
                        + self.args.alpha * batch_terms["code_loss"]
                        + self.args.gamma * batch_terms["canonical_loss"]
                        + self.args.lamb * batch_terms["decoder_equiv"]
                    )

                    # With several GPUs each term is reduced to a scalar by
                    # averaging before it is accumulated.
                    if self.args.n_gpu > 1:
                        batch_terms = {
                            name: value.mean() for name, value in batch_terms.items()
                        }

                    batch_terms["elbo"] = -batch_terms["obj"]
                    for name, value in batch_terms.items():
                        totals[name] += value.item()
        finally:
            # Hand the model back in the mode the caller had it in.
            self.model.train(was_training)

        return {name: total / self.t_total for name, total in totals.items()}



    def qualitative(self, batch_size=64, top_k=10):
        """Collect the best and worst reconstructions over the test set.

        A random pivot batch of ``batch_size`` transformed training images is
        prepended to every test batch before the forward pass. The
        reconstructions of the test half are sliced out of ``outputs[2][0]``
        (the slice depends on ``args.cg``), and ``self.loss_fn`` is evaluated
        per image against the corresponding test input.

        The model is put in eval mode and left in that mode.

        Args:
            batch_size: Size of both the pivot batch and the test batches.
                Defaults to 64.
            top_k: Number of best / worst samples to return. Clamped to the
                number of scored samples. Defaults to 10.

        Returns:
            tuple: ``(best, worst)``. Each is the concatenation along dim 0 of
            the selected original images followed by their reconstructions,
            for the lowest-error and highest-error samples respectively.

        Raises:
            RuntimeError: If the test loader produced no batches (for example
                fewer than ``batch_size`` test samples with ``drop_last``).
        """
        # select pivot data
        pivot_indices = torch.randint(
            low=0, high=len(self.trainset.data), size=(batch_size,)
        )
        pivot_data = torch.stack(
            [
                self.trainset.transforms(self.trainset.data[index])
                for index in pivot_indices
            ],
            dim=0,
        )
        pivot_class = torch.Tensor(self.trainset.latents_classes[pivot_indices])

        logger.info("***********Qualitative Analysis***********")
        test_loader = DataLoader(
            self.testset,
            sampler=SequentialSampler(self.testset),
            batch_size=batch_size,
            drop_last=True,
        )
        iteration = tqdm(test_loader, desc="Iteration")

        imgs, gen_imgs, reconst_errs = [], [], []
        self.model.eval()
        with torch.no_grad():
            for data, class_label in iteration:
                new_data = torch.cat([pivot_data, data], dim=0).to(device)
                class_label = torch.cat([pivot_class, class_label], dim=0).to(device)
                outputs = self.model(new_data, self.loss_fn, class_label.long())

                batch = new_data.size(0)
                if self.args.cg == "r2e":
                    reconst_imgs = outputs[2][0][batch // 2: batch]
                else:
                    reconst_imgs = outputs[2][0][3 * batch // 2:]

                gen_imgs.append(reconst_imgs.detach().cpu())
                imgs.append(data.detach().cpu())

                # Move the targets to the device once per batch rather than
                # once per image.
                targets = data.to(device)
                for i, reconst_img in enumerate(reconst_imgs):
                    reconst_err = self.loss_fn(reconst_img, targets[i]).detach().cpu()
                    reconst_errs.append(reconst_err)

        if not reconst_errs:
            raise RuntimeError(
                "qualitative(): the test loader produced no batches "
                "(batch_size=%d with drop_last=True)" % batch_size
            )

        imgs = torch.cat(imgs, dim=0)
        gen_imgs = torch.cat(gen_imgs, dim=0)
        reconst_errs = torch.stack(reconst_errs, dim=0)

        # torch.topk raises when k exceeds the size of the searched dimension.
        k = min(top_k, reconst_errs.size(-1))

        # best cases
        _, best_idx = torch.topk(reconst_errs, k, largest=False)
        best = torch.cat([imgs[best_idx], gen_imgs[best_idx]], dim=0)

        # worst cases
        _, worst_idx = torch.topk(reconst_errs, k)
        worst = torch.cat([imgs[worst_idx], gen_imgs[worst_idx]], dim=0)

        return best, worst


    def analysis(self):
        """Call ``save_dci_matrix`` for the current model on the training set.

        Uses a fixed batch size of 64 and 100 iterations, and passes
        ``self.save_file`` as the matrix directory.

        Returns:
            None.
        """
        dci_options = {
            "dataset": self.trainset,
            "model": self.model,
            "batch_size": 64,
            "iteration": 100,
            "loss_fn": self.loss_fn,
            "matrix_dir": self.save_file,
            "args": self.args,
        }
        save_dci_matrix(**dci_options)
        return None














