"""N-NPSSM"""

import torch
import logging
import torch.nn as nn
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from .components.beta import BetaVAE_MLP_nonoise
from .components.transition import NPTransitionPrior
from .metrics.correlation import compute_mcc


class StationaryProcess(pl.LightningModule):
    """Non-parametric state-space model (N-NPSSM) trained as a sequential VAE.

    Each time step of ``batch['xt']`` is flattened and passed through the
    ``BetaVAE_MLP_nonoise`` network independently. The training objective is:

    * a Gaussian reconstruction term (lookback steps ``[:lag]`` as-is, forecast
      steps ``[lag:]`` divided by ``length - lag``),
    * ``beta`` times a KL estimate of the first ``lag`` latents against a
      standard normal prior,
    * ``gamma`` times a KL estimate of the remaining latents against the
      learned transition prior. The prior's residuals are scored under a
      standard multivariate normal plus the returned log-abs-det term. This
      term is used in training only.
    """

    def __init__(
        self,
        input_dim,
        length,
        z_dim,
        lag,
        hidden_dim=128,
        trans_prior='np',
        lr=1e-4,
        infer_mode='f',
        beta=0.0025,
        gamma=0.0075,
        decoder_dist='gaussian',
        fixed_noise=False,
        correlation='Pearson'):
        """Non-Parametric State-Space Models.

        Args:
            input_dim: size of the observation vector at each time step.
            length: nominal sequence length. It is stored only; the steps read
                the actual length from the batch.
            z_dim: latent dimensionality.
            lag: number of leading time steps scored under the N(0, I) prior.
            hidden_dim: hidden width passed to ``BetaVAE_MLP_nonoise``.
            trans_prior: ``'np'`` builds an ``NPTransitionPrior``. ``'l'`` is
                accepted but builds no prior, so ``training_step`` cannot run.
            lr: AdamW learning rate.
            infer_mode: ``'f'`` builds the VAE network. ``'r'`` is accepted but
                not implemented; the encode paths raise ``NotImplementedError``.
            beta: weight of the past (lookback) KL term.
            gamma: weight of the future (transition-prior) KL term.
            decoder_dist: stored and forwarded, but the reconstruction
                likelihood is always Gaussian.
            fixed_noise: if True, the observation-noise scale is a frozen
                buffer; otherwise it is a learnable parameter. Both start at
                0.7071.
            correlation: correlation type passed to ``compute_mcc``.
        """
        super().__init__()
        assert trans_prior in ('l', 'np')
        assert infer_mode in ('r', 'f')
        self.z_dim = z_dim
        self.lag = lag
        self.input_dim = input_dim
        self.lr = lr
        self.length = length
        self.beta = beta  # weight for prior z regularization (past KLD)
        self.gamma = gamma  # weight for independent noise regularization (future KLD)
        self.correlation = correlation
        self.decoder_dist = decoder_dist
        self.infer_mode = infer_mode

        if infer_mode == 'f':
            self.net = BetaVAE_MLP_nonoise(input_dim=input_dim,
                                           z_dim=z_dim,
                                           hidden_dim=hidden_dim)

        # Initialize transition prior
        if trans_prior == 'np':
            self.transition_prior = NPTransitionPrior(lags=lag,
                                                      latent_size=z_dim,
                                                      num_layers=3,
                                                      hidden_dim=z_dim * 2)

        # Scale of the Gaussian observation likelihood in reconstruction_loss.
        if fixed_noise:
            self.register_buffer('xnoise', torch.tensor(0.7071))
        else:
            self.xnoise = nn.Parameter(torch.tensor(0.7071))

        # Base distribution (zero mean, identity covariance) used to score the
        # transition-prior residuals. Registered as buffers so it follows the
        # module's device.
        self.register_buffer('base_dist_mean', torch.zeros(self.z_dim))
        self.register_buffer('base_dist_var', torch.eye(self.z_dim))

    @property
    def base_dist(self):
        """Noise density: MultivariateNormal(base_dist_mean, base_dist_var)."""
        # The second positional argument of MultivariateNormal is the
        # covariance matrix.
        return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)

    def reparameterize(self, mean, logvar, random_sampling=True):
        """Return ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``.

        Returns ``mean`` unchanged when ``random_sampling`` is False.
        """
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            z = mean + eps * std
            return z
        else:
            return mean

    def reconstruction_loss(self, x, x_recon, distribution):
        """Negative Gaussian log-likelihood of ``x`` under ``N(x_recon, xnoise)``.

        The log-likelihood is summed over all elements and divided by the batch
        size (``x.size(0)``).

        ``distribution`` is accepted for interface compatibility but ignored:
        the likelihood is always Gaussian with scale ``self.xnoise``.
        """
        batch_size = x.size(0)
        assert batch_size != 0
        likelihood = D.Normal(x_recon, self.xnoise).log_prob(x)
        recon_loss = -likelihood.sum().div(batch_size)
        return recon_loss

    def mse_loss(self, x, x_recon):
        """Sum of squared errors between ``x_recon`` and ``x``, divided by batch size."""
        batch_size = x.size(0)
        assert batch_size != 0
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
        return recon_loss

    def _run_net(self, x_flat):
        """Run the inference network on flattened observations.

        Returns ``(x_recon, mus, logvars, zs)`` as produced by ``self.net``.

        Raises:
            NotImplementedError: if ``infer_mode`` is not ``'f'``. No network is
                built for other modes.
        """
        if self.infer_mode != 'f':
            raise NotImplementedError(
                f"infer_mode={self.infer_mode!r} is not implemented; only 'f' is supported")
        return self.net(x_flat)

    def _encode_sequence(self, batch):
        """Encode ``batch['xt']`` and reshape the outputs to (batch, time, ...).

        Assumes ``batch['xt']`` is 3-D ``(batch, time, input_dim)``; the shape
        unpacking below enforces the rank.

        Returns ``(x, x_recon, mus, logvars, zs)``. ``x_recon`` is
        ``(batch, time, input_dim)`` and the latent tensors are
        ``(batch, time, z_dim)``.
        """
        x = batch['xt']
        batch_size, length, _ = x.shape
        x_recon, mus, logvars, zs = self._run_net(x.view(-1, self.input_dim))

        # Reshape to time-series format
        x_recon = x_recon.view(batch_size, length, self.input_dim)
        mus = mus.reshape(batch_size, length, self.z_dim)
        logvars = logvars.reshape(batch_size, length, self.z_dim)
        zs = zs.reshape(batch_size, length, self.z_dim)
        return x, x_recon, mus, logvars, zs

    def _recon_and_past_kld(self, x, x_recon, mus, logvars, zs):
        """Compute the reconstruction loss and the past (lookback) KL term.

        Returns ``(recon_loss, kld_normal, log_qz)``. ``log_qz`` is the
        per-element log-density of ``zs`` under the posterior
        ``N(mus, exp(logvars / 2))``; the future KL term reuses it.
        """
        lag = self.lag
        length = x.size(1)

        # VAE ELBO loss: recon_loss + kld_loss. The forecast part is divided
        # by the number of forecast steps.
        recon_loss_lookback = self.reconstruction_loss(
            x[:, :lag], x_recon[:, :lag], self.decoder_dist)
        recon_loss_forecast = self.reconstruction_loss(
            x[:, lag:], x_recon[:, lag:], self.decoder_dist) / (length - lag)
        recon_loss = recon_loss_lookback + recon_loss_forecast

        q_dist = D.Normal(mus, torch.exp(logvars / 2))
        log_qz = q_dist.log_prob(zs)

        # Past KLD: log q(z) - log N(z; 0, 1), evaluated at the encoder's zs
        # for the first `lag` steps. Summed over latent and time, then
        # averaged over the batch.
        p_dist = D.Normal(torch.zeros_like(mus[:, :lag]), torch.ones_like(logvars[:, :lag]))
        log_pz_normal = p_dist.log_prob(zs[:, :lag]).sum(dim=-1).sum(dim=-1)
        log_qz_normal = log_qz[:, :lag].sum(dim=-1).sum(dim=-1)
        kld_normal = (log_qz_normal - log_pz_normal).mean()
        return recon_loss, kld_normal, log_qz

    def forward(self, batch):
        """Encode ``batch['xt']`` and return ``(zs, mus, logvars)``.

        The outputs are not reshaped back to (batch, time, ...); they keep the
        flattened leading dimension fed to the network.
        """
        x_flat = batch['xt'].view(-1, self.input_dim)
        _, mus, logvars, zs = self._run_net(x_flat)
        return zs, mus, logvars

    def training_step(self, batch, batch_idx):
        """Compute the full training ELBO loss and log it as ``train_elbo_loss``.

        Returns ``{"loss": loss, "batch_size": batch_size}``.
        ``training_epoch_end`` consumes the ``batch_size`` entry for weighting.
        """
        if not hasattr(self, 'transition_prior'):
            raise NotImplementedError(
                "no transition prior was built (trans_prior='l' is not implemented)")

        x, x_recon, mus, logvars, zs = self._encode_sequence(batch)
        batch_size, length = x.shape[0], x.shape[1]
        recon_loss, kld_normal, log_qz = self._recon_and_past_kld(x, x_recon, mus, logvars, zs)

        # Future KLD: log q(z_t) for t >= lag minus the transition-prior
        # log-density. The prior density is the base log-prob of the residuals
        # plus the returned log-abs-det term.
        log_qz_laplace = log_qz[:, self.lag:]
        residuals, logabsdet = self.transition_prior(zs)
        log_pz_laplace = torch.sum(self.base_dist.log_prob(residuals), dim=1) + logabsdet
        kld_laplace = (log_qz_laplace.sum(dim=-1).sum(dim=-1) - log_pz_laplace) / (length - self.lag)
        kld_laplace = kld_laplace.mean()

        # VAE training
        loss = recon_loss + self.beta * kld_normal + self.gamma * kld_laplace
        self.log("train_elbo_loss", loss)
        return {"loss": loss, "batch_size": batch_size}

    def training_epoch_end(self, training_step_outputs):
        """Log the batch-size-weighted mean of every training-step metric.

        Every key except ``batch_size`` in the step outputs is treated as a
        metric. Does nothing when there are no outputs.
        """
        if not training_step_outputs:
            return
        with torch.no_grad():
            first = training_step_outputs[0]
            metrics = [name for name in first if name != 'batch_size']
            totals = {name: torch.tensor(0).type_as(first['loss']) for name in metrics}
            total_size = 0
            for output in training_step_outputs:
                size = output['batch_size']
                total_size += size
                for name in metrics:
                    totals[name] += size * output[name]
            parts = [f"{name} {(totals[name] / total_size).item()}" for name in metrics]
            logging.info(f"epoch {self.trainer.current_epoch}:" + ", ".join(parts))

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and the MCC between posterior means and ``batch['yt']``.

        The validation loss omits the transition-prior (future) KL term that
        ``training_step`` includes.
        """
        x, x_recon, mus, logvars, zs = self._encode_sequence(batch)
        recon_loss, kld_normal, _ = self._recon_and_past_kld(x, x_recon, mus, logvars, zs)
        loss = recon_loss + self.beta * kld_normal

        # Compute Mean Correlation Coefficient (MCC). Assumes batch["yt"]
        # holds the true latents with z_dim as the last dimension (TODO
        # confirm against the dataset).
        zt_recon = mus.view(-1, self.z_dim).T.detach().cpu().numpy()
        zt_true = batch["yt"].view(-1, self.z_dim).T.detach().cpu().numpy()
        mcc = compute_mcc(zt_recon, zt_true, self.correlation)

        self.log("epoch", self.trainer.current_epoch)
        self.log("val_elbo_loss", loss)
        self.log("val_mcc", mcc)
        logging.info(
            f"epoch {self.trainer.current_epoch}: val_elbo_loss {loss.item()}, "
            f"val_mcc {mcc}")
        return loss

    def sample(self, n=64):
        """Draw ``n`` standard-normal vectors and map them through ``self.spline.inverse``.

        NOTE(review): this class never defines ``self.spline``. The call raises
        AttributeError unless a ``spline`` module is attached externally.
        Confirm the intended sampling flow.
        """
        with torch.no_grad():
            e = torch.randn(n, self.z_dim, device=self.device)
            eps, _ = self.spline.inverse(e)
        return eps

    def configure_optimizers(self):
        """AdamW over all trainable parameters; no LR schedulers."""
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        opt_v = torch.optim.AdamW(trainable, lr=self.lr, betas=(0.9, 0.999), weight_decay=0.0001)
        return [opt_v], []