"""Temporal VAE with gaussian margial and laplacian transition prior"""

import torch
import numpy as np
import torch.nn as nn
import lightning.pytorch as pl
import torch.distributions as D
from torch.nn import functional as F
# from .components.beta import BetaVAE_MLP
# from .components.transition import (MBDTransitionPrior, 
#                                     NPTransitionPrior,
#                                     NPDTransitionPrior)
from .mlp import NLayerMLP
from .mine import MINE
from ..metrics.correlation import compute_mcc, compute_r2
from LiLY.tools.utils import get_parameters
from .drssm_utils import DRSSMUtils, RSSMContState, RSSMDiscState

import ipdb as pdb


import torch.nn as nn
import torch.distributions as td
from typing import *

class TemporalPrior(nn.Module):
    """Per-stream MLP heads mapping deterministic states to Gaussian prior parameters.

    The deterministic input is split along its last dim into four streams
    (s1..s4) of sizes ``deter_size_s1..s4``. Every stream with a non-zero size
    gets a head ``prior_s{i}`` = Linear -> act_fn -> Linear producing
    ``2 * stoch_size_s{i}`` values, which are split into a mean half and a std
    half. Disabled streams (size 0) keep ``prior_s{i} = None``.

    ``forward`` returns all stream means followed by all stream stds,
    concatenated along the last dim.
    """

    def __init__(self, deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4,
                    stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4, node_size, rssm_type="continuous", act_fn=nn.ELU) -> None:
        super().__init__()
        self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4 = (
            deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4)
        self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4 = (
            stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4)
        # A stream's deterministic and stochastic sizes must be both zero or both non-zero.
        assert all((deter == 0) == (stoch == 0) for deter, stoch in zip(
            (deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4),
            (stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4)))
        self.node_size = node_size
        self.rssm_type = rssm_type
        self.act_fn = act_fn
        self.prior_s1 = self.prior_s2 = self.prior_s3 = self.prior_s4 = None
        self._build_model()

    def _build_model(self):
        """Instantiate ``prior_s1..s4`` for the active streams."""
        deter_sizes = (self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4)
        stoch_sizes = (self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4)
        # Hidden layers of every stream are created before any output layer;
        # this fixes the order in which weight initialisation draws from the RNG.
        hidden_layers = {
            stream: [nn.Linear(deter, self.node_size), self.act_fn()]
            for stream, deter in enumerate(deter_sizes, start=1)
            if deter > 0
        }
        if self.rssm_type == 'discrete':
            raise NotImplementedError
        if self.rssm_type != 'continuous':
            return
        for stream, layers in hidden_layers.items():
            layers.append(nn.Linear(self.node_size, 2 * stoch_sizes[stream - 1]))
            setattr(self, f"prior_s{stream}", nn.Sequential(*layers))

    def forward(self, input_tensor):
        """Map the concatenated deterministic state to ``[means..., stds...]``.

        Returns ``None`` when ``rssm_type`` is neither 'discrete' nor 'continuous'.
        """
        deter_chunks = torch.split(
            input_tensor,
            [self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4],
            dim=-1,
        )
        if self.rssm_type == 'discrete':
            raise NotImplementedError
        if self.rssm_type != 'continuous':
            return None
        heads = (self.prior_s1, self.prior_s2, self.prior_s3, self.prior_s4)
        means, stds = [], []
        for head, chunk in zip(heads, deter_chunks):
            if head is None:
                continue
            mean, std = torch.chunk(head(chunk), 2, dim=-1)
            means.append(mean)
            stds.append(std)
        return torch.cat(means + stds, dim=-1)
        
class TemporalPosterior(nn.Module):
    """Per-stream MLP heads producing Gaussian posterior parameters.

    The input's last dim is laid out as
    ``[deter_s1, deter_s2, deter_s3, deter_s4, embedding]``. Heads s1 and s2
    both read ``(deter_s1, deter_s2, embedding)``; head s3 reads
    ``(deter_s3, embedding)`` and head s4 reads ``(deter_s4, embedding)``.
    Each active head (Linear -> act_fn -> Linear) emits ``2 * stoch_size``
    values, split into a mean half and a std half. Disabled streams
    (size 0) keep ``posterior_s{i} = None``.
    """

    def __init__(self, deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4,
                    stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4, embedding_size, node_size, rssm_type="continuous", act_fn=nn.ELU) -> None:
        super().__init__()
        self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4 = (
            deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4)
        self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4 = (
            stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4)
        self.embedding_size = embedding_size
        # A stream's deterministic and stochastic sizes must be both zero or both non-zero.
        assert all((deter == 0) == (stoch == 0) for deter, stoch in zip(
            (deter_size_s1, deter_size_s2, deter_size_s3, deter_size_s4),
            (stoch_size_s1, stoch_size_s2, stoch_size_s3, stoch_size_s4)))
        self.node_size = node_size
        self.rssm_type = rssm_type
        self.act_fn = act_fn
        self.posterior_s1 = self.posterior_s2 = self.posterior_s3 = self.posterior_s4 = None
        self._build_model()

    def _build_model(self):
        """Instantiate ``posterior_s1..s4`` for the active streams."""
        deter_sizes = (self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4)
        stoch_sizes = (self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4)
        shared_in = self.deter_size_s1 + self.deter_size_s2 + self.embedding_size
        in_features = (
            shared_in,
            shared_in,
            self.deter_size_s3 + self.embedding_size,
            self.deter_size_s4 + self.embedding_size,
        )
        # Hidden layers of every stream are created before any output layer;
        # this fixes the order in which weight initialisation draws from the RNG.
        hidden_layers = {
            stream: [nn.Linear(n_in, self.node_size), self.act_fn()]
            for stream, (deter, n_in) in enumerate(zip(deter_sizes, in_features), start=1)
            if deter > 0
        }
        if self.rssm_type == 'discrete':
            raise NotImplementedError
        if self.rssm_type != 'continuous':
            return
        for stream, layers in hidden_layers.items():
            layers.append(nn.Linear(self.node_size, 2 * stoch_sizes[stream - 1]))
            setattr(self, f"posterior_s{stream}", nn.Sequential(*layers))

    def forward(self, input_tensor):
        """Map ``[deter_s1..s4, embedding]`` to ``[means..., stds...]``.

        Returns ``None`` when ``rssm_type`` is neither 'discrete' nor 'continuous'.
        """
        deter_1, deter_2, deter_3, deter_4, embedding = torch.split(
            input_tensor,
            [self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4, self.embedding_size],
            dim=-1,
        )
        shared_features = torch.cat([deter_1, deter_2, embedding], dim=-1)
        routes = (
            (self.posterior_s1, shared_features),
            (self.posterior_s2, shared_features),
            (self.posterior_s3, torch.cat([deter_3, embedding], dim=-1)),
            (self.posterior_s4, torch.cat([deter_4, embedding], dim=-1)),
        )
        if self.rssm_type == 'discrete':
            raise NotImplementedError
        if self.rssm_type != 'continuous':
            return None
        means, stds = [], []
        for head, features in routes:
            if head is None:
                continue
            mean, std = torch.chunk(head(features), 2, dim=-1)
            means.append(mean)
            stds.append(std)
        return torch.cat(means + stds, dim=-1)

    def posterior_s3_s4_param(self):
        """Return the parameters of the s3 and s4 heads (those that exist) as one flat list."""
        return [
            param
            for head in (self.posterior_s3, self.posterior_s4)
            if head is not None
            for param in head.parameters()
        ]


class DRSSM(DRSSMUtils):

    def __init__(
        self, 
        input_dim,
        length,
        z_dim_list,
        deter_dim_list,
        action_dim, 
        lag,
        config,
        hidden_dim=128,
        trans_prior='NP',
        lr=1e-4,
        aux_lr=1e-4,
        infer_mode='F',
        beta=0.0025,
        gamma=0.0075,
        delta=0.01,
        delta_epoch=10,
        decoder_dist='gaussian',
        correlation='Pearson'):
        '''Build a four-stream recurrent state-space model with MINE estimators.

        The latent state is split into four streams (s1..s4). Stream ``i`` has
        a deterministic GRU state of size ``deter_dim_list[i]`` and a stochastic
        state of size ``z_dim_list[i]``. A stream is disabled by setting both
        sizes to 0; disabled streams get ``None`` for their GRU cell.

        Args:
            input_dim: width of an observation vector (encoder input, decoder output).
            length: sequence length, stored as ``self.length``.
            z_dim_list: stochastic sizes ``[s1, s2, s3, s4]``.
            deter_dim_list: deterministic (GRU) sizes ``[s1, s2, s3, s4]``.
            action_dim: width of an action vector.
            lag: number of past latents conditioned on in ``inference``.
            config: not used by this constructor.
            hidden_dim: width of the encoder/decoder hidden layers.
            trans_prior: must be 'L' (linear) or 'NP' (nonparametric); only validated here.
            infer_mode: 'R' (recurrent) or 'F' (factorised), dispatched in ``forward``.
            decoder_dist: stored as ``self.decoder_dist``.

        NOTE(review): ``lr``, ``aux_lr``, ``beta``, ``gamma``, ``delta``,
        ``delta_epoch``, ``correlation`` and ``config`` are accepted but not
        stored here; presumably consumed elsewhere -- confirm.
        NOTE(review): the ``_build_*`` helpers read ``self.act_fn``,
        ``self.node_size``, ``self.rssm_type`` and ``self.embedding_size``,
        which this constructor does not assign; presumably provided by
        ``DRSSMUtils`` -- confirm.
        '''
        super().__init__()
        self.automatic_optimization=False
        # Transition prior must be L (Linear), NP (Nonparametric)
        assert trans_prior in ('L', 'NP')
        self.z1_dim, self.z2_dim, self.z3_dim, self.z4_dim = z_dim_list[0], z_dim_list[1], z_dim_list[2], z_dim_list[3]
        z_dim = sum(z_dim_list)
        self.z_dim = z_dim
        [self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4] = deter_dim_list
        [self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4] = z_dim_list
        self.stoch_size = z_dim
        self.action_dim = action_dim
        self.lag = lag
        self.input_dim = input_dim
        self.length = length
        self.decoder_dist = decoder_dist
        self.infer_mode = infer_mode
        # Modules that make up the representation model (collected below).
        self.representation_model_list = []
        
        self.enc = nn.Sequential(
            nn.Linear(input_dim,hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim,z_dim),
        )

        self.dec = nn.Sequential(
            nn.Linear(z_dim,hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim,input_dim),
        )

        # One GRU cell per active stream. Disabled streams (size 0) get None so
        # that the attribute always exists; it is read unconditionally below.
        self.rnn1 = nn.GRUCell(self.deter_size_s1, self.deter_size_s1) if self.deter_size_s1 > 0 else None
        self.rnn2 = nn.GRUCell(self.deter_size_s2, self.deter_size_s2) if self.deter_size_s2 > 0 else None
        self.rnn3 = nn.GRUCell(self.deter_size_s3, self.deter_size_s3) if self.deter_size_s3 > 0 else None
        self.rnn4 = nn.GRUCell(self.deter_size_s4, self.deter_size_s4) if self.deter_size_s4 > 0 else None
        
        # embed state and action
        self._build_embed_state_action()
        self.fc_prior = self._build_temporal_prior()
        self.fc_posterior = self._build_temporal_posterior()

        self.rew_dec = NLayerMLP(in_features=self.z1_dim+self.z2_dim, out_features=1, num_layers=2)
        # Collect the representation modules, skipping disabled streams (None)
        # so that the list only ever holds real modules. getattr guards against
        # ``fc_embed_s1`` never having been assigned when stream 1 is disabled.
        representation_modules = [
            self.enc, self.dec,
            self.rnn1, self.rnn2, self.rnn3, self.rnn4,
            self.fc_prior, self.fc_posterior,
            getattr(self, 'fc_embed_s1', None), self.fc_embed_s2, self.fc_embed_s3, self.fc_embed_s4,
            self.rew_dec,
        ]
        self.representation_model_list += [module for module in representation_modules if module is not None]
        
        # I(s_t^{1, 2}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{3, 4}; R_{t} | a_{t-1}, s^{1, 2}_{t-1})
        # I(s_t^{1, 3}; a_{t-1}|s_{t-1})
        # I(s_t^{2, 4}; a_{t-1} \,|s_{t-1})
        self.mine_reward_1 = MINE(x_dim=self.z1_dim+self.z2_dim, y_dim=1, z_dim=self.z1_dim+self.z2_dim+action_dim)
        self.mine_reward_2 = MINE(x_dim=self.z3_dim+self.z4_dim, y_dim=1, z_dim=self.z1_dim+self.z2_dim+action_dim)
        self.mine_action_1 = MINE(x_dim=self.z1_dim+self.z3_dim, y_dim=action_dim, z_dim=self.z_dim)
        self.mine_action_2 = MINE(x_dim=self.z2_dim+self.z4_dim, y_dim=action_dim, z_dim=self.z_dim)
        self.aux_model_list = [self.mine_reward_1, self.mine_reward_2, self.mine_action_1, self.mine_action_2]
        # base distribution for calculation of log prob under the model
        self.register_buffer('base_dist_mean', torch.zeros(self.z_dim))
        self.register_buffer('base_dist_var', torch.eye(self.z_dim))

    def _build_embed_state_action(self):
        """Build the per-stream embedding MLPs that produce the GRU inputs.

        Each active stream gets ``Linear -> self.act_fn()``:

        * ``fc_embed_s1``: (s1, s2, action) -> deter_size_s1
        * ``fc_embed_s2``: (s1, s2)         -> deter_size_s2
        * ``fc_embed_s3``: (s, action)      -> deter_size_s3
        * ``fc_embed_s4``: (s)              -> deter_size_s4

        where ``s`` is the full stochastic state (width ``self.stoch_size``).
        Streams with deterministic size 0 are left as ``None``.
        """
        # Every stream attribute must exist (as None) even when the stream is
        # disabled, because __init__ and forward_embed_state read all four.
        # ``fc_embed_s1s2`` is a legacy name that is never populated; it is kept
        # as None for compatibility.
        self.fc_embed_s1s2 = None
        self.fc_embed_s1, self.fc_embed_s2, self.fc_embed_s3, self.fc_embed_s4 = None, None, None, None
        # NOTE(review): ``action_size`` is not assigned in this class's visible
        # __init__ (only ``action_dim`` is). Use it if something else provides
        # it; otherwise fall back to ``action_dim``.
        action_size = getattr(self, 'action_size', self.action_dim)
        if self.deter_size_s1 > 0:
            self.fc_embed_s1 = nn.Sequential(
                nn.Linear(self.stoch_size_s1 + self.stoch_size_s2 + action_size, self.deter_size_s1),
                self.act_fn(),
            )
        if self.deter_size_s2 > 0:
            self.fc_embed_s2 = nn.Sequential(
                nn.Linear(self.stoch_size_s1 + self.stoch_size_s2, self.deter_size_s2),
                self.act_fn(),
            )
        if self.deter_size_s3 > 0:
            self.fc_embed_s3 = nn.Sequential(
                nn.Linear(self.stoch_size + action_size, self.deter_size_s3),
                self.act_fn(),
            )
        if self.deter_size_s4 > 0:
            self.fc_embed_s4 = nn.Sequential(
                nn.Linear(self.stoch_size, self.deter_size_s4),
                self.act_fn(),
            )

    def _build_temporal_prior(self):
        """Create the prior head that maps the deterministic state to stochastic-state statistics.

        The returned ``TemporalPrior`` consumes the concatenated deterministic
        state ``[deter_s1..s4]`` and emits all stream means followed by all
        stream stds.
        """
        return TemporalPrior(
            deter_size_s1=self.deter_size_s1,
            deter_size_s2=self.deter_size_s2,
            deter_size_s3=self.deter_size_s3,
            deter_size_s4=self.deter_size_s4,
            stoch_size_s1=self.stoch_size_s1,
            stoch_size_s2=self.stoch_size_s2,
            stoch_size_s3=self.stoch_size_s3,
            stoch_size_s4=self.stoch_size_s4,
            node_size=self.node_size,
            rssm_type=self.rssm_type,
            act_fn=self.act_fn,
        )

    def _build_temporal_posterior(self):
        """Create the posterior head conditioned on the deterministic state and observation embedding.

        The returned ``TemporalPosterior`` consumes
        ``[deter_s1, deter_s2, deter_s3, deter_s4, obs_embedding]`` along the
        last dim and emits all stream means followed by all stream stds.
        """
        return TemporalPosterior(
            deter_size_s1=self.deter_size_s1,
            deter_size_s2=self.deter_size_s2,
            deter_size_s3=self.deter_size_s3,
            deter_size_s4=self.deter_size_s4,
            stoch_size_s1=self.stoch_size_s1,
            stoch_size_s2=self.stoch_size_s2,
            stoch_size_s3=self.stoch_size_s3,
            stoch_size_s4=self.stoch_size_s4,
            embedding_size=self.embedding_size,
            node_size=self.node_size,
            rssm_type=self.rssm_type,
            act_fn=self.act_fn,
        )
    
    def forward_embed_state(self, stoch_state, prev_action):
        """Embed the previous stochastic state (and action) into GRU inputs.

        Args:
            stoch_state: previous stochastic state; its last dim is split into
                (s1, s2, s3, s4) by ``stoch_size_s1..s4``.
            prev_action: previous action, concatenated on the last dim for
                streams 1 and 3.

        Returns:
            The embeddings of the active streams, concatenated along the last
            dim in stream order s1..s4.
        """
        s1, s2, _, _ = torch.split(
            stoch_state,
            [self.stoch_size_s1, self.stoch_size_s2, self.stoch_size_s3, self.stoch_size_s4],
            dim=-1,
        )
        # Inputs are built lazily so disabled streams cost nothing.
        stream_specs = (
            (self.fc_embed_s1, lambda: torch.cat([s1, s2, prev_action], dim=-1)),
            (self.fc_embed_s2, lambda: torch.cat([s1, s2], dim=-1)),
            (self.fc_embed_s3, lambda: torch.cat([stoch_state, prev_action], dim=-1)),
            (self.fc_embed_s4, lambda: stoch_state),
        )
        embeddings = [embed(make_input()) for embed, make_input in stream_specs if embed is not None]
        return torch.cat(embeddings, dim=-1)
    
    def forward_rnn(self, state_embed, prev_deter_state):
        """Advance the deterministic state of every active stream by one GRU step.

        Both ``state_embed`` and ``prev_deter_state`` are split along the last
        dim by ``deter_size_s1..s4``; stream ``i`` (if its size is non-zero) is
        updated by ``self.rnn{i}``.

        Returns:
            The new deterministic states of the active streams, concatenated
            along the last dim in stream order.
        """
        sizes = [self.deter_size_s1, self.deter_size_s2, self.deter_size_s3, self.deter_size_s4]
        prev_parts = torch.split(prev_deter_state, sizes, dim=-1)
        embed_parts = torch.split(state_embed, sizes, dim=-1)
        next_parts = [
            getattr(self, f'rnn{stream}')(embed_part, prev_part)
            for stream, (size, embed_part, prev_part) in enumerate(zip(sizes, embed_parts, prev_parts), start=1)
            if size > 0
        ]
        return torch.cat(next_parts, dim=-1)
    def rssm_imagine(self, prev_action, prev_rssm_state, nonterms=True):
        """Take one prior ("imagination") step of the RSSM.

        Args:
            prev_action: action taken at the previous step.
            prev_rssm_state: previous RSSM state; its ``stoch`` and ``deter``
                fields are read.
            nonterms: multiplier applied to both previous states before the
                step (``True`` leaves them unchanged). Presumably 0 at episode
                boundaries to reset the state -- confirm with callers.

        Returns:
            ``RSSMContState`` built from (mean, std, sampled stoch, deter) for a
            continuous RSSM, or ``RSSMDiscState`` built from (logit, stoch,
            deter) for a discrete one.
        """
        masked_stoch = prev_rssm_state.stoch * nonterms
        embedded = self.forward_embed_state(masked_stoch, prev_action)
        deter_state = self.forward_rnn(embedded, prev_rssm_state.deter * nonterms)
        if self.rssm_type == 'discrete':
            logit = self.fc_prior(deter_state)
            sample = self.get_stoch_state({'logit': logit})
            prior_rssm_state = RSSMDiscState(logit, sample, deter_state)
        elif self.rssm_type == 'continuous':
            mean, std_param = torch.chunk(self.fc_prior(deter_state), 2, dim=-1)
            sample, std = self.get_stoch_state({'mean': mean, 'std': std_param})
            prior_rssm_state = RSSMContState(mean, std, sample, deter_state)
        return prior_rssm_state
        
    def rssm_observe(self, obs_embed, prev_action, prev_nonterm, prev_rssm_state):
        """Take one filtering step: a prior step, then a posterior step given the observation.

        The prior comes from ``rssm_imagine``. The posterior head reads the
        prior's deterministic state concatenated (last dim) with ``obs_embed``,
        and the posterior state reuses that same deterministic state.

        Returns:
            ``(prior_rssm_state, posterior_rssm_state)``.
        """
        prior_rssm_state = self.rssm_imagine(prev_action, prev_rssm_state, prev_nonterm)
        deter_state = prior_rssm_state.deter
        posterior_input = torch.cat([deter_state, obs_embed], dim=-1)
        if self.rssm_type == 'discrete':
            logit = self.fc_posterior(posterior_input)
            sample = self.get_stoch_state({'logit': logit})
            posterior_rssm_state = RSSMDiscState(logit, sample, deter_state)
        elif self.rssm_type == 'continuous':
            mean, std_param = torch.chunk(self.fc_posterior(posterior_input), 2, dim=-1)
            sample, std = self.get_stoch_state({'mean': mean, 'std': std_param})
            posterior_rssm_state = RSSMContState(mean, std, sample, deter_state)
        return prior_rssm_state, posterior_rssm_state

    def rollout_observation(self, seq_len: int, obs_embed: torch.Tensor, action: torch.Tensor, nonterms: torch.Tensor, prev_rssm_state):
        """Filter ``seq_len`` steps, feeding each posterior back as the next state.

        ``obs_embed``, ``action`` and ``nonterms`` are indexed by time on their
        first dimension. The action at step ``t`` is multiplied by
        ``nonterms[t]`` before use.

        Returns:
            ``(prior, post)``: the per-step prior and posterior states, each
            stacked with ``rssm_stack_states(..., dim=0)``.
        """
        prior_states, posterior_states = [], []
        state = prev_rssm_state
        for step in range(seq_len):
            masked_action = action[step] * nonterms[step]
            prior_t, state = self.rssm_observe(obs_embed[step], masked_action, nonterms[step], state)
            prior_states.append(prior_t)
            posterior_states.append(state)
        return self.rssm_stack_states(prior_states, dim=0), self.rssm_stack_states(posterior_states, dim=0)


    @property
    def base_dist(self):
        """Multivariate normal over ``z_dim`` built from the ``base_dist_mean`` / ``base_dist_var`` buffers."""
        return D.MultivariateNormal(loc=self.base_dist_mean, covariance_matrix=self.base_dist_var)

    def inference(self, ft, random_sampling=True):
        """Sample latents z_t sequentially, each conditioned on the last ``lag`` samples.

        NOTE(review): this relies on ``self.rnn`` and ``self.net``, which the
        visible ``__init__`` does not create (it builds ``rnn1``..``rnn4``);
        presumably legacy or provided elsewhere -- confirm before using
        ``infer_mode='R'``.

        Args:
            ft: features passed through ``self.rnn``; presumably
                (batch, seq_len, feat) -- TODO confirm.
            random_sampling: draw z_t with the reparameterisation trick if
                True, otherwise use the mean.

        Returns:
            ``(zs, mus, logvars)``, each stacked along dim 1 (time), so shaped
            (batch, seq_len, ...). The ``lag`` history placeholders are removed
            from ``zs``.
        """
        output, _ = self.rnn(ft)
        batch_size, length, _ = output.shape
        # Seed the history with ``lag`` all-ones placeholders; they are
        # stripped from the result below.
        zs = [torch.ones((batch_size, self.z_dim), device=output.device) for _ in range(self.lag)]
        mus, logvars = [], []

        for t in range(length):
            history = torch.cat(zs[-self.lag:], dim=1)
            distributions = self.net(torch.cat([history, output[:, t, :]], dim=1))
            mu = distributions[:, :self.z_dim]
            logvar = distributions[:, self.z_dim:]
            zs.append(self.reparameterize(mu, logvar, random_sampling))
            mus.append(mu)
            logvars.append(logvar)

        # Stack along time without torch.squeeze: squeezing would also drop a
        # size-1 batch axis, after which ``[:, self.lag:]`` would slice the
        # latent dimension instead of removing the placeholders.
        zs = torch.stack(zs, dim=1)[:, self.lag:]
        mus = torch.stack(mus, dim=1)
        logvars = torch.stack(logvars, dim=1)
        return zs, mus, logvars
    
    def reparameterize(self, mean, logvar, random_sampling=True):
        """Draw ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``, or return ``mean``.

        Args:
            mean: Gaussian mean.
            logvar: Gaussian log-variance (same shape as ``mean``).
            random_sampling: when False, skip sampling and return ``mean`` as-is.
        """
        if not random_sampling:
            return mean
        noise = torch.randn_like(logvar)
        return mean + noise * torch.exp(0.5 * logvar)

    def reconstruction_loss(self, x, x_recon, distribution):
        """Return the summed reconstruction loss divided by the batch size.

        Args:
            x: target tensor; its first dimension is the batch.
            x_recon: decoder output (logits for 'bernoulli', pre-sigmoid values
                for 'sigmoid_gaussian').
            distribution: one of 'bernoulli', 'gaussian', 'sigmoid_gaussian'.

        Raises:
            ValueError: if ``distribution`` is not one of the supported names.
        """
        batch_size = x.size(0)
        assert batch_size != 0

        # reduction='sum' is the non-deprecated spelling of size_average=False.
        if distribution == 'bernoulli':
            recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum')
        elif distribution == 'gaussian':
            recon_loss = F.mse_loss(x_recon, x, reduction='sum')
        elif distribution == 'sigmoid_gaussian':
            recon_loss = F.mse_loss(torch.sigmoid(x_recon), x, reduction='sum')
        else:
            raise ValueError(f"Unsupported decoder distribution: {distribution!r}")

        return recon_loss.div(batch_size)
    
    def aux_reward_loss(self, distribution, y):
        """Score ``y`` under a predicted distribution.

        ``distribution`` must expose ``log_prob`` and ``mean`` (as
        ``torch.distributions`` objects do).

        Returns:
            ``(nll, mse)``: the negative mean log-likelihood of ``y`` and the
            mean-squared error between ``distribution.mean`` and ``y``.
        """
        nll = -distribution.log_prob(y).mean()
        return nll, F.mse_loss(distribution.mean, y)

    def forward(self, batch):
        """Infer latent samples and Gaussian statistics for a batch of sequences.

        Args:
            batch: mapping with key ``'xt'`` holding a 3-D tensor
                (batch, length, input_dim).

        Returns:
            ``(zs, mus, logvars)`` from ``inference`` (``infer_mode == 'R'``) or
            from ``self.net`` (``infer_mode == 'F'``).
            NOTE(review): ``self.net`` is not created in this class's visible
            ``__init__`` -- confirm where it comes from.

        Raises:
            ValueError: if ``self.infer_mode`` is neither 'R' nor 'F'.
        """
        x = batch['xt']
        batch_size, length, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(-1, self.input_dim)
        if self.infer_mode == 'R':
            ft = self.enc(x_flat).view(batch_size, length, -1)
            zs, mus, logvars = self.inference(ft, random_sampling=True)
        elif self.infer_mode == 'F':
            _, mus, logvars, zs = self.net(x_flat)
        else:
            raise ValueError(f"Unknown infer_mode {self.infer_mode!r}; expected 'R' or 'F'")
        return zs, mus, logvars