import pdb
import os
import math
import wandb
import torch
import logging
from tqdm import tqdm

from dl.src.constants import FACTOR_INFORM
from dl.src.trainers.cmcs.base import CMCS_Trainer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# All tensors in this module are moved to the current CUDA device when CUDA is
# available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CMCS_Super_Trainger(CMCS_Trainer):
    """CMCS trainer that optimises a label-supervised objective.

    The per-step optimised loss is
    ``reconst + beta * label_loss + alpha * code_loss + gamma * canonical_loss``.
    The KL and decoder-equivariance terms returned by the model are tracked and
    logged but are not part of the optimised loss. The model, dataset,
    dataloader, loss function, optimizer, scheduler and output paths are
    expected to be set up by ``CMCS_Trainer`` (not visible in this file).
    """

    # Names of the per-step quantities returned by ``_compute_loss_terms``.
    _METRIC_NAMES = (
        "elbo",
        "total_loss",
        "reconst",
        "kld",
        "decequiv",
        "code",
        "canonical",
        "label_loss",
        "label_acc",
    )

    # (wandb key, metric name) pairs logged by ``train``, in logging order.
    # "Kullback-Reibler" is a misspelling of "Leibler"; it is kept unchanged so
    # the logged metric name stays the same across runs.
    _TRAIN_LOG_KEYS = (
        ("00.ELBO", "elbo"),
        ("01.Total_Loss", "total_loss"),
        ("02.Reconstruction_Loss", "reconst"),
        ("03.Kullback-Reibler_Loss", "kld"),
        ("04.Decoder_Equivariant", "decequiv"),
        ("05.Code_Loss", "code"),
        ("06.canonical_loss", "canonical"),
        ("07.Label_loss", "label_loss"),
        ("08.Label_ACC", "label_acc"),
    )

    # (result key, metric name) pairs returned by ``eval``, in result order.
    _EVAL_RESULT_KEYS = (
        ("elbo", "elbo"),
        ("obj", "total_loss"),
        ("reconst", "reconst"),
        ("kld", "kld"),
        ("code_loss", "code"),
        ("canonical_loss", "canonical"),
        ("decoder_equiv", "decequiv"),
        ("label_loss", "label_loss"),
        ("label_acc", "label_acc"),
    )

    def __init__(self, args):
        super().__init__(args)

    def _compute_loss_terms(self, data, class_label):
        """Run one forward pass and return every tracked loss term.

        Moves ``data`` and ``class_label`` to ``device``, then calls
        ``self.model(data, self.loss_fn, class_label.long())`` and reads the
        terms from ``outputs[0]["obj"]``. When ``args.n_gpu > 1`` every term
        is reduced with ``.mean()``.

        Returns:
            dict keyed by ``_METRIC_NAMES``. ``"total_loss"`` is the optimised
            objective and keeps its autograd graph so the caller may call
            ``.backward()`` on it. ``"elbo"`` is ``-(reconst + kld)``.
        """
        data = data.to(device)
        class_label = class_label.to(device)
        outputs = self.model(data, self.loss_fn, class_label.long())
        obj = outputs[0]["obj"]

        terms = {
            "reconst": obj["reconst"],
            "kld": obj["kld"],
            "decequiv": obj["dec_equiv"],
            "code": obj["code_loss"],
            "canonical": obj["canonical_loss"],
            "label_loss": obj["label_loss"],
            "label_acc": obj["label_acc"],
        }
        terms["total_loss"] = (
            terms["reconst"]
            + self.args.beta * terms["label_loss"]
            + self.args.alpha * terms["code"]
            + self.args.gamma * terms["canonical"]
        )

        if self.args.n_gpu > 1:
            # Presumably each term holds one value per replica under
            # multi-GPU wrapping — TODO confirm; reduce to a scalar.
            terms = {name: value.mean() for name, value in terms.items()}

        terms["elbo"] = -(terms["reconst"] + terms["kld"])
        return terms

    def _log_train_metrics(self, running, last_logged):
        """Log the per-step means since the last log call to wandb.

        ``running`` and ``last_logged`` map metric names to cumulative sums;
        ``last_logged`` is advanced to ``running`` in place.
        """
        logs = {
            key: (running[name] - last_logged[name]) / self.args.logging_steps
            for key, name in self._TRAIN_LOG_KEYS
        }
        last_logged.update(running)
        logs["learning_rate"] = self.scheduler.get_last_lr()[0]
        wandb.log(logs)

    def _save_checkpoint(self):
        """Write model, args, optimizer and scheduler state for the current step."""
        output_dir = os.path.join(
            self.output_dir,
            self.args.model_type,
            self.save_file,
            "checkpoint-{}".format(self.global_step),
        )
        os.makedirs(output_dir, exist_ok=True)
        # Unwrap a (Distributed)DataParallel-style wrapper if one is present.
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(), os.path.join(output_dir, "model.pt"))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        logger.info("Saving model checkpoint to %s", output_dir)
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        logger.info("Saving optimizer and scheduler states to %s", output_dir)

    def train(self):
        """Train for ``args.num_epoch`` epochs, or until ``args.max_steps``.

        Each step calls ``self.model.freeze()``, runs a forward pass,
        backpropagates the total loss, clips gradients to
        ``args.max_grad_norm`` and steps the optimizer and scheduler. On the
        main process (``local_rank`` -1 or 0) it logs step-averaged metrics to
        wandb every ``args.logging_steps`` steps. It also writes a checkpoint
        every ``args.save_steps`` steps, at ``args.max_steps`` and at the final
        step. The wandb run is always finished, including on the
        ``max_steps`` early return.
        """
        logger.info("***********Running Model Training***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Num Epochs = %d", self.args.num_epoch)
        logger.info(" Batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            " Total train batch size = %d",
            self.args.train_batch_size
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
        )
        logger.info("Total optimization steps = %d", self.args.t_total)

        # Cumulative sums since training started, and their values at the
        # last wandb log; (running - last_logged) / logging_steps is logged.
        running = dict.fromkeys(self._METRIC_NAMES, 0.0)
        last_logged = dict.fromkeys(self._METRIC_NAMES, 0.0)

        is_main_process = self.args.local_rank in [-1, 0]
        final_step = len(self.train_dataloader) * self.args.num_epoch

        self.model.zero_grad()
        wandb.init(project=self.args.project_name, name=self.run_file, entity=self.args.entity)
        try:
            for _ in tqdm(range(self.args.num_epoch), desc="Epoch"):
                iteration = tqdm(self.train_dataloader, desc="Iteration")

                for data, class_label in iteration:
                    self.model.freeze()
                    terms = self._compute_loss_terms(data, class_label)
                    for name in self._METRIC_NAMES:
                        running[name] += terms[name].item()

                    terms["total_loss"].backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.max_grad_norm
                    )
                    self.optimizer.step()
                    self.scheduler.step()
                    self.model.zero_grad()
                    self.global_step += 1

                    if (
                        is_main_process
                        and self.args.logging_steps > 0
                        and self.global_step % self.args.logging_steps == 0
                    ):
                        self._log_train_metrics(running, last_logged)

                    # Only the main process writes checkpoints, so distributed
                    # ranks never write the same files concurrently.
                    save_due = (
                        (
                            self.args.save_steps > 0
                            and self.global_step % self.args.save_steps == 0
                        )
                        or self.global_step == self.args.max_steps
                        or self.global_step == final_step
                    )
                    if is_main_process and save_due:
                        self._save_checkpoint()

                    if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                        iteration.close()
                        return
        finally:
            wandb.finish()

    def eval(self):
        """Evaluate on ``self.dataset`` and return batch-averaged metrics.

        Side effects: rebinds ``self.train_sampler`` and
        ``self.train_dataloader`` to an evaluation loader (sequential sampler,
        or ``DistributedSampler`` when ``local_rank != -1``;
        ``batch_size=args.test_batch_size``; ``drop_last=True``). It also
        resets ``self.global_step`` to 0 and sets ``self.t_total`` to the
        number of evaluation batches.

        Returns:
            dict with keys "elbo", "obj", "reconst", "kld", "code_loss",
            "canonical_loss", "decoder_equiv", "label_loss" and "label_acc",
            each the mean of the per-batch values.

        Raises:
            ValueError: if the loader yields no batches (e.g. the dataset is
                smaller than ``test_batch_size``, because of ``drop_last``).
        """
        self.train_sampler = (
            SequentialSampler(self.dataset)
            if self.args.local_rank == -1
            else DistributedSampler(self.dataset)
        )
        self.train_dataloader = DataLoader(
            self.dataset,
            sampler=self.train_sampler,
            batch_size=self.args.test_batch_size,
            drop_last=True,
            pin_memory=True,
        )
        self.global_step = 0
        self.t_total = len(self.train_dataloader)

        logger.info("***********Running Model Evaluation***********")
        logger.info(" Num examples = %d", len(self.dataset))
        logger.info(" Batch size = %d", self.args.test_batch_size)
        logger.info(" Num batches = %d", self.t_total)

        if self.t_total == 0:
            raise ValueError(
                "eval(): no evaluation batches (dataset has {} examples, "
                "test_batch_size={}, drop_last=True)".format(
                    len(self.dataset), self.args.test_batch_size
                )
            )

        totals = dict.fromkeys(self._METRIC_NAMES, 0.0)
        with torch.no_grad():
            self.model.eval()
            for data, class_label in tqdm(self.train_dataloader, desc="Iteration"):
                terms = self._compute_loss_terms(data, class_label)
                for name in self._METRIC_NAMES:
                    totals[name] += terms[name].item()

        return {key: totals[name] / self.t_total for key, name in self._EVAL_RESULT_KEYS}

    def qualitative(self):
        """Decode latent traversals for one randomly sampled batch.

        Draws one batch of ``args.quali_sampling`` examples with a
        ``RandomSampler``. It encodes the batch with ``self.model.encoder``,
        maps code ``outputs[1]`` through ``real_to_theta`` and
        ``select_code``, expands it with ``changed_latent_vector_value`` into
        ``args.interval`` variants per example, and decodes those variants.

        Returns:
            (new_outputs, gen_outputs). ``gen_outputs`` concatenates, along
            dim 0, the ``args.interval`` decoded variants of each example.
            ``new_outputs`` interleaves each input example with its variants.

        Raises:
            ValueError: if the dataset yields no batch.
        """
        logger.info("***********Qualitative Analysis***********")
        sampler = RandomSampler(self.dataset)
        dataloader = DataLoader(
            self.dataset, sampler=sampler, batch_size=self.args.quali_sampling
        )

        # Only the first batch is used.
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            raise ValueError("qualitative(): the dataset yielded no batch")
        data, class_label = first_batch

        step = self.args.interval
        with torch.no_grad():
            self.model.eval()
            data = data.to(device)
            class_label = class_label.to(device)

            z = self.model.encoder(data)[1]
            theta = self.model.select_code(self.model.real_to_theta(z))
            theta = self.changed_latent_vector_value(
                self.model, theta, class_label, args=self.args
            )
            reconst = self.model.decoder(theta)[0]

            # Rows [i * step, (i + 1) * step) are the decoded variants of example i.
            gen_chunks = [
                reconst[i * step: (i + 1) * step, :, :, :] for i in range(data.size(0))
            ]
            interleaved = []
            for sample, chunk in zip(data, gen_chunks):
                interleaved.append(sample.unsqueeze(0))
                interleaved.append(chunk)

            new_outputs = torch.cat(interleaved, dim=0)
            gen_outputs = torch.cat(gen_chunks, dim=0)
        return new_outputs, gen_outputs

    def changed_latent_vector_value(self, model, latent_vector, class_label, args):
        """Expand each latent code into ``args.interval`` traversal variants.

        Let ``n_factors = len(FACTOR_INFORM[args.dataset])`` and let ``k``
        range over ``linspace(0, 9, args.interval)``.

        - For the first ``n_factors`` latent dims, the shift moves the factor
          label from its current value ``l`` to ``(l + k) % factor_size``.
        - Any remaining dims are shifted by ``k`` itself.
        - Shifts are multiplied by ``scale * 2 * pi / model.n`` with
          ``scale = model.n / 100``.
        - ``diagonal_3d`` then applies to example ``i`` only the shift of
          latent dim ``i``.

        Args:
            model: object exposing ``n``.
            latent_vector: codes of shape (B, D).
            class_label: factor labels. Only the diagonal ``class_label[i, i]``
                for ``i < n_factors`` is read (presumably shape
                (B, n_factors) — TODO confirm).
            args: must provide ``latent_dim`` (expected to equal D),
                ``dataset`` and ``interval``.

        Returns:
            Tensor of shape (B * args.interval, D).
        """
        batch, dim = latent_vector.size()
        factor_sizes = FACTOR_INFORM[args.dataset]
        n_factors = len(factor_sizes)

        scale = model.n / 100.0 * torch.ones(size=(args.latent_dim,)).to(device)

        # Assuming 2-D class_label: label of factor i taken from example i,
        # shape (n_factors, 1).
        active_label = torch.diag(class_label[:n_factors], 0).unsqueeze(-1)
        interval = torch.linspace(start=0, end=9, steps=args.interval).to(device)

        # Offset moving factor i from its label to (label + k) mod factor_size:
        # (n_factors, I) -> transposed to (I, n_factors).
        transform = (active_label + interval) % factor_sizes.unsqueeze(-1).to(device)
        transform = (transform - active_label).transpose(-1, -2)

        if dim - n_factors != 0:
            # Non-factor latent dims are shifted by k directly: (I, D - n_factors).
            extra = interval.repeat(dim - n_factors, 1).transpose(-1, -2)
            symmetry = torch.cat([transform, extra], dim=-1)  # (I, D)
        else:
            symmetry = transform  # (I, n_factors)

        symmetry = scale * symmetry * 2 * math.pi / model.n  # (I, latent_dim)
        symmetry = self.diagonal_3d(symmetry, args.interval, batch, dim)  # (B, I, D)

        shifted = latent_vector.unsqueeze(1) + symmetry  # (B, I, D)
        return shifted.view(-1, dim)

    def diagonal_3d(self, z, interval, batch, dim):
        """Place traversal ``z`` so that batch row ``i`` varies only latent dim ``i``.

        Args:
            z: shifts of shape (interval, dim).
            interval: number of traversal steps (size of z's first axis).
            batch: number of batch rows in the output.
            dim: latent dimensionality.

        Returns:
            Zero tensor of shape (batch, interval, dim) on ``device`` with
            ``out[i, :, i] = z[:, i]`` for ``i < min(batch, dim)``. Batch rows
            beyond ``dim``, if any, stay zero; when ``batch < dim`` only the
            first ``batch`` dims are placed (instead of raising IndexError).
        """
        tensor = torch.zeros(size=(batch, interval, dim), device=device)
        for i in range(min(batch, dim)):
            tensor[i, :, i] = z[:, i]
        return tensor
