# Modeling Irregular Time Series with Continuous Recurrent Units (CRUs)
# Copyright (c) 2022 Robert Bosch GmbH
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# This source code is derived from Pytorch RKN Implementation (https://github.com/ALRhub/rkn_share)
# Copyright (c) 2021 Philipp Becker (Autonomous Learning Robots Lab @ KIT)
# licensed under MIT License
# cf. 3rd-party-licenses.txt file in the root directory of this source tree.


import torch
import numpy as np
import time as t
from datetime import datetime
import os
from typing import Tuple
from torch.utils.tensorboard import SummaryWriter
from lib.utils import TimeDistributed, log_to_tensorboard, make_dir, \
    compute_physionet_intermediate, compute_mimic_intermediate, \
    mimic_classification_save_intermediates, compute_auprc
from lib.encoder import Encoder
from lib.decoder import SplitDiagGaussianDecoder, BernoulliDecoder
from lib.CRULayer import CRULayer
from lib.CRUCell import var_activation, var_activation_inverse
from lib.losses import rmse, mse, GaussianNegLogLik, bernoulli_nll, vae_loss, mae
from lib.data_utils import  align_output_and_target, adjust_obs_for_extrapolation
#import lib.custom_gru as custom_gru
#import lib.peann_gru as peann_gru
#import lib.cph_gru as cph_gru
#import lib.learn_dynamics_gru as learn_dynamics_gru
#import lib.learn_discrete_dynamics_gru as learn_dynamics_gru
#from pycox.models.loss import nll_pc_hazard_loss
import pdb
import math

# Short module-level aliases for torch sub-packages used throughout this file.
optim = torch.optim  # optimizers (optim.Adam is used by mTAND)
nn = torch.nn  # layers and the nn.Module base class
F = nn.functional  # functional ops (F.softmax is used by multiTimeAttention)

class multiTimeAttention(nn.Module):
    """Multi-head attention keyed on time embeddings with a per-feature mask.

    Queries and keys are time embeddings (last axis of size ``embed_time``);
    values are feature vectors.  The attention logits are replicated once per
    value feature so that every feature can be masked independently before the
    softmax over the key axis.

    :param input_dim: size of the last axis of ``value``
    :param nhidden: size of the last axis of the output
    :param embed_time: size of the time embedding; must be divisible by ``num_heads``
    :param num_heads: number of attention heads
    """

    def __init__(self, input_dim, nhidden=16,
                 embed_time=16, num_heads=1):
        super().__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        self.nhidden = nhidden
        # Index 0 projects the query, index 1 the key, index 2 the concatenated
        # head outputs.  The layers are created in this order on purpose so the
        # parameter initialisation sequence stays the same.
        query_proj = nn.Linear(embed_time, embed_time)
        key_proj = nn.Linear(embed_time, embed_time)
        output_proj = nn.Linear(input_dim * num_heads, nhidden)
        self.linears = nn.ModuleList([query_proj, key_proj, output_proj])

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention evaluated separately per value feature.

        :param query: projected queries, last axis is the per-head embedding
        :param key: projected keys, same last axis as ``query``
        :param value: values; a head axis is expected to be present already
        :param mask: optional mask broadcastable to ``value``; positions equal
            to 0 are excluded from the softmax
        :param dropout: optional callable applied to the attention weights
        :return: (attended values, attention weights)
        """
        n_features = value.size(-1)
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        # one copy of the logits for every value feature
        logits = logits.unsqueeze(-1).repeat_interleave(n_features, dim=-1)
        if mask is not None:
            logits = logits.masked_fill(mask.unsqueeze(-3) == 0, -1e9)
        # normalise over the key axis
        weights = F.softmax(logits, dim=-2)
        if dropout is not None:
            weights = dropout(weights)
        attended = (weights * value.unsqueeze(-3)).sum(-2)
        return attended, weights

    def forward(self, query, key, value, mask=None, dropout=None):
        """Project query/key into heads, attend over ``value`` and merge the heads.

        :param query: query time embeddings, last axis of size ``embed_time``
        :param key: key time embeddings, last axis of size ``embed_time``
        :param value: 3-D tensor (batch, sequence, features)
        :param mask: optional mask with the same layout as ``value``
        :param dropout: optional callable applied to the attention weights
        :return: tensor whose last axis has size ``nhidden``
        """
        batch, _, dim = value.size()
        if mask is not None:
            # the same mask is shared by all heads
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)

        projected = []
        for layer, inp in zip(self.linears, (query, key)):
            per_head = layer(inp).view(inp.size(0), -1, self.h, self.embed_time_k)
            projected.append(per_head.transpose(1, 2))
        query, key = projected

        x, _ = self.attention(query, key, value, mask, dropout)
        # concatenate the heads along the feature axis
        x = x.transpose(1, 2).contiguous().view(batch, -1, self.h * dim)
        return self.linears[-1](x)


class enc_mtan_rnn(nn.Module):
    """mTAN encoder: attention onto reference time points followed by a bidirectional GRU.

    The observed series is attended to from the fixed reference time points in
    ``query``; the result is passed through a bidirectional GRU and a small MLP
    that emits ``2 * latent_dim`` values per reference point.

    :param input_dim: number of observed features; the input ``x`` carries
        ``input_dim`` data columns followed by the mask columns
    :param query: 1-D tensor of reference time points
    :param latent_dim: half the size of the last output axis
    :param nhidden: hidden size of the attention output and of the GRU
    :param embed_time: size of the time embedding
    :param num_heads: number of attention heads
    :param learn_emb: use the learned time embedding instead of the fixed sinusoidal one
    :param device: device the time embeddings are moved to
    """

    def __init__(self, input_dim, query, latent_dim=2, nhidden=16,
                 embed_time=16, num_heads=1, learn_emb=False, device='cuda'):
        super(enc_mtan_rnn, self).__init__()
        self.embed_time = embed_time
        self.dim = input_dim
        self.device = device
        self.nhidden = nhidden
        self.query = query
        self.learn_emb = learn_emb
        # values are [data, mask] -> 2 * input_dim features
        self.att = multiTimeAttention(2*input_dim, nhidden, embed_time, num_heads)
        self.gru_rnn = nn.GRU(nhidden, nhidden, bidirectional=True, batch_first=True)
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(2*nhidden, 20),
            nn.ReLU(),
            nn.Linear(20, latent_dim * 2))
        if learn_emb:
            self.periodic = nn.Linear(1, embed_time-1)
            self.linear = nn.Linear(1, 1)

    def learn_time_embedding(self, tt):
        """Learned time embedding: one linear channel plus ``embed_time - 1`` sine channels.

        :param tt: tensor of time stamps; a trailing axis of size 1 is appended
        :return: embedding whose last axis has size ``embed_time``
        """
        tt = tt.to(self.device)
        tt = tt.unsqueeze(-1)
        out2 = torch.sin(self.periodic(tt))
        out1 = self.linear(tt)
        return torch.cat([out1, out2], -1)

    def fixed_time_embedding(self, pos):
        """Fixed sinusoidal time embedding.

        The buffers are allocated on ``pos.device`` so that a time tensor that
        already lives on an accelerator (e.g. ``self.query``) no longer causes a
        CPU/GPU mismatch, and odd ``embed_time`` values are supported.

        :param pos: 2-D tensor of time stamps
        :return: double tensor with a trailing axis of size ``embed_time``
        """
        d_model = self.embed_time
        device = pos.device
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model, device=device)
        # NOTE(review): the factor 48 presumably rescales normalised time to hours - confirm
        position = 48.*pos.unsqueeze(2)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device) *
                             -(np.log(10.0) / d_model))
        angles = position * div_term
        pe[:, :, 0::2] = torch.sin(angles)
        # for odd d_model there is one cosine column fewer than sine columns
        pe[:, :, 1::2] = torch.cos(angles)[..., :d_model // 2]
        return pe.double()

    def forward(self, x, time_steps):
        """Encode an observed series at the reference time points.

        :param x: 3-D tensor whose last axis holds ``input_dim`` data columns
            followed by the mask columns
        :param time_steps: observation time stamps of ``x``
        :return: tensor whose last axis has size ``2 * latent_dim``
        """
        time_steps = time_steps.cpu()
        mask = x[:, :, self.dim:]
        # the mask is duplicated so it covers both the data and the mask columns of x
        mask = torch.cat((mask, mask), 2)
        if self.learn_emb:
            key = self.learn_time_embedding(time_steps).to(self.device)
            query = self.learn_time_embedding(self.query.unsqueeze(0)).to(self.device)
        else:
            key = self.fixed_time_embedding(time_steps).to(self.device)
            query = self.fixed_time_embedding(self.query.unsqueeze(0)).to(self.device)
        out = self.att(query, key, x, mask)
        out, _ = self.gru_rnn(out)
        out = self.hiddens_to_z0(out)
        return out

class dec_mtan_rnn(nn.Module):
    """mTAN decoder: bidirectional GRU over latent states, then attention onto output times.

    The latent sequence (one state per reference time point in ``query``) is
    run through a bidirectional GRU; the requested output time stamps then
    attend to the reference points and an MLP maps the result to
    ``input_dim`` values.  An optional second MLP head emits
    ``n_classify_tgts`` values from the same attended features.

    :param input_dim: size of the last axis of the reconstructed output
    :param query: 1-D tensor of reference time points (used as attention keys)
    :param latent_dim: size of the last axis of the latent input ``z``
    :param nhidden: GRU hidden size (the GRU is bidirectional)
    :param embed_time: size of the time embedding
    :param num_heads: number of attention heads
    :param learn_emb: use the learned time embedding instead of the fixed sinusoidal one
    :param device: device the time embeddings are moved to
    :param classification_head: also build and evaluate the classification head
    :param n_classify_tgts: output size of the classification head
    """

    def __init__(self, input_dim, query, latent_dim=2, nhidden=16,
                 embed_time=16, num_heads=1, learn_emb=False, device='cuda',
                 classification_head=False, n_classify_tgts=113):
        super(dec_mtan_rnn, self).__init__()
        self.embed_time = embed_time
        self.dim = input_dim
        self.device = device
        self.nhidden = nhidden
        self.query = query
        self.learn_emb = learn_emb
        self.att = multiTimeAttention(2*nhidden, 2*nhidden, embed_time, num_heads)
        self.gru_rnn = nn.GRU(latent_dim, nhidden, bidirectional=True, batch_first=True)
        self.z0_to_obs = nn.Sequential(
            nn.Linear(2*nhidden, 20),
            nn.ReLU(),
            nn.Linear(20, input_dim))
        if learn_emb:
            self.periodic = nn.Linear(1, embed_time-1)
            self.linear = nn.Linear(1, 1)
        self.classification_head = classification_head
        self.n_classify_tgts = n_classify_tgts
        if classification_head:
            self.z0_to_clf = nn.Sequential(
                nn.Linear(2*nhidden, 20),
                nn.ReLU(),
                nn.Linear(20, self.n_classify_tgts))

    def learn_time_embedding(self, tt):
        """Learned time embedding: one linear channel plus ``embed_time - 1`` sine channels.

        :param tt: tensor of time stamps; a trailing axis of size 1 is appended
        :return: embedding whose last axis has size ``embed_time``
        """
        tt = tt.to(self.device)
        tt = tt.unsqueeze(-1)
        out2 = torch.sin(self.periodic(tt))
        out1 = self.linear(tt)
        return torch.cat([out1, out2], -1)

    def fixed_time_embedding(self, pos):
        """Fixed sinusoidal time embedding.

        The buffers are allocated on ``pos.device`` so that a time tensor that
        already lives on an accelerator (e.g. ``self.query``) no longer causes a
        CPU/GPU mismatch, and odd ``embed_time`` values are supported.

        :param pos: 2-D tensor of time stamps
        :return: double tensor with a trailing axis of size ``embed_time``
        """
        d_model = self.embed_time
        device = pos.device
        pe = torch.zeros(pos.shape[0], pos.shape[1], d_model, device=device)
        # NOTE(review): the factor 48 presumably rescales normalised time to hours - confirm
        position = 48.*pos.unsqueeze(2)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device) *
                             -(np.log(10.0) / d_model))
        angles = position * div_term
        pe[:, :, 0::2] = torch.sin(angles)
        # for odd d_model there is one cosine column fewer than sine columns
        pe[:, :, 1::2] = torch.cos(angles)[..., :d_model // 2]
        return pe.double()

    def forward(self, z, time_steps):
        """Decode latent states at the requested time stamps.

        :param z: latent sequence, one state per reference time point
        :param time_steps: time stamps to produce outputs for (attention queries)
        :return: the reconstruction, or ``(reconstruction, classification output)``
            when the classification head is enabled
        """
        out, _ = self.gru_rnn(z)
        time_steps = time_steps.cpu()
        if self.learn_emb:
            query = self.learn_time_embedding(time_steps).to(self.device)
            key = self.learn_time_embedding(self.query.unsqueeze(0)).to(self.device)
        else:
            query = self.fixed_time_embedding(time_steps).to(self.device)
            key = self.fixed_time_embedding(self.query.unsqueeze(0)).to(self.device)
        # no mask: every reference point contributes to every output time
        out = self.att(query, key, out)
        obs_out = self.z0_to_obs(out)
        if self.classification_head:
            return obs_out, self.z0_to_clf(out)
        return obs_out


class create_classifier(nn.Module):
    """Sequence classifier: a GRU summarises the sequence, an MLP scores the summary.

    :param latent_dim: size of the last axis of the input sequence
    :param nhidden: GRU hidden size
    :param N: size of the classifier output
    """

    def __init__(self, latent_dim=10, nhidden=16, N=2):
        super().__init__()
        mlp_width = 300
        self.gru_rnn = nn.GRU(latent_dim, nhidden, batch_first=True)
        head_layers = [
            nn.Linear(nhidden, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, N),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, z):
        """Classify each sequence in ``z`` from the GRU's final hidden state.

        :param z: batch-first input sequence
        :return: classifier output computed from the final hidden state
        """
        final_hidden = self.gru_rnn(z)[1]
        # drop the leading axis of the hidden state when it has size 1
        summary = final_hidden.squeeze(0)
        return self.classifier(summary)


# taken from https://github.com/ALRhub/rkn_share/ and modified
class mTAND(nn.Module):

    # taken from https://github.com/ALRhub/rkn_share/ and modified
    def __init__(self, target_dim: int, lsd: int, args, use_cuda_if_available: bool = True, bernoulli_output: bool = False):
        """
        Build the mTAN encoder/decoder pair and the optimizer.

        :param target_dim: output dimension
        :param lsd: latent state dimension (also used as the encoder's hidden size)
        :param args: parsed arguments; ``lr``, ``task``, ``mTAND_time_embed_dim`` and
            ``mTAND_num_ref_points`` are read here and ``std`` / ``norm`` are overwritten
        :param use_cuda_if_available: if to use cuda or cpu
        :param bernoulli_output: stored on the instance and consulted by the task methods
        """
        super().__init__()
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() and use_cuda_if_available else "cpu")

        self._lsd = lsd
        self.args = args

        # parameters TODO: Make configurable
        self._learning_rate = self.args.lr
        self.bernoulli_output = bernoulli_output

        # mTAND specific
        self.rec_hidden = self._lsd  # a hyper param
        self.learn_emb = True
        self.embed_time_dim = self.args.mTAND_time_embed_dim
        self.num_ref_points = self.args.mTAND_num_ref_points
        self.dec_num_heads = 1
        self.k_iwae = 5
        self.alpha = 10
        self.target_dim = target_dim
        # NOTE(review): these overwrite fields on the caller's args object; presumably
        # they are consumed by vae_loss, which receives self.args - confirm
        self.args.std = 0.01
        self.args.norm = True

        # The encoder/decoder default to device='cuda'; pass the resolved device
        # explicitly so the time embeddings end up where the modules live (this
        # matters on CPU-only hosts and when use_cuda_if_available is False).
        self.rec = enc_mtan_rnn(
            target_dim,
            torch.linspace(0.0, 1.0, self.num_ref_points).double().to(self._device),
            self._lsd, self.rec_hidden, self.embed_time_dim,
            learn_emb=self.learn_emb,
            device=self._device).double().to(self._device)

        dec_kwargs = dict(
            latent_dim=self._lsd,
            learn_emb=self.learn_emb,
            num_heads=self.dec_num_heads,
            device=self._device)
        if self.args.task == 'classification':
            dec_kwargs.update(classification_head=True, n_classify_tgts=113)
        self.dec = dec_mtan_rnn(
            target_dim,
            torch.linspace(0.0, 1.0, self.num_ref_points).double().to(self._device),
            **dec_kwargs).double().to(self._device)

        # params and optimizer
        self._params = list(self.rec.parameters())
        self._params += list(self.dec.parameters())
        if self.args.task == 'classification':
            self.classifier_loss = nn.CrossEntropyLoss()

        self._optimizer = optim.Adam(self._params, lr=self.args.lr)
        self._shuffle_rng = np.random.RandomState(
            42)  # rng for shuffling batches

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_enc_hidden_layers(self) -> Tuple[nn.ModuleList, int]:
        """
        Builds hidden layers for encoder
        :return: nn.ModuleList of hidden Layers, size of output of last layer
        """
        raise NotImplementedError

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_dec_hidden_layers_mean(self) -> Tuple[nn.ModuleList, int]:
        """
        Builds hidden layers for mean decoder
        :return: nn.ModuleList of hidden Layers, size of output of last layer
        """
        raise NotImplementedError

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_dec_hidden_layers_var(self) -> Tuple[nn.ModuleList, int]:
        """
        Builds hidden layers for variance decoder
        :return: nn.ModuleList of hidden Layers, size of output of last layer
        """
        raise NotImplementedError
    
    # taken from https://github.com/ALRhub/rkn_share/ and modified
    def forward(self, obs_batch: torch.Tensor, time_points: torch.Tensor = None, obs_valid: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single forward pass on a batch
        :param obs_batch: batch of observation sequences
        :param time_points: timestamps of observations
        :param obs_valid: boolean if timestamp contains valid observation 
        """
        #print('time dim len: {}'.format(obs_batch.shape[1]))
        pred_y = None
        batch_len = obs_batch.shape[0]
        #out = self.rec(torch.cat([obs_batch, obs_valid.double()], 2), time_points)
        out = self.rec(torch.cat([obs_batch, obs_valid], 2), time_points)
        qz0_mean, qz0_logvar = out[:, :, :self._lsd], out[:, :, self._lsd:]
        epsilon = torch.randn(self.k_iwae, qz0_mean.shape[0], qz0_mean.shape[1], qz0_mean.shape[2]).to(self._device)
        z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
        z0 = z0.view(-1, qz0_mean.shape[1], qz0_mean.shape[2])
        if self.args.task == 'classification':
            pred_x, pred_y = self.dec(z0, time_points[None, :, :].repeat(self.k_iwae, 1, 1).view(-1, time_points.shape[1]))
            pred_y = pred_y.view(self.k_iwae, batch_len, pred_y.shape[1], pred_y.shape[2])
        else:
            pred_x = self.dec(z0, time_points[None, :, :].repeat(self.k_iwae, 1, 1).view(-1, time_points.shape[1]))
        pred_x = pred_x.view(self.k_iwae, batch_len, pred_x.shape[1], pred_x.shape[2])
        train_batch = torch.cat([obs_batch, obs_valid, time_points[...,None]], 2)
        '''
        logpx, analytic_kl = vae_loss(self.target_dim, train_batch, qz0_mean, qz0_logvar, pred_x, self.args, self._device)
        kl_coef = 1.0
        recon_loss = -(torch.logsumexp(logpx - kl_coef * analytic_kl, dim=0).mean(0) - np.log(self.k_iwae))
        '''


        # output an image
        if self.bernoulli_output:
            out_mean = self._dec(post_mean)
            out_var = None
        # output prediction for the next time step
        elif self.args.task == 'one_step_ahead_prediction':
            out_mean, out_var = self._dec(
                prior_mean, torch.cat(prior_cov, dim=-1))

        # output filtered observation
        else:
            '''
            out_mean, out_var = self._dec(
                post_mean, torch.cat(post_cov, dim=-1))
            '''
            out_mean = pred_x.mean(axis=0)
            out_var = pred_x.var(axis=0)
            #out_var = 1e-6 * torch.ones_like(outputs)

        intermediates = {
            'qz0_mean': qz0_mean, 
            'qz0_logvar': qz0_logvar,
            'pred_x': pred_x,
            'pred_y': pred_y
            #'train_batch': train_batch
            #'surv_outputs': surv_outputs
            #'post_mean': post_mean,
            #'post_cov': post_cov,
            #'prior_mean': prior_mean,
            #'prior_cov': prior_cov,
            #'kalman_gain': kalman_gain,
            #'y': y,
            #'y_var': y_var
        }

        return out_mean, out_var, intermediates

    # new code component
    def interpolation(self, data, track_gradient=True):
        """Computes loss on interpolation task

        :param data: batch of data; unpacked as (obs, truth, obs_valid, obs_times, mask_truth)
            when ``bernoulli_output`` is set and with an additional trailing ``mask_obs`` otherwise
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output mean, output variance, obs, truth, mask_obs, mask_truth,
            intermediate variables, loss on imputed points, mse on imputed points
        """
        on_device = [item.to(self._device) for item in data]
        if self.bernoulli_output:
            obs, truth, obs_valid, obs_times, mask_truth = on_device
            mask_obs = None
        else:
            obs, truth, obs_valid, obs_times, mask_truth, mask_obs = on_device

        # rescale the time stamps by the configured factor
        obs_times = self.args.ts * obs_times

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)

            if self.bernoulli_output:
                loss = bernoulli_nll(truth, output_mean, uint8_targets=False)
                # imputed points: time steps without a valid observation that have a target
                mask_imput = (~obs_valid[..., None, None, None]) * mask_truth
                imput_loss = np.nan  # TODO: compute bernoulli loss on imputed points
                imput_mse = mse(
                    truth.flatten(start_dim=2),
                    output_mean.flatten(start_dim=2),
                    mask=mask_imput.flatten(start_dim=2))
            else:
                loss = GaussianNegLogLik(
                    output_mean, truth, output_var, mask=mask_truth)
                # compute metric on imputed points only
                mask_imput = (~obs_valid[..., None]) * mask_truth
                imput_loss = GaussianNegLogLik(
                    output_mean, truth, output_var, mask=mask_imput)
                imput_mse = mse(truth, output_mean, mask=mask_imput)

        return (loss, output_mean, output_var, obs, truth, mask_obs, mask_truth,
                intermediates, imput_loss, imput_mse)

    def next_obs_prediction(self, data, epoch, track_gradient=True):
        """Computes loss on the next-observation prediction task

        Targets after the first time step at which ``obs_valid`` and
        ``mask_truth`` differ are zeroed out, so only that first step is scored.

        :param data: batch of data; unpacked as (obs, truth, obs_valid, obs_times,
            mask_truth, mask_obs, numeric_event_ids)
        :param epoch: current epoch; controls the warm-up of the KL weight
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, outputs, inputs, intermediate variables, metrics on imputed points
        """
        (obs, truth, obs_valid, obs_times,
         mask_truth, mask_obs, numeric_event_ids) = [item.to(self._device) for item in data]

        # a batched id tensor is reduced to its first row
        if numeric_event_ids.ndim > 1:
            assert numeric_event_ids.ndim==2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0,:]

        obs, _, obs_valid = adjust_obs_for_extrapolation(
            self.args.dataset, obs, obs_valid, mask_obs, obs_times)

        # per time step: did obs_valid differ from mask_truth at or before this step?
        differs_so_far = (obs_valid != mask_truth).sum(-1).cumsum(-1) > 0
        # 1 at the first differing step, 2, 3, ... at the following ones
        steps_since_first_diff = differs_so_far.cumsum(axis=1)
        keep_step = (steps_since_first_diff <= 1)[..., None]
        mask_truth = torch.where(keep_step, mask_truth, 0.0)
        truth = torch.where(keep_step, truth, 0.0)

        kl_warmup_epochs = 10
        n_numeric = len(numeric_event_ids)

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)

            targets = torch.cat(
                [truth[:, :, numeric_event_ids], mask_truth[:, :, numeric_event_ids]],
                dim=-1)
            logpx, analytic_kl = vae_loss(
                n_numeric, targets,
                intermediates['qz0_mean'], intermediates['qz0_logvar'],
                intermediates['pred_x'][..., numeric_event_ids],
                self.args, self._device)

            # no KL penalty during warm-up, then the weight approaches 1
            if epoch >= kl_warmup_epochs:
                kl_coef = 1 - 0.99 ** (epoch - kl_warmup_epochs)
            else:
                kl_coef = 0.
            iwae_bound = torch.logsumexp(logpx - kl_coef * analytic_kl, dim=0).mean(0)
            loss = -(iwae_bound - np.log(self.k_iwae))

            # compute metric on imputed points only
            mask_imput = (~obs_valid) * mask_truth
            mask_imput_numeric = mask_imput[..., numeric_event_ids]
            truth_numeric = truth[..., numeric_event_ids]
            output_mean_numeric = output_mean[..., numeric_event_ids]

            # remove samples with no extrapolation
            has_extrap = mask_imput_numeric.view(mask_imput.shape[0], -1).sum(axis=1) != 0
            imput_logpx, _ = vae_loss(
                n_numeric,
                torch.cat([truth_numeric[has_extrap], mask_imput_numeric[has_extrap]], dim=-1),
                intermediates['qz0_mean'][has_extrap],
                intermediates['qz0_logvar'][has_extrap],
                intermediates['pred_x'][:, has_extrap][..., numeric_event_ids],
                self.args, self._device)

            imput_loss = -torch.logsumexp(imput_logpx, dim=0).mean(0)
            imput_mse = mse(truth_numeric, output_mean_numeric, mask=mask_imput_numeric)
            imput_mae = mae(truth_numeric, output_mean_numeric, mask=mask_imput_numeric)

        return (loss, output_mean, output_var, obs, truth, mask_obs, mask_truth,
                intermediates, imput_loss, imput_mse, mask_imput, numeric_event_ids, imput_mae)

    # new code component
    def extrapolation(self, data, epoch, track_gradient=True):
        """Computes loss on extrapolation task

        :param data: batch of data; unpacked as (obs, truth, obs_valid, obs_times,
            mask_truth, mask_obs, numeric_event_ids)
        :param epoch: current epoch; controls the warm-up of the KL weight
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, outputs, inputs, intermediate variables, metrics on imputed points
        """
        (obs, truth, obs_valid, obs_times,
         mask_truth, mask_obs, numeric_event_ids) = [item.to(self._device) for item in data]

        # a batched id tensor is reduced to its first row
        if numeric_event_ids.ndim > 1:
            assert numeric_event_ids.ndim==2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0,:]

        obs, _, obs_valid = adjust_obs_for_extrapolation(
            self.args.dataset, obs, obs_valid, mask_obs, obs_times)

        kl_warmup_epochs = 10
        n_numeric = len(numeric_event_ids)

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)

            targets = torch.cat(
                [truth[:, :, numeric_event_ids], mask_truth[:, :, numeric_event_ids]],
                dim=-1)
            logpx, analytic_kl = vae_loss(
                n_numeric, targets,
                intermediates['qz0_mean'], intermediates['qz0_logvar'],
                intermediates['pred_x'][..., numeric_event_ids],
                self.args, self._device)

            # no KL penalty during warm-up, then the weight approaches 1
            if epoch >= kl_warmup_epochs:
                kl_coef = 1 - 0.99 ** (epoch - kl_warmup_epochs)
            else:
                kl_coef = 0.
            iwae_bound = torch.logsumexp(logpx - kl_coef * analytic_kl, dim=0).mean(0)
            loss = -(iwae_bound - np.log(self.k_iwae))

            # compute metric on imputed points only
            mask_imput = (~obs_valid) * mask_truth
            mask_imput_numeric = mask_imput[..., numeric_event_ids]
            truth_numeric = truth[..., numeric_event_ids]
            output_mean_numeric = output_mean[..., numeric_event_ids]

            # remove samples with no extrapolation
            has_extrap = mask_imput_numeric.view(mask_imput.shape[0], -1).sum(axis=1) != 0
            imput_logpx, _ = vae_loss(
                n_numeric,
                torch.cat([truth_numeric[has_extrap], mask_imput_numeric[has_extrap]], dim=-1),
                intermediates['qz0_mean'][has_extrap],
                intermediates['qz0_logvar'][has_extrap],
                intermediates['pred_x'][:, has_extrap][..., numeric_event_ids],
                self.args, self._device)

            imput_loss = -torch.logsumexp(imput_logpx, dim=0).mean(0)
            imput_mse = mse(truth_numeric, output_mean_numeric, mask=mask_imput_numeric)
            imput_mae = mae(truth_numeric, output_mean_numeric, mask=mask_imput_numeric)

        return (loss, output_mean, output_var, obs, truth, mask_obs, mask_truth,
                intermediates, imput_loss, imput_mse, mask_imput, numeric_event_ids, imput_mae)

    # new code component
    def regression(self, data, track_gradient=True):
        """Computes loss on regression task

        :param data: batch of data; unpacked as (obs, truth, obs_times, obs_valid)
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output mean, output variance, obs, truth, mask_obs, mask_truth
            (both masks are None for this task) and intermediate variables
        """
        on_device = [item.to(self._device) for item in data]
        obs, truth, obs_times, obs_valid = on_device
        # no masking is used for regression
        mask_truth = mask_obs = None

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)
            loss = GaussianNegLogLik(
                output_mean, truth, output_var, mask=mask_truth)

        return (loss, output_mean, output_var, obs, truth,
                mask_obs, mask_truth, intermediates)

    def classification(self, data, track_gradient=True):
        """Computes loss on extrapolation task

        :param data: batch of data
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, outputs, inputs, intermediate variables, metrics on imputed points
        """
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs, to_classify_mask = [
            j.to(self._device) for j in data[:-1]]

        events_to_report = data[-1]
        classify_mask = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
        classify_mask[to_classify_mask] = True
        '''
        propofol = 445 #31 # high frequency
        phenylephrine = 398#30 # medium frequency
        furosemide = 453 #21 # low frequency
        classify_idx = [furosemide, phenylephrine, propofol]
        classify_mask = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
        classify_mask_pro = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
        classify_mask_phe = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
        classify_mask_fur = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
        classify_mask[classify_idx] = True
        classify_mask_pro[[propofol]] = True
        classify_mask_phe[[phenylephrine]] = True
        classify_mask_fur[[furosemide]] = True
        '''

        # split at an a-point
        #obs_valid *= obs_times <= 0.0
        obs_valid *= obs_times[:,:,None].repeat(1,1,obs.shape[-1]) <= 0.0
        obs = torch.where(obs_valid, obs, 0.)
        assert obs_times[(obs != truth).any(axis=-1)].min().item() > 0.0, \
            "trying to predict before an apoint"

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)

            if not track_gradient: 
                # during eval
                k_truth = truth[None, :, :, :].repeat_interleave(self.k_iwae, 0)
                k_mask_truth = mask_truth[None, :, :, :].repeat_interleave(self.k_iwae, 0)
                gt_apoint_mask = obs_times > 0.0
                for event_name, event_id in events_to_report.items():
                    classify_mask_event = torch.zeros(obs.shape[-1], dtype=torch.bool).to(obs.device)
                    classify_mask_event[event_id] = True
                    pred_mask_event = mask_truth * classify_mask_event[None,None,:] * gt_apoint_mask[:,:,None]
                    intermediates['preds_{}'.format(event_name)] = output_mean[pred_mask_event].detach().cpu()
                    intermediates['label_{}'.format(event_name)] = truth[pred_mask_event].detach().cpu()
                pred_mask_all = mask_truth * classify_mask_event[None,None,:] * gt_apoint_mask[:,:,None]
                intermediates['preds_all'] = output_mean[pred_mask_all].detach().cpu()
                intermediates['label_all'] = truth[pred_mask_all].detach().cpu()
                '''
                pred_mask_pro = k_mask_truth * classify_mask_pro[None, None, None, :] * gt_apoint_mask[None, :, :, None]
                intermediates['preds_propofol'] = intermediates['pred_x'][pred_mask_pro].detach().cpu()
                intermediates['label_propofol'] = k_truth[pred_mask_pro].detach().cpu()

                pred_mask_phe = k_mask_truth * classify_mask_phe[None, None, None, :] * gt_apoint_mask[None, :, :, None]
                intermediates['preds_phenylephrine'] = intermediates['pred_x'][pred_mask_phe].detach().cpu()
                intermediates['label_phenylephrine'] = k_truth[pred_mask_phe].detach().cpu()

                pred_mask_fur = k_mask_truth * classify_mask_fur[None, None, None, :] * gt_apoint_mask[None, :, :, None]
                intermediates['preds_furosemide'] = intermediates['pred_x'][pred_mask_fur].detach().cpu()
                intermediates['label_furosemide'] = k_truth[pred_mask_fur].detach().cpu()

                pred_mask = k_mask_truth * classify_mask[None, None, None, :] * gt_apoint_mask[None, :, :, None]
                intermediates['preds_all'] = intermediates['pred_x'][pred_mask].detach().cpu()
                intermediates['label_all'] = k_truth[pred_mask].detach().cpu()
                '''
            '''
            else:
                # during training
                pred_mask = mask_truth * classify_mask[None, None, :]
                output_mean = output_mean[pred_mask]
                truth = truth[pred_mask]
            '''
            
            logpx, analytic_kl = vae_loss(self.target_dim, 
                #intermediates['train_batch'], 
                torch.cat([truth, mask_truth], dim=-1),
                intermediates['qz0_mean'], intermediates['qz0_logvar'], intermediates['pred_x'],
                self.args, self._device)
            kl_coef = 0.1
            recon_loss = -(torch.logsumexp(logpx - kl_coef * analytic_kl, dim=0).mean(0) - np.log(self.k_iwae))
            clf_tgt = truth[:,:,to_classify_mask].unsqueeze(0).repeat_interleave(self.k_iwae, 0)
            k_mask_truth = mask_truth[None,:,:,to_classify_mask].repeat(self.k_iwae,1,1,1)
            clf_loss = self.classifier_loss(intermediates['pred_y'][k_mask_truth], clf_tgt[k_mask_truth])
            loss = recon_loss + self.alpha * clf_loss

            # compute metric on imputed points only
            mask_imput = (~obs_valid) * mask_truth
            '''
            imput_loss = mse(
                truth, output_mean, mask=mask_imput)
            '''
            #'''
            # remove samples with no extrapolation
            zero_extrap_sequences = mask_imput.view(mask_imput.shape[0], -1).sum(axis=1) == 0
            mask_imput_nz = mask_imput[~zero_extrap_sequences]
            truth_imput_nz = truth[~zero_extrap_sequences]
            imput_logpx, _ = vae_loss(self.target_dim, torch.cat([truth_imput_nz, mask_imput_nz], dim=-1), 
                intermediates['qz0_mean'][~zero_extrap_sequences],
                intermediates['qz0_logvar'][~zero_extrap_sequences],
                intermediates['pred_x'][:, ~zero_extrap_sequences],
                self.args, self._device)

            imput_loss = -torch.logsumexp(imput_logpx, dim=0).mean(0).detach().item()
            imput_mse = mse(truth, output_mean, mask=mask_imput).detach().item()

        return loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, obs_times

    # new code component
    def one_step_ahead_prediction(self, data, track_gradient=True):
        """Runs the forward pass on one batch and scores it with ``GaussianNegLogLik``

        :param data: batch of data, unpacked in the order
            (obs, truth, obs_valid, obs_times, mask_truth, mask_obs)
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output mean, output variance, obs, aligned truth, mask_obs,
            aligned mask_truth, intermediate variables of the forward pass
        """
        device = self._device
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs = [
            tensor.to(device) for tensor in data]

        # rescale the time stamps by the configured factor
        obs_times = self.args.ts * obs_times

        with torch.set_grad_enabled(track_gradient):
            forward_out = self.forward(
                obs_batch=obs,
                time_points=obs_times,
                obs_valid=obs_valid,
            )
            output_mean, output_var, intermediates = forward_out

            # bring prediction and target into correspondence before scoring
            aligned = align_output_and_target(
                output_mean, output_var, truth, mask_truth)
            output_mean, output_var, truth, mask_truth = aligned

            loss = GaussianNegLogLik(
                output_mean, truth, output_var, mask=mask_truth)

        result = (loss, output_mean, output_var, obs, truth,
                  mask_obs, mask_truth, intermediates)
        return result

    # new code component
    def train_epoch(self, dl, optimizer, epoch):
        """Trains model for one epoch

        For every batch the task selected by ``self.args.task`` is run, the loss is
        backpropagated, the optimizer is stepped and running metrics are accumulated.
        NaNs found in the loss, the parameters or the gradients are reported via
        ``print``; training is not interrupted.

        :param dl: dataloader containing training data; must yield at least one batch,
            because the returned values are taken from the last batch
        :param optimizer: optimizer to use for training
        :param epoch: current epoch index, forwarded to the extrapolation and
            next-observation-prediction tasks
        :return: mean loss per batch, mean rmse per batch, mean mse per batch,
            [output_mean, output_var] of the last batch, intermediates of the last batch,
            [obs, truth, mask_obs] of the last batch, imputation metrics
            ([mean imputation loss, mean imputation mse] or None). rmse and mse stay 0
            for tasks other than interpolation / extrapolation / next_obs_prediction.
        :raises Exception: if ``self.args.task`` is not a known task
        """
        imputation_tasks = ('extrapolation', 'interpolation', 'next_obs_prediction')
        task = self.args.task
        save_intermediates = self.args.save_intermediates is not None

        epoch_ll = 0
        epoch_rmse = 0
        epoch_mse = 0
        imput_metrics = None

        if save_intermediates:
            mask_obs_epoch = []
            intermediates_epoch = []

        if task in imputation_tasks:
            epoch_imput_ll = 0
            epoch_imput_mse = 0

        # The flag does not change within an epoch; setting it before the first
        # forward pass lets anomaly detection cover the first batch as well.
        torch.autograd.set_detect_anomaly(self.args.anomaly_detection)

        for i, data in enumerate(dl):
            # Tasks that do not report which features are numeric are scored on
            # all features (slice(None) selects the whole last axis).
            numeric_event_ids = slice(None)

            if task == 'interpolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse = self.interpolation(
                    data)

            elif task == 'extrapolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, \
                    numeric_event_ids, _ = self.extrapolation(data, epoch)

            elif task == 'next_obs_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, \
                    numeric_event_ids, _ = self.next_obs_prediction(data, epoch)

            elif task == 'classification':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, obs_times = self.classification(
                    data, track_gradient=True)

            elif task == 'regression':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.regression(
                    data)

            elif task == 'one_step_ahead_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.one_step_ahead_prediction(
                    data)

            else:
                raise Exception('Unknown task')

            # check for NaNs
            if torch.any(torch.isnan(loss)):
                print('--NAN in loss')
            for name, par in self.named_parameters():
                if torch.any(torch.isnan(par)):
                    print('--NAN before optimiser step in parameter ', name)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            if self.args.grad_clip:
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1)
            optimizer.step()

            # check for NaNs in gradient
            for name, par in self.named_parameters():
                # parameters that did not take part in the loss have no gradient
                if par.grad is not None and torch.any(torch.isnan(par.grad)):
                    print('--NAN in gradient ', name)
                if torch.any(torch.isnan(par)):
                    print('--NAN after optimiser step in parameter ', name)

            # aggregate metrics and intermediates over entire epoch
            epoch_ll += loss.item()

            if task in imputation_tasks:
                epoch_rmse += rmse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                    mask_truth[..., numeric_event_ids]).item()
                epoch_mse += mse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                    mask_truth[..., numeric_event_ids]).item()
                epoch_imput_ll += imput_loss.item()
                epoch_imput_mse += imput_mse.item()
                imput_metrics = [epoch_imput_ll/(i+1), epoch_imput_mse/(i+1)]

            if save_intermediates:
                intermediates_epoch.append(intermediates)
                mask_obs_epoch.append(mask_obs)

        # save for plotting
        if save_intermediates:
            torch.save(mask_obs_epoch, os.path.join(
                self.args.save_intermediates, 'train_mask_obs.pt'))
            torch.save(intermediates_epoch, os.path.join(
                self.args.save_intermediates, 'train_intermediates.pt'))

        return epoch_ll/(i+1), epoch_rmse/(i+1), epoch_mse/(i+1), [output_mean, output_var], \
                intermediates, [obs, truth, mask_obs], imput_metrics

    # new code component
    def eval_epoch(self, dl, epoch):
        """Evaluates model on the entire dataset

        Every batch is run through the task selected by ``self.args.task`` with
        gradient tracking disabled, and loss / rmse / mse are averaged over batches.

        :param dl: dataloader containing validation or test data; must yield at least
            one batch, because the returned outputs are taken from the last batch
        :param epoch: current epoch index, forwarded to the extrapolation and
            next-observation-prediction tasks
        :return: mean loss per batch, mean rmse per batch, mean mse per batch,
            [output_mean, output_var] of the last batch, intermediates (for
            extrapolation / next_obs_prediction the last batch's intermediates extended
            with the per-batch-averaged dataset results; for classification a dict of
            AUPRC scores), [obs, truth, mask_obs] of the last batch, imputation metrics
            ([mean imputation loss, mean imputation mse, mean imputation mae] or None)
        :raises Exception: if ``self.args.task`` is not a known task
        """
        imputation_tasks = ('extrapolation', 'interpolation', 'next_obs_prediction')
        task = self.args.task
        save_intermediates = self.args.save_intermediates is not None

        epoch_ll = 0
        epoch_rmse = 0
        epoch_mse = 0
        intermediate_results = {}
        imput_metrics = None

        if task in imputation_tasks:
            epoch_imput_ll = 0
            epoch_imput_mse = 0
            epoch_imput_mae = 0

        if save_intermediates:
            mask_obs_epoch = []
            intermediates_epoch = []

        for i, data in enumerate(dl):
            # Only extrapolation and next_obs_prediction report which features are
            # numeric; every other task is scored on all features
            # (slice(None) selects the whole last axis).
            numeric_event_ids = slice(None)

            if task == 'interpolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse = self.interpolation(
                    data, track_gradient=False)

            elif task == 'extrapolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, \
                    numeric_event_ids, imput_mae = self.extrapolation(data, epoch, track_gradient=False)

            elif task == 'next_obs_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, \
                    numeric_event_ids, imput_mae = self.next_obs_prediction(data, epoch, track_gradient=False)

            elif task == 'regression':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.regression(
                    data, track_gradient=False)

            elif task == 'classification':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, obs_times = self.classification(
                    data, track_gradient=False)

            elif task == 'one_step_ahead_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.one_step_ahead_prediction(
                    data, track_gradient=False)

            else:
                # same behaviour as train_epoch instead of failing later on an unbound name
                raise Exception('Unknown task')

            epoch_ll += loss.item()
            epoch_rmse += rmse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                mask_truth[..., numeric_event_ids]).item()
            epoch_mse += mse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                mask_truth[..., numeric_event_ids]).item()

            if task in imputation_tasks:
                # NOTE(review): the interpolation call above returns neither mask_imput
                # nor imput_mae, so this branch cannot complete for task 'interpolation'
                # as written - confirm the intended interpolation() return values.
                if self.args.dataset == 'physionet':
                    intermediate_results = compute_physionet_intermediate(mse, mask_imput,
                        truth, output_mean, intermediate_results)
                elif self.args.dataset == 'mimic':
                    intermediate_results = compute_mimic_intermediate(mse, mask_imput,
                        truth, output_mean, intermediate_results)
                epoch_imput_ll += imput_loss
                epoch_imput_mse += imput_mse
                epoch_imput_mae += imput_mae
                imput_metrics = [epoch_imput_ll/(i+1), epoch_imput_mse/(i+1), epoch_imput_mae/(i+1)]
            elif task == 'classification':
                if self.args.dataset == 'mimic':
                    # drop these entries before the remaining intermediates are
                    # handed to the accumulation helper
                    del intermediates['qz0_mean']
                    del intermediates['qz0_logvar']
                    del intermediates['pred_x']
                    del intermediates['pred_y']
                    intermediate_results = mimic_classification_save_intermediates(output_mean, truth,
                        intermediate_results, intermediates)
            else:
                imput_metrics = None

            if save_intermediates:
                mask_obs_epoch.append(mask_obs)
                intermediates_epoch.append(intermediates)

        if task == 'extrapolation' or task == 'next_obs_prediction':
            # average the accumulated results over the number of batches and attach
            # them to the last batch's intermediates
            for k, v in intermediate_results.items():
                intermediates[k] = v / (i+1)
        elif task == 'classification':
            # the returned intermediates are replaced by the AUPRC scores
            intermediates = {}
            intermediates['AUPRC_Propofol'] = compute_auprc(intermediate_results['preds_propofol'], intermediate_results['label_propofol'])
            intermediates['AUPRC_Phenylephrine'] = compute_auprc(intermediate_results['preds_phenylephrine'],
                intermediate_results['label_phenylephrine'])
            intermediates['AUPRC_Furosemide'] = compute_auprc(intermediate_results['preds_furosemide'], intermediate_results['label_furosemide'])
            intermediates['AUPRC'] = compute_auprc(intermediate_results['preds_all'], intermediate_results['label_all'])

        # save for plotting
        if save_intermediates:
            torch.save(output_mean, os.path.join(
                self.args.save_intermediates, 'valid_output_mean.pt'))
            torch.save(obs, os.path.join(
                self.args.save_intermediates, 'valid_obs.pt'))
            torch.save(output_var, os.path.join(
                self.args.save_intermediates, 'valid_output_var.pt'))
            torch.save(truth, os.path.join(
                self.args.save_intermediates, 'valid_truth.pt'))
            torch.save(intermediates_epoch, os.path.join(
                self.args.save_intermediates, 'valid_intermediates.pt'))
            torch.save(mask_obs_epoch, os.path.join(
                self.args.save_intermediates, 'valid_mask_obs.pt'))

        return epoch_ll/(i+1), epoch_rmse/(i+1), epoch_mse/(i+1), [output_mean, output_var], intermediates, [obs, truth, mask_obs], imput_metrics

    # new code component
    def run_train(self, train_dl, valid_dl, identifier, logger, epoch_start=0, wandb=None):
        """Trains model on trainset and evaluates on validation data. Logs results and saves trained model.

        Uses Adam with learning rate ``self.args.lr`` and a per-epoch multiplicative
        decay of ``self.args.lr_decay``. After the last epoch the whole model object
        is written below ``results/models/<dataset>/``.

        :param train_dl: training dataloader
        :param valid_dl: validation dataloader
        :param identifier: logger id, also used as the tensorboard run directory name
        :param logger: logger object
        :param epoch_start: starting epoch
        :param wandb: optional object whose ``log`` method receives the per-epoch metric
            dict when ``self.args.log_wandb`` is set; logging is skipped with a warning
            if it is None
        """

        optimizer = optim.Adam(self.parameters(), self.args.lr)
        def lr_update(epoch): return self.args.lr_decay ** epoch
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lr_update)

        make_dir(f'results/tensorboard/{self.args.dataset}')
        writer = SummaryWriter(f'results/tensorboard/{self.args.dataset}/{identifier}')

        # The writer holds an open event file; close it even if an epoch raises.
        try:
            for epoch in range(epoch_start, self.args.epochs):
                start = datetime.now()
                logger.info(f'Epoch {epoch} starts: {start.strftime("%H:%M:%S")}')

                # train
                train_ll, train_rmse, train_mse, train_output, intermediates, train_input, train_imput_metrics = self.train_epoch(
                    train_dl, optimizer, epoch)
                end_training = datetime.now()
                if self.args.tensorboard:
                    log_to_tensorboard(self, writer=writer,
                                    mode='train',
                                    metrics=[train_ll, train_rmse, train_mse],
                                    output=train_output,
                                    input=train_input,
                                    intermediates=intermediates,
                                    epoch=epoch,
                                    imput_metrics=train_imput_metrics,
                                    log_rythm=self.args.log_rythm)

                # eval
                valid_ll, valid_rmse, valid_mse, valid_output, valid_intermediates, valid_input, valid_imput_metrics = self.eval_epoch(
                    valid_dl, epoch)
                if self.args.tensorboard:
                    # NOTE(review): this passes the *training* intermediates to the
                    # validation log. valid_intermediates may be intended here, but for
                    # classification eval_epoch returns an AUPRC-only dict - confirm what
                    # log_to_tensorboard reads before switching.
                    log_to_tensorboard(self, writer=writer,
                                    mode='valid',
                                    metrics=[valid_ll, valid_rmse, valid_mse],
                                    output=valid_output,
                                    input=valid_input,
                                    intermediates=intermediates,
                                    epoch=epoch,
                                    imput_metrics=valid_imput_metrics,
                                    log_rythm=self.args.log_rythm)

                end = datetime.now()
                logger.info(f'Training epoch {epoch} took: {(end_training - start).total_seconds()}')
                logger.info(f'Epoch {epoch} took: {(end - start).total_seconds()}')
                logger.info(f' train_nll: {train_ll:3f}, train_mse: {train_mse:3f}')
                logger.info(f' valid_nll: {valid_ll:3f}, valid_mse: {valid_mse:3f}')

                # metrics collected for the optional wandb logger
                wandb_dict = {
                    'train_nll': train_ll,
                    'valid_nll': valid_ll,
                    'train_mse': train_mse,
                    'valid_mse': valid_mse,
                }
                if self.args.task == 'extrapolation' or self.args.impute_rate is not None or self.args.task == 'next_obs_prediction':
                    if self.bernoulli_output:
                        logger.info(f' train_mse_imput: {train_imput_metrics[1]:3f}')
                        logger.info(f' valid_mse_imput: {valid_imput_metrics[1]:3f}')
                    else:
                        logger.info(f' train_nll_imput: {train_imput_metrics[0]:3f}, train_mse_imput: {train_imput_metrics[1]:3f}')
                        logger.info(f' valid_nll_imput: {valid_imput_metrics[0]:3f}, valid_mse_imput: {valid_imput_metrics[1]:3f}')
                        wandb_dict['train_nll_imput'] = train_imput_metrics[0]
                        wandb_dict['valid_nll_imput'] = valid_imput_metrics[0]
                        wandb_dict['train_mse_imput'] = train_imput_metrics[1]
                        wandb_dict['valid_mse_imput'] = valid_imput_metrics[1]
                        wandb_dict['valid_mae_imput'] = valid_imput_metrics[2]
                        wandb_dict.update(valid_intermediates)
                elif self.args.task == 'classification':
                    wandb_dict.update(valid_intermediates)

                if self.args.log_wandb:
                    if wandb is not None:
                        wandb.log(wandb_dict)
                    else:
                        # wandb defaults to None; do not lose the run over a missing handle
                        logger.warning('log_wandb is set but no wandb object was passed; skipping wandb logging')

                scheduler.step()
        finally:
            writer.close()

        make_dir(f'results/models/{self.args.dataset}')
        # the whole model object is saved, not just its state dict
        torch.save(self, f'results/models/{self.args.dataset}/mtand_{self.args.random_seed}.tar')
