import torch
import numpy as np
from numpy.random import RandomState
prng = RandomState(123)
import time as t
from datetime import datetime
import os
from typing import Tuple
from torch.utils.tensorboard import SummaryWriter
from lib.utils import TimeDistributed, log_to_tensorboard, make_dir, compute_physionet_intermediate, compute_mimic_intermediate
from lib.encoder import Encoder
from lib.decoder import SplitDiagGaussianDecoder, BernoulliDecoder
from lib.CRULayer import CRULayer
from lib.CRUCell import var_activation, var_activation_inverse
from lib.losses import rmse, mse, GaussianNegLogLik, bernoulli_nll, mae
from lib.data_utils import  align_output_and_target, adjust_obs_for_extrapolation, adjust_obs_for_next_obs_pred
from lib import tacd_gru_cell
from collections import defaultdict
import pdb

# Short module-level aliases for torch sub-packages
# (optimizers, neural-network layers, functional API).
optim = torch.optim
nn = torch.nn
F = nn.functional


# taken from https://github.com/ALRhub/rkn_share/ and modified
class TACD_GRU(nn.Module):

    # taken from https://github.com/ALRhub/rkn_share/ and modified
    def __init__(self, target_dim: int, lsd: int, args, means=None,
                 use_cuda_if_available: bool = True,
                 bernoulli_output: bool = False,
                 use_encoder: bool = False):
        """Build the TACD-GRU cell, its output heads and the Adam optimizer.

        :param target_dim: output dimension (also the input size of the GRU cell)
        :param lsd: latent state dimension (cell and hidden size of the GRU cell)
        :param args: parsed arguments; this constructor reads ``tacd_norm_time``,
            ``gru_intermediate_size``, ``task``, ``lr`` and ``dataset`` and also
            forwards ``args`` to the GRU cell
        :param means: forwarded to the GRU cell as ``X_mean``
        :param use_cuda_if_available: if to use cuda or cpu
        :param bernoulli_output: stored as ``self.bernoulli_output`` and checked
            in ``forward``
        :param use_encoder: forwarded to the GRU cell
        :raises NotImplementedError: if ``args.dataset`` has no known time range
        """
        super().__init__()
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() and use_cuda_if_available else "cpu")

        # Latent state dimension, used directly as the GRU cell/hidden size.
        self._lsd = lsd
        self.args = args

        self.tacd_norm_time = args.tacd_norm_time

        # parameters TODO: Make configurable
        self.bernoulli_output = bernoulli_output

        self.dropout1 = nn.Dropout(p=0.3).to(self._device)
        self.dropout2 = nn.Dropout(p=0.3).to(self._device)
        self.dropout3 = nn.Dropout(p=0.3).to(self._device)
        # Hard-coded: the reverse-direction branch below is only built when True.
        self.bidirectional = False

        self.intermediate_linear = nn.Linear(
            self._lsd,
            args.gru_intermediate_size).to(self._device).double()
        self.output_mu = nn.Linear(args.gru_intermediate_size, target_dim).to(self._device).double()

        def make_cell():
            # Both directions use identically configured cells (no f_out head).
            return tacd_gru_cell.TACD_GRU(input_size=target_dim,
                                          cell_size=self._lsd,
                                          hidden_size=self._lsd,
                                          X_mean=means,
                                          args=args,
                                          use_encoder=use_encoder,
                                          f_out=None).to(self._device).double()

        self.gru = make_cell()
        self.combine_linear = nn.Linear(self._lsd, 1).to(self._device).double()

        # params and optimizer
        self._params = list(self.gru.parameters())
        self._params += list(self.intermediate_linear.parameters())
        self._params += list(self.output_mu.parameters())
        self._params += list(self.combine_linear.parameters())
        if self.args.task == 'classification':
            self.n_classes = n_classes = 7
            self.bi_mul = 2 if self.bidirectional else 1  # bi-directional multiplier
            self.classifier_output = nn.Linear(self.bi_mul * self._lsd, n_classes).to(self._device).double()
            self.clf_linear = nn.Linear(self.bi_mul * self._lsd, 1).to(self._device).double()
            self._params += list(self.classifier_output.parameters())
            self._params += list(self.clf_linear.parameters())
            self.classifier_loss = nn.CrossEntropyLoss()
            self.alpha = 10
        if self.bidirectional:
            self.rev_gru = make_cell()
            # NOTE: relies on self.n_classes, which only the classification task sets.
            self.combine_fwd_bwd_logits = nn.Linear(
                2 * self.n_classes, self.n_classes).to(self._device).double()
            # Append rather than overwrite, so the forward GRU and heads are
            # still handed to the optimizer.
            self._params += list(self.rev_gru.parameters())
            self._params += list(self.combine_fwd_bwd_logits.parameters())

        self._optimizer = optim.Adam(self._params, lr=self.args.lr)
        self._shuffle_rng = np.random.RandomState(42)  # rng for shuffling batches

        # (min_ts, max_ts) per dataset; forward() uses them to rescale time
        # stamps to [0, 1] when tacd_norm_time is set.
        time_ranges = {
            'mimic': (-2 * 24 * 60 * 60, 24 * 60 * 60),  # history, future
            'ushcn': (14610, 16070),
            'physionet': (0, 48),
            'fBM': (0, 1),
        }
        try:
            self.min_ts, self.max_ts = time_ranges[args.dataset]
        except KeyError:
            raise NotImplementedError(
                f"NYI: no time range for dataset {args.dataset!r}") from None

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_enc_hidden_layers(self) -> Tuple[nn.ModuleList, int]:
        """Hook (kept from the RKN code base) for the encoder's hidden layers.

        Not implemented here; a subclass that needs it must override it.

        :return: nn.ModuleList of hidden layers, size of output of last layer
        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_dec_hidden_layers_mean(self) -> Tuple[nn.ModuleList, int]:
        """Hook (kept from the RKN code base) for the mean decoder's hidden layers.

        Not implemented here; a subclass that needs it must override it.

        :return: nn.ModuleList of hidden layers, size of output of last layer
        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    # taken from https://github.com/ALRhub/rkn_share/ and not modified
    def _build_dec_hidden_layers_var(self) -> Tuple[nn.ModuleList, int]:
        """Hook (kept from the RKN code base) for the variance decoder's hidden layers.

        Not implemented here; a subclass that needs it must override it.

        :return: nn.ModuleList of hidden layers, size of output of last layer
        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    def add_dummy_times(self, obs_batch, obs_valid, time_points, is_training, hidden_without_dummy,
                        n_insertions=3, n_dummy=5):
        """Consistency penalty for inserting unobserved time steps.

        At ``n_insertions`` distinct positions drawn with the module-level
        ``prng``, ``n_dummy`` evenly spaced time stamps are inserted between
        ``time_points[:, idx - 1]`` and ``time_points[:, idx]``. The inserted steps
        carry all-zero observations that are marked invalid. The GRU cell is
        re-run on each augmented sequence; the dummy positions are dropped from
        its first output, and the mean squared difference to
        ``hidden_without_dummy`` is accumulated. ``hidden_without_dummy`` is
        detached, so gradients flow only through the augmented run.

        :param obs_batch: observations, indexed as (batch, time, feature)
        :param obs_valid: validity mask, indexed like ``obs_batch``
        :param time_points: time stamps, indexed as (batch, time)
        :param is_training: the penalty is only computed when True
        :param hidden_without_dummy: first cell output for the unaugmented input
        :param n_insertions: number of insertion positions (capped at the number
            of candidate positions)
        :param n_dummy: number of dummy time steps inserted per position
        :return: summed penalty (``0.0`` if no position can be sampled), or
            ``None`` when ``is_training`` is False
        """
        if not is_training:
            return None

        # Candidate insertion positions lie in the first half of the sequence.
        sample_range = list(range(1, (time_points.size(1) - 2) // 2))
        n_draw = min(n_insertions, len(sample_range))
        if n_draw == 0:
            return 0.0
        dummy_indices = prng.choice(sample_range, size=n_draw, replace=False)

        batch_size, _, n_features = obs_batch.shape
        # Multipliers 1..n_dummy, created on the input's device (not forced to CUDA).
        steps = torch.arange(1, n_dummy + 1, device=time_points.device)
        target_hidden = hidden_without_dummy.detach()

        loss_from_dummy = 0.0
        for dummy_idx in dummy_indices:
            # n_dummy evenly spaced stamps strictly between t[idx - 1] and t[idx].
            gap = (time_points[:, dummy_idx] - time_points[:, dummy_idx - 1]) / (n_dummy + 1)
            times_to_insert = time_points[:, dummy_idx - 1][:, None] + gap[:, None] * steps
            d_time_points = torch.cat(
                [time_points[:, :dummy_idx], times_to_insert, time_points[:, dummy_idx:]], 1)

            # Dummy steps carry zero observations, all flagged as invalid.
            dummy_obs = obs_batch.new_zeros((batch_size, n_dummy, n_features), dtype=torch.float64)
            dummy_valid = obs_valid.new_zeros((batch_size, n_dummy, n_features), dtype=torch.bool)
            d_obs_batch = torch.cat([obs_batch[:, :dummy_idx], dummy_obs, obs_batch[:, dummy_idx:]], 1)
            d_obs_valid = torch.cat([obs_valid[:, :dummy_idx], dummy_valid, obs_valid[:, dummy_idx:]], 1)

            # The cell returns a tuple (forward() unpacks four values); use its
            # first element, then remove the dummy positions before comparing.
            hidden_w_dummy = self.gru(d_obs_batch, d_obs_valid, d_time_points)[0]
            hidden_w_dummy = torch.cat(
                [hidden_w_dummy[:, :dummy_idx], hidden_w_dummy[:, dummy_idx + n_dummy:]], 1)
            loss_from_dummy += ((target_hidden - hidden_w_dummy) ** 2).mean()

        return loss_from_dummy
        
    
    # taken from https://github.com/ALRhub/rkn_share/ and modified
    def forward(self, obs_batch: torch.Tensor, time_points: torch.Tensor = None,
            obs_valid: torch.Tensor = None, is_training=True) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Single forward pass on a batch.

        :param obs_batch: batch of observation sequences
        :param time_points: timestamps of observations; rescaled with
            ``self.min_ts``/``self.max_ts`` when ``tacd_norm_time`` is set
        :param obs_valid: boolean if timestamp contains valid observation
        :param is_training: accepted for API compatibility; not used here
        :return: ``(out_mean, out_var, intermediates)``. ``out_mean`` is the
            cell's ``obs_hat``, and ``out_var`` is a constant ``1e-6`` tensor of
            the same shape. ``intermediates`` has the keys ``'obs_hat'``,
            ``'combo_weight'`` and ``'classifier_output'``; the last is ``None``
            unless ``args.task == 'classification'``.
        :raises NotImplementedError: for ``bernoulli_output`` or the
            ``one_step_ahead_prediction`` task, which need a decoder
            (``self._dec``) that this model does not build
        """
        if self.bernoulli_output:
            raise NotImplementedError("bernoulli_output is not supported by TACD_GRU.forward")
        if self.args.task == 'one_step_ahead_prediction':
            raise NotImplementedError(
                "task 'one_step_ahead_prediction' is not supported by TACD_GRU.forward")

        if self.tacd_norm_time:
            time_points = ((time_points - self.min_ts) / (self.max_ts - self.min_ts)).double()
        out, obs_hat, clf_logits, comb_weight = self.gru(obs_batch, obs_valid, time_points)

        if self.bidirectional:
            # Run the reverse cell on the time-flipped sequence (negated stamps),
            # flip its outputs back and fuse both directions.
            rev_out, _, rev_clf_logits, _ = self.rev_gru(
                torch.flip(obs_batch, [1]), torch.flip(obs_valid, [1]), -torch.flip(time_points, [1]))
            out = torch.cat([out, torch.flip(rev_out, [1])], dim=-1)
            clf_logits = self.combine_fwd_bwd_logits(
                torch.cat([clf_logits, torch.flip(rev_clf_logits, [1])], dim=-1))

        # Ablations are handled inside the cell, so the mean prediction is the
        # cell's obs_hat directly.
        out_mean = obs_hat

        classifier_output = None
        if self.args.task == 'classification':
            # Learned gate between the cell's logits and a linear head on `out`;
            # the gate replaces comb_weight in the returned intermediates.
            head_logits = self.classifier_output(out)
            comb_weight = torch.sigmoid(self.clf_linear(self.dropout3(out)))
            classifier_output = comb_weight * clf_logits + (1 - comb_weight) * head_logits
            classifier_output = classifier_output.contiguous()

        # Fixed, near-zero variance: only a mean is predicted.
        out_var = 1e-6 * torch.ones_like(out_mean)

        intermediates = {
            'obs_hat': obs_hat,
            'combo_weight': comb_weight,
            'classifier_output': classifier_output,
        }
        return out_mean, out_var, intermediates

    # new code component
    def interpolation(self, data, track_gradient=True):
        """Computes loss on interpolation task

        :param data: batch ``(obs, truth, obs_valid, obs_times, mask_truth,
            mask_obs, numeric_event_ids)``; ``obs`` must equal ``truth``
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output_mean, output_var, obs, truth, mask_obs, mask_truth,
            intermediates (``{'combo_weight': ndarray}``), imput_loss (equal to
            imput_mse), imput_mse, imput_mae, numeric_event_ids
        """
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs, numeric_event_ids = [
            j.to(self._device) for j in data]

        assert (obs == truth).all().item(), "all observations are observed"

        # Same handling as extrapolation/next_obs_prediction: a collated batch
        # may carry one row of ids per sample; the first row is used
        # (presumably all rows are identical).
        if numeric_event_ids.ndim > 1:
            assert numeric_event_ids.ndim == 2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0, :]
        ids = numeric_event_ids

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=mask_obs.bool(),
                is_training=track_gradient)

            truth_num = truth[:, :, ids]
            mean_num = output_mean[:, :, ids]
            loss = mse(truth_num, mean_num, mask=mask_truth[:, :, ids])

            # Metrics on imputed points only: truth points whose time step is not
            # flagged in obs_valid. obs_valid is expanded over the last axis here,
            # so it is presumably (batch, time) — confirm against the data loader.
            mask_imput = (~obs_valid[..., None]) * mask_truth
            imput_mask_num = mask_imput[:, :, ids]
            imput_mse = mse(truth_num, mean_num, mask=imput_mask_num)
            imput_mae = mae(truth_num, mean_num, mask=imput_mask_num)
            imput_loss = imput_mse

        intermediates = {'combo_weight': intermediates['combo_weight'][:, :, 0].detach().cpu().numpy()}
        return loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, \
            imput_loss, imput_mse, imput_mae, numeric_event_ids



    def classification(self, data, track_gradient=True):
        """Computes the cross-entropy loss on the classification task.

        :param data: batch ``(obs, label, obs_times, mask_obs)``; if
            ``obs_times`` is 3-D, only index 0 of its last axis is used
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, label (as given, on the model device) and the classifier
            logits from ``forward``
        """
        obs, label, obs_times, mask_obs = [
            j.to(self._device) for j in data]
        if obs_times.ndim == 3:
            obs_times = obs_times[:, :, 0]
        mask_obs = mask_obs.bool()
        obs = obs.double()
        obs_times = obs_times.double()

        with torch.set_grad_enabled(track_gradient):
            _, _, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=mask_obs,
                is_training=track_gradient)

            classifier_output = intermediates['classifier_output']
            n_classes = classifier_output.size(-1)
            # Targets are the argmax over the label's last axis (labels are
            # presumably one-hot); every flattened row of logits is scored.
            targets = label.view(-1, label.size(-1)).argmax(1)
            # Only the classification loss is optimized. A reconstruction term
            # (mse of obs vs. the mean output, weighted by self.alpha) is not
            # part of the objective, so it is not computed.
            loss = self.classifier_loss(classifier_output.reshape(-1, n_classes), targets)

        return loss, label, classifier_output

    def next_obs_prediction(self, data, track_gradient=True):
        """ Next observation prediction

        Inputs are prepared with ``adjust_obs_for_extrapolation``. Then
        ``truth`` and ``mask_truth`` are zeroed at every time step after the
        first step where ``obs_valid`` and ``mask_truth`` disagree on any
        feature. Targets beyond that "next" observation are therefore not scored.

        :param data: batch ``(obs, truth, obs_valid, obs_times, mask_truth,
            mask_obs, numeric_event_ids)``
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output_mean, output_var, obs, truth (masked), mask_obs,
            mask_truth (masked), intermediates (``{'combo_weight': ndarray}``),
            imput_loss (Gaussian NLL), imput_mse, mask_imput, numeric_event_ids,
            imput_mae, next_step_impute_loss, obs_times
        """
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs, numeric_event_ids = [
            j.to(self._device) for j in data]

        if numeric_event_ids.ndim > 1:
            # A collated batch may carry one row of ids per sample; the first
            # row is used (presumably all rows are identical).
            assert numeric_event_ids.ndim == 2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0, :]

        # The second return value (extrapolation validity mask) is not needed.
        obs, _, obs_valid = adjust_obs_for_extrapolation(
            self.args.dataset, obs, obs_valid, mask_obs, obs_times)

        # Per time step: True from the first step where obs_valid and mask_truth
        # disagree onwards. The running count of those steps is 1 at that first
        # step and > 1 at every later step; the later steps are dropped as targets.
        cumsum_targets = ((obs_valid != mask_truth).sum(-1).cumsum(-1) > 0).cumsum(axis=1)
        remove_multi_step_targets = cumsum_targets > 1
        mask_truth = torch.where(~remove_multi_step_targets[..., None], mask_truth, 0.0)
        truth = torch.where(~remove_multi_step_targets[..., None], truth, 0.0)

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid,
                is_training=track_gradient)

            # Error of the cell's obs_hat on all features.
            next_step_impute_loss = mse(truth, intermediates['obs_hat'], mask=mask_truth)

            ids = numeric_event_ids
            truth_num = truth[:, :, ids]
            mean_num = output_mean[:, :, ids]
            loss = mse(truth_num, mean_num, mask=mask_truth[:, :, ids])

            # Metrics only on truth points that are not valid model inputs.
            mask_imput = (~obs_valid) * mask_truth
            imput_mask_num = mask_imput[:, :, ids]
            imput_loss = GaussianNegLogLik(
                truth_num, mean_num, output_var[:, :, ids], mask=imput_mask_num)
            imput_mse = mse(truth_num, mean_num, mask=imput_mask_num)
            imput_mae = mae(truth_num, mean_num, mask=imput_mask_num)

        intermediates = {'combo_weight': intermediates['combo_weight'][:, :, 0].detach().cpu().numpy()}
        return loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, numeric_event_ids, imput_mae, next_step_impute_loss, obs_times

    

    # new code component
    def extrapolation(self, data, track_gradient=True):
        """Computes loss on extrapolation task

        :param data: batch ``(obs, truth, obs_valid, obs_times, mask_truth,
            mask_obs, numeric_event_ids)``
        :param track_gradient: if to track gradient for backpropagation
        :return: loss, output_mean, output_var, obs, truth, mask_obs, mask_truth,
            intermediates (``{'combo_weight': ndarray}``), imput_loss (Gaussian
            NLL), imput_mse, mask_imput, numeric_event_ids, imput_mae,
            next_step_impute_loss, obs_times
        """
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs, numeric_event_ids = [
            j.to(self._device) for j in data]

        if numeric_event_ids.ndim > 1:
            # A collated batch may carry one row of ids per sample; the first
            # row is used (presumably all rows are identical).
            assert numeric_event_ids.ndim == 2, "more than two dimensions in numeric event ids"
            numeric_event_ids = numeric_event_ids[0, :]

        # The second return value (extrapolation validity mask) is not needed.
        obs, _, obs_valid = adjust_obs_for_extrapolation(
            self.args.dataset, obs, obs_valid, mask_obs, obs_times)

        with torch.set_grad_enabled(track_gradient):
            output_mean, output_var, intermediates = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid,
                is_training=track_gradient)

            # Error of the cell's obs_hat on all features.
            next_step_impute_loss = mse(truth, intermediates['obs_hat'], mask=mask_truth)

            ids = numeric_event_ids
            truth_num = truth[:, :, ids]
            mean_num = output_mean[:, :, ids]
            loss = mse(truth_num, mean_num, mask=mask_truth[:, :, ids])

            # Metrics only on truth points that are not valid model inputs.
            mask_imput = (~obs_valid) * mask_truth
            imput_mask_num = mask_imput[:, :, ids]
            imput_loss = GaussianNegLogLik(
                truth_num, mean_num, output_var[:, :, ids], mask=imput_mask_num)
            imput_mse = mse(truth_num, mean_num, mask=imput_mask_num)
            imput_mae = mae(truth_num, mean_num, mask=imput_mask_num)

        intermediates = {'combo_weight': intermediates['combo_weight'][:, :, 0].detach().cpu().numpy()}
        return loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, numeric_event_ids, imput_mae, next_step_impute_loss, obs_times

    # new code component
    def regression(self, data, track_gradient=True):
        """Gaussian negative log-likelihood on the regression task.

        :param data: batch ``(obs, truth, obs_times, obs_valid)``
        :param track_gradient: enable autograd while computing the loss
        :return: loss, output mean, output variance, obs, truth, ``None`` for
            both ``mask_obs`` and ``mask_truth``, and the forward intermediates
        """
        obs, truth, obs_times, obs_valid = [item.to(self._device) for item in data]
        no_mask = None  # this task provides neither an observation nor a target mask

        with torch.set_grad_enabled(track_gradient):
            mean, var, extras = self.forward(
                obs_batch=obs, time_points=obs_times, obs_valid=obs_valid)
            nll = GaussianNegLogLik(mean, truth, var, mask=no_mask)

        return nll, mean, var, obs, truth, no_mask, no_mask, extras

    # new code component
    def one_step_ahead_prediction(self, data, track_gradient=True):
        """Gaussian negative log-likelihood for one-step-ahead prediction.

        Time stamps are multiplied by ``args.ts`` before the forward pass, and
        the outputs are paired with the targets via ``align_output_and_target``.

        :param data: batch ``(obs, truth, obs_valid, obs_times, mask_truth, mask_obs)``
        :param track_gradient: enable autograd while computing the loss
        :return: loss, aligned output mean/variance, obs, aligned truth,
            mask_obs, aligned mask_truth, and the forward intermediates
        """
        moved = [item.to(self._device) for item in data]
        obs, truth, obs_valid, obs_times, mask_truth, mask_obs = moved
        scaled_times = self.args.ts * obs_times

        with torch.set_grad_enabled(track_gradient):
            mean, var, extras = self.forward(
                obs_batch=obs, time_points=scaled_times, obs_valid=obs_valid)
            mean, var, truth, mask_truth = align_output_and_target(
                mean, var, truth, mask_truth)
            nll = GaussianNegLogLik(mean, truth, var, mask=mask_truth)

        return nll, mean, var, obs, truth, mask_obs, mask_truth, extras

    # new code component
    def train_epoch(self, dl, optimizer):
        """Trains model for one epoch

        :param dl: dataloader containing training data
        :param optimizer: optimizer to use for training
        :return: ``(mean loss, mean rmse, mean mse, [output_mean, output_var],
            intermediates, [obs, truth, mask_obs], imput_metrics, clf_acc)``.
            The means are over batches. Outputs, inputs and intermediates are
            those of the last batch. ``imput_metrics`` is ``[mean imputation loss,
            mean imputation mse]`` for the extrapolation, interpolation and
            next_obs_prediction tasks, and ``None`` otherwise. ``clf_acc`` is the
            epoch accuracy for classification, and ``None`` otherwise. An empty
            ``dl`` yields zero means and ``None`` metrics.
        """
        task = self.args.task
        imputation_tasks = ('extrapolation', 'interpolation', 'next_obs_prediction')
        save_intermediates = self.args.save_intermediates is not None

        epoch_ll = 0
        epoch_rmse = 0
        epoch_mse = 0
        epoch_imput_ll = 0
        epoch_imput_mse = 0
        epoch_labels = []
        epoch_predictions = []
        mask_obs_epoch = []
        intermediates_epoch = []

        # Values of the last processed batch, returned for inspection/plotting.
        intermediates = None
        output_mean = None
        output_var = None
        obs = None
        truth = None
        mask_obs = None
        imput_metrics = None
        clf_acc = None

        # Global autograd switch: set it before the first forward pass so that
        # the first batch is covered too.
        torch.autograd.set_detect_anomaly(self.args.anomaly_detection)

        n_batches = 0
        for data in dl:
            n_batches += 1

            if task == 'interpolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, \
                    intermediates, imput_loss, imput_mse, _, numeric_event_ids = self.interpolation(data)
            elif task == 'extrapolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, \
                    _, numeric_event_ids, _, _, _ = self.extrapolation(data)
            elif task == 'next_obs_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, \
                    _, numeric_event_ids, _, _, _ = self.next_obs_prediction(data)
            elif task == 'classification':
                loss, classifier_labels, classifier_outputs = self.classification(data)
            elif task == 'regression':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.regression(
                    data)
            elif task == 'one_step_ahead_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.one_step_ahead_prediction(
                    data)
            else:
                raise Exception('Unknown task')

            # check for NaNs
            if torch.any(torch.isnan(loss)):
                print('--NAN in loss')
            for name, par in self.named_parameters():
                if torch.any(torch.isnan(par)):
                    print('--NAN before optimiser step in parameter ', name)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            if self.args.grad_clip:
                nn.utils.clip_grad_norm_(self.parameters(), 1)
            optimizer.step()

            # check for NaNs in parameters after the update
            for name, par in self.named_parameters():
                if name in ('gru.event_ids', 'gru.hidden_ids'):
                    # stored event ids are static and will not have gradients
                    continue
                if torch.any(torch.isnan(par)):
                    print('--NAN after optimiser step in parameter ', name)

            # aggregate metrics and intermediates over entire epoch
            epoch_ll += loss.item()

            if task in imputation_tasks:
                epoch_rmse += rmse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                                   mask_truth[..., numeric_event_ids]).item()
                epoch_mse += mse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids],
                                 mask_truth[..., numeric_event_ids]).item()
                epoch_imput_ll += imput_loss.item()
                epoch_imput_mse += imput_mse.item()
            elif task == 'classification':
                n_classes = classifier_outputs.size(-1)
                epoch_labels.append(classifier_labels.detach().cpu().view(-1, classifier_labels.size(-1)))
                epoch_predictions.append(classifier_outputs.detach().cpu().view(-1, n_classes))

            if save_intermediates:
                intermediates_epoch.append(intermediates)
                mask_obs_epoch.append(mask_obs)

        # save for plotting
        if save_intermediates:
            torch.save(mask_obs_epoch, os.path.join(
                self.args.save_intermediates, 'train_mask_obs.pt'))
            torch.save(intermediates_epoch, os.path.join(
                self.args.save_intermediates, 'train_intermediates.pt'))

        if task in imputation_tasks and n_batches:
            imput_metrics = [epoch_imput_ll / n_batches, epoch_imput_mse / n_batches]

        # generate stats for classification at the end of epoch
        if task == 'classification' and epoch_labels:
            clf_labels = torch.cat(epoch_labels)
            clf_preds = torch.cat(epoch_predictions)
            clf_acc = (clf_preds.argmax(1) == clf_labels.argmax(1)).double().mean().item()

        denom = max(n_batches, 1)  # an empty loader yields zero means
        return epoch_ll / denom, epoch_rmse / denom, epoch_mse / denom, [output_mean, output_var], \
            intermediates, [obs, truth, mask_obs], imput_metrics, clf_acc

    # new code component
    def eval_epoch(self, dl, wandb=None):
        """Evaluates model on the entire dataset

        :param dl: dataloader containing validation or test data
        :return: evaluation metrics, computed output, input, intermediate variables
        """
        epoch_ll = 0
        epoch_rmse = 0
        epoch_mse = 0
        dynamics_mse = 0
        intermediate_results = {'combo_weights_all':[]}
        output_mean = None
        output_var = None
        obs = None 
        truth = None 
        mask_obs = None
        intermediates = None
        imput_metrics = None
        clf_acc = None

        if self.args.task in ['extrapolation', 'interpolation', 'next_obs_prediction']:
            epoch_imput_ll = 0
            epoch_imput_mse = 0
            epoch_imput_mae = 0
        elif self.args.task == 'classification':
            epoch_labels = []
            epoch_predictions = []

        if self.args.save_intermediates is not None:
            mask_obs_epoch = []
            intermediates_epoch = []

        for i, data in enumerate(dl):

            if self.args.task == 'interpolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, \
                imput_loss, imput_mse, imput_mae, numeric_event_ids = self.interpolation(
                    data, track_gradient=False)

            elif self.args.task == 'extrapolation':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, mask_imput, numeric_event_ids, imput_mae, next_step_impute_loss, _ = self.extrapolation(
                    data, track_gradient=False)

            elif self.args.task == 'next_obs_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates, imput_loss, imput_mse, \
                    mask_imput, numeric_event_ids, imput_mae, next_step_impute_loss, _ = self.next_obs_prediction(data, 
                        track_gradient=False)

            elif self.args.task == 'classification':
                loss, classifier_labels, classifier_outputs = self.classification(data, track_gradient=False)

            elif self.args.task == 'regression':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.regression(
                    data, track_gradient=False)

            elif self.args.task == 'one_step_ahead_prediction':
                loss, output_mean, output_var, obs, truth, mask_obs, mask_truth, intermediates = self.one_step_ahead_prediction(
                    data, track_gradient=False)

            epoch_ll += loss
            if self.args.task == 'extrapolation' or self.args.task == 'interpolation' or self.args.task == 'next_obs_prediction':
                epoch_rmse += rmse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids], 
                    mask_truth[..., numeric_event_ids]).item()
                epoch_mse += mse(truth[..., numeric_event_ids], output_mean[..., numeric_event_ids], mask_truth[..., numeric_event_ids]).item()

                if self.args.dataset == 'physionet' and self.args.task in ['extrapolation', 'next_obs_prediction']:
                    intermediate_results = compute_physionet_intermediate(mse, mask_imput,
                        truth,output_mean, intermediate_results)
                elif self.args.dataset == 'mimic':
                    intermediate_results = compute_mimic_intermediate(mse, mask_imput,
                        truth,output_mean, intermediate_results)
                # weights that combine the attention and GRU estimates
                intermediate_results['combo_weights_all'].append(intermediates['combo_weight'])
                epoch_imput_ll += imput_loss
                epoch_imput_mse += imput_mse
                epoch_imput_mae += imput_mae
                imput_metrics = [epoch_imput_ll/(i+1), epoch_imput_mse/(i+1), epoch_imput_mae/(i+1)]

            elif self.args.task == 'classification':
                n_classes = classifier_labels.size(-1)
                epoch_labels.append(classifier_labels.detach().cpu().view(-1, classifier_labels.size(-1)))
                epoch_predictions.append(classifier_outputs.detach().cpu().view(-1, classifier_outputs.size(-1)))
            else:
                imput_metrics = None

            if self.args.save_intermediates is not None:
                intermediates_epoch.append(intermediates)
                mask_obs_epoch.append(mask_obs)

        # normalize by batch size
        if self.args.task == "extrapolation" or self.args.task == 'interpolation':
            for k, v in intermediate_results.items():
                if k != 'combo_weights_all':
                    intermediates[k] = v / (i+1)
                else:
                    comb_weights = intermediate_results['combo_weights_all']
                    max_ts = np.max([comb_weights[i].shape[1] for i in range(len(comb_weights))])
                    stats_by_ts_index = defaultdict(list)
                    mean_comb_by_ts_index = {}
                    for ts_index in range(max_ts):
                        for batch_comb_weights in comb_weights:
                            if ts_index < batch_comb_weights.shape[1]:
                                stats_by_ts_index[ts_index].append(batch_comb_weights[:,ts_index])
                    for ts_index in range(max_ts):
                        mean_comb_by_ts_index[ts_index] = np.concatenate(stats_by_ts_index[ts_index]).mean()
                    data = [[ts_index, mean] for (ts_index, mean) in mean_comb_by_ts_index.items()]
                    wandb_table = wandb.Table(data=data, columns=['ts_index', 'combo_weight_gru'])
                    wandb_plot = wandb.plot.line(wandb_table, "ts_index", "combo_weight_gru", title='GRU combo weight over time indexes')
                    intermediates['combo_wts_gru_by_ts_index'] = wandb_table
                    intermediates['combo_wts_gru_by_ts_index_plot'] = wandb_plot
                    if self.args.task == 'interpolation':
                        all_combo = [intermediate_results['combo_weights_all'][i].flatten() for i in range(len(intermediate_results['combo_weights_all']))]
                        intermediates['avg_combo_wt'] = np.concatenate(all_combo).mean()
                    
         
        # generate stats for classification at the end of epoch
        if self.args.task == 'classification':
            clf_labels = torch.cat(epoch_labels)
            clf_preds = torch.cat(epoch_predictions)
            clf_acc = (clf_preds.argmax(1) == clf_labels.argmax(1)).double().mean().item()

        # save for plotting
        if self.args.save_intermediates is not None:
            torch.save(output_mean, os.path.join(
                self.args.save_intermediates, 'valid_output_mean.pt'))
            torch.save(obs, os.path.join(
                self.args.save_intermediates, 'valid_obs.pt'))
            torch.save(output_var, os.path.join(
                self.args.save_intermediates, 'valid_output_var.pt'))
            torch.save(truth, os.path.join(
                self.args.save_intermediates, 'valid_truth.pt'))
            torch.save(intermediates_epoch, os.path.join(
                self.args.save_intermediates, 'valid_intermediates.pt'))
            torch.save(mask_obs_epoch, os.path.join(
                self.args.save_intermediates, 'valid_mask_obs.pt'))

        return epoch_ll/(i+1), epoch_rmse/(i+1), epoch_mse/(i+1), [output_mean, output_var], \
                intermediates, [obs, truth, mask_obs], imput_metrics, clf_acc

    # new code component
    def run_train(self, train_dl, valid_dl, identifier, logger, epoch_start=0, wandb=None):
        """Train the model, evaluate it after every epoch, log metrics and save the final model.

        Every epoch calls ``self.train_epoch(train_dl, optimizer)`` followed by
        ``self.eval_epoch(valid_dl, wandb=wandb)``. Metrics are written to TensorBoard
        (if ``self.args.tensorboard``), to ``logger`` and, if ``self.args.log_wandb``,
        via ``wandb.log``. After the last epoch the whole module is saved with
        ``torch.save`` under ``results/models/<dataset>/``.

        :param train_dl: training dataloader
        :param valid_dl: validation dataloader
        :param identifier: run id; names the TensorBoard log sub-directory
        :param logger: logger object; only its ``info`` method is used
        :param epoch_start: index of the first epoch to run (runs up to ``self.args.epochs``)
        :param wandb: object exposing ``log`` (e.g. the ``wandb`` module); must not be None
            when ``self.args.log_wandb`` is set. Also forwarded to ``eval_epoch``.
        :raises ValueError: if ``self.args.log_wandb`` is set but ``wandb`` is None
        """
        if self.args.log_wandb and wandb is None:
            # fail before training instead of crashing on ``None.log`` after the first epoch
            raise ValueError('args.log_wandb is set but no wandb object was passed to run_train')

        optimizer = optim.Adam(self.parameters(), self.args.lr)

        # per-epoch multiplicative decay: lr(epoch) = args.lr * args.lr_decay ** epoch
        # NOTE(review): the scheduler counts from 0 even when epoch_start > 0
        def lr_update(epoch): return self.args.lr_decay ** epoch
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lr_update)

        make_dir(f'../results/tensorboard/{self.args.dataset}')
        writer = SummaryWriter(f'../results/tensorboard/{self.args.dataset}/{identifier}')

        # tasks whose epochs report reconstruction mse and imputation metrics
        reconstruction_tasks = ('extrapolation', 'next_obs_prediction', 'interpolation')

        try:
            for epoch in range(epoch_start, self.args.epochs):
                start = datetime.now()
                logger.info(f'Epoch {epoch} starts: {start.strftime("%H:%M:%S")}')

                # train
                (train_ll, train_rmse, train_mse, train_output, train_intermediates, train_input,
                 train_imput_metrics, train_acc) = self.train_epoch(train_dl, optimizer)
                end_training = datetime.now()
                if self.args.tensorboard:
                    log_to_tensorboard(self, writer=writer,
                                       mode='train',
                                       metrics=[train_ll, train_rmse, train_mse],
                                       output=train_output,
                                       input=train_input,
                                       intermediates=train_intermediates,
                                       epoch=epoch,
                                       imput_metrics=train_imput_metrics,
                                       log_rythm=self.args.log_rythm)

                # eval
                (valid_ll, valid_rmse, valid_mse, valid_output, valid_intermediates, valid_input,
                 valid_imput_metrics, valid_acc) = self.eval_epoch(valid_dl, wandb=wandb)
                if self.args.tensorboard:
                    log_to_tensorboard(self, writer=writer,
                                       mode='valid',
                                       metrics=[valid_ll, valid_rmse, valid_mse],
                                       output=valid_output,
                                       input=valid_input,
                                       intermediates=valid_intermediates,
                                       epoch=epoch,
                                       imput_metrics=valid_imput_metrics,
                                       log_rythm=self.args.log_rythm)

                end = datetime.now()
                logger.info(f'Training epoch {epoch} took: {(end_training - start).total_seconds()}')
                logger.info(f'Epoch {epoch} took: {(end - start).total_seconds()}')
                logger.info(f' train_nll: {train_ll:3f}, train_mse: {train_mse:3f}')
                logger.info(f' valid_nll: {valid_ll:3f}, valid_mse: {valid_mse:3f}')

                wandb_dict = {'train_nll': train_ll, 'valid_nll': valid_ll}
                if self.args.task in reconstruction_tasks or self.args.impute_rate is not None:
                    wandb_dict['train_mse'] = train_mse
                    wandb_dict['valid_mse'] = valid_mse
                    # some tasks (e.g. regression in eval_epoch) return imput_metrics=None;
                    # indexing None would crash, so imputation logging is skipped then
                    has_imput = train_imput_metrics is not None and valid_imput_metrics is not None
                    if has_imput and self.bernoulli_output:
                        logger.info(f' train_mse_imput: {train_imput_metrics[1]:3f}')
                        logger.info(f' valid_mse_imput: {valid_imput_metrics[1]:3f}')
                    elif has_imput:
                        logger.info(f' train_nll_imput: {train_imput_metrics[0]:3f}, train_mse_imput: {train_imput_metrics[1]:3f}')
                        logger.info(f' valid_nll_imput: {valid_imput_metrics[0]:3f}, valid_mse_imput: {valid_imput_metrics[1]:3f}')
                        wandb_dict['train_nll_imput'] = train_imput_metrics[0]
                        wandb_dict['valid_nll_imput'] = valid_imput_metrics[0]
                        wandb_dict['train_mse_imput'] = train_imput_metrics[1]
                        wandb_dict['valid_mse_imput'] = valid_imput_metrics[1]
                        wandb_dict['valid_mae_imput'] = valid_imput_metrics[2]
                        # forward every validation intermediate (scalars, wandb tables/plots) as-is
                        wandb_dict.update(valid_intermediates)
                elif self.args.task == 'classification':
                    wandb_dict['train_accuracy'] = train_acc
                    wandb_dict['valid_accuracy'] = valid_acc
                if self.args.log_wandb:
                    wandb.log(wandb_dict)

                scheduler.step()
        finally:
            # flush pending TensorBoard events and release the event file, also on errors
            writer.close()

        # create the directory that is actually written to: torch.save does not create
        # missing parent directories. The whole module is pickled, so loading it
        # requires this class to be importable.
        model_dir = f'results/models/{self.args.dataset}'
        make_dir(model_dir)
        torch.save(self, f'{model_dir}/tacd_gru_{self.args.random_seed}_norm_{self.args.tacd_norm_time}.tar')
