import time

import torch
from torch import nn

from trainer_base import MainTrainerBase
from model.hsqvae import HierSQVAE
from model.rsqvae import ResSQVAE
from util import *
from third_party.piqa import PSNR, SSIM, LPIPS


class HierSQVAETrainer(MainTrainerBase):
    """Trainer for hierarchical SQ-VAE models (e.g. ``HierSQVAE`` / ``ResSQVAE``).

    Builds the model class named by ``cfgs.model.name`` (wrapped in
    ``nn.DataParallel``), an Adam optimizer, and a ``ReduceLROnPlateau``
    scheduler stepped on the evaluation loss. Implements the ``_train`` /
    ``_test`` hooks (presumably driven by ``MainTrainerBase`` — confirm).
    """

    def __init__(self, cfgs, flgs, train_loader, val_loader, test_loader):
        super().__init__(cfgs, flgs, train_loader, val_loader, test_loader)
        # Resolve the model class by name from this module's namespace (which
        # includes the explicit and star imports above) instead of eval()-ing a
        # config-supplied string, which would execute arbitrary code.
        model_name = self.cfgs.model.name
        model_cls = globals().get(model_name)
        if model_cls is None or not callable(model_cls):
            raise ValueError("Unknown model name: {!r}".format(model_name))
        self.model = nn.DataParallel(model_cls(cfgs, flgs).to(self.device))
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.cfgs.train.lr, amsgrad=False, betas=(0.9, 0.9))
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.5, patience=3,
            verbose=True, threshold=0.0001, threshold_mode="rel",
            cooldown=0, min_lr=0, eps=1e-08)
        self.plots = {
            "loss_train": [], "mse_train": [], "perplexity_train": [], "perplexity_bottom_train": [],
            "loss_val": [], "mse_val": [], "perplexity_val": [], "perplexity_bottom_val": [], "psnr_val": [], "ssim_val": [], "lpips_val": [],
            "loss_test": [], "mse_test": [], "perplexity_test": [], "perplexity_bottom_test": [], "psnr_test": [], "ssim_test": [], "lpips_test": []
        }

    def _to_input(self, data):
        """Extract the input batch from a loader item and move it to the GPU.

        CelebA-HQ loaders yield the input tensor directly; for every other
        dataset the loader item is indexed and its first element is used.
        """
        if self.cfgs.dataset.name in ['CelebA-HQ']:
            return data.cuda()
        return data[0].cuda()

    def _train(self, epoch):
        """Run one training epoch and return averaged statistics.

        Args:
            epoch: Epoch index. The global-step formula
                ``(epoch - 1) * len(train_loader)`` assumes it is 1-based.

        Returns:
            dict with ``"loss"`` and ``"mse"`` (means over processed batches)
            and ``"perplexity"`` (element-wise mean of the per-layer
            perplexities over processed batches).

        Raises:
            RuntimeError: if no batch was processed.
        """
        train_loss = []
        ms_error = []
        perplexity_sum = None
        n_batches = 0
        self.model.train()
        start_time = time.time()
        for batch_idx, data in enumerate(self.train_loader):
            # Batches with index 0..max_iteration (inclusive) are processed.
            # Check before the host->device copy so the skipped batch is not
            # transferred for nothing.
            if batch_idx > self.cfgs.train.max_iteration:
                break
            x = self._to_input(data)
            if self.flgs.decay:
                step = (epoch - 1) * len(self.train_loader) + batch_idx + 1
                temperature_current = self._set_temperature(
                    step, self.cfgs.quantization.temperature)
                for quantizer in self.model.module.quantizer:
                    quantizer.set_temperature(temperature_current)
            _, loss = self.model(x, flg_train=True, flg_quant_det=False)
            self.optimizer.zero_grad()
            loss["all"].backward()
            self.optimizer.step()

            train_loss.append(loss["all"].detach().cpu().item())
            ms_error.append(loss["mse"].detach().cpu().item())
            batch_perplexity = np.array(loss["perplexity"])
            if perplexity_sum is None:
                perplexity_sum = batch_perplexity
            else:
                perplexity_sum += batch_perplexity
            n_batches += 1

        if n_batches == 0:
            raise RuntimeError("train_loader yielded no batches to train on")

        result = {}
        result["loss"] = np.asarray(train_loss).mean(0)
        result["mse"] = np.array(ms_error).mean(0)
        # Divide by the number of batches actually accumulated. The previous
        # `batch_idx + 1` over-counted by one when the loop exited through the
        # max_iteration break (batch_idx then indexes the skipped batch).
        result["perplexity"] = perplexity_sum / n_batches
        self.print_loss(result, "train", time.time() - start_time)
        return result

    def _test(self, mode="val"):
        """Evaluate on ``self.<mode>_loader`` with ``flg_quant_det=True``.

        Steps the LR scheduler on the resulting loss. Unless
        ``flgs.noprint`` is set, it also prints
        ``exp(log_param_q_scalar_q[i])`` for each layer, labelled
        "Posterior variance". Returns the dict produced by ``_test_sub``.
        """
        self.model.eval()
        result = self._test_sub(True, mode)
        # NOTE(review): the scheduler is stepped for every mode, including
        # "test". If callers also evaluate the test split each epoch, the
        # test loss influences LR decay and patience. Confirm this is intended.
        self.scheduler.step(result["loss"])
        if not self.flgs.noprint:
            print("-- Posterior variance: [", end="")
            for i in range(len(self.cfgs.model.log_param_q_init)):
                print("Lyr {}: {:4.2f}".format(i + 1, self.model.module.log_param_q_scalar_q[i].exp().detach().cpu().item()), end=" ")
            print("]")
        return result

    def _get_eval_metrics(self):
        """Return ``(psnr, ssim, lpips)`` metric modules, built once and reused.

        Rebuilding the metrics on every evaluation is wasted work. This is
        especially true for LPIPS, which in piqa builds a feature network
        (confirm for the vendored copy). ``reduction="none"`` requests one
        value per sample, which the sum-then-divide-by-``n_samples`` averaging
        in ``_test_sub`` relies on.
        """
        metrics = getattr(self, "_eval_metrics", None)
        if metrics is None:
            metrics = (
                PSNR(reduction="none"),
                SSIM(reduction="none").cuda(),
                LPIPS(reduction="none").cuda(),
            )
            self._eval_metrics = metrics
        return metrics

    def _test_sub(self, flg_quant_det, mode="val"):
        """Evaluate the model on ``self.<mode>_loader`` without gradients.

        Args:
            flg_quant_det: Forwarded to the model's ``flg_quant_det`` argument.
            mode: Loader prefix, e.g. ``"val"`` or ``"test"``.

        Returns:
            dict with:
              - ``"psnr"``, ``"ssim"`` and ``"lpips"``: per-sample means, as
                0-dim tensors.
              - ``"loss"`` and ``"mse"``: per-batch means.
              - ``"perplexity"``: element-wise per-batch mean.

        Raises:
            RuntimeError: if the loader yields no batches.
        """
        psnr, ssim, lpips = self._get_eval_metrics()
        data_loader = getattr(self, "{}_loader".format(mode))

        n_samples = 0
        psnr_cur = 0.
        ssim_cur = 0.
        lpips_cur = 0.
        test_loss = []
        ms_error = []
        perplexity_sum = None
        n_batches = 0
        start_time = time.time()
        with torch.no_grad():
            for data in data_loader:
                x = self._to_input(data)
                x_rec, loss = self.model(x, flg_quant_det=flg_quant_det)

                # Accumulate per-sample metric sums; averaged over n_samples below.
                n_samples += x.shape[0]
                psnr_cur += psnr(x, x_rec).sum()
                ssim_cur += ssim(x, x_rec).sum()
                lpips_cur += lpips(x, x_rec).sum()

                test_loss.append(loss["all"].item())
                ms_error.append(loss["mse"].item())
                batch_perplexity = np.array(loss["perplexity"])
                if perplexity_sum is None:
                    perplexity_sum = batch_perplexity
                else:
                    perplexity_sum += batch_perplexity
                n_batches += 1

        if n_batches == 0 or n_samples == 0:
            raise RuntimeError("{}_loader yielded no batches to evaluate".format(mode))

        result = {}
        result["psnr"] = psnr_cur / n_samples
        result["ssim"] = ssim_cur / n_samples
        result["lpips"] = lpips_cur / n_samples
        result["loss"] = np.asarray(test_loss).mean(0)
        result["mse"] = np.array(ms_error).mean(0)
        result["perplexity"] = perplexity_sum / n_batches
        self.print_loss_val(result, mode, time.time() - start_time)
        return result

    def generate_reconstructions(self, filename, nrows=4, ncols=8):
        """Save a reconstruction grid via the base class's continuous helper."""
        self._generate_reconstructions_continuous(filename, nrows=nrows, ncols=ncols)

    def _print_perplexity(self, perplexity, time_interval):
        """Print the per-layer perplexities and elapsed time, ending the line."""
        print("Perplexity: [", end="")
        for i, perp in enumerate(perplexity):
            print("Lyr {}: {:4.2f}".format(i + 1, perp), end=" ")
        print("], Time: {:3.1f} sec".format(time_interval))

    def print_loss(self, result, mode, time_interval):
        """Print a one-line training summary unless ``flgs.noprint`` is set."""
        if not self.flgs.noprint:
            print(mode.capitalize().ljust(16) +
                "Loss: {:5.4f}, MSE: {:5.4f}, "
                .format(
                    result["loss"], result["mse"]
                ), end="")
            self._print_perplexity(result["perplexity"], time_interval)

    def print_loss_val(self, result, mode, time_interval):
        """Print a one-line evaluation summary (with image metrics) unless ``flgs.noprint`` is set."""
        if not self.flgs.noprint:
            print(mode.capitalize().ljust(16) +
                "Loss: {:5.4f}, MSE: {:5.4f}, PSNR: {:5.4f}, SSIM: {:5.4f}, LPIPS: {:5.4f}, "
                .format(
                    result["loss"], result["mse"], result["psnr"], result["ssim"], result["lpips"]
                ), end="")
            self._print_perplexity(result["perplexity"], time_interval)


