"""Temporal VAE with gaussian margial and laplacian transition prior"""

import torch
import numpy as np
import torch.nn as nn
import lightning.pytorch as pl
import torch.distributions as D
from torch.nn import functional as F
from .components.beta import BetaVAE_MLP
from .components.transition import (MBDTransitionPrior, 
                                    NPTransitionPrior,
                                    NPDTransitionPrior)
from .components.mlp import MLPEncoder, MLPDecoder, Inference, GaussianNet, NLayerLeakyMLP
from .components.mine import MINE
from .metrics.correlation import compute_r2
from IFactor.tools.utils import get_parameters

import ipdb as pdb

class StationaryIFactorProcess(pl.LightningModule):

    def __init__(
        self,
        input_dim,
        length,
        z_dim_list,
        action_dim,
        lag,
        config,
        z_dim_true_list=None,
        hidden_dim=128,
        trans_prior='DNP',
        lr=1e-4,
        aux_lr=1e-4,
        infer_mode='F',
        beta=0.0025,
        gamma=0.0075,
        delta=0.01,
        delta_epoch=10,
        decoder_dist='gaussian',
        correlation='Pearson'):
        """Nonlinear ICA for nonparametric stationary processes.

        Args:
            input_dim: Size of one flattened observation.
            length: Sequence length (stored; not otherwise used here).
            z_dim_list: Sizes of the four latent blocks (z1, z2, z3, z4);
                the total latent size is the sum of all entries.
            action_dim: Size of one action vector.
            lag: Number of initial time steps treated as the "past".
            config: Accepted for interface compatibility; unused here.
            z_dim_true_list: Sizes of the ground-truth latent blocks used
                for the R2 metrics. Defaults to ``[2, 2, 2, 2]``.
            hidden_dim: Hidden width of the sub-networks.
            trans_prior: One of ``'L'``, ``'NP'``, ``'DNP'`` or ``'NONE'``.
            lr: Learning rate of the representation optimizer.
            aux_lr: Learning rate of the MINE (auxiliary) optimizer.
            infer_mode: ``'R'`` (recurrent) or ``'F'`` (factorized).
            beta: Weight of the KL term on the first ``lag`` steps.
            gamma: Weight of the KL term on the remaining steps.
            delta: Weight of the MINE regularizers.
            delta_epoch: First epoch at which the MINE regularizers are
                added to the representation loss.
            decoder_dist: Name passed to ``reconstruction_loss``.
            correlation: Stored; not otherwise used here.

        Raises:
            AssertionError: If ``trans_prior`` is not a supported value.
            ValueError: If ``infer_mode`` is neither ``'R'`` nor ``'F'``.
        """
        super().__init__()
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        # Transition prior: L (linear), NP (nonparametric),
        # DNP (block-wise nonparametric) or NONE (no transition prior).
        assert trans_prior in ('L', 'NP', 'DNP', 'NONE'), \
            f"unsupported trans_prior: {trans_prior!r}"
        # Fail early: otherwise self.net would be missing and the first
        # forward pass would die with an unrelated error.
        if infer_mode not in ('R', 'F'):
            raise ValueError(
                f"Unsupported infer_mode {infer_mode!r}; expected 'R' or 'F'")
        # A fresh list per instance avoids sharing one mutable default.
        if z_dim_true_list is None:
            z_dim_true_list = [2, 2, 2, 2]

        self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim = \
            z_dim_list[0], z_dim_list[1], z_dim_list[2], z_dim_list[3]
        z_dim = sum(z_dim_list)
        self.z_dim = z_dim
        self.z_dim_true_list = list(z_dim_true_list)
        self.action_dim = action_dim
        self.lag = lag
        self.input_dim = input_dim
        self.lr = lr
        self.aux_lr = aux_lr
        self.length = length
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.delta_epoch = delta_epoch
        self.correlation = correlation
        self.decoder_dist = decoder_dist
        self.infer_mode = infer_mode

        # Modules updated by the representation optimizer.
        self.representation_model_list = []
        if infer_mode == 'R':
            # Recurrent inference
            self.enc = MLPEncoder(latent_size=z_dim,
                                  num_layers=3,
                                  hidden_dim=hidden_dim)

            self.dec = MLPDecoder(latent_size=z_dim,
                                  num_layers=2,
                                  hidden_dim=hidden_dim)

            # Bi-directional hidden state rnn
            self.rnn = nn.GRU(input_size=z_dim,
                              hidden_size=hidden_dim,
                              num_layers=1,
                              batch_first=True,
                              bidirectional=True)

            # Inference net
            self.net = Inference(lag=lag,
                                 z_dim=z_dim,
                                 hidden_dim=hidden_dim,
                                 num_layers=2)
            self.representation_model_list += [self.enc, self.dec, self.rnn, self.net]
        else:
            # Factorized inference
            self.net = BetaVAE_MLP(input_dim=input_dim,
                                   z_dim=z_dim,
                                   hidden_dim=hidden_dim)
            self.representation_model_list += [self.net]

        # Initialize transition prior
        if trans_prior == 'L':
            self.transition_prior = MBDTransitionPrior(lags=lag,
                                                       latent_size=z_dim,
                                                       bias=False)
        elif trans_prior == 'DNP':
            self.transition_prior = NPDTransitionPrior(lags=lag,
                                                      latent_size_list=[self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim],
                                                      action_size=action_dim,
                                                      num_layers=3,
                                                      hidden_dim=hidden_dim)
        elif trans_prior == 'NP':
            self.transition_prior = NPTransitionPrior(lags=lag,
                                                      latent_size=self.z_dim,
                                                      action_size=action_dim,
                                                      num_layers=3,
                                                      hidden_dim=hidden_dim)
        else:
            self.transition_prior = None

        # Reward decoder reads only the z1 and z2 blocks.
        self.rew_dec = NLayerLeakyMLP(in_features=self.z1_dim+self.z2_dim, out_features=1, num_layers=2)
        if self.transition_prior is None:
            self.representation_model_list += [self.rew_dec]
        else:
            self.representation_model_list += [self.transition_prior, self.rew_dec]

        # MINE estimators, one per term:
        # I(s_t^{1, 2}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{3, 4}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{1, 3}; a_{t-1}|s_{t-1})
        # I(s_t^{2, 4}; a_{t-1} \,|s_{t-1})
        self.mine_reward_1 = MINE(x_dim=self.z1_dim + self.z2_dim + self.z1_dim + self.z2_dim + action_dim, y_dim=1)
        self.mine_reward_2 = MINE(x_dim=self.z3_dim + self.z4_dim + self.z1_dim + self.z2_dim + action_dim, y_dim=1)
        self.mine_action_1 = MINE(x_dim=self.z1_dim + self.z3_dim + self.z_dim, y_dim=action_dim)
        self.mine_action_2 = MINE(x_dim=self.z2_dim + self.z4_dim + self.z_dim, y_dim=action_dim)
        # Modules updated by the auxiliary optimizer.
        self.aux_model_list = [self.mine_reward_1, self.mine_reward_2, self.mine_action_1, self.mine_action_2]

        # base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(self.z_dim))
        self.register_buffer('base_dist_var', torch.eye(self.z_dim))

    @property
    def base_dist(self):
        # Noise density function
        return D.MultivariateNormal(self.base_dist_mean, self.base_dist_var)

    def inference(self, ft, random_sampling=True):
        ## bidirectional lstm/gru 
        # input: (batch, seq_len, z_dim)
        # output: (batch, seq_len, z_dim)
        output, h_n = self.rnn(ft)
        batch_size, length, _ = output.shape
        # beta, hidden = self.gru(ft, hidden)
        ## sequential sampling & reparametrization
        ## transition: p(zt|z_tau)
        zs, mus, logvars = [], [], []
        for tau in range(self.lag):
            zs.append(torch.ones((batch_size, self.z_dim), device=output.device))

        for t in range(length):
            mid = torch.cat(zs[-self.lag:], dim=1)
            inputs = torch.cat([mid, output[:,t,:]], dim=1)
            distributions = self.net(inputs)
            mu = distributions[:, :self.z_dim]
            logvar = distributions[:, self.z_dim:]
            zt = self.reparameterize(mu, logvar, random_sampling)
            zs.append(zt)
            mus.append(mu)
            logvars.append(logvar)

        zs = torch.squeeze(torch.stack(zs, dim=1))
        # Strip the first L zero-initialized zt 
        zs = zs[:,self.lag:]
        mus = torch.squeeze(torch.stack(mus, dim=1))
        logvars = torch.squeeze(torch.stack(logvars, dim=1))
        return zs, mus, logvars

    def reparameterize(self, mean, logvar, random_sampling=True):
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5*logvar)
            z = mean + eps*std
            return z
        else:
            return mean

    def reconstruction_loss(self, x, x_recon, distribution):
        batch_size = x.size(0)
        assert batch_size != 0

        if distribution == 'bernoulli':
            recon_loss = F.binary_cross_entropy_with_logits(
                x_recon, x, size_average=False).div(batch_size)

        elif distribution == 'gaussian':
            recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)

        elif distribution == 'sigmoid_gaussian':
            x_recon = F.sigmoid(x_recon)
            recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)

        return recon_loss
    
    def aux_reward_loss(self, distribution, y):
        mle_loss = -distribution.log_prob(y).mean()
        mse_loss = F.mse_loss(distribution.mean, y)
        return mle_loss, mse_loss

    def forward(self, batch):
        x, y = batch['xt'], batch['yt']
        batch_size, length, _ = x.shape
        x_flat = x.view(-1, self.input_dim)
        if self.infer_mode == 'R':
            ft = self.enc(x_flat)
            ft = ft.view(batch_size, length, -1)
            zs, mus, logvars = self.inference(ft, random_sampling=True)
        elif self.infer_mode == 'F':
            _, mus, logvars, zs = self.net(x_flat)
        return zs, mus, logvars

    def training_step(self, batch, batch_idx):
        """Run one manually-optimized training step.

        The representation optimizer is stepped on the ELBO-style loss, then
        the auxiliary optimizer is stepped on the MINE estimator loss.

        Args:
            batch: Mapping with ``'xt'`` (observations), ``'rt'`` (rewards)
                and ``'at'`` (actions), each indexed as (batch, time, dim).
            batch_idx: Unused.

        Returns:
            Dict with the representation loss under ``'opt_r'`` and the
            auxiliary loss under ``'opt_aux'``.
        """
        losses = self._compute_losses(batch)
        loss, loss_aux = losses['loss'], losses['loss_aux']

        opt_r, opt_aux = self.optimizers()

        # Representation update.
        opt_r.zero_grad()
        self.manual_backward(loss)
        self.clip_gradients(opt_r, gradient_clip_val=5, gradient_clip_algorithm="norm")
        opt_r.step()

        # MINE update. zero_grad first: backward through the representation
        # loss may have left gradients on the estimator parameters.
        opt_aux.zero_grad()
        self.manual_backward(loss_aux)
        self.clip_gradients(opt_aux, gradient_clip_val=5, gradient_clip_algorithm="norm")
        opt_aux.step()

        self._log_losses(losses)
        return {'opt_r': loss, 'opt_aux': loss_aux}

    def validation_step(self, batch, batch_idx):
        """Compute the losses and the R2 metrics on a validation batch.

        Args:
            batch: Same keys as in ``training_step`` plus ``'yt'``, the
                ground-truth latents, whose last dimension is
                ``sum(self.z_dim_true_list)``.
            batch_idx: Unused.

        Returns:
            Dict with the representation loss under ``'opt_r'`` and the
            auxiliary loss under ``'opt_aux'``.
        """
        losses = self._compute_losses(batch)

        def halves(rows):
            # Midpoint split; equals np.split(rows, 2) for an even row count
            # and, unlike it, does not raise for an odd one.
            mid = rows.shape[0] // 2
            return rows[:mid], rows[mid:]

        # Split posterior means and ground-truth latents into the four blocks.
        zt_recon = losses['mus'].view(-1, self.z_dim).detach().cpu().numpy()
        recon_blocks = np.split(
            zt_recon,
            np.cumsum([self.z1_dim, self.z2_dim, self.z3_dim]).tolist(),
            axis=-1)
        zt_true = batch["yt"].view(-1, sum(self.z_dim_true_list)).detach().cpu().numpy()
        true_blocks = np.split(
            zt_true,
            np.cumsum(list(self.z_dim_true_list[:3])).tolist(),
            axis=-1)

        # First half of the rows is the "train" split, second half the "test".
        paired = [(halves(est), halves(true))
                  for est, true in zip(recon_blocks, true_blocks)]

        # Estimated -> true, then true -> estimated (same call order as before).
        r2_scores = [compute_r2(est_train, true_train, est_test, true_test)
                     for (est_train, est_test), (true_train, true_test) in paired]
        ave_r2 = sum(r2_scores) / 4.0
        r2h_scores = [compute_r2(true_train, est_train, true_test, est_test)
                      for (est_train, est_test), (true_train, true_test) in paired]
        ave_r2h = sum(r2h_scores) / 4.0

        # NOTE(review): the loss keys carry a "train_" prefix even in
        # validation. They are kept because callbacks configured elsewhere
        # may monitor them - confirm before renaming.
        self._log_losses(losses)
        for key, score in zip(("r21", "r22", "r23", "r24"), r2_scores):
            self.log(key, score)
        self.log("ave_r2", ave_r2)
        for key, score in zip(("r21h", "r22h", "r23h", "r24h"), r2h_scores):
            self.log(key, score)
        self.log("ave_r2h", ave_r2h)
        return {'opt_r': losses['loss'], 'opt_aux': losses['loss_aux']}

    def _compute_losses(self, batch):
        """Compute every loss term shared by training and validation.

        Returns a dict with ``loss`` (representation objective),
        ``loss_aux`` (sum of the MINE inner losses), the individual terms
        that get logged, and ``mus`` reshaped to (batch, time, z_dim).
        """
        x, r, a = batch['xt'], batch['rt'], batch['at']
        batch_size, length, _ = x.shape
        x_flat = x.view(-1, self.input_dim)

        # Inference
        if self.infer_mode == 'R':
            ft = self.enc(x_flat).view(batch_size, length, -1)
            zs, mus, logvars = self.inference(ft)
            zs_flat = zs.contiguous().view(-1, self.z_dim)
            x_recon = self.dec(zs_flat)
        elif self.infer_mode == 'F':
            x_recon, mus, logvars, zs = self.net(x_flat)
            zs_flat = zs.contiguous().view(-1, self.z_dim)
        else:
            raise ValueError(
                f"Unsupported infer_mode {self.infer_mode!r}; expected 'R' or 'F'")

        block_sizes = [self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim]
        # The reward is decoded from the z1 and z2 blocks only.
        z1_flat, z2_flat, _, _ = torch.split(zs_flat, block_sizes, dim=-1)
        r_recon = self.rew_dec(torch.cat([z1_flat, z2_flat], dim=-1))

        # Reshape to time-series format
        x_recon = x_recon.view(batch_size, length, self.input_dim)
        r_recon = r_recon.view(batch_size, length, 1)
        mus = mus.reshape(batch_size, length, self.z_dim)
        logvars = logvars.reshape(batch_size, length, self.z_dim)
        zs = zs.reshape(batch_size, length, self.z_dim)
        z1, z2, z3, z4 = torch.split(zs, block_sizes, dim=-1)

        # One-step alignment: index t pairs s_t with s_{t-1}, a_{t-1}, r_t.
        # Previous-step latents are detached so they act as conditioning
        # context only.
        z1_prev, z2_prev, z3_prev, z4_prev = (
            block[:, :-1, :].detach() for block in (z1, z2, z3, z4))
        z1_cur, z2_cur, z3_cur, z4_cur = (
            block[:, 1:, :] for block in (z1, z2, z3, z4))
        a_prev = a[:, :-1, :]
        r_cur = r[:, 1:, :]

        # I(s_t^{1, 2}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{3, 4}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{1, 3}; a_{t-1}|s_{t-1})
        # I(s_t^{2, 4}; a_{t-1} \,|s_{t-1})
        reward_context = [a_prev, z1_prev, z2_prev]
        state_context = [z1_prev, z2_prev, z3_prev, z4_prev]
        reward_1_x = torch.cat([z1_cur, z2_cur] + reward_context, dim=-1)
        reward_2_x = torch.cat([z3_cur, z4_cur] + reward_context, dim=-1)
        action_1_x = torch.cat([z1_cur, z3_cur] + state_context, dim=-1)
        action_2_x = torch.cat([z2_cur, z4_cur] + state_context, dim=-1)

        # Inner losses train the estimators only: every input is detached.
        mine_inner_loss1 = -self.mine_reward_1(reward_1_x.detach(), r_cur.detach())
        mine_inner_loss2 = -self.mine_reward_2(reward_2_x.detach(), r_cur.detach())
        mine_inner_loss3 = -self.mine_action_1(action_1_x.detach(), a_prev.detach())
        mine_inner_loss4 = -self.mine_action_2(action_2_x.detach(), a_prev.detach())

        # Outer losses shape the representation: raise the first estimate
        # and lower the second, each clamped at zero.
        mine_outer_loss_reward = -(
            torch.clamp(self.mine_reward_1(reward_1_x, r_cur), min=0)
            - torch.clamp(self.mine_reward_2(reward_2_x, r_cur), min=0))
        mine_outer_loss_action = -(
            torch.clamp(self.mine_action_1(action_1_x, a_prev), min=0)
            - torch.clamp(self.mine_action_2(action_2_x, a_prev), min=0))

        # VAE ELBO loss: recon_loss + kld_loss. The first `lag` steps count
        # in full; the remaining steps are averaged over their number.
        lag = self.lag
        future_steps = length - lag
        recon_loss = (
            self.reconstruction_loss(x[:, :lag], x_recon[:, :lag], self.decoder_dist)
            + self.reconstruction_loss(x[:, lag:], x_recon[:, lag:], self.decoder_dist) / future_steps)
        recon_reward_loss = (
            self.reconstruction_loss(r[:, :lag], r_recon[:, :lag], self.decoder_dist)
            + self.reconstruction_loss(r[:, lag:], r_recon[:, lag:], self.decoder_dist) / future_steps)

        if self.transition_prior is not None:
            q_dist = D.Normal(mus, torch.exp(logvars / 2))
            log_qz = q_dist.log_prob(zs)
            # Past KLD: single-sample estimate against a standard normal.
            p_dist = D.Normal(torch.zeros_like(mus[:, :lag]), torch.ones_like(logvars[:, :lag]))
            log_pz_normal = p_dist.log_prob(zs[:, :lag]).sum(dim=-1).sum(dim=-1)
            log_qz_normal = log_qz[:, :lag].sum(dim=-1).sum(dim=-1)
            kld_normal = (log_qz_normal - log_pz_normal).mean()
            # Future KLD: prior density via the transition residuals plus the
            # log|det Jacobian| returned by the transition prior.
            log_qz_future = log_qz[:, lag:].sum(dim=-1).sum(dim=-1)
            residuals, logabsdet = self.transition_prior(zs, a)
            log_pz_laplace = torch.sum(self.base_dist.log_prob(residuals), dim=1) + logabsdet
            kld_laplace = ((log_qz_future - log_pz_laplace) / future_steps).mean()
        else:
            # No transition prior: analytic KL against a standard normal.
            q_dist = D.Normal(mus[:, :lag], torch.exp(logvars[:, :lag] / 2))
            p_dist = D.Normal(torch.zeros_like(mus[:, :lag]), torch.ones_like(logvars[:, :lag]))
            kld_normal = torch.mean(
                torch.distributions.kl_divergence(q_dist, p_dist).sum(dim=-1).sum(dim=-1))
            # Future KLD
            q_dist_future = D.Normal(mus[:, lag:], torch.exp(logvars / 2)[:, lag:])
            p_dist_future = D.Normal(torch.zeros_like(mus[:, lag:]), torch.ones_like(logvars[:, lag:]))
            kld_laplace = torch.mean(
                torch.distributions.kl_divergence(q_dist_future, p_dist_future).sum(dim=-1).sum(dim=-1)
            ) / future_steps

        loss = recon_loss + recon_reward_loss + self.beta * kld_normal + self.gamma * kld_laplace
        # The MINE regularizers only join the objective from delta_epoch on.
        if self.trainer.current_epoch >= self.delta_epoch:
            loss = loss + self.delta * mine_outer_loss_reward + self.delta * mine_outer_loss_action
        loss_aux = mine_inner_loss1 + mine_inner_loss2 + mine_inner_loss3 + mine_inner_loss4

        return {
            'loss': loss,
            'loss_aux': loss_aux,
            'recon_loss': recon_loss,
            'recon_reward_loss': recon_reward_loss,
            'mine_inner_loss1': mine_inner_loss1,
            'mine_inner_loss2': mine_inner_loss2,
            'mine_inner_loss3': mine_inner_loss3,
            'mine_inner_loss4': mine_inner_loss4,
            'mine_outer_loss_reward': mine_outer_loss_reward,
            'mine_outer_loss_action': mine_outer_loss_action,
            'kld_normal': kld_normal,
            'kld_laplace': kld_laplace,
            'mus': mus,
        }

    def _log_losses(self, losses):
        """Log the loss terms from ``_compute_losses`` under their metric names."""
        for metric_name, term in (
                ("train_elbo_loss", 'loss'),
                ("train_recon_obs_loss", 'recon_loss'),
                ("train_recon_reward_loss", 'recon_reward_loss'),
                ("mine_inner_loss1", 'mine_inner_loss1'),
                ("mine_inner_loss2", 'mine_inner_loss2'),
                ("mine_inner_loss3", 'mine_inner_loss3'),
                ("mine_inner_loss4", 'mine_inner_loss4'),
                ("mine_outer_loss_reward", 'mine_outer_loss_reward'),
                ("mine_outer_loss_action", 'mine_outer_loss_action'),
                ("train_kld_normal", 'kld_normal'),
                ("train_kld_laplace", 'kld_laplace')):
            self.log(metric_name, losses[term])

    def sample(self, n=64):
        """Map ``n`` standard-normal draws through ``self.spline.inverse``.

        NOTE(review): ``self.spline`` is not assigned anywhere in the visible
        class; unless it is attached elsewhere this raises AttributeError -
        confirm against callers.
        """
        with torch.no_grad():
            noise = torch.randn(n, self.z_dim, device=self.device)
            samples, _ = self.spline.inverse(noise)
        return samples

    def configure_optimizers(self):
        """Build the two AdamW optimizers used by the manual training loop.

        Returns ``[opt_r, opt_aux]``: the first covers the trainable
        parameters of ``representation_model_list`` (lr ``self.lr``), the
        second those of ``aux_model_list`` (lr ``self.aux_lr``).
        """
        def build(modules, learning_rate):
            trainable = (param for param in get_parameters(modules)
                         if param.requires_grad)
            return torch.optim.AdamW(trainable,
                                     lr=learning_rate,
                                     betas=(0.9, 0.999),
                                     weight_decay=0.0001)

        opt_r = build(self.representation_model_list, self.lr)
        opt_aux = build(self.aux_model_list, self.aux_lr)
        return [opt_r, opt_aux]
    
    def load(self, checkpoint_path):
        """Load representation weights from a checkpoint, skipping MINE weights.

        Entries of the checkpoint's ``'state_dict'`` whose key starts with
        ``'mine'`` are dropped; the rest is loaded non-strictly, so missing
        or unexpected keys are ignored.

        Args:
            checkpoint_path: Path (or file-like object) accepted by
                ``torch.load``.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location='cpu' lets a GPU-saved checkpoint load on a host
        # without that device; load_state_dict then copies the values onto
        # whatever device the module's parameters live on.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        kept_weights = {
            key: value
            for key, value in checkpoint['state_dict'].items()
            if not key.startswith('mine')
        }
        self.load_state_dict(kept_weights, strict=False)