"""
Code originates from GRUD_mean.ipynb from GitHub repository https://github.com/Han-JD/GRU-D.
"""

from baselines.Raindrop.code.baselines.models import GRUD
import torch
import numpy as np
import pandas as pd
import os
import math
import warnings
import itertools
import numbers
import torch.utils.data as utils
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, precision_score, recall_score, f1_score
from sklearn.metrics import average_precision_score
from baselines.Raindrop.code.baselines.utils_phy12 import random_sample


def one_hot(y_):
    """Encode a column of integer class indexes as one-hot rows.

    e.g.: [[5], [0], [3]] --> [[0, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]]

    The number of columns is the largest index present in ``y_`` plus one,
    so the width depends on the data passed in.
    """
    # Flatten to one entry per sample and force plain Python ints.
    class_ids = list(map(int, y_.reshape(len(y_))))

    # Row k of the identity matrix is the one-hot code of class k.
    identity = np.eye(np.max(class_ids) + 1)
    lookup = np.array(class_ids, dtype=np.int32)
    return identity[lookup]


def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Default directory holding the pre-computed split indices and density scores.
_GRUD_SAVED_DIR = '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/'

# File-name prefix of the saved age/gender index arrays, per dataset.
_SPLIT_INDEX_PREFIX = {
    'P12': 'grud',
    'P19': 'P19',
    'eICU': 'eICU',  # not possible with split_type == 'age'
}

# Datasets for which a saved 'IG_density_scores_<name>.npy' ranking is expected.
_DENSITY_SCORE_DATASETS = ('P12', 'P19', 'eICU', 'PAM')

# P12 variable name -> feature index. Only indices below 33 are used for removal.
_P12_FEATURE_INDEX = {
    "ALP": 0,  # o
    "ALT": 1,  # o
    "AST": 2,  # o
    "Albumin": 3,  # o
    "BUN": 4,  # o
    "Bilirubin": 5,  # o
    "Cholesterol": 6,  # o
    "Creatinine": 7,  # o
    "DiasABP": 8,  # o
    "FiO2": 9,  # o
    "GCS": 10,  # o
    "Glucose": 11,  # o
    "HCO3": 12,  # o
    "HCT": 13,  # o
    "HR": 14,  # o
    "K": 15,  # o
    "Lactate": 16,  # o
    "MAP": 17,  # o
    "Mg": 18,  # o
    "Na": 19,  # o
    "PaCO2": 20,  # o
    "PaO2": 21,  # o
    "Platelets": 22,  # o
    "RespRate": 23,  # o
    "SaO2": 24,  # o
    "SysABP": 25,  # o
    "Temp": 26,  # o
    "Tropl": 27,  # o
    "TroponinI": 27,  # temp: regarded same as Tropl
    "TropT": 28,  # o
    "TroponinT": 28,  # temp: regarded same as TropT
    "Urine": 29,  # o
    "WBC": 30,  # o
    "Weight": 31,  # o
    "pH": 32,  # o
    "NIDiasABP": 33,  # unused variable
    "NIMAP": 34,  # unused variable
    "NISysABP": 35,  # unused variable
    "MechVent": 36,  # unused variable
    "RecordID": 37,  # unused variable
    "Age": 38,  # unused variable
    "Gender": 39,  # unused variable
    "ICUType": 40,  # unused variable
    "Height": 41  # unused variable
}


def _upsample_training_batches(dataset, outcomes, train_rows, batch_size):
    # Build len(train_rows) // batch_size class-balanced batches via random_sample.
    rows_0 = train_rows[np.where(outcomes[train_rows, :] == 0)[0]]
    rows_1 = train_rows[np.where(outcomes[train_rows, :] == 1)[0]]

    train_data = []
    train_label = []
    for _ in range(len(train_rows) // batch_size):   # last small batch is dropped
        indices = random_sample(rows_0, rows_1, batch_size)
        train_data.extend(dataset[indices, :, :, :])
        train_label.extend(outcomes[indices, :])
    return np.array(train_data), np.array(train_label)


def _zero_random_features_per_sample(data, num_all_features, num_missing_features):
    # Zero an independently drawn random subset of features (axis 2) for every sample, in place.
    for patient in data:
        idx = np.random.choice(num_all_features, num_missing_features, replace=False)
        patient[:, idx, :] = 0   # `patient` is a view, so this writes into `data`


def _ranked_feature_indices(dataset_name, num_missing_features, saved_dir):
    # Return the first `num_missing_features` feature indices of the saved density-score ranking.
    if dataset_name not in _DENSITY_SCORE_DATASETS:
        raise ValueError(
            "feature_removal_level='set' is not supported for dataset_name={!r}; expected one of {}".format(
                dataset_name, _DENSITY_SCORE_DATASETS))

    density_scores = np.load(os.path.join(saved_dir, 'IG_density_scores_' + dataset_name + '.npy'),
                             allow_pickle=True)

    if dataset_name == 'P12':
        # P12 scores are keyed by variable name; map to indices and drop the unused variables.
        ranked = [_P12_FEATURE_INDEX[name] for _, name in density_scores if _P12_FEATURE_INDEX[name] < 33]
        # Aliased names share an index, so de-duplicate after truncation.
        return list(set(ranked[:num_missing_features]))

    # Other datasets store the feature index in the first column.
    return list(map(int, density_scores[:, 0][:num_missing_features]))


def data_dataloader(dataset, outcomes, upsampling_batch, batch_size, split_type, feature_removal_level, missing_ratio,
                    train_proportion=0.8, dev_proportion=0.1, dataset_name='P12', saved_dir=None):
    """Split a GRU-D dataset into train/dev/test sets and wrap them in DataLoaders.

    Args:
        dataset: 4-D ndarray indexed as ``[sample, :, feature, :]`` (axis 2 is
            treated as the feature axis by the feature-removal code).
        outcomes: 2-D ndarray of labels, one row per sample.
        upsampling_batch: if truthy, the training set is rebuilt from
            ``random_sample`` draws over the class-0 / class-1 rows. When there
            are fewer training rows than ``batch_size`` the plain training rows
            are used instead.
        batch_size: size of each upsampled draw. It is NOT passed to the
            DataLoaders, which are created with PyTorch defaults.
        split_type: ``'random'`` (shuffle, then split by proportions), or
            ``'age'`` / ``'gender'`` (train on the saved under-65 / male
            indices; the over-65 / female indices are shuffled and halved into
            dev and test).
        feature_removal_level: ``'sample'`` zeroes a random feature subset per
            dev/test sample; ``'set'`` zeroes the top-ranked features of the
            saved density scores in dev/test; any other value removes nothing.
        missing_ratio: fraction of features to zero (rounded to a count).
        train_proportion: training fraction for the random split.
        dev_proportion: validation fraction for the random split.
        dataset_name: selects which saved index / density-score files to load.
        saved_dir: directory containing the saved ``.npy`` files; defaults to
            the project's ``saved/`` directory.

    Returns:
        ``(train_dataloader, dev_dataloader, test_dataloader)``.

    Raises:
        ValueError: for an unknown ``split_type``, or a ``dataset_name`` that
            has no saved files for the requested split / removal mode.
    """
    if saved_dir is None:
        saved_dir = _GRUD_SAVED_DIR

    if split_type == 'random':
        # np.random.seed(77)   # if you want the same permutation for each run
        # shuffle data
        permuted_idx = np.random.permutation(dataset.shape[0])
        dataset = dataset[permuted_idx]
        outcomes = outcomes[permuted_idx]

        # 80% train, 10% validation, 10% test with the default proportions
        train_index = int(np.floor(dataset.shape[0] * train_proportion))
        dev_index = int(np.floor(dataset.shape[0] * (train_proportion + dev_proportion)))

        if upsampling_batch and train_index >= batch_size:
            train_data, train_label = _upsample_training_batches(dataset, outcomes, np.arange(train_index),
                                                                 batch_size)
        else:
            # No upsampling requested, or too few samples to fill a single batch.
            train_data, train_label = dataset[:train_index, :, :, :], outcomes[:train_index, :]

        dev_data, dev_label = dataset[train_index:dev_index, :, :, :], outcomes[train_index:dev_index, :]
        test_data, test_label = dataset[dev_index:, :, :, :], outcomes[dev_index:, :]
    elif split_type == 'age' or split_type == 'gender':
        # The index files were produced offline from each patient's static age/gender
        # (age > 0 and < 65 -> under_65, otherwise over_65; gender 0 -> female, 1 -> male).
        if dataset_name not in _SPLIT_INDEX_PREFIX:
            raise ValueError(
                "split_type={!r} is not supported for dataset_name={!r}; expected one of {}".format(
                    split_type, dataset_name, tuple(_SPLIT_INDEX_PREFIX)))
        prefix = _SPLIT_INDEX_PREFIX[dataset_name]

        if split_type == 'age':
            train_group, held_out_group = 'under_65', 'over_65'
        else:   # split_type == 'gender'
            train_group, held_out_group = 'male', 'female'

        idx_train = np.load(os.path.join(saved_dir, prefix + '_idx_' + train_group + '.npy'), allow_pickle=True)
        idx_vt = np.load(os.path.join(saved_dir, prefix + '_idx_' + held_out_group + '.npy'), allow_pickle=True)

        if upsampling_batch and len(idx_train) >= batch_size:
            train_data, train_label = _upsample_training_batches(dataset, outcomes, idx_train, batch_size)
        else:
            # No upsampling requested, or too few samples to fill a single batch.
            train_data, train_label = dataset[idx_train, :, :, :], outcomes[idx_train, :]

        # The held-out group is shuffled and split in half: validation / test.
        np.random.shuffle(idx_vt)
        half = round(len(idx_vt) / 2)
        idx_val = idx_vt[:half]
        idx_test = idx_vt[half:]

        dev_data, dev_label = dataset[idx_val, :, :, :], outcomes[idx_val, :]
        test_data, test_label = dataset[idx_test, :, :, :], outcomes[idx_test, :]
    else:
        raise ValueError(
            "Unknown split_type {!r}; expected 'random', 'age' or 'gender'".format(split_type))

    # Feature removal only ever touches the validation and test sets.
    if feature_removal_level == 'sample':
        num_all_features = dev_data.shape[2]
        num_missing_features = round(missing_ratio * num_all_features)
        _zero_random_features_per_sample(dev_data, num_all_features, num_missing_features)
        _zero_random_features_per_sample(test_data, num_all_features, num_missing_features)
    elif feature_removal_level == 'set':
        num_all_features = dev_data.shape[2]
        num_missing_features = round(missing_ratio * num_all_features)
        idx = _ranked_feature_indices(dataset_name, num_missing_features, saved_dir)

        dev_data[:, :, idx, :] = 0
        test_data[:, :, idx, :] = 0

    # ndarray to tensor
    train_data, train_label = torch.Tensor(train_data), torch.Tensor(train_label)
    dev_data, dev_label = torch.Tensor(dev_data), torch.Tensor(dev_label)
    test_data, test_label = torch.Tensor(test_data), torch.Tensor(test_label)

    # tensor to dataset
    train_dataset = utils.TensorDataset(train_data, train_label)
    dev_dataset = utils.TensorDataset(dev_data, dev_label)
    test_dataset = utils.TensorDataset(test_data, test_label)

    # dataset to dataloader (default settings: one sample per step, no shuffling)
    train_dataloader = utils.DataLoader(train_dataset)
    dev_dataloader = utils.DataLoader(dev_dataset)
    test_dataloader = utils.DataLoader(test_dataset)

    print("train_data.shape : {}\t train_label.shape : {}".format(train_data.shape, train_label.shape))
    print("dev_data.shape : {}\t dev_label.shape : {}".format(dev_data.shape, dev_label.shape))
    print("test_data.shape : {}\t test_label.shape : {}".format(test_data.shape, test_label.shape))

    return train_dataloader, dev_dataloader, test_dataloader


def train_gru_d(num_runs, input_size, hidden_size, output_size, num_layers, dropout, learning_rate, n_epochs,
                batch_size, upsampling_batch, split_type, feature_removal_level, missing_ratio, dataset, predictive_label='mortality'):
    model_path = '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/grud_model_best.pt'

    acc_all = []
    auc_all = []
    aupr_all = []
    precision_all = []
    recall_all = []
    F1_all = []

    # Progress bar for runs
    run_pbar = tqdm(range(num_runs), desc="Processing runs", unit="run")
    for r in run_pbar:
        if dataset == 'CD':
            # For CD dataset, we need to use the same data loading approach as other baselines
            # This will be handled differently since CD doesn't have pre-saved numpy files
            from baselines.Raindrop.code.baselines.utils_phy12 import get_data_split
            base_path = '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/CDdata'
            split_path = '/splits/cd_split1.npy'  # Use first split for now
            
            Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(base_path, split_path, split_type='random',
                                                                      reverse=False, baseline=True, dataset='CD',
                                                                      predictive_label='mortality')
            
            # Quick test mode: limit to 100 samples each
            if args.quick_test:
                max_samples = 100
                if len(Ptrain) > max_samples:
                    print(f"Quick test mode: Limiting data to {max_samples} samples each")
                    Ptrain = Ptrain[:max_samples]
                    Pval = Pval[:max_samples]
                    Ptest = Ptest[:max_samples]
                    ytrain = ytrain[:max_samples]
                    yval = yval[:max_samples]
                    ytest = ytest[:max_samples]
                    print(f"After limiting: train={len(Ptrain)}, val={len(Pval)}, test={len(Ptest)}")
            
            # Convert CD data format to GRU-D format
            # CD data is in PTdict_list format where each patient has 'arr' (time series) and 'time' keys
            all_patients = list(Ptrain) + list(Pval) + list(Ptest)
            all_labels = np.concatenate([ytrain, yval, ytest], axis=0)
            
            # Get dimensions from first patient
            max_len = all_patients[0]['arr'].shape[0]  # time steps
            num_features = all_patients[0]['arr'].shape[1]  # features
            
            # Create GRU-D format: tuple of (X, Mask, Delta)
            # X: [samples, time_steps, features] - the actual values
            # Mask: [samples, time_steps, features] - 1 where observed, 0 where missing
            # Delta: [samples, time_steps, features] - time intervals between observations
            X = np.zeros((len(all_patients), max_len, num_features))
            Mask = np.zeros((len(all_patients), max_len, num_features))
            Delta = np.zeros((len(all_patients), max_len, num_features))
            
            for i, patient in enumerate(all_patients):
                # Extract time series data
                patient_arr = patient['arr']  # shape: [time_steps, features]
                patient_time = patient['time']  # shape: [time_steps, 1]
                
                # Set values
                X[i, :, :] = patient_arr
                
                # Set masks - 1 where data exists (non-zero), 0 where missing
                Mask[i, :, :] = (patient_arr != 0).astype(float)
                
                # Calculate time deltas
                time_deltas = np.zeros((max_len, num_features))
                for t in range(1, max_len):
                    if patient_time[t, 0] > 0 and patient_time[t-1, 0] > 0:
                        delta = patient_time[t, 0] - patient_time[t-1, 0]
                        time_deltas[t, :] = delta
                Delta[i, :, :] = time_deltas
            
            # Create the tuple format expected by GRU-D
            t_dataset = (X, Mask, Delta)
            t_out = all_labels
        elif dataset == 'P12':
            use_cached_env = os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')
            if use_cached_env:
                # Use cached PSV dataset via get_data_split (env provides cached dir and split pkl)
                from baselines.Raindrop.code.baselines.utils_phy12 import get_data_split
                base_path = '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/P12data'
                split_path = '/splits/phy12_split1.npy'
                Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(base_path, split_path, split_type='random',
                                                                          reverse=False, baseline=True, dataset='P12',
                                                                          predictive_label=predictive_label)
                # Convert to GRU-D tuple format (X, Mask, Delta) per set
                def build_tuple(P_list):
                    if len(P_list) == 0:
                        return torch.zeros(1, 1, 1), torch.zeros(1, 1, 1), torch.zeros(1, 1, 1)
                    max_len_local = max(p['arr'].shape[0] for p in P_list)
                    num_features_local = P_list[0]['arr'].shape[1]
                    X_local = np.zeros((len(P_list), max_len_local, num_features_local))
                    Mask_local = np.zeros((len(P_list), max_len_local, num_features_local))
                    Delta_local = np.zeros((len(P_list), max_len_local, num_features_local))
                    for i_p, p in enumerate(P_list):
                        arr_i = p['arr']
                        t_i = len(arr_i)
                        X_local[i_p, :t_i, :] = arr_i
                        Mask_local[i_p, :t_i, :] = (arr_i != 0).astype(float)
                        # time deltas in minutes
                        time_vec = p['time'].reshape(-1)
                        for t in range(1, t_i):
                            dt = max(0.0, float(time_vec[t] - time_vec[t-1]))
                            Delta_local[i_p, t, :] = dt
                    return torch.Tensor(X_local), torch.Tensor(Mask_local), torch.Tensor(Delta_local)
                X_tr, M_tr, D_tr = build_tuple(Ptrain)
                X_val, M_val, D_val = build_tuple(Pval)
                X_te, M_te, D_te = build_tuple(Ptest)
                # x_mean from training data non-zero values per feature
                train_values = X_tr.numpy()
                x_mean_vec = []
                for feat in range(train_values.shape[2]):
                    vals = train_values[:, :, feat]
                    non_zero = vals[vals != 0]
                    x_mean_vec.append(non_zero.mean() if non_zero.size > 0 else 0.0)
                x_mean = torch.tensor(x_mean_vec, dtype=torch.float32)
                # Override model dimensions to match cached tensors
                # Features dimension -> input_size and hidden_size should match
                input_size = int(X_tr.shape[2])
                hidden_size = int(input_size)
                # Time dimension / layers (sequence length)
                num_layers = int(X_tr.shape[1])
                # GRU-D expects tensors shaped [features, time]; permute from [time, features]
                X_tr = X_tr.permute(0, 2, 1)
                M_tr = M_tr.permute(0, 2, 1)
                D_tr = D_tr.permute(0, 2, 1)
                X_val = X_val.permute(0, 2, 1)
                M_val = M_val.permute(0, 2, 1)
                D_val = D_val.permute(0, 2, 1)
                X_te = X_te.permute(0, 2, 1)
                M_te = M_te.permute(0, 2, 1)
                D_te = D_te.permute(0, 2, 1)
                # Now create dataloaders
                train_dataloader = utils.DataLoader(utils.TensorDataset(X_tr, M_tr, D_tr, torch.Tensor(ytrain)))
                dev_dataloader = utils.DataLoader(utils.TensorDataset(X_val, M_val, D_val, torch.Tensor(yval)))
                test_dataloader = utils.DataLoader(utils.TensorDataset(X_te, M_te, D_te, torch.Tensor(ytest)))
                n_classes = 2
                if r == 0:
                    print(f"Cached P12 shapes - X: {X_tr.shape}, Mask: {M_tr.shape}, Delta: {D_tr.shape}, Labels: {torch.Tensor(ytrain).shape}")
                # Short-circuit the standard loader creation path
                t_dataset = None
                t_out = None
            else:
                t_dataset = np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/dataset.npy')
                if predictive_label == 'mortality':
                    t_out = np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/y1_out.npy')
                elif predictive_label == 'LoS':  # for P12 only
                    t_out = np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/LoS_y1_out.npy')
        elif dataset == 'P19':
            t_dataset = np.load('saved/P19_dataset.npy')
            t_out = np.load('saved/P19_y1_out.npy')
        elif dataset == 'eICU':
            t_dataset = np.load('saved/eICU_dataset.npy')
            t_out = np.load('saved/eICU_y1_out.npy')
        elif dataset == 'PAM':
            t_dataset = np.load('saved/PAM_dataset.npy') 
            t_out = np.load('saved/PAM_y1_out.npy')

        if r == 0:
            if dataset == 'CD':
                X, Mask, Delta = t_dataset
                print(f"CD dataset shapes - X: {X.shape}, Mask: {Mask.shape}, Delta: {Delta.shape}, Labels: {t_out.shape}")
            else:
                # When using cached P12 PSV path, t_dataset is intentionally None
                if t_dataset is not None and t_out is not None:
                    print(t_dataset.shape, t_out.shape)

        if dataset == 'CD':
            # For CD dataset, handle the tuple format directly
            X, Mask, Delta = t_dataset
            
            # Split the data according to train/val/test proportions
            train_size = int(0.8 * len(X))
            val_size = int(0.1 * len(X))
            
            # Split X, Mask, Delta, and labels
            train_X = torch.Tensor(X[:train_size])
            train_Mask = torch.Tensor(Mask[:train_size])
            train_Delta = torch.Tensor(Delta[:train_size])
            train_labels = torch.Tensor(t_out[:train_size])
            
            val_X = torch.Tensor(X[train_size:train_size+val_size])
            val_Mask = torch.Tensor(Mask[train_size:train_size+val_size])
            val_Delta = torch.Tensor(Delta[train_size:train_size+val_size])
            val_labels = torch.Tensor(t_out[train_size:train_size+val_size])
            
            test_X = torch.Tensor(X[train_size+val_size:])
            test_Mask = torch.Tensor(Mask[train_size+val_size:])
            test_Delta = torch.Tensor(Delta[train_size+val_size:])
            test_labels = torch.Tensor(t_out[train_size+val_size:])
            
            # Create datasets
            train_dataset = utils.TensorDataset(train_X, train_Mask, train_Delta, train_labels)
            dev_dataset = utils.TensorDataset(val_X, val_Mask, val_Delta, val_labels)
            test_dataset = utils.TensorDataset(test_X, test_Mask, test_Delta, test_labels)
            
            # Create dataloaders
            train_dataloader = utils.DataLoader(train_dataset)
            dev_dataloader = utils.DataLoader(dev_dataset)
            test_dataloader = utils.DataLoader(test_dataset)
            
            print("train_data.shape : {}\t train_label.shape : {}".format(train_X.shape, train_labels.shape))
            print("dev_data.shape : {}\t dev_label.shape : {}".format(val_X.shape, val_labels.shape))
            print("test_data.shape : {}\t test_label.shape : {}".format(test_X.shape, test_labels.shape))
        else:
            if not (dataset == 'P12' and os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
                train_dataloader, dev_dataloader, test_dataloader = data_dataloader(t_dataset, t_out, upsampling_batch, batch_size,
                                                                                    split_type, feature_removal_level, missing_ratio,
                                                                                    train_proportion=0.8, dev_proportion=0.1,
                                                                                    dataset_name=dataset)
        if dataset == 'P12':
            if os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes'):
                # x_mean already computed above for cached path
                n_classes = 2
            else:
                x_mean = torch.Tensor(np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/x_mean_aft_nor.npy'))
                n_classes = 2
        elif dataset == 'P19':
            x_mean = torch.Tensor(np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/P19_x_mean_aft_nor.npy'))
            n_classes = 2
        elif dataset == 'eICU':
            x_mean = torch.Tensor(np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/eICU_x_mean_aft_nor.npy'))
            n_classes = 2
        elif dataset == 'PAM':
            x_mean = torch.Tensor(np.load('/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/code/baselines/saved/PAM_x_mean_aft_nor.npy'))
            n_classes = 8
        elif dataset == 'CD':
            # For CD dataset, calculate x_mean from the training data
            # Use the X values from the tuple
            X, _, _ = t_dataset
            train_data_values = X[:int(0.8 * len(X))]  # 80% for training
            # Calculate mean across all non-zero values for each feature
            x_mean = torch.zeros(input_size)
            for i in range(input_size):
                feature_values = train_data_values[:, :, i]
                non_zero_values = feature_values[feature_values != 0]
                if len(non_zero_values) > 0:
                    x_mean[i] = torch.tensor(non_zero_values.mean())
            n_classes = 2
        print(x_mean.shape)

        model = GRUD(input_size=input_size, hidden_size=hidden_size, output_size=output_size, dropout=dropout,
                     dropout_type='mloss', x_mean=x_mean, num_layers=num_layers)

        epoch_losses = []

        old_state_dict = {}
        for key in model.state_dict():
            old_state_dict[key] = model.state_dict()[key].clone()

        if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
            criterion = torch.nn.BCELoss()
        elif dataset == 'PAM':
            criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                               patience=1, threshold=0.0001, threshold_mode='rel',
                                                               cooldown=0, min_lr=1e-8, eps=1e-08)

        print('\n------------------\nRUN %d: Training started\n------------------' % r)
        best_aupr_val = 0
        
        # Progress bar for epochs
        epoch_pbar = tqdm(range(n_epochs), desc=f"Training (Run {r+1})", unit="epoch")
        for epoch in epoch_pbar:
            # train the model
            losses, acc = [], []
            label, pred = [], []
            y_pred_col = []
            model.train()
            for data_batch in train_dataloader:
                # Zero the parameter gradients
                optimizer.zero_grad()

                if dataset == 'CD' or (dataset == 'P12' and os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
                    # CD and cached P12 datasets return (X, Mask, Delta, label)
                    train_X, train_Mask, train_Delta, train_label = data_batch
                    # Keep batch dimension for GRUD model which expects [batch, time, features]
                    train_data = (train_X, train_Mask, train_Delta)
                    train_label = torch.squeeze(train_label, dim=0)
                else:
                    # Other datasets return (data, label)
                    train_data, train_label = data_batch
                    train_data = torch.squeeze(train_data)
                    train_label = torch.squeeze(train_label, dim=0)

                if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                    y_pred = model(train_data)

                    # Save predict and label
                    y_pred_col.append(y_pred.item())
                    pred.append(y_pred.item() > 0.5)
                    label.append(train_label.item())

                    # Compute loss
                    loss = criterion(y_pred, train_label)
                    acc.append(
                        torch.eq(
                            (y_pred.data > 0.5).float(),
                            train_label)
                    )
                    losses.append(loss.item())
                elif dataset == 'PAM':
                    y_pred = model(train_data, dataset_name=dataset)

                    # Save predict and label
                    y_pred_col.append(torch.argmax(y_pred).item())
                    label.append(train_label.item())

                    # Compute loss
                    loss = criterion(torch.unsqueeze(y_pred, 0), train_label.type(torch.LongTensor))

                    acc.append(
                        torch.eq(
                            torch.argmax(y_pred),
                            train_label)
                    )
                    losses.append(loss.item())

                loss.backward()
                optimizer.step()

            train_acc = torch.mean(torch.cat(acc).float())
            train_loss = np.mean(losses)

            train_pred_out = pred
            train_label_out = label

            # save new params
            new_state_dict = {}
            for key in model.state_dict():
                new_state_dict[key] = model.state_dict()[key].clone()

            # compare params
            for key in old_state_dict:
                if (old_state_dict[key] == new_state_dict[key]).all():
                    print('Not updated in {}'.format(key))

            # validation loss
            losses, acc = [], []
            label, pred = [], []
            model.eval()
            for data_batch in dev_dataloader:
                if dataset == 'CD' or (dataset == 'P12' and os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
                    # CD and cached P12 datasets return (X, Mask, Delta, label)
                    dev_X, dev_Mask, dev_Delta, dev_label = data_batch
                    # Keep batch dimension for GRUD model which expects [batch, time, features]
                    dev_data = (dev_X, dev_Mask, dev_Delta)
                    dev_label = torch.squeeze(dev_label, dim=0)
                else:
                    # Other datasets return (data, label)
                    dev_data, dev_label = data_batch
                    dev_data = torch.squeeze(dev_data)
                    dev_label = torch.squeeze(dev_label, dim=0)

                if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                    y_pred = model(dev_data)

                    # Save predict and label
                    pred.append(y_pred.item())
                    label.append(dev_label.item())

                    # Compute loss
                    loss = criterion(y_pred, dev_label)
                    acc.append(
                        torch.eq(
                            (y_pred.data > 0.5).float(),
                            dev_label)
                    )

                    losses.append(loss.item())
                elif dataset == 'PAM':
                    y_pred = model(dev_data, dataset_name=dataset)

                    # Save predict and label
                    pred.append(torch.argmax(y_pred).item())
                    label.append(dev_label.item())

                    # Compute loss
                    loss = criterion(torch.unsqueeze(y_pred, 0), dev_label.type(torch.LongTensor))

                    acc.append(
                        torch.eq(
                            torch.argmax(y_pred),
                            dev_label)
                    )
                    losses.append(loss.item())

            if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU':
                auc_val = roc_auc_score(label, pred)
                aupr_val = average_precision_score(label, pred)
            elif dataset == 'PAM':
                label_oh = np.array(label)
                label_oh = one_hot(label_oh[:, np.newaxis])
                pred_oh = np.array(pred)
                pred_oh = one_hot(pred_oh[:, np.newaxis])
                auc_val = roc_auc_score(label_oh, pred_oh)
                aupr_val = average_precision_score(label_oh, pred_oh)

            scheduler.step(aupr_val)  # reduce learning rate when this metric has stopped improving

            if aupr_val > best_aupr_val:
                best_aupr_val = aupr_val
                torch.save(model, model_path)

            dev_acc = torch.mean(torch.cat(acc).float())
            dev_loss = np.mean(losses)

            dev_pred_out = pred
            dev_label_out = label

            print("VALIDATION: Epoch %d, val_acc: %.2f, val_loss: %.2f, aupr_val: %.2f, auc_val: %.2f" %
                  (epoch, dev_acc * 100, dev_loss.item(), aupr_val * 100, auc_val * 100))

            # if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU':
            #     print(confusion_matrix(label, (np.array(pred) > 0.5).astype(int), labels=list(range(n_classes))))
            # elif dataset == 'PAM':
            #     print(confusion_matrix(label, pred, labels=list(range(n_classes))))

            # test loss
            losses, acc = [], []
            label, pred = [], []
            model.eval()
            for data_batch in test_dataloader:
                if dataset == 'CD' or (dataset == 'P12' and os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
                    # CD and cached P12 datasets return (X, Mask, Delta, label)
                    test_X, test_Mask, test_Delta, test_label = data_batch
                    # Keep batch dimension for GRUD model which expects [batch, time, features]
                    test_data = (test_X, test_Mask, test_Delta)
                    test_label = torch.squeeze(test_label, dim=0)
                else:
                    # Other datasets return (data, label)
                    test_data, test_label = data_batch
                    test_data = torch.squeeze(test_data)
                    test_label = torch.squeeze(test_label, dim=0)

                if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                    y_pred = model(test_data)

                    # Save predict and label
                    pred.append(y_pred.item())
                    label.append(test_label.item())

                    # Compute loss
                    loss = criterion(y_pred, test_label)
                    acc.append(
                        torch.eq(
                            (y_pred.data > 0.5).float(),
                            test_label)
                    )
                    losses.append(loss.item())
                elif dataset == 'PAM':
                    y_pred = model(test_data, dataset_name=dataset)

                    # Save predict and label
                    pred.append(torch.argmax(y_pred).item())
                    label.append(test_label.item())

                    # Compute loss
                    loss = criterion(torch.unsqueeze(y_pred, 0), test_label.type(torch.LongTensor))

                    acc.append(
                        torch.eq(
                            torch.argmax(y_pred),
                            test_label)
                    )
                    losses.append(loss.item())

            test_acc = torch.mean(torch.cat(acc).float())
            test_loss = np.mean(losses)

            test_pred_out = pred
            test_label_out = label

            epoch_losses.append([
                 train_loss, dev_loss, test_loss,
                 train_acc, dev_acc, test_acc,
                 train_pred_out, dev_pred_out, test_pred_out,
                 train_label_out, dev_label_out, test_label_out,
             ])

        print('\n------------------\nRUN %d: Training finished\n------------------' % r)

        # Test set
        # In PyTorch >= 2.6, torch.load defaults to weights_only=True which breaks loading full models saved via torch.save(model, ...)
        # Explicitly disable weights_only to allow loading the full serialized model object.
        model = torch.load(model_path, weights_only=False)

        losses, acc = [], []
        label, pred = [], []
        model.eval()
        for data_batch in test_dataloader:
            if dataset == 'CD' or (dataset == 'P12' and os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
                # CD and cached P12 datasets return (X, Mask, Delta, label)
                test_X, test_Mask, test_Delta, test_label = data_batch
                # Keep batch dimension for GRUD model which expects [batch, time, features]
                test_data = (test_X, test_Mask, test_Delta)
                test_label = torch.squeeze(test_label, dim=0)
            else:
                # Other datasets return (data, label)
                test_data, test_label = data_batch
                test_data = torch.squeeze(test_data)
                test_label = torch.squeeze(test_label, dim=0)

            if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                y_pred = model(test_data)

                # Save predict and label
                pred.append(y_pred.item())
                label.append(test_label.item())

                # Compute loss
                loss = criterion(y_pred, test_label)
                acc.append(
                    torch.eq(
                        (y_pred.data > 0.5).float(),
                        test_label)
                )
                losses.append(loss.item())
            elif dataset == 'PAM':
                y_pred = model(test_data, dataset_name=dataset)

                # Save predict and label
                pred.append(torch.argmax(y_pred).item())
                label.append(test_label.item())

                # Compute loss
                loss = criterion(torch.unsqueeze(y_pred, 0), test_label.type(torch.LongTensor))

                acc.append(
                    torch.eq(
                        torch.argmax(y_pred),
                        test_label)
                )
                losses.append(loss.item())

        test_acc = torch.mean(torch.cat(acc).float())
        test_loss = np.mean(losses)

        pred = np.asarray(pred)
        label = np.asarray(label)

        if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU':
            auc_score = roc_auc_score(label, pred)
            aupr_score = average_precision_score(label, pred)
        elif dataset == 'PAM':
            label_oh = np.array(label)
            label_oh = one_hot(label_oh[:, np.newaxis])
            pred_oh = np.array(pred)
            pred_oh = one_hot(pred_oh[:, np.newaxis])
            auc_score = roc_auc_score(label_oh, pred_oh)
            aupr_score = average_precision_score(label_oh, pred_oh)
            precision = precision_score(label, pred, average='macro', )
            recall = recall_score(label, pred, average='macro', )
            F1_score = f1_score(label, pred, average='macro', )

            print("\nTEST: test_precision: %.2f test_recall: %.2f, test_F1: %.2f\n" %
                  (precision * 100, recall * 100, F1_score * 100))

        print("\nTEST: test_acc: %.2f aupr_test: %.2f, auc_test: %.2f\n" %
              (test_acc * 100, aupr_score * 100, auc_score * 100))

        if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU':
            print(confusion_matrix(label, (np.array(pred) > 0.5).astype(int), labels=list(range(n_classes))))
        elif dataset == 'PAM':
            print(confusion_matrix(label, pred, labels=list(range(n_classes))))

        acc_all.append(test_acc * 100)
        auc_all.append(auc_score * 100)
        aupr_all.append(aupr_score * 100)
        if dataset == 'PAM':
            precision_all.append(precision * 100)
            recall_all.append(recall * 100)
            F1_all.append(F1_score * 100)

    # print mean and std of all metrics
    acc_all, auc_all, aupr_all = np.array(acc_all), np.array(auc_all), np.array(aupr_all)
    mean_acc, std_acc = np.mean(acc_all), np.std(acc_all)
    mean_auc, std_auc = np.mean(auc_all), np.std(auc_all)
    mean_aupr, std_aupr = np.mean(aupr_all), np.std(aupr_all)
    print('------------------------------------------')
    print('Accuracy = %.1f +/- %.1f' % (mean_acc, std_acc))
    print('AUROC    = %.1f +/- %.1f' % (mean_auc, std_auc))
    print('AUPRC    = %.1f +/- %.1f' % (mean_aupr, std_aupr))
    if dataset == 'PAM':
        precision_all, recall_all, F1_all = np.array(precision_all), np.array(recall_all), np.array(F1_all)
        mean_precision, std_precision = np.mean(precision_all), np.std(precision_all)
        mean_recall, std_recall = np.mean(recall_all), np.std(recall_all)
        mean_F1, std_F1 = np.mean(F1_all), np.std(F1_all)
        print('Precision = %.1f +/- %.1f' % (mean_precision, std_precision))
        print('Recall    = %.1f +/- %.1f' % (mean_recall, std_recall))
        print('F1        = %.1f +/- %.1f' % (mean_F1, std_F1))

    # #show AUROC on test data for last trained epoch
    # test_preds, test_labels = epoch_losses[-1][8], epoch_losses[-1][11]
    # plot_roc_and_auc_score(test_preds, test_labels, 'GRU-D')


def plot_roc_and_auc_score(outputs, labels, title):
    """Plot the ROC curve of `outputs` against `labels` and show the figure.

    The curve is labelled with the ROC AUC (four decimals) and drawn together
    with a red diagonal from (0, 0) to (1, 1) as a visual reference.

    Args:
        outputs: predicted scores, passed to sklearn as ``y_score``.
        labels: ground-truth labels, passed to sklearn as ``y_true``.
        title: title placed above the plot.
    """
    # The third value returned by roc_curve (the thresholds) is not needed here.
    fpr, tpr, _ = roc_curve(labels, outputs)
    area = roc_auc_score(labels, outputs)
    curve_label = 'ROC curve, AREA = {:.4f}'.format(area)

    # Model curve first, then the diagonal reference line.
    plt.plot(fpr, tpr, label=curve_label)
    plt.plot([0, 1], [0, 1], 'red')

    # Axis cosmetics: both rates live in [0, 1].
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.axis([0, 1, 0, 1])

    plt.title(title)
    plt.legend(loc='lower right')
    plt.show()


if __name__ == '__main__':
    # Command-line entry point: parse options, seed every RNG, derive the
    # per-dataset model sizes and launch `train_gru_d` once per missing ratio.
    import argparse
    import random

    def _str2bool(value):
        """Convert a command-line token such as 'True', 'false', '1' or 'no' to a bool.

        Without an explicit `type`, argparse hands back the raw string, so
        `--flag False` would yield the truthy string 'False' and
        `args.flag == True` would never hold.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ('1', 'true', 't', 'yes', 'y'):
            return True
        if token in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='P12', choices=['P12', 'P19', 'eICU', 'PAM', 'CD'])
    # `nargs='?'` + `const=True` lets both `--withmissingratio` and
    # `--withmissingratio True` enable the sweep.
    parser.add_argument('--withmissingratio', type=_str2bool, nargs='?', const=True, default=False,
                        help='if True, missing ratio ranges from 0 to 0.5; if False, missing ratio =0')  #
    parser.add_argument('--splittype', type=str, default='random', choices=['random', 'age', 'gender'],
                        help='only use for P12 and P19')
    parser.add_argument('--reverse', type=_str2bool, nargs='?', const=True, default=False,
                        help='if True,use female, older for tarining; if False, use female or younger for training')  #
    parser.add_argument('--feature_removal_level', type=str, default='no_removal',
                        choices=['no_removal', 'set', 'sample'],
                        help='use this only when splittype==random; otherwise, set as no_removal')  #
    parser.add_argument('--predictive_label', type=str, default='mortality', choices=['mortality', 'LoS'],
                        help='use this only with P12 dataset (mortality or length of stay)')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
    parser.add_argument('--quick_test', action='store_true', help='Run with only 100 samples for quick testing')
    parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
    # Cached PSV options (P12 only)
    parser.add_argument('--use_cached_dataset', action='store_true', help='Use cached PSV dataset (P12 only)')
    parser.add_argument('--cached_dataset_dir', type=str, default='/tmp', help='Directory with cached PSV files (P12)')
    parser.add_argument('--split_pkl_path', type=str, default='P12_data_splits/split_1.pkl', help='Path to split_*.pkl file')
    parser.add_argument('--los_threshold_days', type=int, default=3, help='LoS threshold in days for P12 cached dataset')
    args = parser.parse_args()
    print('Dataset used: ', args.dataset)   # possible values: 'P12', 'P19', 'eICU', 'PAM', 'CD'

    # Set random seeds for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f'Using seed: {args.seed}')

    # Device detection (only reported here; nothing in this block moves data to it)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Number of input variables per dataset; the hidden size mirrors the input size.
    # `choices=` on --dataset guarantees the lookup cannot miss.
    n_variables = {'P12': 33, 'P19': 40, 'eICU': 16, 'PAM': 17, 'CD': 17}
    input_size = n_variables[args.dataset]
    hidden_size = input_size

    # With the sweep enabled, train once per ratio; otherwise a single run with ratio 0.
    missing_ratios = [0.1, 0.2, 0.3, 0.4, 0.5] if args.withmissingratio else [0]

    # Hyper-parameters are identical for every missing ratio, so set them once.
    num_runs = 1
    # PAM is the only multi-class dataset (8 outputs, no upsampling); all others
    # use a single output with upsampled batches.
    is_pam = args.dataset == 'PAM'
    output_size = 8 if is_pam else 1
    upsampling_batch = not is_pam
    num_layers = 49  # num of step or layers base on the paper / number of hidden states
    dropout = 0.1  # dropout_type : Moon, Gal, mloss
    learning_rate = 0.001
    n_epochs = args.epochs
    batch_size = 16

    split_type = args.splittype  # possible values: 'random', 'age', 'gender'
    reverse_ = args.reverse  # parsed for CLI compatibility; not forwarded to train_gru_d below
    feature_removal_level = args.feature_removal_level  # possible values: 'no_removal', 'sample', 'set'

    # If using cached PSV dataset for P12, export the settings as environment variables
    # (the training loop checks USE_CACHED_DATASET to pick the batch layout).
    if args.dataset == 'P12' and args.use_cached_dataset:
        os.environ['USE_CACHED_DATASET'] = '1'
        os.environ['CACHED_PSV_DIR'] = str(args.cached_dataset_dir)
        os.environ['SPLIT_PKL_PATH'] = str(args.split_pkl_path)
        os.environ['LOS_THRESHOLD_DAYS'] = str(args.los_threshold_days)

    for missing_ratio in missing_ratios:
        train_gru_d(num_runs, input_size, hidden_size, output_size, num_layers, dropout, learning_rate, n_epochs,
                    batch_size, upsampling_batch, split_type, feature_removal_level, missing_ratio, args.dataset, args.predictive_label)
