# Modeling Irregular Time Series with Continuous Recurrent Units (CRUs)
# Copyright (c) 2022 Robert Bosch GmbH
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
import os
from lib.physionet_preprocessing import download_and_process_physionet
from lib.ushcn_preprocessing import download_and_process_ushcn
#from lib.pendulum_generation import generate_pendulums
#from lib.person_activity import PersonActivity
from sklearn import model_selection
from collections import defaultdict
#from torchcubicspline import(natural_cubic_spline_coeffs, 
#                             NaturalCubicSpline)
import pdb
import pickle as pkl


# Directory prefix under which MIMIC_CDE saves its per-sample '<id>.pkl'
# spline-coefficient files (joined with os.path.join; empty = current dir).
NCDE_DATASET = ''
# MIMIC lookup tables. They stay None until load_data() is called with
# args.dataset == 'mimic'; the MIMIC collate functions read them as globals.
# Feature header names loaded from 'feature_headers.pt'.
MIMIC_HEADERS = None
# Maps a feature header name to its column index in MIMIC_HEADERS.
MIMIC_ITEM_TO_ID = None
# Column indices of headers that do NOT start with 'chart_events_' or 'lab_events_'.
MIMIC_IND_IDS = None
# Column indices of the remaining headers (chart_events_* / lab_events_*).
MIMIC_NUM_IDS = None
# Column indices whose header starts with any prefix listed in args.remove.
MIMIC_EVENTS_TO_EXCLUDE = None
# new code component 
def load_data(args, device):
    """Create the train and validation ``DataLoader`` pair for ``args.dataset``.

    Supported datasets are ``'pendulum'``, ``'ushcn'``, ``'physionet'``,
    ``'mimic'``, ``'fBM'`` and ``'activity'``.  For ``'mimic'`` the
    module-level ``MIMIC_*`` lookup tables are (re)populated as a side effect,
    because the MIMIC collate functions read them as globals.

    Args:
        args: Namespace-like configuration.  ``dataset`` and ``batch_size``
            are always read; ``num_workers`` and ``pin_memory`` are read for
            every dataset except ``'activity'``.  Dataset-specific fields:
            ``task``, ``impute_rate``, ``sample_rate``, ``data_random_seed``
            (pendulum); ``unobserved_rate``, ``impute_rate``, ``sample_rate``
            (ushcn); ``task``, ``remove`` (mimic).
        device: Only used by the ``'activity'`` branch, where it is passed to
            ``PersonActivity`` and ``variable_time_collate_fn``.

    Returns:
        ``(train_dl, valid_dl)``.  Note that for ushcn, physionet, mimic, fBM
        and activity the second loader is built from the *test* file/split.

    Raises:
        Exception: if ``args.dataset`` is unknown, or the task is not
            available for the pendulum or MIMIC data.
    """
    file_path = f'data/{args.dataset}/'

    # Pendulum
    if args.dataset == 'pendulum':
        # NOTE(review): `generate_pendulums` is not imported in this file's
        # visible import block (its import line is commented out), so the
        # generation fallback below needs that name to be provided elsewhere.
        if args.task == 'interpolation':
            name = f'pend_interpolation_ir{args.impute_rate}.npz'
            if not os.path.exists(os.path.join(file_path, name)):
                print(f'Generating pendulum trajectories and saving to {file_path} ...')
                generate_pendulums(file_path, task=args.task, impute_rate=args.impute_rate)

            train = Pendulum_interpolation(file_path=file_path, name=name,
                    mode='train', sample_rate=args.sample_rate, random_state=args.data_random_seed)
            valid = Pendulum_interpolation(file_path=file_path, name=name,
                    mode='valid', sample_rate=args.sample_rate, random_state=args.data_random_seed)

        elif args.task == 'regression':
            name = 'pend_regression.npz'
            if not os.path.exists(os.path.join(file_path, name)):
                print(f'Generating pendulum trajectories and saving to {file_path} ...')
                generate_pendulums(file_path, task=args.task)

            train = Pendulum_regression(file_path=file_path, name=name,
                               mode='train', sample_rate=args.sample_rate, random_state=args.data_random_seed)
            valid = Pendulum_regression(file_path=file_path, name=name,
                               mode='valid', sample_rate=args.sample_rate, random_state=args.data_random_seed)
        else:
            raise Exception('Task not available for Pendulum data')
        collate_fn = None

    # USHCN
    elif args.dataset == 'ushcn':
        if not os.path.exists(os.path.join(file_path, 'pivot_train_valid_1990_1993_thr4_normalize.csv')):
            print(f'Downloading USHCN data and saving to {file_path} ...')
            download_and_process_ushcn(file_path)

        # Train on train+valid and evaluate on test.  The alternative split
        # uses 'pivot_train_1990_1993_thr4_normalize.csv' and
        # 'pivot_valid_1990_1993_thr4_normalize.csv'.
        train = USHCN(file_path=file_path, name='pivot_train_valid_1990_1993_thr4_normalize.csv', unobserved_rate=args.unobserved_rate,
                     impute_rate=args.impute_rate, sample_rate=args.sample_rate)
        valid = USHCN(file_path=file_path, name='pivot_test_1990_1993_thr4_normalize.csv', unobserved_rate=args.unobserved_rate,
                     impute_rate=args.impute_rate, sample_rate=args.sample_rate)
        collate_fn = None

    # Physionet
    elif args.dataset == 'physionet':
        if not os.path.exists(os.path.join(file_path, 'norm_train_valid.pt')):
            print(f'Downloading Physionet data and saving to {file_path} ...')
            download_and_process_physionet(file_path)

        # Train on train+valid and evaluate on test.  The alternative split
        # uses 'norm_train.pt' and 'norm_valid.pt'.
        train = Physionet(file_path=file_path, name='norm_train_valid.pt')
        valid = Physionet(file_path=file_path, name='norm_test.pt')
        # Alternative collate: collate_fn_physionet_full_valid_mask
        collate_fn = collate_fn_physionet

    # MIMIC
    elif args.dataset == 'mimic':
        global MIMIC_HEADERS
        global MIMIC_ITEM_TO_ID
        global MIMIC_IND_IDS
        global MIMIC_NUM_IDS
        global MIMIC_EVENTS_TO_EXCLUDE

        # Resolve the collate function first so that an unsupported task fails
        # with a clear message before any file is read or global is touched
        # (previously this surfaced later as an UnboundLocalError).
        if args.task == 'extrapolation' or args.task == 'next_obs_prediction':
            collate_fn = collate_fn_mimic_extrapolation
        elif args.task == 'classification':
            collate_fn = collate_fn_mimic_classification
        else:
            raise Exception('Task not available for MIMIC data')

        MIMIC_HEADERS = torch.load(os.path.join(file_path, 'feature_headers.pt'))
        MIMIC_IND_IDS = [i for i, feat in enumerate(MIMIC_HEADERS)
                         if not feat.startswith('chart_events_') and not feat.startswith('lab_events_')]
        # Complement of MIMIC_IND_IDS, kept in ascending column order so the
        # result does not depend on set iteration order.
        ind_id_set = set(MIMIC_IND_IDS)
        MIMIC_NUM_IDS = [i for i in range(len(MIMIC_HEADERS)) if i not in ind_id_set]
        MIMIC_ITEM_TO_ID = {v: i for i, v in enumerate(MIMIC_HEADERS)}

        # Other splits that have been used with this loader:
        #   subset_5k/{train,val}_f_86400_p_86400_afreq_43200.pt
        #   subset_5k_new/{train,val}_f_86400_p_86400_afreq_172800.pt
        #   subset_5k_new/{train_val,test}_f_86400_p_86400_afreq_172800.pt
        train = MIMIC(file_path=file_path, name='subsampled_1k/train_val.pt')
        valid = MIMIC(file_path=file_path, name='subsampled_1k/test.pt')

        # all types: chart_*, lab_*, input_*, proc_*
        MIMIC_EVENTS_TO_EXCLUDE = []
        for ev_type in args.remove:
            MIMIC_EVENTS_TO_EXCLUDE += [i for i, feat in enumerate(MIMIC_HEADERS) if feat.startswith(ev_type)]

    elif args.dataset == 'fBM':
        train = fBM('data/train_fBM_0.75.pkl')
        valid = fBM('data/test_fBM_0.75.pkl')
        collate_fn = collate_fn_fBM_extrapolation

    # Human activity dataset
    elif args.dataset == 'activity':
        # NOTE(review): `PersonActivity` is not imported in this file's
        # visible import block (its import line is commented out).
        n_samples = 10000
        dataset_obj = PersonActivity('data/PersonActivity',
                                 download=True, n_samples=n_samples,
                                 device=device)
        print(dataset_obj)
        train_data, test_data = model_selection.train_test_split(dataset_obj, train_size=0.8,
                                                             random_state=42, shuffle=True)

        # train / test split: each split is collated once into padded tensors
        # and wrapped as (values, labels, times, mask).
        test_data_combined = variable_time_collate_fn(test_data, device, classify=True,
                                                  activity=True)
        train_data_combined = variable_time_collate_fn(
            train_data, device, classify=True, activity=True)

        train_data_combined = TensorDataset(
            train_data_combined[0], train_data_combined[1].long(), train_data_combined[2], train_data_combined[3])
        test_data_combined = TensorDataset(
            test_data_combined[0], test_data_combined[1].long(), test_data_combined[2], test_data_combined[3])

        train_dataloader = DataLoader(
            train_data_combined, batch_size=args.batch_size, shuffle=False)
        test_dataloader = DataLoader(
            test_data_combined, batch_size=args.batch_size, shuffle=False)
        # The test loader doubles as the validation loader for this dataset.
        return train_dataloader, test_dataloader

    else:
        raise Exception(f'Dataset {args.dataset} not available')

    train_dl = DataLoader(train, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=args.pin_memory)
    valid_dl = DataLoader(valid, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=args.pin_memory)

    return train_dl, valid_dl

def mimic_get_all_times(train, valid):
    """Return the sorted, de-duplicated time stamps found in both datasets.

    Args:
        train: Iterable of ``(times, obs_values, apoint_info)`` samples.
        valid: Iterable of samples with the same layout.

    Returns:
        1-D ``numpy.ndarray`` with every distinct time value in ascending
        order.
    """
    # Collect straight into a set: repeatedly concatenating lists copies the
    # accumulated list for every sample and is quadratic in the total length.
    unique_times = set()
    for dataset in (train, valid):
        for sample in dataset:
            times, _obs_values, _apoint_info = sample
            unique_times.update(times)
    return np.array(sorted(unique_times))
        

def variable_time_collate_fn(batch, device=torch.device("cpu"), classify=False, activity=False,
                             data_min=None, data_max=None):
    """Zero-pad a batch of variable-length series to a common length.

    Expects a batch of time series data in the form of
    (record_id, tt, vals, mask, labels) where
      - record_id is an identifier of the series (not used here),
      - tt is a 1-dimensional tensor containing T time values of observations,
      - vals is a (T, D) tensor containing observed values for D variables,
      - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise,
      - labels holds the labels of the series: a (T, N) tensor when
        ``activity`` is set, otherwise something assignable to one row of a
        (batch, 1) tensor. Only read when ``classify`` or ``activity`` is set.

    Each series is padded to the longest T in the batch (``maxlen``), and the
    times are divided by the largest time in the batch unless that is 0.
    When ``activity`` is False the values are additionally passed through
    ``normalize_masked_data`` with ``data_min`` / ``data_max``.
    NOTE(review): ``normalize_masked_data`` is not defined or imported in the
    visible part of this file — confirm it is available at runtime.

    Returns:
      If ``classify``: the tuple ``(vals, labels, tt, mask)`` with shapes
      (B, maxlen, D), (B, maxlen, N) for activity data or (B, 1) otherwise,
      (B, maxlen, 1) and (B, maxlen, D).
      Otherwise: one (B, maxlen, 2*D + 1) tensor holding values, mask and
      time concatenated along the last axis.
    """
    n_series = len(batch)
    D = batch[0][2].shape[1]
    # number of labels
    N = batch[0][-1].shape[1] if activity else 1
    maxlen = max(ex[1].size(0) for ex in batch)

    enc_combined_tt = torch.zeros([n_series, maxlen], device=device)
    enc_combined_vals = torch.zeros([n_series, maxlen, D], device=device)
    enc_combined_mask = torch.zeros([n_series, maxlen, D], device=device)
    if classify:
        label_shape = [n_series, maxlen, N] if activity else [n_series, N]
        combined_labels = torch.zeros(label_shape, device=device)

    for b, (record_id, tt, vals, mask, labels) in enumerate(batch):
        currlen = tt.size(0)
        enc_combined_tt[b, :currlen] = tt.to(device)
        enc_combined_vals[b, :currlen] = vals.to(device)
        enc_combined_mask[b, :currlen] = mask.to(device)
        if classify:
            if activity:
                combined_labels[b, :currlen] = labels.to(device)
            else:
                combined_labels[b] = labels.to(device)

    if not activity:
        enc_combined_vals, _, _ = normalize_masked_data(enc_combined_vals, enc_combined_mask,
                                                        att_min=data_min, att_max=data_max)

    # Rescale times by the batch maximum; skipped when it is 0 to avoid 0/0.
    max_time = torch.max(enc_combined_tt)
    if max_time != 0.:
        enc_combined_tt = enc_combined_tt / max_time

    if classify:
        return enc_combined_vals, combined_labels, enc_combined_tt.unsqueeze(-1), enc_combined_mask

    # The concatenated layout is only needed on the non-classification path.
    return torch.cat(
        (enc_combined_vals, enc_combined_mask, enc_combined_tt.unsqueeze(-1)), 2)


# new code component 
class Pendulum_interpolation(Dataset):
    """Pendulum image sequences for the interpolation task.

    The ``.npz`` archive at ``file_path/name`` is passed through ``subsample``
    (with ``imagepred=True``); ``mode == 'train'`` selects the training part,
    any other value selects the test part.
    """

    def __init__(self, file_path, name, mode, sample_rate=0.5, random_state=0):
        archive = dict(np.load(os.path.join(file_path, name)))
        (tr_obs, tr_targets, tr_times, tr_valid,
         te_obs, te_targets, te_times, te_valid) = subsample(
            archive, sample_rate=sample_rate, imagepred=True, random_state=random_state)

        use_train = mode == 'train'
        raw_obs = tr_obs if use_train else te_obs
        raw_targets = tr_targets if use_train else te_targets
        raw_valid = tr_valid if use_train else te_valid
        self.time_points = tr_times if use_train else te_times

        # Move the last axis in front of the two axes before it and scale by 1/255.
        axis_order = [0, 1, 4, 2, 3]
        self.obs = np.ascontiguousarray(np.transpose(raw_obs, axis_order)) / 255.0
        self.targets = np.ascontiguousarray(np.transpose(raw_targets, axis_order)) / 255.0
        # Drop the singleton third axis of the validity flags.
        self.obs_valid = np.squeeze(raw_valid, axis=2)

    def __len__(self):
        """Number of sequences."""
        return self.obs.shape[0]

    def __getitem__(self, idx):
        """Return ``(obs, targets, obs_valid, time_points, mask_truth)`` for ``idx``.

        ``obs`` and ``targets`` are float64 tensors; ``mask_truth`` is an
        all-ones tensor shaped like ``targets``.
        """
        targets = torch.from_numpy(self.targets[idx, ...].astype(np.float64))
        return (
            torch.from_numpy(self.obs[idx, ...].astype(np.float64)),
            targets,
            torch.from_numpy(self.obs_valid[idx, ...]),
            torch.from_numpy(self.time_points[idx, ...]),
            torch.ones_like(targets),
        )


# new code component 
class Pendulum_regression(Dataset):
    """Pendulum image sequences for the regression task.

    The ``.npz`` archive at ``file_path/name`` is passed through ``subsample``;
    ``mode == 'train'`` selects the training part, any other value selects
    the test part.
    """

    def __init__(self, file_path, name, mode, sample_rate=0.5, random_state=0):
        archive = dict(np.load(os.path.join(file_path, name)))
        (tr_obs, tr_targets, te_obs, te_targets,
         tr_times, te_times) = subsample(
            archive, sample_rate=sample_rate, random_state=random_state)

        use_train = mode == 'train'
        raw_obs = tr_obs if use_train else te_obs
        self.targets = tr_targets if use_train else te_targets
        self.time_points = tr_times if use_train else te_times

        # Move the last axis in front of the two axes before it and scale by 1/255.
        self.obs = np.ascontiguousarray(
            np.transpose(raw_obs, [0, 1, 4, 2, 3])) / 255.0

    def __len__(self):
        """Number of sequences."""
        return self.obs.shape[0]

    def __getitem__(self, idx):
        """Return ``(obs, targets, time_points, obs_valid)`` for ``idx``.

        ``obs`` and ``targets`` are float64 tensors; ``obs_valid`` is an
        all-True boolean tensor shaped like ``time_points``.
        """
        time_points = torch.from_numpy(self.time_points[idx, ...])
        return (
            torch.from_numpy(self.obs[idx, ...].astype(np.float64)),
            torch.from_numpy(self.targets[idx, ...].astype(np.float64)),
            time_points,
            torch.ones_like(time_points, dtype=torch.bool),
        )


# new code component 
class USHCN(Dataset):
    """USHCN series read from a pivoted CSV, one item per ``UNIQUE_ID``.

    The CSV must provide the columns ``UNIQUE_ID`` and ``TIME_STAMP``; the
    feature columns are those whose name equals ``str(i)`` for ``i`` in
    ``columns``.

    Args:
        file_path: Directory prefix; it is concatenated with ``name`` as a
            plain string, so it has to end with a path separator.
        name: CSV file name.
        impute_rate: ``None`` (no imputation) or a float in ``[0, 1)`` giving
            the probability that a time step is hidden from the observations.
        sample_rate: Fraction used to derive the number of sampled time points.
        columns: Feature column ids to use.
        unobserved_rate: ``None`` (nothing hidden), ``'stratified'``, or a
            number in ``[0, 1)`` giving the per-entry hiding probability.
        year_range: Number of years covered; the number of sampled time
            points is ``int(365 * year_range * sample_rate)``.
    """

    params = ['PRCP', 'SNOW', 'SNWD', 'TMAX', 'TMIN']

    def __init__(self, file_path, name, impute_rate=None, sample_rate=0.5, columns=[0, 1, 2, 3, 4], unobserved_rate=None, year_range=4):
        self.sample_rate = sample_rate
        self.impute_rate = impute_rate
        self.unobserved_rate = unobserved_rate
        self.year_range = year_range
        self.data = pd.read_csv(
            file_path + name).sort_values(['UNIQUE_ID', 'TIME_STAMP']).set_index('UNIQUE_ID')
        self.label_columns = self.data.columns[self.data.columns.isin(
            [str(i) for i in columns])]

    def __len__(self):
        """Number of distinct ``UNIQUE_ID`` values."""
        return self.data.index.nunique()

    def subsample_time_points(self, sample, n_total_time_points, n_sample_time_points, seed=0):
        """Keep ``n_sample_time_points`` rows of ``sample``, chosen without replacement.

        The choice is seeded, so it is the same on every call with the same
        sizes; the kept rows stay in their original order.
        """
        rng = np.random.RandomState(seed)

        choice = np.sort(rng.choice(n_total_time_points,
                         n_sample_time_points, replace=False))
        return sample.loc[choice]

    def subsample_features(self, n_features, n_sample_time_points, seed=0):
        """Build the boolean mask of entries to hide (``True`` = unobserved).

        Returns:
            For ``unobserved_rate`` ``None`` or numeric, an array of shape
            ``(n_sample_time_points, n_features)``; for ``'stratified'``
            whatever ``create_unobserved_mask`` returns.

        Raises:
            Exception: if ``unobserved_rate`` has an unsupported value.
        """
        rng = np.random.RandomState(seed)

        # no subsampling: nothing is hidden.  The mask is laid out
        # (time, feature) like the numeric branch, because it is applied
        # row-wise to the sampled frame later on.
        if self.unobserved_rate is None:
            return np.full(
                (n_sample_time_points, n_features), False, dtype=bool)

        # subsample such that it is equally probable that 1, 2, 3,.. features are missing per time point
        if self.unobserved_rate == 'stratified':
            return create_unobserved_mask(
                n_features, n_sample_time_points)

        # subsample features based on overall rate (most time points will have 1, 2 features missing, few will have more missing)
        if isinstance(self.unobserved_rate, (float, int)):
            assert 0 <= self.unobserved_rate < 1, 'Unobserved rate must be between 0 and 1.'
            return ~ (
                rng.rand(n_sample_time_points, n_features) > self.unobserved_rate)

        raise Exception('Unobserved mode unknown')

    def get_data_based_on_impute_rate(self, sample, unobserved_mask, n_features, n_sample_time_points):
        """Turn a sampled frame into observation/target tensors and masks.

        Returns:
            ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs)``
            where NaNs in ``obs`` and ``targets`` are replaced by 0,
            ``mask_targets`` is 1 where a target value exists and
            ``mask_obs`` is the negated ``unobserved_mask``.

        Raises:
            Exception: if ``impute_rate`` is neither ``None`` nor a float.
        """
        impute = self.impute_rate is not None
        if impute:
            if not isinstance(self.impute_rate, float):
                raise Exception('Impute mode unknown')
            assert 0 <= self.impute_rate < 1, 'Imputation rate must be between 0 and 1.'

        # Hide the unobserved entries; targets are taken after this step, so
        # hidden entries are excluded from the targets as well.
        sample[self.label_columns] = np.where(
            unobserved_mask, np.nan, sample[self.label_columns])
        obs = torch.tensor(sample.loc[:, self.label_columns].values)
        targets = obs.clone()

        if impute:
            # remove time steps that have to be imputed; the first 10 steps
            # are always kept
            obs_valid = torch.rand(n_sample_time_points) >= self.impute_rate
            obs_valid[:10] = True
            obs[~obs_valid] = np.nan
        else:
            # task is not imputation (i.e. extrapolation or one-step-ahead
            # prediction): valid if we have at least one dim observed
            obs_valid = np.sum(unobserved_mask, axis=-1) < n_features

        time_points = torch.tensor(sample.loc[:, 'TIME_STAMP'].values)
        mask_targets = 1*~targets.isnan()
        mask_obs = ~ unobserved_mask

        return torch.nan_to_num(obs), torch.nan_to_num(targets), obs_valid, time_points, mask_targets, mask_obs

    def __getitem__(self, idx):
        """Return the processed series stored under index label ``idx``.

        Returns:
            ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs,
            numeric_event_ids)``; ``numeric_event_ids`` enumerates all
            feature columns and ``time_points`` is float64.
        """
        # assumes the UNIQUE_ID labels are 0 .. len(self) - 1 and that each id
        # has several rows — TODO confirm against the preprocessing
        sample = self.data.loc[idx, :].reset_index(drop=True)
        n_total_time_points = len(sample)
        n_sample_time_points = int(365 * self.year_range * self.sample_rate)
        n_features = len(self.label_columns)

        # subsample time points to increase irregularity
        sample = self.subsample_time_points(
            sample, n_total_time_points, n_sample_time_points)

        # subsample features to increase partial observability
        unobserved_mask = self.subsample_features(
            n_features, n_sample_time_points)

        # create masks and target based on if/what kind of imputation
        obs, targets, obs_valid, time_points, mask_targets, mask_obs = \
            self.get_data_based_on_impute_rate(
                sample, unobserved_mask, n_features, n_sample_time_points)

        numeric_event_ids = torch.tensor(np.arange(obs.shape[-1]))

        time_points = time_points.double()
        return obs, targets, obs_valid, time_points, mask_targets, mask_obs, numeric_event_ids

class fBM(Dataset):
    """Dataset backed by a pickle file with ``'times'`` and ``'normalized_values'``.

    NOTE: the file is read with ``pickle``; only point this at trusted files,
    because unpickling can execute arbitrary code.
    """

    def __init__(self, file_path):
        with open(file_path, 'rb') as handle:
            payload = pkl.load(handle)
        times_array = np.array(payload['times'])
        values_array = np.array(payload['normalized_values'])
        self.times = torch.from_numpy(times_array)
        self.values = torch.from_numpy(values_array)

    def __len__(self):
        """Number of stored series (first dimension of ``times``)."""
        return len(self.times)

    def __getitem__(self, index):
        """Return the ``(times, values)`` pair stored at ``index``."""
        return self.times[index], self.values[index]

class MIMIC(Dataset):
    """Thin ``Dataset`` wrapper around an object stored with ``torch.save``.

    The file ``file_path/name`` is loaded onto the CPU once; items are served
    by indexing the loaded object directly.
    """

    def __init__(self, file_path, name):
        full_path = os.path.join(file_path, name)
        self.data = torch.load(full_path, map_location='cpu')

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored sample at ``index`` unchanged."""
        return self.data[index]

class MIMIC_CDE(Dataset):
    """Pre-compute natural cubic spline coefficients for MIMIC samples.

    Every sample is laid out on the shared time grid ``all_times``, the batch
    is augmented with the time channel and the cumulative observation counts,
    and one coefficient file per sample is written to
    ``NCDE_DATASET/<sample_id>.pkl`` with ``torch.save``.

    Args:
        dataset: Iterable of ``(times, obs_values, apoint_info)`` samples;
            ``apoint_info[0]`` is used as the sample id.
        all_times: Sorted array of every time stamp that can occur.

    NOTE(review): relies on the module-level ``MIMIC_*`` tables having been
    populated, and on ``natural_cubic_spline_coeffs``, whose import is
    commented out in this file's visible import block.
    NOTE(review): ``__init__`` never assigns ``self.data``, so ``__len__`` and
    ``__getitem__`` raise ``AttributeError`` as written — confirm what the
    dataset is supposed to serve.
    """

    def __init__(self, dataset, all_times):
        n_feats = len(MIMIC_HEADERS)
        n_times = len(all_times)
        get_index = {ts: i for i, ts in enumerate(all_times)}
        cut_time = 0 # used for extrapolation
        # Grid rows at or before the cut, i.e. the part that may be observed.
        context_rows = torch.as_tensor(np.asarray(all_times) <= cut_time)

        tabular_samples = []
        sample_ids = []
        for sample in dataset:
            sample_ids.append(sample[-1][0])
            times, obs_values, apoint_info = sample
            ts_feats = torch.full((n_times, n_feats), float('nan'), dtype=torch.double)
            # Default the columns listed in MIMIC_IND_IDS to 0 on the context
            # rows.  A chained `ts_feats[mask][:, ids] = 0.0` would write into
            # a temporary copy, so go through an explicit read-modify-write.
            context_block = ts_feats[context_rows]
            context_block[:, MIMIC_IND_IDS] = 0.0
            ts_feats[context_rows] = context_block
            for ts_id, ts in enumerate(times):
                # apply extrapolation mask: only values up to the cut are kept
                if ts > cut_time:
                    continue
                for k, v in obs_values[ts_id].items():
                    ts_feats[get_index[ts], MIMIC_ITEM_TO_ID[k]] = v
            tabular_samples.append(ts_feats)

        obs = pad_sequence(tabular_samples, batch_first=True, padding_value=np.nan).to(
            device='cpu', dtype=torch.double)
        obs[..., MIMIC_EVENTS_TO_EXCLUDE] = np.nan
        mask_obs = (~obs.isnan())
        obs = obs.nan_to_num(nan=0.0)

        # augmented for cde interpolation
        augmented_X = []
        augmented_X.append(torch.tensor(all_times).unsqueeze(0).repeat(obs.size(0), 1)[:,:,None]) # add time
        augmented_X.append(mask_obs.cumsum(dim=1)) # add intensities
        augmented_X.append(obs)
        augmented_X = torch.cat(augmented_X, dim=2)
        # interpolate
        coeffs = natural_cubic_spline_coeffs(torch.tensor(all_times), augmented_X)
        for i, sample_id in enumerate(sample_ids):
            fname = os.path.join(NCDE_DATASET, '{}.pkl'.format(sample_id))
            print('\t saving: {}'.format(fname))
            # coeffs[0] is stored as a whole; the other four entries are
            # sliced per sample.
            torch.save([coeffs[0], coeffs[1][i], coeffs[2][i], coeffs[3][i], coeffs[4][i]], fname)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

def collate_fn_fBM_extrapolation(batch):
    """Collate ``(times, values)`` pairs of equal length into batch tensors.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs,
        numeric_event_ids)``.  ``obs`` gets a trailing feature axis of size 1
        and keeps its NaNs; the masks are True where a value is not NaN and
        ``obs_valid`` reduces the observation mask over the feature axis.
    """
    time_points = torch.stack([times for times, _ in batch])
    obs = torch.stack([values for _, values in batch]).unsqueeze(-1)
    targets = obs.clone()

    # true if observed; false if not
    mask_obs = ~obs.isnan()
    mask_targets = mask_obs.clone()
    # to be consistent with the codebase: one flag per time step
    obs_valid = mask_obs.any(dim=-1)

    numeric_event_ids = torch.tensor(np.arange(obs.shape[-1]))
    return obs, targets, obs_valid, time_points, mask_targets, mask_obs, numeric_event_ids

def collate_fn_mimic_extrapolation(batch):
    """Collate sparse MIMIC samples into dense, padded batch tensors.

    Each sample is ``(times, obs_values, apoint_info)`` where ``obs_values``
    holds, per time step, a dict mapping a feature header to its value.
    Columns listed in ``MIMIC_IND_IDS`` default to 0 on real time steps; every
    other missing entry, every padded step and every column listed in
    ``MIMIC_EVENTS_TO_EXCLUDE`` counts as unobserved.

    Relies on the module-level ``MIMIC_*`` tables, which ``load_data``
    populates for the MIMIC dataset.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs,
        numeric_ids)``: ``obs`` / ``targets`` are float64 with unobserved
        entries set to 0, the masks are True where a value is observed,
        ``obs_valid`` is True for time steps with at least one observed
        entry, ``time_points`` is zero-padded, and ``numeric_ids`` is
        ``MIMIC_NUM_IDS`` as a tensor.
    """
    n_feats = len(MIMIC_HEADERS)
    tabular_samples = []
    time_points = []
    for sample in batch:
        times, obs_values, apoint_info = sample
        # torch.full builds the NaN canvas directly (and keeps it 2-D even for
        # an empty sample); `np.NaN` no longer exists in NumPy >= 2.0.
        ts_feats = torch.full((len(times), n_feats), float('nan'), dtype=torch.double)
        ts_feats[:, MIMIC_IND_IDS] = 0.0
        time_points.append(torch.tensor(times, dtype=torch.double))
        for ts_id, _ in enumerate(times):
            for k, v in obs_values[ts_id].items():
                ts_feats[ts_id, MIMIC_ITEM_TO_ID[k]] = v
        tabular_samples.append(ts_feats)

    # Pad with NaN so that padded steps end up unobserved in the mask.
    obs = pad_sequence(tabular_samples, batch_first=True, padding_value=np.nan).to(
        device='cpu', dtype=torch.double)
    obs[..., MIMIC_EVENTS_TO_EXCLUDE] = np.nan
    time_points = pad_sequence(time_points, batch_first=True).to(
        device='cpu', dtype=torch.double)
    mask_obs = (~obs.isnan())
    obs = obs.nan_to_num(nan=0.0)

    targets = obs.clone()
    mask_targets = mask_obs.clone()

    # true if observed; false if not — reduced to one flag per time step to be
    # consistent with the codebase
    obs_valid = mask_obs.any(dim=-1)

    return obs, targets, obs_valid, time_points, mask_targets, mask_obs, torch.tensor(MIMIC_NUM_IDS)

def old_collate_fn_mimic_full_valid_mask(batch):
    """Zero-pad ``(time_points, obs, mask, apoint)`` samples to a common length.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs)``.
        ``obs`` and ``time_points`` are float64; ``obs_valid`` is the very
        same tensor object as ``mask_obs`` (per-entry validity).
    """
    obs_seqs, time_seqs, mask_seqs = [], [], []
    for sample_times, sample_obs, sample_mask, _apoint in batch:
        obs_seqs.append(sample_obs)
        time_seqs.append(sample_times)
        mask_seqs.append(sample_mask)

    obs = pad_sequence(obs_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    time_points = pad_sequence(time_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    mask_obs = pad_sequence(mask_seqs, batch_first=True).to(device='cpu')

    targets = obs.clone()
    mask_targets = mask_obs.clone()

    # true if observed; false if not
    obs_valid = mask_obs

    return obs, targets, obs_valid, time_points, mask_targets, mask_obs

def collate_fn_mimic_classification(batch):
    """Collate sparse MIMIC samples for the classification task.

    Each sample is ``(times, obs_values, apoint_info)`` where ``obs_values``
    holds, per time step, a dict mapping a feature header to its value.
    Columns listed in ``MIMIC_IND_IDS`` default to 0 on real time steps; every
    other missing entry and every padded step counts as unobserved.

    Relies on the module-level ``MIMIC_*`` tables, which ``load_data``
    populates for the MIMIC dataset.  Unlike
    ``collate_fn_mimic_extrapolation``, ``MIMIC_EVENTS_TO_EXCLUDE`` is not
    applied here and ``obs_valid`` stays per entry.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs,
        ind_ids, events_to_report)``: ``obs`` / ``targets`` are float64 with
        unobserved entries set to 0, ``obs_valid`` is the same tensor object
        as ``mask_obs``, ``ind_ids`` is ``MIMIC_IND_IDS`` as a tensor and
        ``events_to_report`` maps a readable event name to its column index.

    Raises:
        KeyError: if one of the reported event headers is missing from
            ``MIMIC_ITEM_TO_ID``.
    """
    n_feats = len(MIMIC_HEADERS)
    tabular_samples = []
    time_points = []
    # headers to ids can change so use text names
    events_to_report = {
        'furosemide': MIMIC_ITEM_TO_ID['input_events_221794'],
        'propofol': MIMIC_ITEM_TO_ID['input_events_222168'],
        'phenylephrine': MIMIC_ITEM_TO_ID['input_events_221749'],
        'inv_ventilation': MIMIC_ITEM_TO_ID['proc_events_225792'],
        'bronchoscopy': MIMIC_ITEM_TO_ID['proc_events_225400'],
        'mri': MIMIC_ITEM_TO_ID['proc_events_223253']
    }
    for sample in batch:
        times, obs_values, apoint_info = sample
        # torch.full builds the NaN canvas directly (and keeps it 2-D even for
        # an empty sample); `np.NaN` no longer exists in NumPy >= 2.0.
        ts_feats = torch.full((len(times), n_feats), float('nan'), dtype=torch.double)
        ts_feats[:, MIMIC_IND_IDS] = 0.0
        time_points.append(torch.tensor(times, dtype=torch.double))
        for ts_id, _ in enumerate(times):
            for k, v in obs_values[ts_id].items():
                ts_feats[ts_id, MIMIC_ITEM_TO_ID[k]] = v
        tabular_samples.append(ts_feats)

    # Pad with NaN so that padded steps end up unobserved in the mask.
    obs = pad_sequence(tabular_samples, batch_first=True, padding_value=np.nan).to(
        device='cpu', dtype=torch.double)
    time_points = pad_sequence(time_points, batch_first=True).to(
        device='cpu', dtype=torch.double)
    mask_obs = (~obs.isnan())
    obs = obs.nan_to_num(nan=0.0)

    targets = obs.clone()
    mask_targets = mask_obs.clone()

    # true if observed; false if not
    obs_valid = mask_obs

    return obs, targets, obs_valid, time_points, mask_targets, mask_obs, torch.tensor(MIMIC_IND_IDS), events_to_report

# new code component 
class Physionet(Dataset):
    """Thin ``Dataset`` wrapper around an object stored with ``torch.save``.

    The file ``file_path/name`` is loaded onto the CPU once; items are served
    by indexing the loaded object directly.
    """

    def __init__(self, file_path, name):
        full_path = os.path.join(file_path, name)
        self.data = torch.load(full_path, map_location='cpu')

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored sample at ``index`` unchanged."""
        return self.data[index]

def collate_fn_physionet_full_valid_mask(batch):
    """Zero-pad Physionet samples and return a per-entry validity mask.

    Each sample is ``(patient_id, time_points, obs, mask, label)``.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs)``
        where ``obs`` and ``time_points`` are float64 and ``obs_valid`` is a
        boolean tensor, True wherever ``mask_obs`` is non-zero.
    """
    obs_seqs, time_seqs, mask_seqs = [], [], []
    for _patient_id, sample_times, sample_obs, sample_mask, _label in batch:
        obs_seqs.append(sample_obs)
        time_seqs.append(sample_times)
        mask_seqs.append(sample_mask)

    obs = pad_sequence(obs_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    time_points = pad_sequence(time_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    mask_obs = pad_sequence(mask_seqs, batch_first=True).to(device='cpu')

    targets = obs.clone()
    mask_targets = mask_obs.clone()

    # Entry-wise validity: keeps the feature axis instead of reducing it.
    obs_valid = mask_obs != 0

    return obs, targets, obs_valid, time_points, mask_targets, mask_obs

# new code component 
def collate_fn_physionet(batch):
    """Zero-pad Physionet samples and mark which time steps carry data.

    Each sample is ``(patient_id, time_points, obs, mask, label)``.

    Returns:
        ``(obs, targets, obs_valid, time_points, mask_targets, mask_obs,
        numeric_event_ids)`` where ``obs`` and ``time_points`` are float64,
        ``obs_valid`` is True for time steps with at least one non-zero mask
        entry, and ``numeric_event_ids`` enumerates all feature columns.
    """
    obs_seqs, time_seqs, mask_seqs = [], [], []
    for _patient_id, sample_times, sample_obs, sample_mask, _label in batch:
        obs_seqs.append(sample_obs)
        time_seqs.append(sample_times)
        mask_seqs.append(sample_mask)

    obs = pad_sequence(obs_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    time_points = pad_sequence(time_seqs, batch_first=True).to(device='cpu', dtype=torch.double)
    mask_obs = pad_sequence(mask_seqs, batch_first=True).to(device='cpu')

    targets = obs.clone()
    mask_targets = mask_obs.clone()

    # create obs_valid mask such that update step will be skipped on padded
    # time points (and on any step without a single observed feature)
    obs_valid = (mask_obs != 0).any(dim=-1)

    # for extrapolation: time_points range from [0, 48]; subtracting 24 here
    # would make the extrapolation method extrapolate for timepoints > 24
    # hours (currently not applied)

    # every obs is numeric
    numeric_event_ids = torch.tensor(np.arange(obs.shape[-1]))

    return obs, targets, obs_valid, time_points, mask_targets, mask_obs, numeric_event_ids


# new code component 
def subsample(data, sample_rate, imagepred=False, random_state=0):
    """Randomly keep a fixed fraction of the time steps of every sequence.

    For each sequence, ``n = int(sample_rate * seq_length)`` distinct time
    indices are drawn without replacement and sorted; ``seq_length`` is taken
    from ``data["train_obs"].shape[1]`` for both splits. Every sequence uses
    its own generator: train sequence ``i`` is seeded with
    ``random_state + i + n_train`` and test sequence ``i`` with
    ``random_state + i``.

    Args:
        data: mapping with keys ``train_obs``, ``train_targets``, ``test_obs``,
            ``test_targets`` (plus ``train_obs_valid`` / ``test_obs_valid``
            when ``imagepred`` is True).
        sample_rate: fraction of time steps to keep.
        imagepred: also subsample and return the ``*_obs_valid`` arrays.
        random_state: base seed.

    Returns:
        If ``imagepred``: ``(train_obs, train_targets, train_time_points,
        train_obs_valid, test_obs, test_targets, test_time_points,
        test_obs_valid)``; otherwise ``(train_obs, train_targets, test_obs,
        test_targets, train_time_points, test_time_points)``. Note that the
        two orderings differ.
    """
    keys = ["train_obs", "train_targets", "test_obs", "test_targets"]
    if imagepred:
        keys += ["train_obs_valid", "test_obs_valid"]

    full = {key: data[key] for key in keys}
    seq_length = full["train_obs"].shape[1]
    n = int(sample_rate * seq_length)

    # output buffers share dtype with their source but only hold n time steps
    sub = {key: np.zeros_like(arr[:, :n, ...]) for key, arr in full.items()}

    def _fill_split(split, seed_offset):
        # subsample every array of one split with the same per-sequence indices
        names = [key for key in keys if key.startswith(split)]
        chosen = []
        for i in range(full[f"{split}_obs"].shape[0]):
            rng = np.random.default_rng(random_state + i + seed_offset)
            picked = np.sort(rng.choice(seq_length, n, replace=False))
            chosen.append(picked)
            for name in names:
                sub[name][i, ...] = full[name][i, picked, ...]
        return np.stack(chosen, 0)

    train_time_points = _fill_split("train", full["train_obs"].shape[0])
    test_time_points = _fill_split("test", 0)

    if imagepred:
        return (sub["train_obs"], sub["train_targets"], train_time_points,
                sub["train_obs_valid"], sub["test_obs"], sub["test_targets"],
                test_time_points, sub["test_obs_valid"])
    return (sub["train_obs"], sub["train_targets"], sub["test_obs"],
            sub["test_targets"], train_time_points, test_time_points)



# new code component 
def find_nearest(array, value):
    """Return the (flattened) index of the entry of ``array`` closest to ``value``.

    Ties are resolved in favour of the first occurrence, as ``argmin`` does.
    """
    distances = np.abs(np.asarray(array) - value)
    return np.argmin(distances)


# new code component 
def discretize_data(obs, targets, time_points, obs_valid, n_bins=10, take_always_closest=True):
    """Resample irregularly timed image sequences onto ``n_bins`` equal-width time bins.

    The bin edges come from ``np.histogram`` over *all* entries of
    ``time_points``. For every sample and every bin, the time point closest to
    the bin centre is looked up (first one wins ties) and its observation,
    target, time stamp and validity flag are copied to the output.

    Args:
        obs: observations indexed as ``obs[sample, time, ...]``.
        targets: targets indexed as ``targets[sample, time, ...]``.
        time_points: time stamps indexed as ``time_points[sample, time]``.
        obs_valid: validity flags, one per (sample, time); singleton axes are
            squeezed away.
        n_bins: number of equal-width bins.
        take_always_closest: if True the closest time point is used even when
            it lies outside the bin; if False such bins get time ``nan``,
            validity ``False`` and all-zero observation/target.

    Returns:
        ``(obs_all, targets_all, tp_all, obs_valid_all)``. The first two are
        ``uint8`` arrays of the hard-coded shape ``(N, n_bins, 24, 24, 1)``, so
        the per-step observation must be assignable to a ``(24, 24, 1)`` slot.
        ``tp_all`` and ``obs_valid_all`` have shape ``(N, n_bins)``.
    """
    N = obs.shape[0]

    obs_valid = np.squeeze(obs_valid)
    if obs_valid.ndim < 2:
        # np.squeeze also removes a singleton batch or time axis (N == 1 or a
        # single time step); restore the (sample, time) layout indexed below.
        obs_valid = obs_valid.reshape(N, -1)

    # define the bins and the centre of each bin
    _, bin_edges = np.histogram(time_points, bins=n_bins)
    bin_length = bin_edges[1] - bin_edges[0]
    bin_center = bin_edges[:-1] + bin_length / 2

    tp_all = []
    obs_valid_all = []
    # buffers start zeroed, so bins without a usable time point stay all-zero
    obs_all = np.zeros((N, n_bins, 24, 24, 1), dtype='uint8')
    targets_all = np.zeros((N, n_bins, 24, 24, 1), dtype='uint8')
    for i in range(N):
        sample_tp = time_points[i, :]
        # index of the time point closest to each bin centre, for all bins at once
        distances = np.abs(np.asarray(sample_tp).reshape(-1, 1) - bin_center)
        nearest = distances.argmin(axis=0)

        tp_list = []
        obs_valid_list = []
        for j, idx in enumerate(nearest):
            in_bin = bin_edges[j] <= sample_tp[idx] <= bin_edges[j + 1]
            if in_bin or take_always_closest:
                tp_list.append(sample_tp[idx])
                obs_valid_list.append(obs_valid[i, idx])
                obs_all[i, j, ...] = obs[i, idx, ...]
                targets_all[i, j, ...] = targets[i, idx, ...]
            else:
                tp_list.append(np.nan)
                obs_valid_list.append(False)

        tp_all.append(tp_list)
        obs_valid_all.append(obs_valid_list)

    return obs_all, targets_all, np.array(tp_all), np.array(obs_valid_all)


# new code component 
def create_unobserved_mask(n_col, T, seed=0):
    """Draw a random boolean mask of unobserved features for ``T`` time steps.

    Used to experiment with partial observability on USHCN. For each time step
    the number of unobserved features is drawn from ``0 .. n_col - 1`` with
    probabilities ``[0.6, 0.1, 0.1, 0.1, 0.1]`` and that many distinct feature
    indices are then set to True. The probability vector is hard-coded with
    five entries, which is what ``rng.choice`` is given for ``n_col`` choices.

    Args:
        n_col: number of features.
        T: number of time steps (rows of the mask).
        seed: seed of the ``np.random.RandomState`` used for all draws.

    Returns:
        ``np.array`` built from ``T`` boolean rows of length ``n_col``;
        True marks an unobserved feature.
    """
    rng = np.random.RandomState(seed)
    count_probabilities = [0.6, 0.1, 0.1, 0.1, 0.1]

    def _draw_row():
        row = np.zeros(n_col, dtype=bool)
        # first draw how many features are hidden, then which ones
        how_many = rng.choice(n_col, 1, p=count_probabilities)
        hidden = rng.choice(n_col, how_many, replace=False)
        row[hidden] = True
        return row

    return np.array([_draw_row() for _ in range(T)])


# new code component
def align_output_and_target(output_mean, output_var, targets, mask_targets):
    """Shift outputs and targets against each other for one-step-ahead prediction.

    The last step along axis 1 is dropped from ``output_mean`` / ``output_var``
    and the first step along axis 1 is dropped from ``targets`` /
    ``mask_targets``, so output step ``t`` lines up with target step ``t + 1``.

    Returns:
        ``(output_mean, output_var, targets, mask_targets)`` after slicing.
    """
    drop_last_step = (slice(None), slice(None, -1), Ellipsis)
    drop_first_step = (slice(None), slice(1, None), Ellipsis)
    return (
        output_mean[drop_last_step],
        output_var[drop_last_step],
        targets[drop_first_step],
        mask_targets[drop_first_step],
    )

def adjust_obs_for_next_obs_pred(dataset, obs, obs_valid, mask_obs, mask_truth, truth,
        obs_times=None, cut_time=24):
    """Prepare a batch for predicting the next observation after ``cut_time``.

    Only ``dataset == 'physionet'`` is implemented. For it:

    * time points at or after ``cut_time`` are marked invalid and their
      observations / observation masks are zeroed,
    * every tensor is trimmed along the time axis to the largest
      ``last_indices`` value in the batch, where ``last_indices`` is one past
      the position of the first time point ``>= cut_time`` of each sequence,
    * in ``truth`` / ``mask_truth`` everything after that first post-cut time
      point is zeroed, so only one future observation per sequence remains.

    NOTE(review): for a sequence without any time point ``>= cut_time`` the
    ``argmax`` runs over an all-zero row, which yields ``last_indices == 1``
    for that sequence — confirm this is the intended handling.

    Args:
        dataset: dataset name; must be ``'physionet'``.
        obs: observations, indexed ``[batch, time, feature]``.
        obs_valid: per-time-point validity, indexed ``[batch, time]``.
            It is updated **in place** (no clone is taken).
        mask_obs: per-entry observation mask matching ``obs``.
        mask_truth: per-entry mask of the ground truth.
        truth: ground-truth values.
        obs_times: time stamps, indexed ``[batch, time]``.
        cut_time: observations at or after this time are hidden.

    Returns:
        ``(obs_extrap, obs_valid_extrap, mask_obs, mask_truth, truth,
        last_indices, obs_times)``; all but ``last_indices`` are trimmed along
        the time axis.

    Raises:
        NotImplementedError: for ``'ushcn'``, ``'mimic'`` and ``'fBM'``.
        ValueError: for any other dataset name.
    """
    if dataset in ('ushcn', 'mimic', 'fBM'):
        # These variants were stubbed out with `assert False, "NYI"`; raise
        # explicitly so the guard also holds when asserts are disabled (-O).
        raise NotImplementedError(
            f"NYI: next-observation prediction is not implemented for dataset '{dataset}'")
    if dataset != 'physionet':
        raise ValueError(f"unknown dataset '{dataset}' for next-observation prediction")

    obs_valid_extrap = obs_valid  # deliberately not cloned: caller's tensor is modified
    obs_extrap = obs

    # zero out observations at >= cut_time
    mask_before_cut_time = obs_times < cut_time
    obs_valid_extrap *= mask_before_cut_time
    obs_extrap = torch.where(obs_valid_extrap[:, :, None], obs_extrap, 0.)
    mask_obs = torch.where(obs_valid_extrap[:, :, None], mask_obs, 0.).bool()

    # running count of time points at or after the cut, per sequence
    n_after_cut = (obs_times >= cut_time).cumsum(axis=1)

    # trim all the future observations after the next observation
    last_indices = (n_after_cut == 1).float().argmax(dim=1) + 1
    last_index = last_indices.max().item()
    obs_extrap = obs_extrap[:, :last_index]
    mask_obs = mask_obs[:, :last_index]
    obs_valid_extrap = obs_valid_extrap[:, :last_index]
    mask_truth = mask_truth[:, :last_index]
    truth = truth[:, :last_index]
    obs_times = obs_times[:, :last_index]

    # unobserve anything after the next observation (the running count is a
    # prefix sum, so slicing it equals recomputing it on the trimmed times)
    next_obs_problem_mask = n_after_cut[:, :last_index] <= 1
    truth = torch.where(next_obs_problem_mask[:, :, None], truth, 0.)
    mask_truth = torch.where(next_obs_problem_mask[:, :, None], mask_truth, 0.).bool()

    return obs_extrap, obs_valid_extrap, mask_obs, mask_truth, truth, last_indices, obs_times


def adjust_obs_for_extrapolation(dataset, obs, obs_valid, mask_obs, obs_times=None, cut_time=24):
    """Hide the part of each sequence that has to be extrapolated.

    Behaviour per ``dataset``:

    * ``'ushcn'``: the second half of the time axis (from
      ``obs.shape[1] // 2``) is zeroed in ``obs`` and set to False in
      ``obs_valid`` and ``mask_obs``; all three inputs are modified in place.
    * ``'physionet'``: time points with ``obs_times < cut_time`` stay observed.
    * ``'mimic'``: time points with ``obs_times <= 0`` stay observed
      (``cut_time`` is ignored).
    * ``'fBM'``: time points with ``obs_times <= 0.5`` stay observed
      (``cut_time`` is ignored).
    * anything else: the inputs are returned untouched.

    For the three time-based variants ``obs_valid`` is updated in place, while
    the returned observations and observation mask are new tensors with the
    hidden time points zeroed.

    Returns:
        ``(obs_extrap, obs_valid_extrap, mask_obs)``.
    """
    if dataset == 'ushcn':
        # zero out last half of the observations
        first_hidden = obs.shape[1] // 2
        obs_valid[:, first_hidden:, ...] = False
        obs[:, first_hidden:, ...] = 0
        mask_obs[:, first_hidden:, :] = False
        return obs, obs_valid, mask_obs

    if dataset == 'physionet':
        observed_window = obs_times < cut_time
    elif dataset == 'mimic':
        observed_window = obs_times <= 0
    elif dataset == 'fBM':
        observed_window = obs_times <= 0.5
    else:
        return obs, obs_valid, mask_obs

    # no clone is taken: the caller's obs_valid is narrowed in place
    obs_valid *= observed_window
    keep = obs_valid[:, :, None]
    obs_extrap = torch.where(keep, obs, 0.)
    mask_extrap = torch.where(keep, mask_obs, 0.).bool()
    return obs_extrap, obs_valid, mask_extrap


# new code component
def adjust_obs_for_extrapolation_old(obs, obs_valid, obs_times=None, cut_time=None):
    """Hide the part of each sequence that has to be extrapolated (copying variant).

    Works on clones, so ``obs`` and ``obs_valid`` are left untouched.

    * ``cut_time is None`` (used for USHCN): the second half of the time axis,
      starting at ``obs.shape[1] // 2``, is zeroed / marked invalid.
    * otherwise (used for Physionet): only time points with
      ``obs_times < cut_time`` stay valid. ``obs_valid`` may have the same
      number of dimensions as ``obs_times`` (one flag per time point) or one
      more (one flag per entry).

    Returns:
        ``(obs_extrap, obs_valid_extrap)``.
    """
    valid = obs_valid.clone()
    values = obs.clone()

    if cut_time is None:
        half = obs.shape[1] // 2
        valid[:, half:, ...] = False
        values[:, half:, ...] = 0
        return values, valid

    before_cut = obs_times < cut_time
    if valid.ndim > before_cut.ndim:
        # per-entry validity: broadcast the time mask over the feature axis
        valid *= before_cut[:, :, None]
        keep = valid
    else:
        # per-time-point validity: add a feature axis for the selection below
        valid *= before_cut
        keep = valid[:, :, None]

    return torch.where(keep, values, 0.), valid

def compute_mean_inter_arrivals(train_dl):
    """Return the mean gap between consecutive valid time points of a dataset.

    Per the original note, this average over the train set is required for
    GRU-D. Each batch is indexed positionally: ``batch[0]`` (observations)
    gives the number of sequences, ``batch[2]`` is used as a per-sequence
    selector of valid time points and ``batch[3]`` holds the time stamps.

    Returns:
        ``np.mean`` over all consecutive differences of the selected times.
    """
    print('computing the avg inter arrival ...')
    gaps = []
    for batch in train_dl:
        stamps = batch[3]
        valid_flags = batch[2]
        n_sequences = batch[0].shape[0]
        for row in range(n_sequences):
            valid_stamps = stamps[row][valid_flags[row]]
            gaps.extend(valid_stamps.diff().numpy().tolist())
    mean_gap = np.mean(gaps)
    print('done computing avg inter arrival period ...')
    return mean_gap

def compute_sample_period_per_channel(train_dl):
    """Return the mean sampling period of every channel over a dataset.

    Per the original note, these train-set statistics are required for GRU-D.
    Each batch is indexed positionally: ``batch[0]`` (observations) gives the
    number of sequences and channels, ``batch[3]`` holds the time stamps and
    ``batch[4]`` is the per-entry mask, indexed ``[sequence, time, channel]``
    (cast to bool if necessary). For every sequence and channel, the
    differences between consecutive masked-in time stamps are collected.

    Returns:
        A list with one mean period per channel (channel count taken from the
        last batch). Channels without any difference fall back to the mean over
        all channels (``nan`` if there is no difference at all). An empty
        ``train_dl`` yields an empty list.
    """
    print('computing the sample period ...')
    sample_periods = defaultdict(list)
    n_channels = 0  # stays 0 when the loader yields no batch
    for batch in train_dl:
        obs, times, mask = batch[0], batch[3], batch[4]
        n_channels = obs.shape[-1]
        if mask.dtype != torch.bool:
            # boolean indexing below needs a bool mask; convert once per batch
            mask = mask.bool()
        for channel_idx in range(n_channels):
            for batch_idx in range(obs.shape[0]):
                channel_times = times[batch_idx, mask[batch_idx, :, channel_idx]]
                sample_periods[channel_idx].extend(
                    channel_times.diff().cpu().numpy().tolist())

    all_periods = [a_val for periods in sample_periods.values() for a_val in periods]
    # default for channels that never have two observations in one sequence
    avg_mean_period = np.mean(all_periods) if all_periods else np.nan

    mean_sample_period = []
    for i in range(n_channels):
        channel_samples = sample_periods[i]
        if not channel_samples:
            print('Zero deltas for {} channel'.format(i))
            mean_sample_period.append(avg_mean_period)
        else:
            mean_sample_period.append(np.mean(channel_samples))
    print('done computing sample period ...')
    return mean_sample_period

def ensure_lb_1(x):
    """Return ``x`` bounded from below by 1.0, i.e. ``max(x, 1.0)``."""
    return max(x, 1.0)



def means_on_physionet(train_dl):
    """Return the per-feature mean of all observed values of a dataset.

    Per the original note, this train-set mean is required for GRU-D. Each
    batch is indexed positionally: ``batch[0]`` holds the observations and
    ``batch[4]`` the observation mask; both are flattened to
    ``(-1, n_features)``. Values are weighted by the mask and summed with
    ``nansum``, and the total is divided by the summed mask. Features whose
    mean is NaN are counted, reported via ``print`` and replaced by 0.

    Returns:
        1-D tensor with one mean per feature.
    """
    print('computing the means ...')
    per_batch_sums = []
    per_batch_counts = []
    for batch in train_dl:
        n_features = batch[0].shape[-1]
        flat_obs = batch[0].view(-1, n_features)
        flat_mask = batch[4].view(-1, n_features).double()
        per_batch_sums.append(torch.nansum(flat_obs * flat_mask, dim=0))
        per_batch_counts.append(torch.nansum(flat_mask, dim=0))

    # one row per batch, then reduce over the batches
    total = torch.stack(per_batch_sums, dim=0).sum(dim=0)
    count = torch.stack(per_batch_counts, dim=0).sum(dim=0)
    means = total / count

    n_missing = means.isnan().sum().item()
    print('NaNs in means for {} items ... '.format(n_missing))
    means = torch.nan_to_num(means, nan=0.0)
    print('computing the means ... done!')
    return means

