
from models import ODEfunc, ODEnet, MLPSimple
import torch.nn as nn
from cnf import CNF, SequentialFlow
from plotters import plot_trajectories, plot_trajectories_test, plot_trajectories_test_all, plot_trajectories_with_logprobabilities, plot_forward_backcward, plot_trajectories_test_simple
import pytorch_lightning as pl
import argparse
import torch
import pandas as pd
import numpy as np
import random
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable
from torch.distributions import Normal


def count_parameters(model):
    """Return the number of trainable scalar parameters held by ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def build_flow( obs_dim, state_dim, contex_dim, num_blocks, conditional=True, use_adjoint=True):
    """Assemble a ``SequentialFlow`` out of ``num_blocks`` freshly built CNF blocks.

    Every block gets its own ``ODEnet`` (three hidden layers of width 128,
    ``concatsquash`` layers, ``tanh`` nonlinearity) wrapped in an ``ODEfunc``
    and a ``CNF`` that uses the ``dopri5`` solver with ``atol = rtol = 1e-4``
    and a trainable end time initialised at ``T=1.0``.

    Args:
        obs_dim: forwarded to ``ODEfunc`` and ``CNF``.
        state_dim: forwarded to ``ODEnet`` as the one-element tuple ``(state_dim,)``.
        contex_dim: forwarded to ``ODEnet`` (spelling kept for compatibility).
        num_blocks: number of CNF blocks in the chain.
        conditional: forwarded to ``CNF``.
        use_adjoint: forwarded to ``CNF``.

    Returns:
        The ``SequentialFlow`` wrapping the list of blocks.
    """
    hidden_layout = (128, 128, 128)

    def _new_block():
        # Each block owns an independent dynamics network.
        dynamics = ODEnet(
            hidden_dims=hidden_layout,
            state_dim=(state_dim,),
            contex_dim=contex_dim,
            layer_type='concatsquash',
            nonlinearity='tanh',
        )
        return CNF(
            odefunc=ODEfunc(diffeq=dynamics, obs_dim=obs_dim),
            T=1.0,
            obs_dim=obs_dim,
            train_T=True,
            conditional=conditional,
            solver='dopri5',
            use_adjoint=use_adjoint,
            atol=1e-4,
            rtol=1e-4,
        )

    blocks = []
    for _ in range(num_blocks):
        blocks.append(_new_block())
    return SequentialFlow(blocks)



class cnfModule(pl.LightningModule):
    def __init__(self, loss_func, obs_dim, state_dim, num_blocks, action_dim,
                 lr, reg1, reg2, dataset_name, MLP_decoding_mode=False, use_adjoint=True, **kwargs):
        """
        Build the conditional flow and store the training hyper-parameters.

        Args:
            loss_func: key selecting the training loss in ``compute_losses``.
            obs_dim: number of observed dimensions.
            state_dim: dimension of the full state handled by the flow.
            num_blocks: number of CNF blocks in the flow.
            action_dim: number of action dimensions; the context fed to the
                flow has ``obs_dim + action_dim`` features.
            lr: learning rate for the optimizer.
            reg1: weight of the ``div`` term in the training loss.
            reg2: weight of the two regularisation terms returned by the flow.
            dataset_name: name handed to the plotting helpers.
            MLP_decoding_mode: accepted but not used in this constructor.
            use_adjoint: forwarded to ``build_flow``.
            **kwargs: ignored.
        """
        super().__init__()

        # Plain hyper-parameters.
        self.dataset_name = dataset_name
        self.loss_func = loss_func
        self.lr = lr
        self.reg1 = reg1
        self.reg2 = reg2
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.context_dim = obs_dim + action_dim

        self.flow = build_flow(
            self.obs_dim,
            state_dim,
            self.context_dim,
            num_blocks,
            True,
            use_adjoint=use_adjoint,
        )

        # Plain tensors (not registered as parameters or buffers).
        self.shift_factor = Variable(torch.zeros(self.obs_dim), requires_grad=True)
        self.scale_factor = Variable(torch.ones(self.obs_dim), requires_grad=True)

        # The sampling priors live on the GPU whenever one is available,
        # independently of where the module itself is placed later.
        gpu = torch.device("cuda") if torch.cuda.is_available() else None

        def _gaussian(std):
            loc = torch.tensor([0.0])
            scale = torch.tensor([std])
            if gpu is not None:
                loc = loc.to(device=gpu)
                scale = scale.to(device=gpu)
            return Normal(loc, scale)

        self.normal = _gaussian(1.0)        # prior over the full state
        self.normal_state = _gaussian(0.2)  # jitter around observed start values

        print("Number of trainable parameters of Point CNF: {}".format(count_parameters(self.flow)))


    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
        #scheduler = {"monitor": "val_loss", "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode = "min", factor = 0.5, patience = 10, verbose = True)}
        #return {"optimizer": optimizer}, "lr_scheduler":scheduler}
        #scheduler = MultiStepLR(optimizer, 
        #                milestones=list(np.arange(5,250,10)),#[8, 24, 28], # List of epoch indices
        #                gamma = self.lr_gamma) # Multiplicative factor of learning rate decay
        
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_gamma)
        return [optimizer]#, [scheduler]

    

    def forward(self, state0, X, actions, times, logpx=None, parent_forcing=False, reverse=False, deBug=False, addNoise=True):
        """
        actions: batchsize
        obs_times: the times between observations where we integrate (batch size x timesteps)
        X: observations (batch size x obs dim)
        """
        if (self.reg1>0 or self.reg2>0) and logpx==None: # we need divergence to get the regularization terms
            _logpx = torch.zeros(*X.shape[:-1], 1, device= self.device)  #torch.sum(logpx[:,:-self.obs_dim], dim=-1, keepdim= True)#prod #logpx
        else:
            _logpx = logpx

        #run the model through the full sequence of observations
        #create the context
        batch_size = times.size()[0]
        time_steps = times.size()[1]

        context = torch.cat((actions, X), dim=2) #batch size x timesteps x context dim # actions
        #state0[:,-self.obs_dim:] =(X / self.scale_factor) - self.shift_factor
        startState = state0
        
        if addNoise:
            noise = torch.randn_like(context)*0.2
            context = ((noise + context).detach() - context).detach() + context
            noise_s = torch.randn_like(startState)*0.2
            startState = ((noise_s + startState).detach() - startState).detach() + startState

        # Use given time array or if contains Nans replace with evenly spaced range
        if torch.isnan(times).any():
            t = torch.arange(times.shape[1], dtype=torch.float)
        else:
            t = times[0,:]

        # add logpx when training with correct loss!
        if deBug:
            print('before cnf')
            print(startState.mean())
            print(context.mean())
            print(reverse)

        
        if _logpx is None:
            state = self.flow(
                x=startState, 
                context=context, 
                logpx=_logpx, 
                integration_times=t, 
                reverse=reverse,
                deBug=deBug,
                parent_forcing= parent_forcing,
                addNoise= False #addNoise
                ) 
            regTerm1=torch.tensor([0.0])
            regTerm2=torch.tensor([0.0])
        else:
            state, _logpx, regTerm1, regTerm2 = self.flow(
                x=startState, 
                context=context, 
                logpx=_logpx, 
                integration_times=t,
                reverse=reverse,
                deBug=deBug,
                parent_forcing= parent_forcing,
                addNoise= False#addNoise
                ) #, _logpx

        
        obsPred = state[:,:,-self.obs_dim:] #+ self.shift_factor) * self.scale_factor
        
        return obsPred, state, _logpx, regTerm1, regTerm2

    def startState(self, X0, batch_size):
        mask = ~torch.isnan(X0)
        # sample from normal distribution
        s0 = self.normal.sample(sample_shape = (batch_size, self.state_dim)).squeeze(-1)
        # replace observed starting sttates with a normal distribution sampled with observed mean and small variance

        s0[:, -self.obs_dim:][mask] = self.normal_state.sample(sample_shape = (batch_size, self.obs_dim)).squeeze(-1)[mask] + X0[mask]
        #s0 = self.normal_state.sample(sample_shape = (batch_size, self.obs_dim)).squeeze(-1) + X0
        return s0 #torch.cat((l0,s0), dim=1)
    
    def training_step(self, batch, batch_idx):
        X, actions, t_X , weights  = batch
        weights= weights.type(torch.bool)
        """
        X: observations (batch size x obs times x obs dim)
        actions: (batch size x action dim) --> batch size x obs times x action dim
        """
        #torch.cuda.empty_cache()
        # change the data generation such that the actions are randomized
        batch_size = X.size()[0]
        time_steps = X.size()[1]
        logpx = None 

        #sample initial state from a normal distribution
        state0 = self.startState(X[:,0,:], batch_size) #self.normal.sample(sample_shape = (batch_size, self.state_dim)).squeeze(-1).to(X)
        parent_forcing = False #bool(random.getrandbits(1)) #batch_idx %2
        Y = X[:,1:,:]
        x= X[:,:-1,:].clone()
        #replace missing values with some flag value
        if torch.isnan(X).any():
            x[torch.isnan(x)] = 100
        
        Y_hat, states, logpx , _regTerm1, _regTerm2= self(
            state0=state0, 
            X=x, 
            actions=actions[:,:-1,:], 
            times=t_X, 
            logpx=logpx,
            parent_forcing=parent_forcing
            )
        loss = self.compute_losses(X, actions[:,:-1,:], t_X, states, self.loss_func, weights)
        mse = self.compute_losses(X.clone(), actions[:,:-1,:].clone(), t_X.clone(), states.clone(),'mse', weights)
        div = self.compute_losses(X, actions[:,:-1,:], t_X, states,'div', weights)



        loss = 0.6*(torch.mean(loss)) + self.reg1 *torch.mean(div) + self.reg2*torch.mean(_regTerm1) + self.reg2*torch.mean(_regTerm2) #/(t_X.size()[1]-1)
        mse = torch.mean(torch.sqrt(mse))
        div = torch.mean(div)


        on_step = True #False
        on_epoch = False 
        self.log("train_loss",loss, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
        self.log("train_rmse",mse, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
        self.log("train_div",div, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation batch and, every 10th epoch, plot diagnostics.

        ``batch`` is ``(X, actions, t_X, weights)``; ``weights`` is cast to
        bool before being passed to ``compute_losses`` as the observation mask.

        Always logged (aggregated per epoch): ``val_rmse`` and ``val_loss``.

        When ``current_epoch`` is a multiple of 10, the first trajectory of the
        batch is replicated 20 times and run forward, then backward twice
        (once from the final forward state, once from a state whose observed
        part is replaced by the last observation). ``val_diverg_rmse`` is
        logged and several plots are sent to the logger.

        Returns:
            The mean validation loss computed on the full batch.
        """
        X, actions, t_X, weights = batch
        weights= weights.type(torch.bool)

        ''' 
        X: observations (batch size x obs times x obs dim)
        actions: (batch size x action dim)
        '''
        batch_size = X.size()[0]
        time_steps = X.size()[1]  # not used below

        #sample initial state from a normal distribution
        state0 = self.startState(X[:,0,:], batch_size)
        #self.normal.sample(sample_shape = (batch_size, self.state_dim)).squeeze(-1).to(X)
        logpx = None #self.normal.log_prob(state0)

        # context observations: everything except the last time step
        x= X[:,:-1,:].clone()
        #replace missing values with some flag value
        if torch.isnan(X).any():
            x[torch.isnan(x)] = 100

        # deterministic forward pass: no input noise during validation
        Y_hat, states, logpx, _regTerm1, _regTerm2= self(
            state0=state0, 
            X=x, 
            actions=actions[:,:-1,:], 
            times=t_X, 
            logpx=logpx,
            parent_forcing=False,
            addNoise = False
            )
        loss = self.compute_losses(X, actions[:,:-1,:], t_X, states, self.loss_func, weights)
        mse = self.compute_losses(X, actions[:,:-1,:], t_X, states,'mse', weights)

        loss = torch.mean(loss) #/(t_X.size()[1]-1)
        mse =  torch.mean(torch.sqrt(mse)) #/(t_X.size()[1]-1)
 
        on_step = False
        on_epoch = True #False 
        self.log("val_rmse",mse, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
        self.log("val_loss",loss, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
        if (self.current_epoch%10)==0:
            # NOTE: this branch runs for every validation batch of such an epoch.
            # loglikelihood estimates
            #loglik, ll_std = self.compute_logLikelihood(batch)
            #self.log("val_mean_loglik",loglik, on_step = on_step, on_epoch = on_epoch, sync_dist=True)
            #self.log("val_loglik_mean_std",ll_std, on_step = on_step, on_epoch = on_epoch, sync_dist=True)


            #plot_trajectories( X.detach().clone(), Y_hat.detach().clone(), t_X.detach().clone(), self.dataset_name , chart_type = "val", logger=self.logger )
            

            # sample trajectories and find the average 
            N = 20
            # replicate the first trajectory N times; t_X keeps its original batch
            # size, forward() only reads its first row
            X = X[0,:,:][None,].repeat(N,1,1)
            actions = actions[0,:,:][None,].repeat(N,1,1)
            state0 = self.startState(X[:,0,:], N)#self.normal.sample(sample_shape = (N, self.state_dim)).squeeze(-1).to(X)
            logpx = None
            x= X[:,:-1,:].clone()
            #replace missing values with some flag value
            if torch.isnan(X).any():
                x[torch.isnan(x)] = 100
            # forward run from N independently sampled start states
            Y_hat, states, logpx, _, _= self(
                state0=state0, 
                X=x, 
                actions=actions[:,:-1,:], 
                times=t_X, 
                logpx=logpx,
                parent_forcing=False,
                addNoise = False
                )
            weights= weights[0,:,:][None,].repeat(N,1,1)
            # NOTE(review): this value is overwritten below and never read
            mse = self.compute_losses(X, actions[:,:-1,:], t_X, states,'mse', weights)

            # keep the forward predictions for the forward/backward comparison
            compare_traj= Y_hat.clone().detach()
            plot_trajectories_test_all( X.detach().clone(), Y_hat.detach().clone(), t_X.detach().clone(), self.dataset_name, chart_type = "Forward_val", logger=self.logger )

            
            # final forward state with its observed part replaced by the last observation
            state_kl = torch.cat((states[:,-1,:].detach().clone()[:,:-self.obs_dim], X[:,-1,:]) , dim=1).detach()
            # backward run starting from the final forward state
            # (logpx returned by the previous call is passed on unchanged)
            Y_hat_r, states, logpx, _, _= self(
                state0=states[:,-1,:], 
                X=x, 
                actions=actions[:,:-1,:], 
                times=t_X, 
                logpx=logpx,
                parent_forcing=False,
                addNoise = False,
                reverse = True
                )
            # NOTE(review): result is not used afterwards
            mse = self.compute_losses(X, actions[:,:-1,:], t_X, states,'mse', weights)
            # backward run starting from the observation-anchored final state
            Y_hat_kl, states_kl, logpx , _, _= self(
                state0=state_kl, 
                X=x, 
                actions=actions[:,:-1,:], 
                times=t_X, 
                logpx=logpx,
                parent_forcing=False,
                addNoise = False,
                reverse = True
                )
            # RMSE between the forward predictions and the time-flipped backward predictions
            div_rmse = torch.sqrt(torch.mean((compare_traj-torch.flip(Y_hat_r, dims=[1])).pow(2)))
            # printed value is the mean difference of the sample-averaged trajectories
            # (a signed difference, despite the 'mse' label)
            print('mean mse ' +str(torch.mean(torch.mean(compare_traj, dim=0)-torch.mean(torch.flip(Y_hat_r, dims=[1]),dim=0) )))
            self.log("val_diverg_rmse",div_rmse, on_step = on_step, on_epoch = on_epoch, sync_dist=True)

            plot_forward_backcward(X, Y_hat, torch.flip(Y_hat_r, dims=[1]), t_X.detach().clone(), actions, self.dataset_name, chart_type = "validation bakward comp", logger=self.logger )
            plot_forward_backcward(X, Y_hat, torch.flip(Y_hat_kl, dims=[1]), t_X.detach().clone(), actions, self.dataset_name, chart_type = "validation bakward comp kl", logger=self.logger )
            plot_trajectories_test_all( X.detach().clone(), torch.flip(Y_hat_r.detach().clone(), dims=[1]), t_X.detach().clone(), self.dataset_name, chart_type = "Backward_val", logger=self.logger )
            plot_trajectories_test_all( X.detach().clone(), torch.flip(Y_hat_kl.detach().clone(), dims=[1]), t_X.detach().clone(), self.dataset_name, chart_type = "Backward_kl_val", logger=self.logger )
        return loss


    def test_step(self, batch, batch_idx):
        """
        Evaluate one test batch; for the first batch also produce plots.

        ``batch`` is ``(X, actions, t_X, weights)``. Logs ``test_rmse``,
        ``test_rmse_std`` and ``test_loss``. Nothing is returned.

        Side effects visible in this method: ``self.reg1`` and ``self.reg2``
        are set to 0 on the module (and stay 0 afterwards), and ``X`` from the
        batch is written to in place by the masking step.
        """
        #torch.set_grad_enabled(True)
        # with both weights at 0, forward() does not create a logpx tensor
        self.reg1 = 0
        self.reg2 = 0

        X, actions, t_X, weights = batch
        #self.compute_logLikelihood(batch)
        # untouched copy of the observations, used as the target for the losses
        X_all = X.detach().clone()
        # Mask part of the data!

        total_samples=X.shape[1]
        # with missing_prop = 0.0 no time step is selected, so the two NaN
        # assignments below leave X unchanged
        missing_prop = 0.0
        samples=int(missing_prop*total_samples)
        mask_0 =random.sample(range(0, total_samples), samples)
        mask_1 =random.sample(range(0, total_samples), samples)
        # NOTE(review): hard-codes observation channels 0 and 1, i.e. assumes obs_dim >= 2
        X[:,mask_0,0] = np.nan
        X[:,mask_1,1] = np.nan

        weights= weights.type(torch.bool)
        batch_size = X.size()[0]
        time_steps = X.size()[1]  # not used below
        state0 = self.startState(X[:,0,:], batch_size)# self.normal.sample(sample_shape = (batch_size, self.state_dim)).squeeze(-1).to(X)
        logpx = None #self.normal.log_prob(state0)

        # NOTE(review): unlike training_step/validation_step, NaNs in X are not
        # replaced by the flag value 100 before being used as context here
        Y_hat, states, logpx , _regTerm1, _regTerm2= self(
            state0=state0.requires_grad_(True), 
            X=X[:,:-1,:].clone().requires_grad_(True), 
            actions=actions[:,:-1,:].requires_grad_(True), 
            times=t_X.requires_grad_(True), 
            logpx=logpx,#.requires_grad_(True),
            parent_forcing=False,
            addNoise = False,
            reverse =False
            )
        loss = self.compute_losses(X_all, actions[:,:-1,:], t_X, states, self.loss_func, weights)
        mse = self.compute_losses(X_all, actions[:,:-1,:], t_X, states,'mse', weights)
        
        loss = torch.mean(loss) #/(t_X.size()[1]-1)
        rmse = torch.mean(torch.sqrt(mse))
        # NOTE(review): dim=(1,2) needs a 3-D mse; compute_losses returns a 1-D
        # tensor for 'mse' when weights contain any False entry - confirm that
        # test data never has missing observations
        rmse_std = torch.std(torch.mean(torch.sqrt(mse),dim=(1,2)))
        #with torch.set_grad_enabled(True): 
        #    loglik, ll_std = self.compute_logLikelihood(batch)

        self.log("test_rmse",rmse)
        self.log("test_rmse_std",rmse_std)
        self.log("test_loss",loss)
        #self.log("test_loglik", loglik)
        #self.log("test_loglik_mean_std",ll_std)

        if batch_idx==0:
            # plots are produced for the first test batch only
            self.plot_intermediate_steps(X.detach().clone(), actions.detach().clone(),t_X.detach().clone(),state0.detach().clone(), steps=2)

            N = 20
            # replicate the first trajectory N times and sample N start states
            X = X[0,:,:][None,].repeat(N,1,1)
            actions = actions[0,:,:][None,].repeat(N,1,1)


            state0 = self.startState(X[:,0,:], N)#self.normal.sample(sample_shape = (N, self.state_dim)).squeeze(-1).to(X)
            logpx = None #self.normal.log_prob(state0)

            # note: state0 is rebound to the full state trajectory returned by the model
            Y_hat, state0, logpx, _, _= self(
                state0=state0, 
                X=X[:,:-1,:], 
                actions=actions[:,:-1,:], 
                times=t_X, 
                logpx=logpx,
                parent_forcing=False,
                addNoise = False,
                reverse =False
                )
            #Masked values for plotting
            # X_mask holds the true values at the masked positions and NaN elsewhere
            X_mask = torch.full_like(X_all, np.nan)
            X_mask[:,mask_0,0] = X_all[:,mask_0,0]
            X_mask[:,mask_1,1] = X_all[:,mask_1,1]
            plot_trajectories_test( X.detach().clone(), Y_hat.detach().clone(), t_X.detach().clone(), actions.detach().clone(), self.dataset_name, chart_type = "test", logger=self.logger )
            plot_trajectories_test_simple( X.detach().clone(),X_mask.detach().clone() ,Y_hat.detach().clone(), t_X.detach().clone(), actions.detach().clone(), self.dataset_name, chart_type = "test_simple", logger=self.logger )
            plot_trajectories_test_all( X.detach().clone(), Y_hat.detach().clone(), t_X.detach().clone(), self.dataset_name, chart_type = "test2", logger=self.logger )
        
        #return {'test_loss':loss}
    def predict_step(self, batch, batch_idx, reverse=False, parent_forcing=False):
        X, actions, t_X, weights, state0 = batch
        print('PREDICT')
        print(X.shape)
        print(actions.shape)
        print(t_X.shape)
        print(weights.shape)
        print(state0.shape)
        state0 = state0 #self.startState(X[:,0,:], batch_size)#self.normal.sample(sample_shape = (batch_size, self.state_dim)).squeeze(-1).to(X)
        logpx = None #self.normal.log_prob(state0)

        #Run model until t:

        x= X[:,:-1,:].clone()
        #replace missing values with some flag value
        if torch.isnan(X).any():
            x[torch.isnan(x)] = 100

        Y_hat, states, logpx, _regTerm1, _regTerm2= self(
            state0=state0, 
            X=x, 
            actions=actions[:,:-1,:], 
            times=t_X, 
            logpx=logpx,
            parent_forcing=parent_forcing,
            reverse=reverse,
            addNoise = False,
            )

        return Y_hat, states
        
        #Obtain counterfactual data



    #maybe move to utils...
    def standard_normal_logprob(z):
        dim = z.size(-1)
        log_z = -0.5 * dim * log(2 * pi)
        return log_z - z.pow(2) / 2
    


    
    def compute_losses(self, Y, actions, t_X, s_hat, loss_func, weights):
        """
        Compyte loss based on the given key word loss_func
        Y: observations
        actions: 
        t_X: time points
        s_hat: predictions
        ignore missing observations in the loss function evaluation.
        """
        # Compute loss on observations, ignore missing values
        #print(weights)
        mask = weights #~torch.isnan(Y)
        missing_obs = (~mask).any()
        #end_mask = ~torch.isnan(Y[:,-1,:])
        if loss_func == "elbo":
            print('elbo loss not implemented')
            return None
        elif loss_func == "maxloglik":
            print("maxloglik not implemented yet")
        elif loss_func == "div":
            logpx = None
            state0 = s_hat[:,-1,:]
            
            #state0 =  s_hat#torch.cat((s_hat[:,-1,:-self.obs_dim], Y[:,-1,:]), dim=1)
            # in backward pass replace missing observations with forward approximations
            if missing_obs:
                state0[:,-self.obs_dim:][mask[:,-1]].data =Y[:,-1,:][mask[:,-1]]
                x =  Y #[:,:-1,:] #state[:,:i,:self.obs_dim]
                x[~mask] = s_hat[:,:, -self.obs_dim:][~mask].clone().detach()
                x = x[:,:-1,:]
            else:
                x= Y[:,:-1,:]

            _, state_hat, _logpx, _ , _= self(
                state0=state0, 
                X= x, 
                actions=actions, 
                times=t_X,  #:,:i]
                logpx=logpx,
                parent_forcing=False,
                addNoise = False,
                reverse =True
                )
            loss = nn.MSELoss()
            #relaxed version
            div = loss(torch.flip(state_hat[:,:,-self.obs_dim:], dims=[1])[mask],Y[mask])
            #strickt version
            #div = loss(torch.flip(state_hat, dims=[1]),s_hat )
            return div
        elif loss_func == "mse":
            loss = nn.MSELoss(reduce=False)
            if missing_obs:
                _loss = loss(s_hat[:,:,-self.obs_dim:][mask], Y[:,:,:][mask])
            else:
                _loss = loss(s_hat[:,:,-self.obs_dim:], Y[:,:,:])
            #_loss = loss(s_hat[:,:,-self.obs_dim:], Y[:,:,:])

            #for i in range(s_hat.size()[1]):
            #    _loss += loss(s_hat[:,i,-self.obs_dim:], Y[:,i,:])
            return _loss
        elif loss_func == "weighted_mse":
            Y= Y.clone()
            loss = nn.MSELoss(reduce=False)
            Y[~mask] = s_hat[:,:,-self.obs_dim:][~mask]
            _loss = (weights * (loss(s_hat[:,:,-self.obs_dim:], Y))) / weights.sum()
            return _loss

        else:
            print("given loss function {} did not match: mse, maxloglik or elbo".format(self.loss_func))
             


    def plot_intermediate_steps(self, X, actions, t_X, state0, steps):
        """
        Run the model on a refined time grid and plot the result.

        Every interval between two consecutive time points of ``t_X[0]`` is
        split into ``steps`` equal sub-intervals. Observations and actions are
        placed at every ``steps``-th position of the refined sequences and all
        intermediate positions are NaN. The refined sequences are passed to
        the model and then to ``plot_trajectories`` with chart type
        ``"val_multistep"``.

        Args:
            X: observations (batch x time x obs_dim).
            actions: actions (batch x time x action_dim).
            t_X: time points; only the first row is used.
            state0: initial state passed to the model.
            steps: number of sub-intervals per original interval.
        """
        N = steps
        t_size = t_X.shape[1]*N-(N-1)

        # NaN-filled tensors with every N-th entry set to the observation.
        # They are created on the inputs' devices so the model never receives
        # a mix of CPU and GPU tensors.
        _X = torch.full((X.shape[0], t_size, X.shape[2]), np.nan, device=X.device)
        _actions = torch.full((actions.shape[0], t_size, actions.shape[2]), np.nan, device=actions.device)
        _X[:, ::N, :] = X
        _actions[:, ::N, :] = actions

        # Refined grid: N points per original interval (right end excluded),
        # followed by the final time point.
        t = []
        for t0, t1 in zip(t_X[0, :-1], t_X[0, 1:]):
            t.append(torch.linspace(float(t0), float(t1), N+1, device=t_X.device)[:-1])
        t.append(t_X[0, -1:])
        t = torch.cat(t, 0).view((1, t_size))

        # NOTE(review): other callers pass observations/actions with the last
        # time step removed; here the context has the same length as the time
        # grid - confirm the flow accepts that.
        Y_hat, state, logpx, _, _ = self(
            state0=state0,
            X=_X,
            actions=_actions,
            times=t,
            logpx=None,
            parent_forcing=False,
            addNoise=False,
            reverse=False,
        )
        plot_trajectories( _X.detach().clone(), Y_hat.detach().clone(), t.detach().clone(), self.dataset_name , chart_type = "val_multistep", logger=self.logger )

    def logmeanexp(self, x, dim=None, keepdim=False):
        """Stable computation of log(mean(exp(x))"""

    
        if dim is None:
            x, dim = x.view(-1), 0
        x_max, _ = torch.max(x, dim, keepdim=True)
        x = x_max + torch.log(torch.mean(torch.exp(x - x_max), dim, keepdim=True))
        return x if keepdim else x.squeeze(dim)

    def compute_logLikelihood(self, batch):
        """
        Estimate a log-likelihood for the first 10 trajectories of ``batch``.

        Each selected trajectory is repeated ``N = 10`` times. After one
        forward run, a reverse run is started from every time step ``i >= 1``
        and two per-step quantities are stored: the negated last ``logpx``
        returned by the reverse run, and the log-density under
        ``self.normal_state`` of the difference between the reverse-run end
        point and the first observation. Both are combined per trajectory
        over the ``N`` repetitions and a plot is sent to the logger.

        Returns:
            ``(-mean(loglik), mean(std))`` over the selected trajectories.

        NOTE(review): every call to this method is commented out in the
        visible code, so it is presumably experimental - confirm before relying on it.
        """
        #torch.set_grad_enabled(True)
        X, actions, t_X, _ = batch
        #Take a subset
        X = X[:10,:,:]
        actions = actions[:10,:,:]
        t_X = t_X[:10, :]

        batch_size = X.size()[0]
        time_steps = X.size()[1]
        #states = torch.cat((state[:,:,:-self.obs_dim], Y) , dim=1).detach()
        N = 10 #expectation over N samples
        loglik = 0.  # not used below
        # repeat() tiles the whole subset, so rows i, i+batch_size, ... belong to trajectory i
        X = X.repeat(N, 1, 1) #X[0,:,:][None,].repeat(N,1,1)
        actions = actions.repeat(N, 1, 1) #[0,:,:][None,].repeat(N,1,1)
        t_X = t_X.repeat(N, 1)
        X_true =X.clone()

        #Corrupted samples:
        # with missing_prop = 0.0 no time step is corrupted (samples == 0)
        missing_prop = 0.0
        samples=int(missing_prop*time_steps)

        # NOTE(review): divides by zero if samples == 1
        step = (time_steps-15)/float(samples-1)
        
        # indices of the corrupted time steps, spread evenly from index 14 on
        _mask = [int(round(14+x*step)) for x in range(samples)]  #random.sample(range(5, time_steps), samples)
        print(_mask)
        # mask: True at corrupted time steps; mask_true: True everywhere else
        mask = np.zeros(time_steps)
        mask[_mask] = 1
        mask=np.array(mask, dtype=bool)#.int()
        #print(mask)
        mask_true = np.ones(time_steps)
        mask_true[_mask] = 0
        mask_true= np.array(mask_true, dtype=bool)#.int()

        #for ii in mask:
        # shift the selected steps of channels 0 and 1 by +/-0.3
        # NOTE(review): the error tensors are created on the CPU and channels 0/1 are hard-coded
        error =torch.tensor(np.random.choice([-0.3,0.3], samples),dtype=torch.float)
        X[:,mask,0]= X[:,mask,0] + error
        error =torch.tensor(np.random.choice([-0.3,0.3], samples),dtype=torch.float)
        X[:,mask,1]= X[:,mask,1] + error


        state0 = self.startState(X[:,0,:], N*batch_size)#self.normal.sample(sample_shape = (N, self.state_dim)).squeeze(-1).to(X)
        logpx = None #self.normal.log_prob(state0)

        # forward run over the full (possibly corrupted) sequence
        Y_hat, state, logpx, _, _= self(
            state0=state0, 
            X=X[:,:-1,:], 
            actions=actions[:,:-1,:], 
            times=t_X, 
            logpx=logpx,
            parent_forcing=False,
            addNoise = False,
            reverse =False
            )
        Y_pred = Y_hat.clone()
        # per-step accumulators, shaped like X
        logpx= torch.zeros_like(X)
        px_0 =torch.zeros_like(X)
        for i in range(1,t_X.size()[1]):
            #logpx = torch.zeros(*X[:,0,:].shape[:-2], 1, device= self.device) #
            end_mask = ~torch.isnan(X[:,i,:])
            #mask =  torch.isnan(X[:,:i,:])

            # start the reverse run at step i from the forward state, with its
            # observed part replaced by the actual observation where available
            state0 = state[:,i,:].clone()
            state0[:,-self.obs_dim:][end_mask] = X[:,i,:][end_mask].clone()

            #state0 = torch.cat((state[:,i,:-self.obs_dim], X[:,i,:]), dim=1) # state[:,i,:] #

            x =  X[:,:i,:].clone() #Y#[:,:-1,:] #state[:,:i,:self.obs_dim]
            # replace missing observations with forward predictions
            #x[mask] = state[:,:i, -self.obs_dim:][mask].clone()
            #x = x[:,:-1,:]

            #x =  X[:,:i,:] #state[:,:i,:self.obs_dim]
            # NOTE(review): the accumulator tensor `logpx` (shaped like X) is
            # also passed in as the initial logpx of the reverse run - confirm
            # this is intended
            _, state_hat, _logpx, _ , _= self(
                state0=state0, 
                X= x, 
                actions=actions[:,:i,:],#[:,:i,:], 
                times=t_X[:,:i+1],  #:,:i]
                logpx=logpx,
                parent_forcing=False,
                addNoise = False,
                reverse =True
                )

            logpx[:,i,:] = -_logpx[:,-1,:] 
            # log-density of (reverse-run end point - first observation) under normal_state
            px_0[:,i,:] = self.normal_state.log_prob(state_hat[:,-1,-self.obs_dim:]-X[:,0,:])
            #loglik += torch.sum(_logpx, dim=1) + torch.sum(self.normal_state.log_prob(state_hat[:,-1,-self.obs_dim:]-X[:,0,:]), dim=-1, keepdim=True)


        #print(logpx.size())
        # results are collected in CPU tensors
        mcmc_loglik = torch.zeros(batch_size)
        mcmc_std = torch.zeros(batch_size)
        #print(mcmc_loglik.size())

        for i in range(batch_size):
            
            # rows i::batch_size are the N repetitions of trajectory i
            _div=torch.sum((torch.sum(logpx[i::batch_size,:,:],dim=0)/N))
            #print( _div.shape)
            
            _px = torch.sum(self.logmeanexp(px_0[i::batch_size,:,:],dim=0))
            #torch.sum(torch.log(torch.sum(px_0[i::batch_size,:,:],dim=0)/N))
            #print(_px.shape)
            mcmc_loglik[i] = torch.add(_div,_px) #torch.sum((torch.sum(logpx[i::batch_size,:,:],dim=0)/N),dim=-2) + torch.sum(torch.log(torch.sum(px_0[i::batch_size,:,:],dim=0)/N),dim=-2)
            # NOTE(review): the inner torch.sum(px_0[...]) is a scalar that is
            # added to every element before the outer sum - confirm intended
            mcmc_std[i] = torch.std(torch.sum(logpx[i::batch_size,:,:]+torch.sum(px_0[i::batch_size,:,:]) , dim=[1,2]))

        #PLOTTING
        # X_corrupt keeps only the corrupted steps, X_true only the uncorrupted ones
        X_corrupt = X.clone()
        X_corrupt[:,mask_true,:] = torch.tensor(np.nan,dtype=torch.float)
        X_true[:,mask,:] = torch.tensor(np.nan,dtype=torch.float)
        #print(X_corrupt[0,:,0])
        #print(X_true[0,:,0])
        # ::batch_size selects the N repetitions of the first trajectory
        plot_trajectories_with_logprobabilities(
            X_true[::batch_size,:,:].detach(), 
            X_corrupt[::batch_size,:,:].detach(),
            Y_pred[::batch_size,:,:].detach(), 
            t_X[::batch_size,:].detach(), 
            logpx[::batch_size,:,:].detach(), 
            px_0[::batch_size,:,:].detach(),
            dataset_name =  self.dataset_name, 
            chart_type = "loglikelihoods", 
            logger= self.logger
            )



        return -torch.mean(mcmc_loglik), torch.mean(mcmc_std)

    

                

    @classmethod
    def add_model_specific_args(cls, parent):
        parser = argparse.ArgumentParser(parents=[parent])
        parser.add_argument('--loss_func', type=str, default='mse', help= "loss function options: mse, elbo, maxloglik")
        parser.add_argument("--state_dim", type=int, default=5, help = "dimension of the embedding space") #10
        parser.add_argument("--num_blocks", type=int, default=1, help = "gate blocks in cnf")
        parser.add_argument("--lr", type=float, default=0.001) # 0.001
        #parser.add_argument("--lr_gamma", type=float, default=0.98)
        parser.add_argument("--reg1", type=float, default=0.0)
        parser.add_argument("--reg2", type=float, default=0.0)

        
        return parser
