import pdb
import os
import wandb
import torch
import logging
from tqdm import tqdm

from dl.src.trainers.base import Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


# Compute device shared by every trainer method in this module: the GPU when
# CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FactorVAE_Trainer(Trainer):
    def __init__(self, args):
        """Forward ``args`` unchanged to the base ``Trainer`` constructor."""
        super().__init__(args)

    def train(self):
        """Run the FactorVAE training loop.

        Every batch from ``self.train_dataloader`` is a triple
        ``(data1, data2, class_label)`` and drives two updates:

        1. ``data1`` is passed through the model; ``reconst + kld +
           gamma * tc`` is back-propagated, gradients are clipped to
           ``args.max_grad_norm`` and ``self.optimizer`` /
           ``self.scheduler`` are stepped.
        2. ``data2`` is passed through the model; the ``"disc"`` objective
           is back-propagated and ``self.disc_optimizer`` /
           ``self.disc_scheduler`` are stepped.

        Averages over the last ``args.logging_steps`` steps are sent to
        wandb, and checkpoints are written every ``args.save_steps`` steps
        and on the final step. Training stops early once
        ``args.max_steps`` (when positive) is reached.

        Returns:
            None.
        """
        world_size = (
            torch.distributed.get_world_size() if self.args.local_rank != -1 else 1
        )
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d", self.args.train_batch_size * world_size
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # wandb panel name for every tracked quantity (insertion order is the
        # order in which the panels are logged).
        panel_names = {
            "elbo": "00.ELBO",
            "total": "01.Total_Loss",
            "reconst": "02.Reconstruction_Loss",
            "kld": "03.Kullback-Reibler_Loss",
            "disc": "04.DISCRIMINATOR",
            "tc": "05.TC",
        }
        # Cumulative sums since the start of training, and a snapshot of them
        # taken at the previous logging step; their difference is the window sum.
        running = dict.fromkeys(panel_names, 0.0)
        last_logged = dict(running)

        is_main_process = self.args.local_rank in [-1, 0]
        final_step = len(self.train_dataloader) * self.args.num_epoch

        self.model.zero_grad()
        self.model.discriminator.zero_grad()

        wandb.init(
            project=self.args.project_name, name=self.run_file, entity=self.args.entity
        )

        for _ in tqdm(range(self.args.num_epoch), desc="Epoch"):
            iteration = tqdm(self.train_dataloader, desc="Iteration")
            # Nothing inside the step loop switches the model out of train mode.
            self.model.train()

            for data1, data2, _ in iteration:
                data1 = data1.to(device)
                data2 = data2.to(device)

                # ---- first update: reconstruction + KL + weighted TC ----
                vae_obj = self.model(data1, self.loss_fn)[0]["obj"]
                reconst, kld, tc = vae_obj["reconst"], vae_obj["kld"], vae_obj["tc"]
                total_loss = reconst + kld + self.args.gamma * tc

                if self.args.n_gpu > 1:
                    # Presumably a multi-GPU wrapper returns one value per
                    # device -- reduce to a scalar (TODO confirm).
                    total_loss = total_loss.mean()
                    reconst = reconst.mean()
                    kld = kld.mean()
                    tc = tc.mean()

                running["elbo"] += -total_loss.item()
                running["total"] += total_loss.item()
                running["reconst"] += reconst.item()
                running["kld"] += kld.item()
                running["tc"] += tc.item()

                self.optimizer.zero_grad()
                # NOTE(review): retain_graph=True is kept from the original
                # code; the discriminator loss below comes from a fresh
                # forward pass, so it may not be required -- confirm.
                total_loss.backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.args.max_grad_norm
                )
                self.optimizer.step()
                self.scheduler.step()

                # ---- second update: discriminator objective on data2 ----
                disc = self.model(data2, self.loss_fn)[0]["obj"]["disc"]
                if self.args.n_gpu > 1:
                    disc = disc.mean()
                running["disc"] += disc.item()

                self.disc_optimizer.zero_grad()
                disc.backward()
                self.disc_optimizer.step()
                self.disc_scheduler.step()

                self.global_step += 1

                # ---- periodic wandb logging (main process only) ----
                if (
                    is_main_process
                    and self.args.logging_steps > 0
                    and self.global_step % self.args.logging_steps == 0
                ):
                    logs = {
                        panel: (running[key] - last_logged[key])
                        / self.args.logging_steps
                        for key, panel in panel_names.items()
                    }
                    last_logged = dict(running)
                    logs["learning_rate"] = self.scheduler.get_last_lr()[0]
                    wandb.log(logs)

                # ---- checkpointing ----
                # Only the main process writes, including on the final step;
                # otherwise every distributed rank would write the same files
                # at the same time.
                periodic_save = (
                    self.args.save_steps > 0
                    and self.global_step % self.args.save_steps == 0
                )
                last_step = (
                    self.global_step == self.args.max_steps
                    or self.global_step == final_step
                )
                if is_main_process and (periodic_save or last_step):
                    self._save_factorvae_checkpoint()

                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    iteration.close()
                    # Close the wandb run on the early-exit path as well.
                    wandb.finish()
                    return

        wandb.finish()
        return

    def _save_factorvae_checkpoint(self):
        """Write model weights, args and all optimizer/scheduler states for the current step."""
        output_dir = os.path.join(
            self.output_dir,
            self.args.model_type,
            self.save_file,
            "checkpoint-{}".format(self.global_step),
        )
        os.makedirs(output_dir, exist_ok=True)

        # Unwrap a ``.module`` wrapper if present so the saved state dict has
        # un-prefixed parameter names.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)

        stateful = (
            ("optimizer.pt", self.optimizer),
            ("scheduler.pt", self.scheduler),
            ("disc_optimizer.pt", self.disc_optimizer),
            ("disc_scheduler.pt", self.disc_scheduler),
        )
        for file_name, component in stateful:
            torch.save(component.state_dict(), os.path.join(output_dir, file_name))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    def eval(self):
        """Evaluate the model on ``self.dataset`` and return mean losses.

        Side effects: replaces ``self.train_sampler`` and
        ``self.train_dataloader`` with evaluation versions (batch size
        ``args.test_batch_size``, ``drop_last=True``), resets
        ``self.global_step`` to 0 and sets ``self.t_total`` to the number
        of evaluation batches.

        Returns:
            dict with keys ``"elbo"``, ``"obj"``, ``"reconst"``, ``"kld"``,
            ``"disc"`` and ``"tc"``; each value is the sum of the per-batch
            value divided by the number of batches.
        """
        # Sequential order on a single process, DistributedSampler otherwise.
        self.train_sampler = (
            SequentialSampler(self.dataset)
            if self.args.local_rank == -1
            else DistributedSampler(self.dataset)
        )
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.test_batch_size,
            drop_last=True,
            pin_memory=True,
        )
        self.global_step = 0
        self.t_total = len(self.train_dataloader)

        world_size = (
            torch.distributed.get_world_size() if self.args.local_rank != -1 else 1
        )
        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d", self.args.train_batch_size * world_size
        )
        # Report the batch count this evaluation actually iterates over.
        logger.info("Total optimization steps = %d", self.t_total)

        totals = dict.fromkeys(("elbo", "obj", "reconst", "kld", "disc", "tc"), 0.0)

        self.model.eval()
        with torch.no_grad():
            for data1, data2, _ in tqdm(self.train_dataloader, desc="Iteration"):
                data1 = data1.to(device)
                data2 = data2.to(device)

                vae_obj = self.model(data1, self.loss_fn)[0]["obj"]
                reconst, kld, tc = vae_obj["reconst"], vae_obj["kld"], vae_obj["tc"]
                disc = self.model(data2, self.loss_fn)[0]["obj"]["disc"]

                # NOTE(review): unlike train(), the evaluation objective also
                # adds the raw discriminator loss -- confirm this is intended.
                total_loss = reconst + kld + disc + self.args.gamma * tc

                if self.args.n_gpu > 1:
                    # Presumably one value per device -- reduce to scalars
                    # (TODO confirm).
                    total_loss = total_loss.mean()
                    reconst = reconst.mean()
                    kld = kld.mean()
                    disc = disc.mean()
                    tc = tc.mean()

                totals["elbo"] += -total_loss.item()
                totals["obj"] += total_loss.item()
                totals["reconst"] += reconst.item()
                totals["kld"] += kld.item()
                # Accumulate (not overwrite) so the averages below are correct.
                totals["disc"] += disc.item()
                totals["tc"] += tc.item()

        if self.t_total == 0:
            # drop_last=True yields no batches when the dataset is smaller
            # than one batch; report zeros instead of dividing by zero.
            logger.warning("Evaluation dataloader produced no batches; returning zeros")
        num_batches = max(self.t_total, 1)

        return {name: total / num_batches for name, total in totals.items()}

    def qualitative(self):
        """Produce latent-traversal reconstructions for one random batch.

        A single batch of ``args.quali_sampling`` randomly sampled items is
        encoded, its latent codes are expanded by
        ``changed_latent_vector_value`` (``args.interval`` variants per
        sample) and the result is decoded.

        Returns:
            tuple ``(new_outputs, gen_outputs)``:
            ``new_outputs`` holds, for each sample, the input image followed
            by its ``args.interval`` decoded variants, concatenated along
            dim 0; ``gen_outputs`` holds only the decoded variants.

        Raises:
            ValueError: if the dataset yields no batch.
        """
        logger.info("***********Qualitative Analysis***********")
        loader = DataLoader(
            self.dataset,
            sampler=RandomSampler(self.dataset),
            batch_size=self.args.quali_sampling,
        )

        # Only the first randomly drawn batch is ever used.
        try:
            data, _, _ = next(iter(loader))
        except StopIteration:
            raise ValueError(
                "qualitative() needs at least one sample in the dataset"
            ) from None

        steps = self.args.interval
        self.model.eval()
        with torch.no_grad():
            data = data.to(device)
            z = self.model.encoder(data)[0]  # (Batch, dimension)
            z = self.changed_latent_vector_value(z, interval=steps)
            reconst = self.model.decoder(z)[0]  # decoded traversal images

            new_outputs, gen_outputs = [], []
            for i in range(data.size(0)):
                # Rows [i * steps, (i + 1) * steps) are the variants of sample i.
                variants = reconst[i * steps: (i + 1) * steps]
                new_outputs.append(data[i].unsqueeze(0))
                new_outputs.append(variants)
                gen_outputs.append(variants)
            new_outputs = torch.cat(new_outputs, dim=0)
            gen_outputs = torch.cat(gen_outputs, dim=0)

        return new_outputs, gen_outputs

    def changed_latent_vector_value(self, latent_vector, interval):
        dim = latent_vector.size(0)  # Batch == dimension

        mask = torch.ones_like(latent_vector) - torch.eye(dim).to(
            device
        )  # diagonal is zeros and others are ones

        # set latent z values from -2 to 2
        latent_vector = latent_vector * mask
        interval = torch.arange(-2, 2.1, 4 / (interval - 1)).to(device)  # [latent_dim]
        interval = interval.unsqueeze(-1)  # [latent_dim, 1]
        interval = interval.unsqueeze(2).expand(
            *interval.size(), interval.size(1)
        )  # [latent_dim, 1, 1]
        interval = interval * torch.eye(dim).to(device)
        latent_vector = (
            (latent_vector + interval).permute(1, 0, 2).reshape(-1, dim)
        )  # (Batch * interval , dim)
        return latent_vector
