"""Temporal VAE with gaussian margial and laplacian transition prior"""

import torch
import logging
import numpy as np
import torch.nn as nn
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from .components.beta import BetaVAE_MLP
from .components.transition import NPTransitionPrior
from .metrics.forecasting import (compute_mae, compute_rho)
from .components.mlp import MLPEncoder, MLPDecoder, Inference, NLayerLeakyMLP
from .metrics.correlation import compute_mcc
from .stationary_scalarnoise import StationaryProcess

import ipdb as pdb


class StationaryProcessPred2Step(pl.LightningModule):
    """Latent-space forecaster trained on top of a frozen ``StationaryProcess``.

    The pretrained model (loaded from ``pretrain_path`` and frozen) supplies
    the encoder/decoder (``pretrain.net``) and the transition prior
    (``pretrain.transition_prior``). This module only trains an LSTM
    (``predictor``) plus an MLP (``mixnoise``) that combines the LSTM output
    with a noise/residual term to predict the last ``length`` latent steps of
    each input sequence.

    Only ``infer_mode='f'`` and ``predict_mode='noisez'`` are implemented by
    the training/validation steps; other values raise ``NotImplementedError``
    when a step is executed.
    """

    def __init__(
        self,
        pretrain_path,
        input_dim,
        length,
        z_dim,
        lag,
        hidden_dim=128,
        trans_prior='np',
        predict_mode='noisez',
        lr=1e-4,
        pred_lr=1e-3,
        infer_mode='f',
        alpha=1,
        beta=0.0025,
        gamma=0.0075,
        decoder_dist='gaussian',
        preidction_sample_times=50,
        correlation='Pearson'
        ):
        """Load the frozen pretrained model and build the predictor.

        Args:
            pretrain_path: Checkpoint path passed to
                ``StationaryProcess.load_from_checkpoint``.
            input_dim: Size of the last axis of the observations ``xt``.
            length: Number of trailing time steps to predict.
            z_dim: Latent dimensionality.
            lag: Stored on the module and forwarded to the pretrained model.
            hidden_dim: Forwarded to the pretrained model.
            trans_prior: ``'l'`` or ``'np'``; forwarded to the pretrained model.
            predict_mode: Only ``'noisez'`` builds a predictor.
            lr: Forwarded to the pretrained model.
            pred_lr: Learning rate of the optimizer created by
                ``configure_optimizers``.
            infer_mode: ``'r'`` or ``'f'``; only ``'f'`` is implemented.
            alpha: Accepted for interface compatibility; not used here.
            beta: Score for prior z regularization (forwarded).
            gamma: Score for independent noise regularization (forwarded).
            decoder_dist: Forwarded to the pretrained model.
            preidction_sample_times: Number of noise samples drawn per
                sequence during validation (parameter name kept, including
                its spelling, for backward compatibility).
            correlation: Forwarded to the pretrained model.
        """
        super().__init__()
        assert trans_prior in ('l', 'np'), "trans_prior must be 'l' or 'np'"
        assert infer_mode in ('r', 'f'), "infer_mode must be 'r' or 'f'"
        self.z_dim = z_dim
        self.lag = lag
        self.input_dim = input_dim
        self.lr = lr
        self.pred_lr = pred_lr
        self.length = length
        self.beta = beta  # score for prior z regularization
        self.gamma = gamma  # score for independent noise regularization
        self.correlation = correlation
        self.decoder_dist = decoder_dist
        self.infer_mode = infer_mode
        self.predict_mode = predict_mode
        self.prediction_sample_times = preidction_sample_times
        self.criterion = nn.MSELoss()

        self.pretrain = StationaryProcess.load_from_checkpoint(
            checkpoint_path=pretrain_path,
            input_dim=input_dim,
            length=length,
            z_dim=z_dim,
            lag=lag,
            hidden_dim=hidden_dim,
            trans_prior=trans_prior,
            lr=lr,
            beta=beta,
            gamma=gamma,
            decoder_dist=decoder_dist,
            correlation=correlation,
        )
        # The pretrained model is used as a fixed feature extractor/decoder.
        self.pretrain.freeze()

        if self.predict_mode == "noisez":
            self.lstm_layer = 2
            self.predictor = nn.LSTM(
                input_size=self.z_dim,
                hidden_size=self.z_dim,
                num_layers=self.lstm_layer,
                batch_first=True,
                bias=True,
            )
            # Maps [LSTM output, noise] (2 * z_dim) back to a latent (z_dim).
            self.mixnoise = NLayerLeakyMLP(
                in_features=2 * z_dim,
                out_features=z_dim,
                num_layers=2,
                hidden_dim=z_dim,
            )
            # Learnable initial LSTM state, shared across the batch.
            self.init_hidden = nn.Parameter(torch.zeros(self.lstm_layer, self.z_dim))
            self.init_cell = nn.Parameter(torch.zeros(self.lstm_layer, self.z_dim))

            # base distribution for calculation of log prob under the model
            self.register_buffer('base_dist_mean', torch.zeros(self.z_dim))
            self.register_buffer('base_dist_var', torch.eye(self.z_dim))

    @property
    def base_dist(self):
        """Standard multivariate normal used as the noise density."""
        return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)

    # ------------------------------------------------------------------
    # Private helpers shared by training_step and validation_step
    # ------------------------------------------------------------------
    def _check_supported_modes(self):
        """Fail early, with a clear message, on unimplemented configurations."""
        if self.infer_mode != 'f':
            raise NotImplementedError(
                f"infer_mode={self.infer_mode!r} is not implemented; only 'f' is supported"
            )
        if self.predict_mode != 'noisez':
            raise NotImplementedError(
                f"predict_mode={self.predict_mode!r} is not implemented; only 'noisez' is supported"
            )

    def _context_steps(self, length):
        """Return the number of leading steps that precede the forecast window."""
        if length < self.length:
            raise ValueError(
                f"sequence length {length} is shorter than the prediction length {self.length}"
            )
        return length - self.length

    def _encode_latents(self, x):
        """Encode ``x`` (batch, time, input_dim) into latents (batch, time, z_dim)."""
        batch_size, length, _ = x.shape
        # The pretrained net works on flattened time steps; only zs is needed.
        _, _, _, zs = self.pretrain.net(x.reshape(-1, self.input_dim))
        return zs.reshape(batch_size, length, self.z_dim)

    @staticmethod
    def _shift_right(zs):
        """Prepend a zero step and drop the last one, so step t holds z_{t-1}."""
        start = zs.new_zeros(zs.shape[0], 1, zs.shape[2])
        return torch.cat((start, zs[:, :-1, :]), dim=1)

    def _warm_up(self, zs_input, steps, batch_size):
        """Run the LSTM over the first ``steps`` inputs and return its state."""
        hidden = self.init_hidden.unsqueeze(1).repeat(1, batch_size, 1)
        cell = self.init_cell.unsqueeze(1).repeat(1, batch_size, 1)
        if steps > 0:
            # One call over the whole context is equivalent to stepping.
            _, (hidden, cell) = self.predictor(zs_input[:, :steps, :], (hidden, cell))
        return hidden, cell

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def training_step(self, batch, batch_idx):
        """Teacher-forced prediction of the last ``self.length`` latent steps.

        The LSTM is fed the true previous latent at every step, and the
        residuals returned by the pretrained transition prior are used as the
        noise input of ``mixnoise``. The loss is the sum over predicted steps
        of the MSE between predicted and encoded latents.

        Returns:
            dict with ``loss``, ``batch_size`` and ``train_prediction_mae``.
        """
        self._check_supported_modes()
        x = batch['xt']
        batch_size, length, _ = x.shape
        horizon = self.length
        context = self._context_steps(length)

        zs = self._encode_latents(x)
        residuals, _ = self.pretrain.transition_prior(zs)
        # The prior may return fewer steps than the input; left-pad with zeros
        # so residuals_input is aligned with zs along the time axis.
        padding = x.new_zeros(batch_size, length - residuals.shape[1], residuals.shape[2])
        residuals_input = torch.cat((padding, residuals), dim=1)
        zs_input = self._shift_right(zs)

        hidden, cell = self._warm_up(zs_input, context, batch_size)

        loss = x.new_zeros(())
        predictions_z = x.new_zeros(batch_size, horizon, self.z_dim)
        for i in range(horizon):
            t = context + i
            prediction, (hidden, cell) = self.predictor(zs_input[:, t:t + 1, :], (hidden, cell))
            prediction_input = torch.cat(
                [prediction, residuals_input[:, t:t + 1, :]], dim=1
            ).view(batch_size, -1)
            prediction_noise = self.mixnoise(prediction_input)
            predictions_z[:, i, :] = prediction_noise
            loss = loss + self.criterion(prediction_noise, zs[:, t, :])

        prediction_x = self.pretrain.net.decoder(predictions_z)
        # Detached: this value is only reported, never back-propagated, and
        # should not keep the autograd graph alive in the returned outputs.
        mae = compute_mae(prediction_x, x[:, -horizon:, :]).mean().detach()
        self.log("train_loss", loss)
        self.log("train_prediction_mae", mae)

        return {"loss": loss, "batch_size": batch_size, "train_prediction_mae": mae}

    def training_epoch_end(self, training_step_outputs):
        """Log the batch-size-weighted epoch average of each returned metric."""
        if not training_step_outputs:
            return

        with torch.no_grad():
            reference = training_step_outputs[0]['loss']
            metric_names = [name for name in training_step_outputs[0] if name != 'batch_size']
            total_size = sum(output['batch_size'] for output in training_step_outputs)
            if total_size == 0:
                return

            averages = {}
            for name in metric_names:
                weighted_sum = reference.new_zeros(())
                for output in training_step_outputs:
                    weighted_sum = weighted_sum + output['batch_size'] * output[name]
                averages[name] = weighted_sum / total_size

            logs_str = f"epoch {self.trainer.current_epoch}:" + ", ".join(
                f"{name} {value.item()}" for name, value in averages.items()
            )
            logging.info(logs_str)

    def validation_step(self, batch, batch_idx):
        """Autoregressive, sampled prediction of the last ``self.length`` steps.

        After warming up on the true latents, the model is rolled out
        ``prediction_sample_times`` times; each roll-out feeds its own previous
        prediction back in and draws its noise from ``base_dist``. The decoded
        samples are scored with MAE (on the per-element median) and rho-risk.

        Returns:
            The summed MSE over all samples and predicted steps.
        """
        self._check_supported_modes()
        x = batch['xt']
        batch_size, length, _ = x.shape
        horizon = self.length
        context = self._context_steps(length)
        n_samples = self.prediction_sample_times

        zs = self._encode_latents(x)
        residuals = self.base_dist.sample([n_samples, batch_size, length])
        zs_input = self._shift_right(zs)

        hidden, cell = self._warm_up(zs_input, context, batch_size)
        # Every sampled roll-out restarts from the same warmed-up state.
        hidden_start = hidden.clone().detach()
        cell_start = cell.clone().detach()

        loss = x.new_zeros(())
        predictions_z = x.new_zeros(n_samples, batch_size, horizon, self.z_dim)
        for j in range(n_samples):
            zs_previous = zs_input[:, context, :]
            hidden, cell = hidden_start, cell_start
            for i in range(horizon):
                t = context + i
                prediction, (hidden, cell) = self.predictor(zs_previous.unsqueeze(1), (hidden, cell))
                prediction_input = torch.cat(
                    [prediction, residuals[j, :, t:t + 1]], dim=1
                ).view(batch_size, -1)
                prediction_noise = self.mixnoise(prediction_input)
                predictions_z[j, :, i, :] = prediction_noise
                zs_previous = prediction_noise
                # Target is the latent at the predicted position t, matching
                # the residual slice above and the training objective.
                loss = loss + self.criterion(prediction_noise, zs[:, t, :])

        prediction_x = self.pretrain.net.decoder(predictions_z)
        prediction_x_median = prediction_x.median(dim=0).values
        target_x = x[:, -horizon:, :]
        mae = compute_mae(prediction_x_median, target_x)
        rho50 = compute_rho(0.5, prediction_x, target_x)
        rho90 = compute_rho(0.9, prediction_x, target_x)

        self.log("epoch", self.trainer.current_epoch)
        self.log("val_loss", loss)

        self.log("val_predict_mae", mae.mean())
        self.log("val_predict_rho50", rho50[0] / rho50[1])
        self.log("val_predict_rho90", rho90[0] / rho90[1])

        logging.info(
            f"epoch {self.trainer.current_epoch}: val_loss {loss.item()}, "
            f"val_predict_mae {mae.mean().item()}, val_predict_rho50 {rho50[0] / rho50[1]}, val_predict_rho90 {rho90[0] / rho90[1]}"
            )
        return loss

    def sample(self, n=64):
        """Draw ``n`` noise samples through the inverse of ``self.spline``.

        NOTE(review): this class never defines ``self.spline``; the method
        only works if a ``spline`` attribute is attached externally.

        Raises:
            AttributeError: if no ``spline`` attribute is available.
        """
        spline = getattr(self, "spline", None)
        if spline is None:
            raise AttributeError(
                f"{type(self).__name__}.sample() requires a `spline` attribute, "
                "which this module does not define"
            )
        with torch.no_grad():
            e = torch.randn(n, self.z_dim, device=self.device)
            eps, _ = spline.inverse(e)
        return eps

    def reparameterize(self, mean, logvar, random_sampling=True):
        """Return ``mean + eps * exp(0.5 * logvar)``, or ``mean`` if not sampling."""
        if not random_sampling:
            return mean
        eps = torch.randn_like(logvar)
        std = torch.exp(0.5 * logvar)
        return mean + eps * std

    def configure_optimizers(self):
        """Create an AdamW optimizer over the trainable (non-frozen) parameters."""
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        opt_v = torch.optim.AdamW(trainable, lr=self.pred_lr, betas=(0.9, 0.999), weight_decay=0.0001)
        return [opt_v], []