
import os
import wandb
import torch
import logging
from tqdm import tqdm
import numpy as np

from disent.trainer.base import Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

# Run on the (default) CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Configure INFO-level logging on the root logger and grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CGEL_BetaVAE_Trainer(Trainer):
    """Train, evaluate and visualise a beta-VAE style model.

    The objective optimised is ``reconst + args.beta * kld``.

    Attributes such as ``self.args``, ``self.model``, ``self.dataset``,
    ``self.train_dataloader``, ``self.optimizer``, ``self.scheduler``,
    ``self.loss_fn``, ``self.global_step``, ``self.output_dir``,
    ``self.run_file`` and ``self.save_file`` are expected to be set up by the
    base ``Trainer``. That class is not visible here — TODO confirm.
    """

    def __init__(self, args):
        super().__init__(args)

    def _loss_terms(self, outputs):
        """Return ``(reconst, kld, total_loss)`` from a model forward output.

        ``total_loss = reconst + args.beta * kld``. When ``args.n_gpu > 1``
        each term is averaged so it becomes a scalar. This presumably covers
        the per-replica values produced by a multi-GPU wrapper.
        """
        objective = outputs[0]["obj"]
        reconst, kld = objective["reconst"], objective["kld"]
        total_loss = reconst + self.args.beta * kld
        if self.args.n_gpu > 1:
            total_loss = total_loss.mean()
            reconst = reconst.mean()
            kld = kld.mean()
        return reconst, kld, total_loss

    def _write_checkpoint(self):
        """Save model, args, optimizer and scheduler state for ``global_step``.

        Files are written to
        ``<output_dir>/<model_type>/<save_file>/checkpoint-<global_step>/``.
        """
        output_dir = os.path.join(
            self.output_dir,
            self.args.model_type,
            self.save_file,
            "checkpoint-{}".format(self.global_step),
        )
        os.makedirs(output_dir, exist_ok=True)
        # Unwrap models exposing ``.module`` (e.g. DataParallel/DDP wrappers)
        # so the saved state dict is the bare model's.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    def train(self):
        """Train the model, logging metrics to wandb and writing checkpoints.

        Each step does the following:
        - runs a forward pass and backpropagates ``reconst + beta * kld``;
        - clips gradients to ``args.max_grad_norm``;
        - steps the optimizer and the scheduler.

        Training ends after ``args.num_epoch`` epochs. It ends earlier once
        ``global_step`` reaches ``args.max_steps`` (when ``max_steps > 0``).
        The wandb run is finished in both cases.
        """
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # Running sums of each metric. The logging_* values snapshot the sums
        # at the previous log point, so each log reports the window average.
        tr_elbo, logging_elbo = 0.0, 0.0
        tr_reconst, logging_reconst = 0.0, 0.0
        tr_kld, logging_kld = 0.0, 0.0
        tr_total_loss, logging_total_loss = 0.0, 0.0

        # Only rank -1 (non-distributed) or rank 0 logs and writes checkpoints.
        is_main_process = self.args.local_rank in [-1, 0]
        last_step = len(self.train_dataloader) * self.args.num_epoch

        self.model.zero_grad()
        wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)

        stop_training = False
        for _epoch in tqdm(range(self.args.num_epoch), desc="Epoch"):
            iteration = tqdm(self.train_dataloader, desc="Iteration")

            for data, _class_label in iteration:
                self.model.train()
                data = data.to(device)
                outputs = self.model(data, self.loss_fn)
                reconst, kld, total_loss = self._loss_terms(outputs)

                elbo = -total_loss
                tr_elbo += elbo.item()
                tr_total_loss += total_loss.item()
                tr_reconst += reconst.item()
                tr_kld += kld.item()

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()
                self.global_step += 1

                if (
                    is_main_process
                    and self.args.logging_steps > 0
                    and self.global_step % self.args.logging_steps == 0
                ):
                    window = self.args.logging_steps
                    logs = {
                        "00.ELBO": (tr_elbo - logging_elbo) / window,
                        "01.Total_Loss": (tr_total_loss - logging_total_loss) / window,
                        "02.Reconstruction_Loss": (tr_reconst - logging_reconst) / window,
                        # NOTE: "Reibler" is a misspelling of "Leibler"; the key is
                        # kept unchanged so existing wandb dashboards still match.
                        "03.Kullback-Reibler_Loss": (tr_kld - logging_kld) / window,
                        "learning_rate": self.scheduler.get_last_lr()[0],
                    }
                    logging_elbo = tr_elbo
                    logging_total_loss = tr_total_loss
                    logging_reconst = tr_reconst
                    logging_kld = tr_kld
                    wandb.log(logs)

                # Save periodically, and at the final step (either max_steps
                # or the end of the last epoch). The rank gate applies to
                # every case so distributed ranks do not all write the same
                # files.
                should_save = (
                    (self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0)
                    or self.global_step == self.args.max_steps
                    or self.global_step == last_step
                )
                if is_main_process and should_save:
                    self._write_checkpoint()

                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    iteration.close()
                    stop_training = True
                    break

            if stop_training:
                break

        # Reached on both normal completion and the max_steps early stop.
        wandb.finish()

    def eval(self):
        """Evaluate the model over ``self.dataset`` and return mean losses.

        Returns:
            dict with keys ``"elbo"``, ``"obj"`` (total loss), ``"reconst"``
            and ``"kld"``. Each value is averaged over the evaluated batches.

        Raises:
            ValueError: if the dataset yields no full batch of
                ``args.test_batch_size`` examples (``drop_last=True``).
        """
        # NOTE(review): this reuses (and overwrites) the training sampler,
        # dataloader and global_step attributes. Calling train() after eval()
        # on the same instance would therefore train on the evaluation loader.
        # Kept for backward compatibility — confirm whether callers rely on it.
        self.train_sampler = (
            SequentialSampler(self.dataset)
            if self.args.local_rank == -1
            else DistributedSampler(self.dataset)
        )
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.test_batch_size,
            drop_last=True,
            pin_memory=True,
        )
        self.global_step = 0
        self.t_total = len(self.train_dataloader)

        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        if self.t_total == 0:
            raise ValueError(
                "evaluation dataloader is empty: dataset has fewer examples "
                "than test_batch_size and drop_last=True"
            )

        tr_elbo = tr_total_loss = tr_reconst = tr_kld = 0.0

        self.model.eval()
        iteration = tqdm(self.train_dataloader, desc="Iteration")
        with torch.no_grad():
            for data, _class_label in iteration:
                data = data.to(device)
                outputs = self.model(data, self.loss_fn)
                reconst, kld, total_loss = self._loss_terms(outputs)

                tr_elbo += (-total_loss).item()
                tr_total_loss += total_loss.item()
                tr_reconst += reconst.item()
                tr_kld += kld.item()

        return {
            "elbo": tr_elbo / self.t_total,
            "obj": tr_total_loss / self.t_total,
            "reconst": tr_reconst / self.t_total,
            "kld": tr_kld / self.t_total,
        }

    def qualitative(self):
        """Decode latent traversals for one randomly sampled batch.

        The method draws one batch of ``args.quali_sampling`` examples and
        encodes it. For each example ``i``, latent entry ``i`` is swept over
        ``args.interval`` evenly spaced values in ``[-2, 2]`` (see
        ``changed_latent_vector_value``). The results are then decoded. The
        batch size must therefore equal the latent dimensionality.

        Returns:
            ``(new_outputs, gen_outputs)``. ``new_outputs`` interleaves each
            input image with its ``interval`` decoded traversal images.
            ``gen_outputs`` holds only the decoded traversals.

        Raises:
            ValueError: if the dataloader yields no batch.
        """
        logger.info("***********Qualitative Analysis***********")
        sampler = RandomSampler(self.dataset)
        loader = DataLoader(self.dataset, sampler=sampler, batch_size=self.args.quali_sampling)

        self.model.eval()
        with torch.no_grad():
            # Only the first (random) batch is used.
            first_batch = next(iter(loader), None)
            if first_batch is None:
                raise ValueError("qualitative analysis needs a non-empty dataset")
            data, _class_label = first_batch
            data = data.to(device)

            z = self.model.encoder(data)[0]  # (batch, latent_dim)
            z = self.changed_latent_vector_value(z, interval=self.args.interval)
            reconst = self.model.decoder(z)[0]  # reconstructed images

            span = self.args.interval
            new_outputs, gen_outputs = [], []
            for i in range(data.size(0)):
                traversal = reconst[i * span:(i + 1) * span]
                new_outputs.append(data[i].unsqueeze(0))
                new_outputs.append(traversal)
                gen_outputs.append(traversal)
            new_outputs = torch.cat(new_outputs, dim=0)
            gen_outputs = torch.cat(gen_outputs, dim=0)
        return new_outputs, gen_outputs

    def changed_latent_vector_value(self, latent_vector, interval):
        """Build latent traversals in which code ``i`` sweeps its own entry ``i``.

        Args:
            latent_vector: ``(dim, dim)`` tensor with one latent code per row.
                The batch size must equal the latent dimensionality.
            interval: number of evenly spaced sweep values in ``[-2, 2]``.

        Returns:
            ``(dim * interval, dim)`` tensor. Rows
            ``i * interval`` to ``(i + 1) * interval - 1`` are copies of
            ``latent_vector[i]``, with entry ``i`` replaced by each sweep value
            in turn.

        Raises:
            ValueError: if ``latent_vector`` is not a square 2-D tensor.
        """
        if latent_vector.dim() != 2 or latent_vector.size(0) != latent_vector.size(1):
            raise ValueError(
                "latent_vector must be square (batch == latent dim), got shape "
                "{}".format(tuple(latent_vector.shape))
            )
        dim = latent_vector.size(0)
        eye = torch.eye(dim, dtype=latent_vector.dtype, device=latent_vector.device)

        # Zero the diagonal so entry i of code i can be replaced by the sweep.
        base = latent_vector * (1 - eye)
        # linspace always yields exactly `interval` points. arange with a
        # float step and a 2.1 upper bound yields interval + 1 points once the
        # step is below 0.1.
        sweep = torch.linspace(
            -2.0, 2.0, interval, dtype=latent_vector.dtype, device=latent_vector.device
        )
        # (dim, 1, dim) + (1, interval, 1) * (dim, 1, dim) -> (dim, interval, dim)
        traversed = base.unsqueeze(1) + sweep.view(1, -1, 1) * eye.unsqueeze(1)
        return traversed.reshape(-1, dim)  # (dim * interval, dim)
