"""N-NPSSM"""

import torch
import logging
import torch.nn as nn
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from .components.beta import BetaVAE_MLP_nonoise
from .components.mlp import NLayerLeakyMLP
from .metrics.forecasting import (compute_mae, compute_rho)
from .components.transition import (
                                    NPChangeTransitionPrior,
                                    NPTransitionPrior)


class NonStationaryPredProcess(pl.LightningModule):
    """Sequential latent-variable model with a learned transition prior and an LSTM forecaster.

    Each observation step of ``batch['xt']`` (shape ``(batch, length, input_dim)``)
    is encoded by ``self.net`` into a latent ``z`` of size ``z_dim``. An auxiliary
    latent ``a`` of size ``int(z_dim / 2)`` is inferred from every ``(lag + 1)``-step
    window of the posterior means. The training objective is::

        recon + beta * KLD(past z) + gamma * KLD(future z)
              + delta * KLD(first a) + epsilon * KLD(future a) + alpha * prediction MSE

    With ``predict_mode == 'noisez'`` an LSTM (``self.predictor``) combined with
    ``self.mixnoise`` predicts the last ``length`` latents of each window.

    Args:
        input_dim: feature size of one observation step.
        length: number of trailing time steps the predictor forecasts.
        z_dim: latent size.
        lag: number of past latent steps in each window used to infer ``a``.
        hidden_dim: hidden width of the inference MLPs.
        trans_prior: ``'np'`` builds the nonparametric transition priors. ``'l'`` is
            accepted but builds none (NOTE(review): ``training_step`` then fails
            with AttributeError).
        lr: AdamW learning rate.
        infer_mode: only ``'f'`` is implemented.
        alpha, beta, gamma, delta, epsilon: loss weights (see formula above).
        a_distribution: only ``'gaussian'`` is implemented.
        decoder_dist: only ``'gaussian'`` is implemented.
        prediction_sample_times: number of sampled trajectories used in validation.
        predict_mode: ``'noisez'`` builds the LSTM predictor. With any other value
            the prediction loss is zero.
        predict_witha: if true, the predictor also receives the shifted ``a``.
        lstm_layer: number of LSTM layers in the predictor.
    """

    def __init__(
        self,
        input_dim,
        length,
        z_dim,
        lag,
        hidden_dim=128,
        trans_prior='np',
        lr=1e-4,
        infer_mode='f',
        alpha=1,
        beta=0.0025,
        gamma=0.0075,
        delta=0.0025,
        epsilon=0.0075,
        a_distribution='gaussian',
        decoder_dist='gaussian',
        prediction_sample_times=50,
        predict_mode='noisez',
        predict_witha=False,
        lstm_layer=2,
    ):
        super().__init__()
        assert trans_prior in ('l', 'np')
        assert infer_mode in ('r', 'f')
        self.z_dim = z_dim
        self.a_dim = int(self.z_dim / 2)
        self.lag = lag
        self.input_dim = input_dim
        self.lr = lr
        self.length = length  # number of trailing steps forecast by the predictor
        self.beta = beta  # weight of the past-z KLD (standard-normal prior)
        self.gamma = gamma  # weight of the future-z KLD (learned transition prior)
        self.delta = delta  # weight of the first-step a KLD (standard-normal prior)
        self.epsilon = epsilon  # weight of the future-a KLD (learned transition prior)
        self.alpha = alpha  # weight of the prediction MSE
        self.decoder_dist = decoder_dist
        self.a_distribution = a_distribution
        self.infer_mode = infer_mode

        self.predict_mode = predict_mode
        self.prediction_sample_times = prediction_sample_times
        self.predict_witha = predict_witha
        self.criterion = nn.MSELoss()

        # MLP inference network
        self.net = BetaVAE_MLP_nonoise(input_dim=input_dim,
                                       z_dim=z_dim,
                                       hidden_dim=hidden_dim)

        # Initialize transition priors. NOTE(review): nothing is built for
        # trans_prior == 'l'.
        if trans_prior == 'np':
            self.transition_prior_a = NPTransitionPrior(lags=1,
                                                        latent_size=self.a_dim,
                                                        num_layers=2,
                                                        hidden_dim=z_dim)

            self.transition_prior = NPChangeTransitionPrior(lags=lag,
                                                            latent_size=z_dim,
                                                            embedding_dim=self.a_dim,
                                                            num_layers=3,
                                                            hidden_dim=z_dim * 2)

        # Infer a from a flattened window of (lag + 1) latent means
        self.lag_a = 1
        self.mlp_z_a = NLayerLeakyMLP(in_features=self.z_dim * (self.lag + 1),
                                      out_features=hidden_dim,
                                      num_layers=2,
                                      hidden_dim=hidden_dim)
        self.a_mean = nn.Linear(hidden_dim, self.a_dim)
        self.a_logvar = nn.Linear(hidden_dim, self.a_dim)

        # Base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(self.z_dim))
        self.register_buffer('base_dist_var', torch.eye(self.z_dim))
        if self.predict_mode == "noisez":
            self.lstm_layer = lstm_layer
            self.predictor = nn.LSTM(input_size=self.z_dim,
                                     hidden_size=self.z_dim,
                                     num_layers=self.lstm_layer,
                                     batch_first=True,
                                     bias=True
                                     )
            # Input: LSTM output + one residual step (+ a when predict_witha).
            self.mixnoise = NLayerLeakyMLP(in_features=2 * z_dim + self.a_dim if self.predict_witha else 2 * z_dim,
                                           out_features=z_dim,
                                           num_layers=2,
                                           hidden_dim=z_dim)
            # Learned initial LSTM state, shape (lstm_layer, z_dim).
            self.init_hidden = nn.Parameter(torch.zeros(self.lstm_layer, self.z_dim))
            self.init_cell = nn.Parameter(torch.zeros(self.lstm_layer, self.z_dim))

    @property
    def base_dist(self):
        """Standard multivariate normal over ``z_dim``; density of the z-level residuals."""
        return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)

    @property
    def base_dist_a(self):
        """Standard multivariate normal over ``a_dim``; density of the a-level residuals.

        Created on the device/dtype of the ``base_dist_mean`` buffer, so reading it
        does not depend on (or mutate) training-step state.
        """
        ref = self.base_dist_mean
        return D.MultivariateNormal(
            ref.new_zeros(self.a_dim),
            torch.eye(self.a_dim, device=ref.device, dtype=ref.dtype))

    def reparameterize(self, mean, logvar, random_sampling=True):
        """Return ``mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``, or ``mean`` if not sampling."""
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            z = mean + eps * std
            return z
        else:
            return mean

    def reconstruction_loss(self, x, x_recon, distribution):
        """Summed squared reconstruction error divided by the batch size (dim 0 of ``x``).

        Raises:
            ValueError: if ``distribution`` is not ``'gaussian'``, the only supported
                decoder distribution.
        """
        if distribution != 'gaussian':
            raise ValueError(f"Unsupported decoder distribution: {distribution!r}")
        return self.mse_loss(x, x_recon)

    def mse_loss(self, x, x_recon):
        """Summed squared error between ``x_recon`` and ``x`` divided by the batch size."""
        batch_size = x.size(0)
        assert batch_size != 0
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        return F.mse_loss(x_recon, x, reduction='sum').div(batch_size)

    def _run_inference_net(self, x_flat):
        """Apply ``self.net`` to flattened observations; callers unpack ``(x_recon, mus, logvars, zs)``.

        Raises:
            NotImplementedError: for any ``infer_mode`` other than ``'f'``.
        """
        if self.infer_mode != 'f':
            raise NotImplementedError(
                f"infer_mode={self.infer_mode!r} is not implemented; only 'f' is supported")
        return self.net(x_flat)

    def forward(self, batch):
        """Encode ``batch['xt']`` and return ``(zs, mus, logvars)``.

        The outputs are exactly what ``self.net`` returns for the flattened
        ``(-1, input_dim)`` input; they are not reshaped back to time-series layout.
        """
        x_flat = batch['xt'].view(-1, self.input_dim)
        _, mus, logvars, zs = self._run_inference_net(x_flat)
        return zs, mus, logvars

    def inference_a(self, zs):
        """Infer the auxiliary latent ``a`` from flattened latent windows.

        Returns:
            ``(a, a_mean, a_logvar)``, where ``a`` is a reparameterized sample.
        """
        zs_feature = self.mlp_z_a(zs)
        a_mean = self.a_mean(zs_feature)
        a_logvar = self.a_logvar(zs_feature)
        a = self.reparameterize(a_mean, a_logvar)
        return a, a_mean, a_logvar

    def _encode_window(self, x):
        """Encode a ``(batch, length, input_dim)`` window and infer ``a``.

        Returns:
            x_recon, mus, logvars, zs: reshaped to ``(batch, length, ...)``.
            a, a_mean, a_logvar: one row per ``(lag + 1)``-step window of ``mus``
                (``length - lag`` rows).
            a_input: ``a`` shifted right by ``lag + 1`` steps and zero-padded to
                ``length`` steps. It is the extra predictor input when
                ``predict_witha`` is set.
        """
        batch_size, length, _ = x.shape
        x_flat = x.view(-1, self.input_dim)
        x_recon, mus, logvars, zs = self._run_inference_net(x_flat)

        # Reshape to time-series format
        x_recon = x_recon.view(batch_size, length, self.input_dim)
        mus = mus.reshape(batch_size, length, self.z_dim)
        logvars = logvars.reshape(batch_size, length, self.z_dim)
        zs = zs.reshape(batch_size, length, self.z_dim)
        zs_tuple = mus.unfold(dimension=1, size=self.lag + 1, step=1)
        zs_tuple = zs_tuple.reshape(batch_size, length - self.lag, -1)

        a, a_mean, a_logvar = self.inference_a(zs_tuple)
        a_input = torch.cat(
            (torch.zeros(batch_size, length - a[:, :-1, :].shape[1], self.a_dim).type_as(x), a[:, :-1, :]), dim=1)
        return x_recon, mus, logvars, zs, a, a_mean, a_logvar, a_input

    def _window_recon_loss(self, x, x_recon):
        """Lookback reconstruction loss plus post-lag loss divided by its ``length - lag`` steps."""
        length = x.shape[1]
        recon_loss_lookback = self.reconstruction_loss(x[:, :self.lag], x_recon[:, :self.lag], self.decoder_dist)
        recon_loss_forecast = self.reconstruction_loss(
            x[:, self.lag:], x_recon[:, self.lag:], self.decoder_dist) / (length - self.lag)
        return recon_loss_lookback + recon_loss_forecast

    def _past_kld_z(self, mus, logvars, zs, log_qz):
        """Batch-mean of ``log q(z) - log N(z; 0, I)`` over the first ``lag`` sampled latents."""
        p_dist = D.Normal(torch.zeros_like(mus[:, :self.lag]), torch.ones_like(logvars[:, :self.lag]))
        log_pz_normal = torch.sum(torch.sum(p_dist.log_prob(zs[:, :self.lag]), dim=-1), dim=-1)
        log_qz_normal = torch.sum(torch.sum(log_qz[:, :self.lag], dim=-1), dim=-1)
        return (log_qz_normal - log_pz_normal).mean()

    def _log_q_a(self, a, a_mean, a_logvar):
        """Element-wise log-density of ``a`` under its Gaussian posterior.

        Raises:
            NotImplementedError: for any ``a_distribution`` other than ``'gaussian'``.
        """
        if self.a_distribution != 'gaussian':
            raise NotImplementedError(
                f"a_distribution={self.a_distribution!r} is not implemented; only 'gaussian' is supported")
        return D.Normal(a_mean, torch.exp(a_logvar / 2)).log_prob(a)

    def _first_step_kld_a(self, a, a_mean, a_logvar, log_qa):
        """Batch-mean of ``log q(a) - log N(a; 0, I)`` for the first inferred ``a`` step."""
        p_dist = D.Normal(torch.zeros_like(a_mean[:, :1]), torch.ones_like(a_logvar[:, :1]))
        log_pa_normal = torch.sum(torch.sum(p_dist.log_prob(a[:, :1]), dim=-1), dim=-1)
        log_qa_normal = torch.sum(torch.sum(log_qa[:, :1], dim=-1), dim=-1)
        return (log_qa_normal - log_pa_normal).mean()

    def _initial_lstm_state(self, batch_size):
        """Repeat the learned initial ``(hidden, cell)`` to ``(lstm_layer, batch_size, z_dim)``."""
        hidden = self.init_hidden.unsqueeze(1).repeat(1, batch_size, 1)
        cell = self.init_cell.unsqueeze(1).repeat(1, batch_size, 1)
        return hidden, cell

    def _teacher_forced_pred_loss(self, x, zs, residuals, a_input):
        """Sum of per-step MSEs of the one-step predictor over the last ``self.length`` steps.

        The LSTM is first run over the earlier (shifted) latents. Then, at each
        forecast step, it receives the true previous latent (teacher forcing), and
        its output is mixed with that step's transition-prior residual (plus
        ``a_input`` when ``predict_witha``). Returns a 0-dim zero tensor when
        ``predict_mode != 'noisez'``.
        """
        pred_loss = x.new_zeros(())
        if self.predict_mode != 'noisez':
            return pred_loss
        batch_size, length, _ = zs.shape
        offset = length - self.length
        hidden, cell = self._initial_lstm_state(batch_size)
        # Left-pad the residuals with zeros so they are indexed by absolute time step.
        residuals_input = torch.cat(
            (torch.zeros(batch_size, length - residuals.shape[1], residuals.shape[2]).type_as(x), residuals), dim=1)
        # zs_input[:, t] == zs[:, t - 1] (zeros at t == 0).
        zs_input = torch.cat((torch.zeros(batch_size, 1, self.z_dim).type_as(x), zs[:, :-1, :]), dim=1)
        for t in range(offset):
            _, (hidden, cell) = self.predictor(zs_input[:, t, :].unsqueeze(1), (hidden, cell))
        for t in range(offset, length):
            prediction, (hidden, cell) = self.predictor(zs_input[:, t, :].unsqueeze(1), (hidden, cell))
            prediction_input = torch.cat([prediction, residuals_input[:, t:t + 1, :]], dim=1).view(batch_size, -1)
            if self.predict_witha:
                prediction_input = torch.cat([prediction_input, a_input[:, t]], dim=1)
            prediction_noise = self.mixnoise(prediction_input)
            pred_loss = pred_loss + self.criterion(prediction_noise, zs[:, t, :])
        return pred_loss

    def _sampled_forecast(self, x, zs, a_input):
        """Autoregressively forecast the last ``self.length`` latents ``prediction_sample_times`` times.

        Residuals are drawn from ``self.base_dist``. Every trajectory restarts from
        the same LSTM state, obtained by running over the earlier true latents.

        Returns:
            predictions_z: ``(prediction_sample_times, batch, self.length, z_dim)``;
                all zeros when ``predict_mode != 'noisez'``.
            pred_loss: 0-dim sum of per-step MSEs against ``zs`` over all samples.
        """
        batch_size, length, _ = zs.shape
        pred_loss = x.new_zeros(())
        predictions_z = torch.zeros(self.prediction_sample_times, batch_size, self.length, self.z_dim).type_as(x)
        if self.predict_mode != 'noisez':
            return predictions_z, pred_loss
        offset = length - self.length
        hidden, cell = self._initial_lstm_state(batch_size)
        # Shape (prediction_sample_times, batch, length, z_dim).
        residuals = self.base_dist.sample([self.prediction_sample_times, batch_size, length])
        zs_input = torch.cat((torch.zeros(batch_size, 1, self.z_dim).type_as(x), zs[:, :-1, :]), dim=1)
        for t in range(offset):
            _, (hidden, cell) = self.predictor(zs_input[:, t, :].unsqueeze(1), (hidden, cell))
        hidden_start = hidden.clone().detach()
        cell_start = cell.clone().detach()
        for j in range(self.prediction_sample_times):
            hidden, cell = hidden_start, cell_start
            zs_previous = zs_input[:, offset, :]
            for i in range(self.length):
                t = offset + i
                prediction, (hidden, cell) = self.predictor(zs_previous.unsqueeze(1), (hidden, cell))
                prediction_input = torch.cat(
                    [prediction, residuals[j, :, t:t + 1]], dim=1).view(batch_size, -1)
                if self.predict_witha:
                    prediction_input = torch.cat([prediction_input, a_input[:, t]], dim=1)
                prediction_noise = self.mixnoise(prediction_input)
                predictions_z[j, :, i, :] = prediction_noise
                # Feed the model's own prediction back in (no teacher forcing).
                zs_previous = prediction_noise
                pred_loss = pred_loss + self.criterion(prediction_noise, zs[:, t, :])
        return predictions_z, pred_loss

    def training_step(self, batch, batch_idx):
        """Compute the full training objective for one batch and log it as ``train_elbo_loss``."""
        x = batch['xt']
        batch_size, length, _ = x.shape
        x_recon, mus, logvars, zs, a, a_mean, a_logvar, a_input = self._encode_window(x)

        # VAE ELBO loss: recon_loss + kld_loss
        recon_loss = self._window_recon_loss(x, x_recon)
        log_qz = D.Normal(mus, torch.exp(logvars / 2)).log_prob(zs)
        kld_normal = self._past_kld_z(mus, logvars, zs, log_qz)

        # Future z KLD under the transition prior conditioned on a. The residuals
        # are computed once and reused as the predictor's noise input below.
        residuals, logabsdet = self.transition_prior(zs, a)
        log_pz_laplace = torch.sum(self.base_dist.log_prob(residuals), dim=1) + logabsdet
        log_qz_laplace = torch.sum(torch.sum(log_qz[:, self.lag:], dim=-1), dim=-1)
        kld_laplace = ((log_qz_laplace - log_pz_laplace) / (length - self.lag)).mean()

        # a-level KLDs: first step vs N(0, I), later steps vs the a transition prior.
        log_qa = self._log_q_a(a, a_mean, a_logvar)
        kld_normal_a = self._first_step_kld_a(a, a_mean, a_logvar, log_qa)
        log_qa_laplace = log_qa[:, 1:]
        if log_qa_laplace.numel() == 0:
            kld_laplace_a = torch.tensor(0.0).type_as(kld_normal_a)
        else:
            residuals_a, logabsdet_a = self.transition_prior_a(a)
            log_pa_laplace = torch.sum(self.base_dist_a.log_prob(residuals_a), dim=1) + logabsdet_a
            log_qa_laplace_sum = torch.sum(torch.sum(log_qa_laplace, dim=-1), dim=-1)
            kld_laplace_a = ((log_qa_laplace_sum - log_pa_laplace) / (length - self.lag)).mean()

        # Prediction
        pred_loss = self._teacher_forced_pred_loss(x, zs, residuals, a_input)

        loss = (recon_loss + self.beta * kld_normal + self.gamma * kld_laplace + self.delta * kld_normal_a
                + self.epsilon * kld_laplace_a + self.alpha * pred_loss)
        self.log("train_elbo_loss", loss)
        return {"loss": loss, "batch_size": batch_size}

    def training_epoch_end(self, training_step_outputs):
        """Log batch-size-weighted averages of every per-step output (PyTorch Lightning 1.x hook).

        Every key except ``batch_size`` in the step outputs is averaged, weighted by
        ``batch_size``. Does nothing for an epoch with no outputs.
        """
        if not training_step_outputs:
            return
        with torch.no_grad():
            first = training_step_outputs[0]
            metrics = [metric for metric in first if metric != 'batch_size']
            totals = {metric: torch.tensor(0).type_as(first['loss']) for metric in metrics}
            total_size = 0
            for log_output in training_step_outputs:
                batch_size = log_output['batch_size']
                total_size += batch_size
                for metric in metrics:
                    totals[metric] += batch_size * log_output[metric]
            parts = [f"{metric} {(totals[metric] / total_size).item()}" for metric in metrics]
            logging.info(f"epoch {self.trainer.current_epoch}:" + ", ".join(parts))

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and sampled-forecast metrics (MAE, rho50, rho90) for one batch.

        Unlike training, the loss omits the future-KLD terms, and the predictor runs
        autoregressively on sampled residuals.
        """
        x = batch['xt']
        x_recon, mus, logvars, zs, a, a_mean, a_logvar, a_input = self._encode_window(x)

        # VAE ELBO loss: recon_loss + kld_loss
        recon_loss = self._window_recon_loss(x, x_recon)
        log_qz = D.Normal(mus, torch.exp(logvars / 2)).log_prob(zs)
        kld_normal = self._past_kld_z(mus, logvars, zs, log_qz)
        log_qa = self._log_q_a(a, a_mean, a_logvar)
        # Same first-step slice as training_step; a [lag-1:lag] slice of the prior
        # parameters can be empty for short windows and silently zero the log-prior.
        kld_normal_a = self._first_step_kld_a(a, a_mean, a_logvar, log_qa)

        # Prediction
        predictions_z, pred_loss = self._sampled_forecast(x, zs, a_input)
        prediction_x = self.net.decoder(predictions_z)
        prediction_x_median = prediction_x.median(dim=0).values
        target_x = x[:, -self.length:, :]
        mae = compute_mae(prediction_x_median, target_x)
        rho50 = compute_rho(0.5, prediction_x, target_x)
        rho90 = compute_rho(0.9, prediction_x, target_x)

        loss = recon_loss + self.beta * kld_normal + self.delta * kld_normal_a + self.alpha * pred_loss
        mae_mean = mae.mean()
        rho50_ratio = rho50[0] / rho50[1]
        rho90_ratio = rho90[0] / rho90[1]
        self.log("epoch", self.trainer.current_epoch)
        self.log("val_elbo_loss", loss)

        self.log("val_predict_mae", mae_mean)
        self.log("val_predict_rho50", rho50_ratio)
        self.log("val_predict_rho90", rho90_ratio)

        logging.info(
            f"epoch {self.trainer.current_epoch}: val_elbo_loss {loss.item()}, "
            f"val_predict_mae {mae_mean.item()}, val_predict_rho50 {rho50_ratio}, val_predict_rho90 {rho90_ratio}"
            )
        return loss

    def sample(self, n=64):
        """Draw ``n`` standard-normal vectors of size ``z_dim`` and map them through ``self.spline.inverse``.

        NOTE(review): no ``spline`` attribute is created anywhere in this class, so
        as written this raises AttributeError. TODO confirm whether this is dead
        code or should use one of the transition priors.
        """
        with torch.no_grad():
            e = torch.randn(n, self.z_dim, device=self.device)
            eps, _ = self.spline.inverse(e)
        return eps

    def configure_optimizers(self):
        """Return AdamW over trainable parameters (``lr=self.lr``, weight decay 1e-4) and no schedulers."""
        opt_v = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()),
                                  lr=self.lr, betas=(0.9, 0.999), weight_decay=0.0001)
        return [opt_v], []