import os
import sys
import numpy as np
import scipy.io
import argparse
from random import SystemRandom

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence

from lib.functions import *
from lib.utils import *

################################################################
###   Proposed methods                                       ###
################################################################
class RLNO_MI(nn.Module):
    """Recurrent latent neural-operator model with an additional operator input ``u``.

    Encoder: walks backwards over the first ``args.rec_len`` observations,
    alternating a DeepONet step (``z0_operator_solver``) with a GRU update, and
    maps the final hidden state/std through ``transform_z0`` to the mean and
    std of the initial latent state ``z0``.

    Decoder: ``operator_solver`` is queried with sampled ``z0`` values, the
    per-trajectory input ``u`` and the requested times; ``decoder`` maps the
    result back to observation space.

    Batches are indexed as ``batch[0]`` (observations), ``batch[1]``
    (observation times) and ``batch[2]`` (the extra input ``u``).
    """

    def __init__(self, args, device, input_size2):
        super().__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # Submodules are built in a fixed order: nn.Linear (and presumably the
        # project modules) draw initial weights from the global RNG, so
        # reordering would change initialisation.
        self.z0_operator_solver = DeepONet1(input_size1=rec_dims, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.transform_z0 = nn.Sequential(nn.Linear(rec_dims*2, 100), nn.Tanh(), nn.Linear(100, latents*2),).to(device)
        init_network_weights(self.transform_z0)

        self.GRU_update = GRU_unit(rec_dims, input_dim, n_units=gru_units, device=device).to(device)
        self.operator_solver = DeepONet_MI(input_size1=latents, input_size2=input_size2, input_size3=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When True, the next compute_all_losses call also evaluates on
        # ``self.time_full`` and clears the flag (set by test_full).
        self.iftest = False

    def _encode_first_point(self, data, time_steps, encoder_ntp):
        """Encode the observation window into ``(mean, std)`` of ``z0``.

        ``data`` / ``time_steps`` hold the first ``encoder_ntp`` observations
        and their times. The returned std is made non-negative with ``abs``.
        """
        device = self.device
        rec_dims = self.rec_dims
        n_traj = data.shape[0]

        prev_y = torch.zeros((1, n_traj, rec_dims), device=device)
        prev_std = torch.zeros((1, n_traj, rec_dims), device=device)

        if time_steps.shape[1] == 1:
            # Single observation: one GRU update from the zero state.
            prev_y, prev_std = self.GRU_update(prev_y, prev_std, data[:, 0, :].unsqueeze(0))
        else:
            # Walk backwards in time. The first step has no later observation,
            # so it uses a fixed time offset of -0.01.
            prev_t, t_i = time_steps[:, -1] + 0.01, time_steps[:, -1]
            for i in reversed(range(encoder_ntp)):
                delta_t = (t_i - prev_t).unsqueeze(0).unsqueeze(-1)
                y_no = self.z0_operator_solver(prev_y, delta_t)
                assert not torch.isnan(y_no).any()
                xi = data[:, i, :].unsqueeze(0)
                prev_y, prev_std = self.GRU_update(y_no, prev_std, xi)
                # At i == 0, index i - 1 wraps to -1; the value is never used
                # because the loop ends there.
                prev_t, t_i = time_steps[:, i], time_steps[:, i - 1]

        means_z0 = prev_y.reshape(1, n_traj, rec_dims)
        std_z0 = prev_std.reshape(1, n_traj, rec_dims)
        mean_z0, std_z0 = split_last_dim(self.transform_z0(torch.cat((means_z0, std_z0), -1)))
        return mean_z0, std_z0.abs()

    def _decode(self, z0_samples, u_batch, times, n_points):
        """Query the operator at ``times`` (``n_points`` of them) for every z0 sample and decode."""
        n_samples = self.args.n_traj_samples
        y1 = z0_samples.unsqueeze(-2).repeat(1, 1, n_points, 1)
        y2 = u_batch.unsqueeze(0).unsqueeze(-2).repeat(n_samples, 1, n_points, 1)
        y3 = times.unsqueeze(0).unsqueeze(-1).repeat(n_samples, 1, 1, 1)
        return self.decoder(self.operator_solver(y1, y2, y3))

    def _loss_results(self, s_batch, pred_y, fp_mu, fp_std):
        """Compute the IWAE-style loss and the logged metrics dict."""
        device = self.device
        obsrv_std = torch.tensor([0.01], device=device)
        z0_prior = Normal(torch.tensor([0.0], device=device), torch.tensor([1.0], device=device))

        # Replace exact-zero scales so Normal gets a strictly positive std.
        fp_std = fp_std.abs()
        fp_std = fp_std.masked_fill(fp_std == 0, 1e-6)
        kldiv_z0 = kl_divergence(Normal(fp_mu, fp_std), z0_prior)
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        rec_likelihood = get_gaussian_likelihood(s_batch, pred_y, obsrv_std)
        mse = get_mse(s_batch, pred_y)
        # Placeholders so the results dict has the same keys for every model.
        pois_log_likelihood = torch.tensor([0.0], device=device)
        ce_loss = torch.tensor([0.0], device=device)

        # NOTE(review): the KL weight is hard-coded to 0.1; the ``kl_coef``
        # argument of compute_all_losses is not used. Kept to preserve training
        # behaviour — confirm whether this is intended.
        kl_coef = 0.1
        objective = rec_likelihood - kl_coef * kldiv_z0
        loss = -torch.logsumexp(objective, 0)
        if torch.isnan(loss):
            loss = -torch.mean(objective, 0)

        return {
            "loss": torch.mean(loss),
            "likelihood": torch.mean(rec_likelihood).detach(),
            "mse": torch.mean(mse).detach(),
            "pois_likelihood": torch.mean(pois_log_likelihood).detach(),
            "ce_loss": torch.mean(ce_loss).detach(),
            "kl_first_p": torch.mean(kldiv_z0).detach(),
            "std_first_p": torch.mean(fp_std).detach(),
        }

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute losses.

        Args:
            batch: indexable as ``batch[0]`` (observations, unpacked as
                ``(n_traj, n_tp, n_dims)``), ``batch[1]`` (times, sliced as
                ``[:, :k]``) and ``batch[2]`` (extra operator input ``u``).
            kl_coef: accepted for interface compatibility but ignored; the KL
                term is always weighted by 0.1.

        Returns:
            ``(results, pred_y)``: dict of losses/metrics and decoded
            predictions. If ``self.iftest`` is set, ``pred_y`` is evaluated on
            ``self.time_full`` instead and the flag is cleared.

        Raises:
            ValueError: if the encoder window ``min(args.rec_len, n_tp)`` is empty.
        """
        n_traj_samples = self.args.n_traj_samples

        s_batch, t_batch, u_batch = batch[0], batch[1], batch[2]
        _, n_tp, _ = s_batch.size()

        encoder_ntp = min(self.args.rec_len, n_tp)
        if encoder_ntp < 1:
            raise ValueError(
                "encoder window is empty (rec_len={}, n_tp={})".format(self.args.rec_len, n_tp))
        data = s_batch[:, :encoder_ntp, :]
        time_steps = t_batch[:, :encoder_ntp]

        fp_mu, fp_std = self._encode_first_point(data, time_steps, encoder_ntp)

        means_z0 = fp_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = fp_std.repeat(n_traj_samples, 1, 1)
        z0_samples = sample_standard_gaussian(means_z0, sigma_z0)

        pred_y = self._decode(z0_samples, u_batch, t_batch, n_tp)
        results = self._loss_results(s_batch, pred_y, fp_mu, fp_std)

        if self.iftest:
            self.iftest = False
            pred_y = self._decode(z0_samples, u_batch, self.time_full, self.time_full.shape[-1])

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` (without gradients) and log a summary."""
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            itr//num_batches,
            test_res["loss"].detach(), test_res["likelihood"].detach(),
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message)
        logger.info("KL coef: {}".format(kl_coef))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))

        if "mse" in test_res:
            logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

        if "pois_likelihood" in test_res:
            logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

        if "ce_loss" in test_res:
            logger.info("CE loss: {}".format(test_res["ce_loss"]))

    def test(self, batch):
        """Return predictions for ``batch`` without tracking gradients."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Return predictions evaluated on ``time_full``, without gradients.

        The ``iftest`` flag is always cleared afterwards, even on error, so a
        failed call cannot leak full-horizon mode into later calls.
        """
        with torch.no_grad():
            self.iftest = True
            self.time_full = time_full
            try:
                _, pred_y = self.compute_all_losses(batch)
            finally:
                self.iftest = False
        return pred_y


class RLNO(nn.Module):
    """Recurrent latent neural-operator model.

    Encoder: walks backwards over the first ``args.rec_len`` observations,
    alternating a DeepONet step (``z0_operator_solver``) with a GRU update, and
    maps the final hidden state/std through ``transform_z0`` to the mean and
    std of the initial latent state ``z0``.

    Decoder: ``operator_solver`` is queried with sampled ``z0`` values and the
    requested times; ``decoder`` maps the result back to observation space.

    Batches are indexed as ``batch[0]`` (observations) and ``batch[1]``
    (observation times).
    """

    def __init__(self, args, device):
        super().__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # Submodules are built in a fixed order: nn.Linear (and presumably the
        # project modules) draw initial weights from the global RNG, so
        # reordering would change initialisation.
        self.z0_operator_solver = DeepONet1(input_size1=rec_dims, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.transform_z0 = nn.Sequential(nn.Linear(rec_dims*2, 100), nn.Tanh(), nn.Linear(100, latents*2),).to(device)
        init_network_weights(self.transform_z0)

        self.GRU_update = GRU_unit(rec_dims, input_dim, n_units=gru_units, device=device).to(device)
        self.operator_solver = DeepONet1(input_size1=latents, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When True, the next compute_all_losses call also evaluates on
        # ``self.time_full`` and clears the flag (set by test_full).
        self.iftest = False

    def _encode_first_point(self, data, time_steps, encoder_ntp):
        """Encode the observation window into ``(mean, std)`` of ``z0``.

        ``data`` / ``time_steps`` hold the first ``encoder_ntp`` observations
        and their times. The returned std is made non-negative with ``abs``.
        """
        device = self.device
        rec_dims = self.rec_dims
        n_traj = data.shape[0]

        prev_y = torch.zeros((1, n_traj, rec_dims), device=device)
        prev_std = torch.zeros((1, n_traj, rec_dims), device=device)

        if time_steps.shape[1] == 1:
            # Single observation: one GRU update from the zero state.
            prev_y, prev_std = self.GRU_update(prev_y, prev_std, data[:, 0, :].unsqueeze(0))
        else:
            # Walk backwards in time. The first step has no later observation,
            # so it uses a fixed time offset of -0.01.
            prev_t, t_i = time_steps[:, -1] + 0.01, time_steps[:, -1]
            for i in reversed(range(encoder_ntp)):
                delta_t = (t_i - prev_t).unsqueeze(0).unsqueeze(-1)
                y_no = self.z0_operator_solver(prev_y, delta_t)
                assert not torch.isnan(y_no).any()
                xi = data[:, i, :].unsqueeze(0)
                prev_y, prev_std = self.GRU_update(y_no, prev_std, xi)
                # At i == 0, index i - 1 wraps to -1; the value is never used
                # because the loop ends there.
                prev_t, t_i = time_steps[:, i], time_steps[:, i - 1]

        means_z0 = prev_y.reshape(1, n_traj, rec_dims)
        std_z0 = prev_std.reshape(1, n_traj, rec_dims)
        mean_z0, std_z0 = split_last_dim(self.transform_z0(torch.cat((means_z0, std_z0), -1)))
        return mean_z0, std_z0.abs()

    def _decode(self, z0_samples, times, n_points):
        """Query the operator at ``times`` (``n_points`` of them) for every z0 sample and decode."""
        n_samples = self.args.n_traj_samples
        y1 = z0_samples.unsqueeze(-2).repeat(1, 1, n_points, 1)
        y2 = times.unsqueeze(0).unsqueeze(-1).repeat(n_samples, 1, 1, 1)
        return self.decoder(self.operator_solver(y1, y2))

    def _loss_results(self, s_batch, pred_y, fp_mu, fp_std):
        """Compute the IWAE-style loss and the logged metrics dict."""
        device = self.device
        obsrv_std = torch.tensor([0.01], device=device)
        z0_prior = Normal(torch.tensor([0.0], device=device), torch.tensor([1.0], device=device))

        # Replace exact-zero scales so Normal gets a strictly positive std.
        fp_std = fp_std.abs()
        fp_std = fp_std.masked_fill(fp_std == 0, 1e-6)
        kldiv_z0 = kl_divergence(Normal(fp_mu, fp_std), z0_prior)
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        rec_likelihood = get_gaussian_likelihood(s_batch, pred_y, obsrv_std)
        mse = get_mse(s_batch, pred_y)
        # Placeholders so the results dict has the same keys for every model.
        pois_log_likelihood = torch.tensor([0.0], device=device)
        ce_loss = torch.tensor([0.0], device=device)

        # NOTE(review): the KL weight is hard-coded to 0.1; the ``kl_coef``
        # argument of compute_all_losses is not used. Kept to preserve training
        # behaviour — confirm whether this is intended.
        kl_coef = 0.1
        objective = rec_likelihood - kl_coef * kldiv_z0
        loss = -torch.logsumexp(objective, 0)
        if torch.isnan(loss):
            loss = -torch.mean(objective, 0)

        return {
            "loss": torch.mean(loss),
            "likelihood": torch.mean(rec_likelihood).detach(),
            "mse": torch.mean(mse).detach(),
            "pois_likelihood": torch.mean(pois_log_likelihood).detach(),
            "ce_loss": torch.mean(ce_loss).detach(),
            "kl_first_p": torch.mean(kldiv_z0).detach(),
            "std_first_p": torch.mean(fp_std).detach(),
        }

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute losses.

        Args:
            batch: indexable as ``batch[0]`` (observations, unpacked as
                ``(n_traj, n_tp, n_dims)``) and ``batch[1]`` (times, sliced
                as ``[:, :k]``).
            kl_coef: accepted for interface compatibility but ignored; the KL
                term is always weighted by 0.1.

        Returns:
            ``(results, pred_y)``: dict of losses/metrics and decoded
            predictions. If ``self.iftest`` is set, ``pred_y`` is evaluated on
            ``self.time_full`` instead and the flag is cleared.

        Raises:
            ValueError: if the encoder window ``min(args.rec_len, n_tp)`` is empty.
        """
        n_traj_samples = self.args.n_traj_samples

        s_batch, t_batch = batch[0], batch[1]
        _, n_tp, _ = s_batch.size()

        encoder_ntp = min(self.args.rec_len, n_tp)
        if encoder_ntp < 1:
            raise ValueError(
                "encoder window is empty (rec_len={}, n_tp={})".format(self.args.rec_len, n_tp))
        data = s_batch[:, :encoder_ntp, :]
        time_steps = t_batch[:, :encoder_ntp]

        fp_mu, fp_std = self._encode_first_point(data, time_steps, encoder_ntp)

        means_z0 = fp_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = fp_std.repeat(n_traj_samples, 1, 1)
        z0_samples = sample_standard_gaussian(means_z0, sigma_z0)

        pred_y = self._decode(z0_samples, t_batch, n_tp)
        results = self._loss_results(s_batch, pred_y, fp_mu, fp_std)

        if self.iftest:
            self.iftest = False
            pred_y = self._decode(z0_samples, self.time_full, self.time_full.shape[-1])

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` (without gradients) and log a summary."""
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            itr//num_batches,
            test_res["loss"].detach(), test_res["likelihood"].detach(),
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message)
        logger.info("KL coef: {}".format(kl_coef))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))

        if "mse" in test_res:
            logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

        if "pois_likelihood" in test_res:
            logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

        if "ce_loss" in test_res:
            logger.info("CE loss: {}".format(test_res["ce_loss"]))

    def test(self, batch):
        """Return predictions for ``batch`` without tracking gradients."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Return predictions evaluated on ``time_full``, without gradients.

        The ``iftest`` flag is always cleared afterwards, even on error, so a
        failed call cannot leak full-horizon mode into later calls.
        """
        with torch.no_grad():
            self.iftest = True
            self.time_full = time_full
            try:
                _, pred_y = self.compute_all_losses(batch)
            finally:
                self.iftest = False
        return pred_y



################################################################
###   Ablation experiments                                   ###
################################################################        

################################################################
###   RNN encoder                                            ###
################################################################   
class LNO_Ab1(nn.Module):
    """Ablation: latent neural-operator model with a plain backward GRU encoder.

    The encoder runs ``GRU_update`` backwards over the first ``args.rec_len``
    observations (no time information is used by the encoder) and maps the
    final hidden state/std through ``transform_z0`` to the mean and std of
    ``z0``. Decoding is the same as in ``RLNO``: ``operator_solver`` queried at
    the requested times, followed by ``decoder``.
    """

    def __init__(self, args, device):
        super().__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # NOTE(review): z0_operator_solver is not used by compute_all_losses in
        # this ablation; it is kept so the parameter set and the RNG-driven
        # initialisation order stay unchanged (checkpoint compatibility).
        self.z0_operator_solver = DeepONet1(input_size1=rec_dims, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.transform_z0 = nn.Sequential(nn.Linear(rec_dims*2, 100), nn.Tanh(), nn.Linear(100, latents*2),).to(device)
        init_network_weights(self.transform_z0)

        self.GRU_update = GRU_unit(rec_dims, input_dim, n_units=gru_units, device=device).to(device)
        self.operator_solver = DeepONet1(input_size1=latents, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When True, the next compute_all_losses call also evaluates on
        # ``self.time_full`` and clears the flag (set by test_full).
        self.iftest = False

    def _encode_first_point(self, data, encoder_ntp):
        """Run the GRU backwards over ``data``; return ``(mean, std)`` of ``z0``.

        A single observation reduces to one update from the zero state. The
        returned std is made non-negative with ``abs``.
        """
        device = self.device
        rec_dims = self.rec_dims
        n_traj = data.shape[0]

        prev_y = torch.zeros((1, n_traj, rec_dims), device=device)
        prev_std = torch.zeros((1, n_traj, rec_dims), device=device)
        for i in reversed(range(encoder_ntp)):
            prev_y, prev_std = self.GRU_update(prev_y, prev_std, data[:, i, :].unsqueeze(0))

        means_z0 = prev_y.reshape(1, n_traj, rec_dims)
        std_z0 = prev_std.reshape(1, n_traj, rec_dims)
        mean_z0, std_z0 = split_last_dim(self.transform_z0(torch.cat((means_z0, std_z0), -1)))
        return mean_z0, std_z0.abs()

    def _decode(self, z0_samples, times, n_points):
        """Query the operator at ``times`` (``n_points`` of them) for every z0 sample and decode."""
        n_samples = self.args.n_traj_samples
        y1 = z0_samples.unsqueeze(-2).repeat(1, 1, n_points, 1)
        y2 = times.unsqueeze(0).unsqueeze(-1).repeat(n_samples, 1, 1, 1)
        return self.decoder(self.operator_solver(y1, y2))

    def _loss_results(self, s_batch, pred_y, fp_mu, fp_std):
        """Compute the IWAE-style loss and the logged metrics dict."""
        device = self.device
        obsrv_std = torch.tensor([0.01], device=device)
        z0_prior = Normal(torch.tensor([0.0], device=device), torch.tensor([1.0], device=device))

        # Replace exact-zero scales so Normal gets a strictly positive std.
        fp_std = fp_std.abs()
        fp_std = fp_std.masked_fill(fp_std == 0, 1e-6)
        kldiv_z0 = kl_divergence(Normal(fp_mu, fp_std), z0_prior)
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        rec_likelihood = get_gaussian_likelihood(s_batch, pred_y, obsrv_std)
        mse = get_mse(s_batch, pred_y)
        # Placeholders so the results dict has the same keys for every model.
        pois_log_likelihood = torch.tensor([0.0], device=device)
        ce_loss = torch.tensor([0.0], device=device)

        # NOTE(review): the KL weight is hard-coded to 0.1; the ``kl_coef``
        # argument of compute_all_losses is not used. Kept to preserve training
        # behaviour — confirm whether this is intended.
        kl_coef = 0.1
        objective = rec_likelihood - kl_coef * kldiv_z0
        loss = -torch.logsumexp(objective, 0)
        if torch.isnan(loss):
            loss = -torch.mean(objective, 0)

        return {
            "loss": torch.mean(loss),
            "likelihood": torch.mean(rec_likelihood).detach(),
            "mse": torch.mean(mse).detach(),
            "pois_likelihood": torch.mean(pois_log_likelihood).detach(),
            "ce_loss": torch.mean(ce_loss).detach(),
            "kl_first_p": torch.mean(kldiv_z0).detach(),
            "std_first_p": torch.mean(fp_std).detach(),
        }

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute losses.

        Args:
            batch: indexable as ``batch[0]`` (observations, unpacked as
                ``(n_traj, n_tp, n_dims)``) and ``batch[1]`` (times).
            kl_coef: accepted for interface compatibility but ignored; the KL
                term is always weighted by 0.1.

        Returns:
            ``(results, pred_y)``: dict of losses/metrics and decoded
            predictions. If ``self.iftest`` is set, ``pred_y`` is evaluated on
            ``self.time_full`` instead and the flag is cleared.

        Raises:
            ValueError: if the encoder window ``min(args.rec_len, n_tp)`` is empty.
        """
        n_traj_samples = self.args.n_traj_samples

        s_batch, t_batch = batch[0], batch[1]
        _, n_tp, _ = s_batch.size()

        encoder_ntp = min(self.args.rec_len, n_tp)
        if encoder_ntp < 1:
            raise ValueError(
                "encoder window is empty (rec_len={}, n_tp={})".format(self.args.rec_len, n_tp))
        data = s_batch[:, :encoder_ntp]

        fp_mu, fp_std = self._encode_first_point(data, encoder_ntp)

        means_z0 = fp_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = fp_std.repeat(n_traj_samples, 1, 1)
        z0_samples = sample_standard_gaussian(means_z0, sigma_z0)

        pred_y = self._decode(z0_samples, t_batch, n_tp)
        results = self._loss_results(s_batch, pred_y, fp_mu, fp_std)

        if self.iftest:
            self.iftest = False
            pred_y = self._decode(z0_samples, self.time_full, self.time_full.shape[-1])

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` (without gradients) and log a summary."""
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            itr//num_batches,
            test_res["loss"].detach(), test_res["likelihood"].detach(),
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message)
        logger.info("KL coef: {}".format(kl_coef))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))

        if "mse" in test_res:
            logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

        if "pois_likelihood" in test_res:
            logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

        if "ce_loss" in test_res:
            logger.info("CE loss: {}".format(test_res["ce_loss"]))

    def test(self, batch):
        """Return predictions for ``batch`` without tracking gradients."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Return predictions evaluated on ``time_full``, without gradients.

        The ``iftest`` flag is always cleared afterwards, even on error, so a
        failed call cannot leak full-horizon mode into later calls.
        """
        with torch.no_grad():
            self.iftest = True
            self.time_full = time_full
            try:
                _, pred_y = self.compute_all_losses(batch)
            finally:
                self.iftest = False
        return pred_y

################################################################
###   GRUDecay encoder                                       ###
################################################################  
class LNO_Ab2(nn.Module):
    """Ablation: latent neural-operator model with a GRU-with-decay encoder.

    The encoder reverses the first ``args.rec_len`` observations and runs them
    through ``run_rnn`` with a ``GRUCellExpDecay`` cell and learnable input-decay
    parameters (the exact decay formula lives in ``run_rnn``, not visible here).
    ``z0_net`` maps the final hidden state to the mean and std of ``z0``.
    Decoding is ``operator_solver`` queried at the requested times, followed by
    ``decoder``.
    """

    def __init__(self, args, device):
        super().__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # Learnable input-decay parameters, created directly on ``device`` so
        # they are registered with the module and updated by the optimiser.
        # Previously they were built from uninitialised memory
        # (``torch.Tensor(1, n)``) and wrapped as ``Parameter(...).to(device)``.
        # When ``device`` is not the CPU, that call returns a plain, unregistered
        # tensor. Checkpoints saved that way lack these keys; load them with
        # ``strict=False`` if needed.
        self.w_input_decay = Parameter(torch.zeros(1, int(input_dim), device=device))
        self.b_input_decay = Parameter(torch.zeros(1, int(input_dim), device=device))
        self.rnn_cell_enc = GRUCellExpDecay(
                            input_size = 2*input_dim,
                            input_size_for_decay = input_dim,
                            hidden_size = rec_dims,
                            device = device).to(device)

        self.z0_net = nn.Sequential(nn.Linear(rec_dims, 100), nn.Tanh(), nn.Linear(100, latents*2),).to(device)
        init_network_weights(self.z0_net)

        self.operator_solver = DeepONet1(input_size1=latents, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When True, the next compute_all_losses call also evaluates on
        # ``self.time_full`` and clears the flag (set by test_full).
        self.iftest = False

    def _encode_first_point(self, data, time_steps):
        """Run the decay-RNN over the reversed window; return ``(mean, std)`` of ``z0``.

        The returned std is made non-negative with ``abs``.
        """
        # Reverse along the time axis so the RNN consumes the window last-to-first.
        data = torch.flip(data, dims=[1])
        # Gaps between consecutive encoder times, reversed to match ``data`` and
        # padded with a zero gap for the final step so the lengths agree.
        delta_ts = torch.flip(time_steps[:, 1:] - time_steps[:, :-1], dims=[1])
        zero_delta_t = torch.zeros(delta_ts.shape[0], 1, device=self.device)
        delta_ts = torch.cat((delta_ts, zero_delta_t), dim=1).unsqueeze(-1)

        input_decay_params = (self.w_input_decay, self.b_input_decay)
        hidden_state, _ = run_rnn(data, delta_ts, cell = self.rnn_cell_enc, input_decay_params = input_decay_params)
        assert not torch.isnan(hidden_state).any()

        z0_mean, z0_std = split_last_dim(self.z0_net(hidden_state.unsqueeze(0)))
        return z0_mean, z0_std.abs()

    def _decode(self, z0_samples, times, n_points):
        """Query the operator at ``times`` (``n_points`` of them) for every z0 sample and decode."""
        n_samples = self.args.n_traj_samples
        y1 = z0_samples.unsqueeze(-2).repeat(1, 1, n_points, 1)
        y2 = times.unsqueeze(0).unsqueeze(-1).repeat(n_samples, 1, 1, 1)
        return self.decoder(self.operator_solver(y1, y2))

    def _loss_results(self, s_batch, pred_y, fp_mu, fp_std):
        """Compute the IWAE-style loss and the logged metrics dict."""
        device = self.device
        obsrv_std = torch.tensor([0.01], device=device)
        z0_prior = Normal(torch.tensor([0.0], device=device), torch.tensor([1.0], device=device))

        # Replace exact-zero scales so Normal gets a strictly positive std.
        fp_std = fp_std.abs()
        fp_std = fp_std.masked_fill(fp_std == 0, 1e-6)
        kldiv_z0 = kl_divergence(Normal(fp_mu, fp_std), z0_prior)
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        rec_likelihood = get_gaussian_likelihood(s_batch, pred_y, obsrv_std)
        mse = get_mse(s_batch, pred_y)
        # Placeholders so the results dict has the same keys for every model.
        pois_log_likelihood = torch.tensor([0.0], device=device)
        ce_loss = torch.tensor([0.0], device=device)

        # NOTE(review): the KL weight is hard-coded to 0.1; the ``kl_coef``
        # argument of compute_all_losses is not used. Kept to preserve training
        # behaviour — confirm whether this is intended.
        kl_coef = 0.1
        objective = rec_likelihood - kl_coef * kldiv_z0
        loss = -torch.logsumexp(objective, 0)
        if torch.isnan(loss):
            loss = -torch.mean(objective, 0)

        return {
            "loss": torch.mean(loss),
            "likelihood": torch.mean(rec_likelihood).detach(),
            "mse": torch.mean(mse).detach(),
            "pois_likelihood": torch.mean(pois_log_likelihood).detach(),
            "ce_loss": torch.mean(ce_loss).detach(),
            "kl_first_p": torch.mean(kldiv_z0).detach(),
            "std_first_p": torch.mean(fp_std).detach(),
        }

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute losses.

        Args:
            batch: indexable as ``batch[0]`` (observations, unpacked as
                ``(n_traj, n_tp, n_dims)``) and ``batch[1]`` (times, sliced
                as ``[:, :k]``).
            kl_coef: accepted for interface compatibility but ignored; the KL
                term is always weighted by 0.1.

        Returns:
            ``(results, pred_y)``: dict of losses/metrics and decoded
            predictions. If ``self.iftest`` is set, ``pred_y`` is evaluated on
            ``self.time_full`` instead and the flag is cleared.

        Raises:
            ValueError: if the encoder window ``min(args.rec_len, n_tp)`` is empty.
        """
        n_traj_samples = self.args.n_traj_samples

        s_batch, t_batch = batch[0], batch[1]
        _, n_tp, _ = s_batch.size()

        encoder_ntp = min(self.args.rec_len, n_tp)
        if encoder_ntp < 1:
            raise ValueError(
                "encoder window is empty (rec_len={}, n_tp={})".format(self.args.rec_len, n_tp))
        data = s_batch[:, :encoder_ntp]
        time_steps = t_batch[:, :encoder_ntp]

        fp_mu, fp_std = self._encode_first_point(data, time_steps)

        means_z0 = fp_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = fp_std.repeat(n_traj_samples, 1, 1)
        z0_samples = sample_standard_gaussian(means_z0, sigma_z0)

        pred_y = self._decode(z0_samples, t_batch, n_tp)
        results = self._loss_results(s_batch, pred_y, fp_mu, fp_std)

        if self.iftest:
            self.iftest = False
            pred_y = self._decode(z0_samples, self.time_full, self.time_full.shape[-1])

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` (without gradients) and log a summary."""
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            itr//num_batches,
            test_res["loss"].detach(), test_res["likelihood"].detach(),
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message)
        logger.info("KL coef: {}".format(kl_coef))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))

        if "mse" in test_res:
            logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

        if "pois_likelihood" in test_res:
            logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

        if "ce_loss" in test_res:
            logger.info("CE loss: {}".format(test_res["ce_loss"]))

    def test(self, batch):
        """Return predictions for ``batch`` without tracking gradients."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Return predictions evaluated on ``time_full``, without gradients.

        The ``iftest`` flag is always cleared afterwards, even on error, so a
        failed call cannot leak full-horizon mode into later calls.
        """
        with torch.no_grad():
            self.iftest = True
            self.time_full = time_full
            try:
                _, pred_y = self.compute_all_losses(batch)
            finally:
                self.iftest = False
        return pred_y
        

################################################################
###   ODE-RNN encoder                                        ###
################################################################  
class LNO_Ab3(nn.Module):
    """Variant whose initial-state encoder is an ODE-RNN.

    The initial latent state is inferred by running a GRU backwards over the
    first ``args.rec_len`` observations, evolving the hidden state between
    observation times with an ODE (``DiffeqSolver``). Decoding queries a
    ``DeepONet1`` operator at the requested times, and training uses an
    IWAE-style objective with a KL term on the initial state.
    """

    def __init__(self, args, device):
        """Build the ODE encoder, GRU update, z0 transform, operator and decoder.

        Args:
            args: Config object; ``input_dim``, ``rec_dims``, ``latents`` and
                ``gru_units`` are read here (``rec_len`` and ``n_traj_samples``
                are read in ``compute_all_losses``).
            device: Torch device every sub-module is moved to.
        """
        super(LNO_Ab3, self).__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # ODE dynamics used by the encoder to evolve the hidden state between
        # observations (this replaces the encoder's DeepONet1 z0 operator).
        ode_func_net = create_net(rec_dims, rec_dims, n_layers = 1, n_units = 100, nonlinear = nn.Tanh)
        rec_ode_func = ODEFunc(input_dim = 2*input_dim, latent_dim = rec_dims,
                               ode_func_net = ode_func_net, device = device).to(device)
        self.z0_diffeq_solver = DiffeqSolver(2*input_dim, rec_ode_func, "euler",
                                        odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
        # Maps the concatenated (mean, std) encoder state to (mean, std) of z0.
        self.transform_z0 = nn.Sequential(nn.Linear(rec_dims*2, 100), nn.Tanh(), nn.Linear(100, latents*2),).to(device)
        init_network_weights(self.transform_z0)

        self.GRU_update = GRU_unit(rec_dims, input_dim, n_units=gru_units, device=device).to(device)
        self.operator_solver = DeepONet1(input_size1=latents, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When set (by test_full), compute_all_losses also decodes along self.time_full.
        self.iftest = False

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute the training loss.

        Args:
            batch: ``batch[0]`` holds the observations, unpacked as
                ``(n_traj, n_tp, n_dims)``; ``batch[1]`` holds the observation
                times as a 2-D tensor indexed ``[traj, time]`` (presumably
                ``(n_traj, n_tp)`` — TODO confirm).
            kl_coef: Ignored. A fixed coefficient of 0.1 is applied below; the
                parameter is kept only for interface compatibility.

        Returns:
            ``(results, pred_y)``. ``results`` holds the scalar ``"loss"`` plus
            detached diagnostics. ``pred_y`` is the decoded prediction at
            ``batch[1]``, or at ``self.time_full`` when ``self.iftest`` is set
            (the flag is cleared here).
        """
        device = self.device
        args = self.args
        rec_dims = self.rec_dims
        n_traj_samples = args.n_traj_samples

        ## ODE-based RNN for encoder
        s_batch = batch[0]
        t_batch = batch[1]
        n_traj, n_tp, _ = s_batch.size()

        encoder_ntp = min(args.rec_len, n_tp)
        data = s_batch[:, :encoder_ntp]
        time_steps = t_batch[:, :encoder_ntp]

        prev_y = torch.zeros((1, n_traj, rec_dims)).to(device)
        prev_std = torch.zeros((1, n_traj, rec_dims)).to(device)
        if time_steps.shape[1] == 1:
            # Single observation: one GRU update from the zero state.
            xi = data[:, 0, :].unsqueeze(0)
            last_yi, last_yi_std = self.GRU_update(prev_y, prev_std, xi)
        else:
            # Start slightly after the last encoder time so the first step is short.
            prev_t, t_i = time_steps[:, -1] + 0.01, time_steps[:, -1]
            # Gaps shorter than this use one explicit Euler step instead of the solver.
            minimum_step = min(time_steps[:, -1] - time_steps[:, 0]) / 50

            # Run ODE backwards and combine the y(t) estimates using gating
            for i in reversed(range(encoder_ntp)):
                yi_ode = torch.zeros(1, n_traj, rec_dims).to(device)
                # Trajectories may have different time stamps, so integrate each separately.
                for j in range(len(prev_t)):
                    y_prev_j = prev_y[:, j:j + 1]
                    if (prev_t[j] - t_i[j]) < minimum_step:
                        inc = self.z0_diffeq_solver.ode_func(prev_t[j], y_prev_j) * (t_i[j] - prev_t[j])
                        assert(not torch.isnan(inc).any())
                        ode_sol = torch.stack((y_prev_j, y_prev_j + inc), 2).to(device)
                    else:
                        n_intermediate_tp = max(2, ((prev_t[j] - t_i[j]) / minimum_step).int())
                        time_points = linspace_vector(prev_t[j], t_i[j], n_intermediate_tp).to(device)
                        ode_sol = self.z0_diffeq_solver(y_prev_j, time_points)
                    assert(not torch.isnan(ode_sol).any())
                    # Keep only the state at the end of the integration interval.
                    yi_ode[:, j:j + 1] = ode_sol[:, :, -1, :]

                xi = data[:, i, :].unsqueeze(0)
                yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)
                prev_y, prev_std = yi, yi_std
                # For i == 0 this wraps to index -1, but the value is never used.
                prev_t, t_i = time_steps[:, i], time_steps[:, i - 1]

            last_yi, last_yi_std = yi, yi_std

        means_z0 = last_yi.reshape(1, n_traj, rec_dims)
        std_z0 = last_yi_std.reshape(1, n_traj, rec_dims)
        mean_z0, std_z0 = split_last_dim(self.transform_z0(torch.cat((means_z0, std_z0), -1)))
        first_point_mu, first_point_std = mean_z0, std_z0.abs()

        ## Operator for reconstruction and prediction
        means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
        first_point_enc = sample_standard_gaussian(means_z0, sigma_z0)

        y1 = first_point_enc.unsqueeze(-2).repeat(1, 1, n_tp, 1)
        y2 = t_batch.unsqueeze(0).unsqueeze(-1).repeat(n_traj_samples, 1, 1, 1)
        sol_y = self.operator_solver(y1, y2)
        pred_y = self.decoder(sol_y)

        obsrv_std = torch.Tensor([0.01]).to(device)
        z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))
        # Fresh tensor (abs) so the in-place clamp below cannot touch the graph above.
        fp_std = first_point_std.abs()
        fp_std[fp_std == 0] = 1e-6
        kldiv_z0 = kl_divergence(Normal(first_point_mu, fp_std), z0_prior)
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        # Compute likelihood of all the points
        rec_likelihood = get_gaussian_likelihood(s_batch, pred_y, obsrv_std)
        mse = get_mse(s_batch, pred_y)
        # Placeholders kept so the results dict matches the other models.
        pois_log_likelihood = torch.Tensor([0.]).to(device)
        ce_loss = torch.Tensor([0.]).to(device)

        # IWAE loss.
        # NOTE(review): the ``kl_coef`` argument is deliberately overridden with a
        # fixed 0.1, preserved for reproducibility; confirm before honoring it.
        kl_coef = 0.1
        loss = - torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0)
        if torch.isnan(loss):
            loss = - torch.mean(rec_likelihood - kl_coef * kldiv_z0, 0)

        results = {}
        results["loss"] = torch.mean(loss)
        results["likelihood"] = torch.mean(rec_likelihood).detach()
        results["mse"] = torch.mean(mse).detach()
        results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
        results["ce_loss"] = torch.mean(ce_loss).detach()
        results["kl_first_p"] = torch.mean(kldiv_z0).detach()
        results["std_first_p"] = torch.mean(fp_std).detach()

        if self.iftest:
            # One-shot request from test_full: decode along the full time grid.
            self.iftest = False
            y1 = first_point_enc.unsqueeze(-2).repeat(1, 1, self.time_full.shape[-1], 1)
            y2 = self.time_full.unsqueeze(0).unsqueeze(-1).repeat(n_traj_samples, 1, 1, 1)
            pred_y = self.decoder(self.operator_solver(y1, y2))

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` and log test and training diagnostics.

        Args:
            experimentID: Identifier written in the ``"Experiment ..."`` log line.
            test_dataset: Batch passed unchanged to ``compute_all_losses``.
            train_res: Result dict of a training step; its ``"loss"`` and
                ``"ce_loss"`` tensors are logged.
            itr: Global iteration counter; the logged epoch is ``itr // num_batches``.
            num_batches: Number of batches per epoch.
            kl_coef: Only logged; it is not forwarded to ``compute_all_losses``.
            logger: Object with an ``info(str)`` method.
        """
        # Pure evaluation: do not build an autograd graph for the test batch.
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            itr//num_batches,
            test_res["loss"].detach(), test_res["likelihood"].detach(),
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message)
        logger.info("KL coef: {}".format(kl_coef))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))

        if "mse" in test_res:
            logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

        if "pois_likelihood" in test_res:
            logger.info("Poisson likelihood: {}".format(test_res["pois_likelihood"]))

        if "ce_loss" in test_res:
            logger.info("CE loss: {}".format(test_res["ce_loss"]))

    def test(self, batch):
        """Forward ``batch`` without gradient tracking and return the prediction."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Predict along the time grid ``time_full``, conditioned on ``batch``.

        Raises ``self.iftest`` so ``compute_all_losses`` decodes along
        ``time_full``; runs without gradient tracking.
        """
        self.time_full = time_full
        self.iftest = True
        try:
            with torch.no_grad():
                _, pred_y = self.compute_all_losses(batch)
        finally:
            # compute_all_losses clears the flag on success; also clear it when it
            # raises so later training calls do not use a stale time_full.
            self.iftest = False
        return pred_y
        

################################################################
###   ELBO --> MSE                                           ###
################################################################  
class LNO_Ab4(nn.Module):
    """Variant trained with a plain MSE loss instead of the ELBO.

    A deterministic initial state is inferred by running a GRU backwards over
    the first ``args.rec_len`` observations, propagating the hidden state
    between observation times with a ``DeepONet1`` operator. A second operator
    is queried at the requested times and decoded to observation space.
    """

    def __init__(self, args, device):
        """Build the encoder operator, GRU update, z0 transform, operator and decoder.

        Args:
            args: Config object; ``input_dim``, ``rec_dims``, ``latents`` and
                ``gru_units`` are read here (``rec_len`` is read in
                ``compute_all_losses``).
            device: Torch device every sub-module is moved to.
        """
        super(LNO_Ab4, self).__init__()

        self.device = device
        self.args = args
        input_dim = args.input_dim
        rec_dims = args.rec_dims
        latents = args.latents
        gru_units = args.gru_units
        self.input_dim = input_dim
        self.rec_dims = rec_dims
        self.latents = latents
        self.gru_units = gru_units

        # Propagates the encoder hidden state across a time gap.
        self.z0_operator_solver = DeepONet1(input_size1=rec_dims, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        # Deterministic map from the encoder state to z0 (no std output here).
        self.transform_z0 = nn.Sequential(nn.Linear(rec_dims, 100), nn.Tanh(), nn.Linear(100, latents),).to(device)
        init_network_weights(self.transform_z0)

        self.GRU_update = GRU_unit(rec_dims, input_dim, n_units=gru_units, device=device).to(device)
        self.operator_solver = DeepONet1(input_size1=latents, input_size2=1, hidden_size=64, p=32, device=device).to(device)
        self.decoder = Decoder(latents, input_dim).to(device)
        # When set (by test_full), compute_all_losses also decodes along self.time_full.
        self.iftest = False

    def compute_all_losses(self, batch, kl_coef = 1.):
        """Encode ``batch``, reconstruct it and compute the MSE loss.

        Args:
            batch: ``batch[0]`` holds the observations, unpacked as
                ``(n_traj, n_tp, n_dims)``; ``batch[1]`` holds the observation
                times, indexed ``[traj, time]`` (presumably ``(n_traj, n_tp)``
                — TODO confirm).
            kl_coef: Unused; kept for interface compatibility with the ELBO models.

        Returns:
            ``(results, pred_y)`` where ``results["loss"]`` is the mean squared
            error against ``batch[0]``, and ``pred_y`` is the decoded prediction
            at ``batch[1]``. When ``self.iftest`` is set, ``pred_y`` is taken at
            ``self.time_full`` instead and the flag is cleared.
        """
        device = self.device
        args = self.args
        rec_dims = self.rec_dims

        ## Operator-based RNN for encoder
        s_batch = batch[0]
        t_batch = batch[1]
        n_traj, n_tp, _ = s_batch.size()

        encoder_ntp = min(args.rec_len, n_tp)
        data = s_batch[:, :encoder_ntp]
        time_steps = t_batch[:, :encoder_ntp]

        prev_y = torch.zeros((1, n_traj, rec_dims)).to(device)
        prev_std = torch.zeros((1, n_traj, rec_dims)).to(device)
        if time_steps.shape[1] == 1:
            # Single observation: one GRU update from the zero state.
            xi = data[:, 0, :].unsqueeze(0)
            last_yi, _ = self.GRU_update(prev_y, prev_std, xi)
        else:
            # Start slightly after the last encoder time so the first gap is short.
            prev_t, t_i = time_steps[:, -1] + 0.01, time_steps[:, -1]

            # Run NO backwards and combine the y(t) estimates using gating
            for i in reversed(range(encoder_ntp)):
                # Signed (negative, going backwards) time step per trajectory.
                Delta_t = t_i - prev_t
                yi_no = self.z0_operator_solver(prev_y, Delta_t.unsqueeze(0).unsqueeze(-1))
                assert(not torch.isnan(yi_no).any())

                xi = data[:, i, :].unsqueeze(0)
                # NOTE(review): unlike LNO_Ab3, prev_std is never updated, so the GRU
                # always receives a zero std. Preserved as-is; confirm if intended.
                yi, _ = self.GRU_update(yi_no, prev_std, xi)
                prev_y = yi
                # For i == 0 this wraps to index -1, but the value is never used.
                prev_t, t_i = time_steps[:, i], time_steps[:, i - 1]

            last_yi = yi

        mean_z0 = self.transform_z0(last_yi.reshape(1, n_traj, rec_dims))
        # (1, n_traj, latents) -> (n_traj, 1, latents); repeated once per query time.
        z0_query = mean_z0.squeeze(0).unsqueeze(-2)

        ## Operator for reconstruction and prediction
        sol_y = self.operator_solver(z0_query.repeat(1, n_tp, 1), t_batch.unsqueeze(-1))
        pred_y = self.decoder(sol_y)

        results = {}
        results["loss"] = torch.mean((s_batch - pred_y)**2)

        if self.iftest:
            # One-shot request from test_full: decode along the full time grid.
            self.iftest = False
            n_full = self.time_full.shape[-1]
            sol_y = self.operator_solver(z0_query.repeat(1, n_full, 1), self.time_full.unsqueeze(-1))
            pred_y = self.decoder(sol_y)

        return results, pred_y

    def TestInfo(self, experimentID, test_dataset, train_res, itr, num_batches, kl_coef, logger):
        """Evaluate on ``test_dataset`` and log the epoch, train loss and test MSE.

        Args:
            experimentID: Identifier written in the ``"Experiment ..."`` log line.
            test_dataset: Batch passed unchanged to ``compute_all_losses``.
            train_res: Result dict of a training step; its ``"loss"`` is logged.
            itr: Global iteration counter; the logged epoch is ``itr // num_batches``.
            num_batches: Number of batches per epoch.
            kl_coef: Unused; kept for interface compatibility.
            logger: Object with an ``info(str)`` method.
        """
        # Pure evaluation: do not build an autograd graph for the test batch.
        with torch.no_grad():
            test_res, _ = self.compute_all_losses(test_dataset)

        logger.info("Experiment " + str(experimentID))
        logger.info("Epoch {:04d}".format(itr//num_batches))
        logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
        logger.info("Test MSE: {:.4f}".format(test_res["loss"]))

    def test(self, batch):
        """Forward ``batch`` without gradient tracking and return the prediction."""
        with torch.no_grad():
            _, pred_y = self.compute_all_losses(batch)
        return pred_y

    def test_full(self, batch, time_full):
        """Predict along the time grid ``time_full``, conditioned on ``batch``.

        Raises ``self.iftest`` so ``compute_all_losses`` decodes along
        ``time_full``; runs without gradient tracking.
        """
        self.time_full = time_full
        self.iftest = True
        try:
            with torch.no_grad():
                _, pred_y = self.compute_all_losses(batch)
        finally:
            # compute_all_losses clears the flag on success; also clear it when it
            # raises so later training calls do not use a stale time_full.
            self.iftest = False
        return pred_y


        
        
        
        
        
        
        