
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import time

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score


def _psv_scalar(df, column, default=None, last=False):
    """Return the first (or last) numeric, non-NaN entry of ``df[column]`` as a float.

    ``default`` is returned when the column is absent or holds no value that
    ``pd.to_numeric(..., errors='coerce')`` can parse.
    """
    if column not in df.columns:
        return default
    numeric = pd.to_numeric(df[column], errors='coerce').dropna()
    if numeric.empty:
        return default
    return float(numeric.iloc[-1 if last else 0])


def _load_cached_p12_from_psv(cached_dir: str, split_pkl_path: str, predictive_label: str = 'mortality', los_threshold_days: int = 3):
    """Build P12 data lists from cached PSV files with 'newlabel' column.

    The split pickle must contain the keys 'train_files', 'val_files' and
    'test_files'; each entry is a path whose basename (without extension) is an
    integer record id. Only records for which ``<cached_dir>/<id>.psv`` exists
    are used; entries with a non-integer basename are skipped.

    Every returned sample is a dict with
      * 'arr': float array [T, 36] in ``biomarker_features`` order, missing -> 0
      * 'time': float array [T, 1], the 'ICULOS' column multiplied by 60
      * 'extended_static': [Age, Gender=0, Gender=1, Height, ICUType=1..4, Weight]

    Labels:
      * predictive_label == 'LoS': first valid 'newlabel' value; when there is
        none, 1 if the last valid 'ICULOS' / 24 exceeds ``los_threshold_days``.
      * otherwise: last valid 'Survival' value (0 when unavailable).

    Files that cannot be read or parsed are skipped with a warning.

    Returns Ptrain, Pval, Ptest, ytrain, yval, ytest compatible with Raindrop.
    """
    import pickle as pkl
    import warnings

    # Biomarker order must match Raindrop expectations (36 features)
    biomarker_features = [
        'ALP', 'ALT', 'AST', 'Albumin', 'BUN', 'Bilirubin', 'Cholesterol', 'Creatinine',
        'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'MAP',
        'MechVent', 'Mg', 'NIDiasABP', 'NIMAP', 'NISysABP', 'Na', 'PaCO2', 'PaO2',
        'Platelets', 'RespRate', 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT',
        'Urine', 'WBC', 'pH'
    ]

    # NOTE(security): pickle.load can execute arbitrary code; only point this at
    # split files from a trusted source.
    with open(split_pkl_path, 'rb') as f:
        data_split = pkl.load(f)

    def map_to_cached(psv_paths):
        """Map split entries to (record_id, cached_path) pairs that exist on disk."""
        mapped = []
        for rel_path in psv_paths:
            try:
                rid = int(os.path.splitext(os.path.basename(rel_path))[0])
            except (TypeError, ValueError):
                # Not a path, or the basename is not an integer record id.
                continue
            cand = os.path.join(cached_dir, f"{rid}.psv")
            if os.path.exists(cand):
                mapped.append((rid, cand))
        return mapped

    train_files = map_to_cached(data_split['train_files'])
    val_files = map_to_cached(data_split['val_files'])
    test_files = map_to_cached(data_split['test_files'])

    def build_lists(rid_and_paths):
        """Read every PSV file and return (list of sample dicts, label array)."""
        P_list = []
        y_list = []
        for _rid, path in rid_and_paths:
            try:
                df = pd.read_csv(path, sep='|')
            except (OSError, ValueError) as err:
                # pandas parser/empty-data errors derive from ValueError.
                warnings.warn(f"Skipping unreadable PSV file {path}: {err}")
                continue

            # 'ICULOS' is multiplied by 60 here (hours -> minutes, going by the
            # original variable names). NOTE: for LoS prediction the absolute
            # ICULOS value is correlated with the target.
            time_hours = df.get('ICULOS', pd.Series([0.0] * len(df))).astype(float).to_numpy()
            time_minutes = time_hours * 60.0

            # Build arr [T,F] in biomarker_features order, fill missing as 0
            arr = np.zeros((len(df), len(biomarker_features)), dtype=float)
            for j, biom in enumerate(biomarker_features):
                if biom in df.columns:
                    arr[:, j] = pd.to_numeric(df[biom], errors='coerce').fillna(0.0).to_numpy()

            # Extended static [Age, Gender=0, Gender=1, Height, ICUType=1..4, Weight]
            age = _psv_scalar(df, 'Age', default=0.0)
            height = _psv_scalar(df, 'Height', default=0.0)
            weight = _psv_scalar(df, 'Weight', default=0.0)

            g0, g1 = 0.0, 0.0
            gender = _psv_scalar(df, 'Gender')
            if gender is not None:
                gv = int(gender)
                g0, g1 = float(gv == 0), float(gv == 1)

            icu_one_hot = [0.0, 0.0, 0.0, 0.0]
            icu_type = _psv_scalar(df, 'ICUType')
            if icu_type is not None:
                iv = int(icu_type)
                if 1 <= iv <= 4:
                    icu_one_hot[iv - 1] = 1.0

            extended_static = np.array([age, g0, g1, height, *icu_one_hot, weight], dtype=float)

            # Label
            if predictive_label == 'LoS':
                cached_label = _psv_scalar(df, 'newlabel')
                if cached_label is not None:
                    y = int(cached_label)
                else:
                    # No usable cached label: threshold the stay length in days.
                    los_days = _psv_scalar(df, 'ICULOS', default=0.0, last=True) / 24.0
                    y = 1 if los_days > float(los_threshold_days) else 0
            else:
                # Mortality: use Survival last as proxy
                y = int(_psv_scalar(df, 'Survival', default=0.0, last=True))

            P_list.append({'arr': arr, 'time': time_minutes.reshape(-1, 1), 'extended_static': extended_static})
            y_list.append([y])
        return P_list, np.array(y_list, dtype=int)

    Ptrain, ytrain = build_lists(train_files)
    Pval, yval = build_lists(val_files)
    Ptest, ytest = build_lists(test_files)

    return Ptrain, Pval, Ptest, ytrain, yval, ytest


def get_data_split(base_path, split_path, split_type='random', reverse=False, baseline=True, dataset='P12', predictive_label='mortality'):
    """Load a dataset from disk and return its train/validation/test partitions.

    For ``dataset == 'P12'`` with the environment variable USE_CACHED_DATASET
    set to '1', 'true' or 'yes', loading is delegated to
    ``_load_cached_p12_from_psv`` (configured through CACHED_PSV_DIR,
    SPLIT_PKL_PATH and LOS_THRESHOLD_DAYS) and the arguments ``base_path``,
    ``split_path``, ``split_type``, ``reverse`` and ``baseline`` are not used.

    Otherwise the sample list and outcome array are read from ``.npy`` files
    under ``base_path`` and indexed with the selected split.

    :param base_path: dataset root directory (string; paths are built with ``+``)
    :param split_path: path of the split file, appended to ``base_path``; only
        used when ``split_type == 'random'``
    :param split_type: 'random', 'age' or 'gender'
    :param reverse: for 'age'/'gender' splits, swaps which group is used for
        training and which is divided into validation/test
    :param baseline: when True the index files are read from 'saved/',
        otherwise from 'baselines/saved/'
    :param dataset: 'P12', 'P19', 'eICU', 'PAM' or 'CD'
    :param predictive_label: 'mortality' or 'LoS' (used for P12/P19/PAM)
    :return: Ptrain, Pval, Ptest, ytrain, yval, ytest
    """
    # Optional override: use cached PSV dataset if env flag set
    if dataset == 'P12':
        use_cached_env = os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')
        if use_cached_env:
            cached_dir = os.environ.get('CACHED_PSV_DIR', '/tmp')
            split_pkl_path = os.environ.get('SPLIT_PKL_PATH', 'P12_data_splits/split_1.pkl')
            los_thresh = int(os.environ.get('LOS_THRESHOLD_DAYS', '3'))
            return _load_cached_p12_from_psv(cached_dir, split_pkl_path, predictive_label=predictive_label, los_threshold_days=los_thresh)
    # load data
    # NOTE(review): a dataset name outside the branches below leaves Pdict_list,
    # arr_outcomes and dataset_prefix unassigned, so later uses raise NameError.
    if dataset == 'P12':
        Pdict_list = np.load(base_path + '/processed_data/PTdict_list.npy', allow_pickle=True)
        arr_outcomes = np.load(base_path + '/processed_data/arr_outcomes.npy', allow_pickle=True)
        dataset_prefix = ''
    elif dataset == 'P19':
        Pdict_list = np.load(base_path + '/processed_data/PT_dict_list_6.npy', allow_pickle=True)
        arr_outcomes = np.load(base_path + '/processed_data/arr_outcomes_6.npy', allow_pickle=True)
        dataset_prefix = 'P19_'
    elif dataset == 'eICU':
        Pdict_list = np.load(base_path + '/processed_data/PTdict_list.npy', allow_pickle=True)
        arr_outcomes = np.load(base_path + '/processed_data/arr_outcomes.npy', allow_pickle=True)
        dataset_prefix = 'eICU_'
    elif dataset == 'PAM':
        Pdict_list = np.load(base_path + '/processed_data/PTdict_list.npy', allow_pickle=True)
        arr_outcomes = np.load(base_path + '/processed_data/arr_outcomes.npy', allow_pickle=True)
        dataset_prefix = ''  # not applicable
    elif dataset == 'CD':
        Pdict_list = np.load(base_path + '/processed_data_new/PTdict_list.npy', allow_pickle=True)
        arr_outcomes = np.load(base_path + '/processed_data_new/arr_outcomes.npy', allow_pickle=True)
        dataset_prefix = 'CD_'

    # Debug-only statistics; disabled by the hard-coded flag below. The unpacking
    # of 'extended_static' into nine values matches the P12 static layout.
    show_statistics = False
    if show_statistics:
        idx_under_65 = []
        idx_over_65 = []

        idx_male = []
        idx_female = []

        # variables for statistics
        all_ages = []
        female_count = 0
        male_count = 0
        all_BMI = []

        X_static = np.zeros((len(Pdict_list), len(Pdict_list[0]['extended_static'])))
        for i in range(len(Pdict_list)):
            X_static[i] = Pdict_list[i]['extended_static']
            age, gender_0, gender_1, height, _, _, _, _, weight = X_static[i]
            if age > 0:
                all_ages.append(age)
                if age < 65:
                    idx_under_65.append(i)
                else:
                    idx_over_65.append(i)
            if gender_0 == 1:
                female_count += 1
                idx_female.append(i)
            if gender_1 == 1:
                male_count += 1
                idx_male.append(i)
            if height > 0 and weight > 0:
                # BMI = weight / (height in metres)^2; height is divided by 100 here
                all_BMI.append(weight / ((height / 100) ** 2))

        # plot statistics
        plt.hist(all_ages, bins=[i * 10 for i in range(12)])
        plt.xlabel('Years')
        plt.ylabel('# people')
        plt.title('Histogram of patients ages, age known in %d samples.\nMean: %.1f, Std: %.1f, Median: %.1f' %
                  (len(all_ages), np.mean(np.array(all_ages)), np.std(np.array(all_ages)), np.median(np.array(all_ages))))
        plt.show()

        # The histogram uses all BMI values; the title statistics use only 10 < BMI < 65.
        plt.hist(all_BMI, bins=[5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60])
        all_BMI = np.array(all_BMI)
        all_BMI = all_BMI[(all_BMI > 10) & (all_BMI < 65)]
        plt.xlabel('BMI')
        plt.ylabel('# people')
        plt.title('Histogram of patients BMI, height and weight known in %d samples.\nMean: %.1f, Std: %.1f, Median: %.1f' %
                  (len(all_BMI), np.mean(all_BMI), np.std(all_BMI), np.median(all_BMI)))
        plt.show()
        print('\nGender known: %d,  Male count: %d,  Female count: %d\n' % (male_count + female_count, male_count, female_count))

    # np.save('saved/idx_under_65.npy', np.array(idx_under_65), allow_pickle=True)
    # np.save('saved/idx_over_65.npy', np.array(idx_over_65), allow_pickle=True)
    # np.save('saved/idx_male.npy', np.array(idx_male), allow_pickle=True)
    # np.save('saved/idx_female.npy', np.array(idx_female), allow_pickle=True)

    # Directory prefix for the saved age/gender index files.
    # transformer_path = True
    if baseline==True:
        BL_path = ''
    else:
        BL_path = 'baselines/'

    # NOTE(review): a split_type other than 'random', 'age' or 'gender' leaves
    # idx_train/idx_val/idx_test unassigned.
    if split_type == 'random':
        # load random indices from a split
        idx_train, idx_val, idx_test = np.load(base_path + split_path, allow_pickle=True)
    elif split_type == 'age':
        # Train on one age group; the other group is shuffled and halved into val/test.
        if reverse == False:
            idx_train = np.load(BL_path+'saved/' + dataset_prefix + 'idx_under_65.npy', allow_pickle=True)
            idx_vt = np.load(BL_path+'saved/' + dataset_prefix + 'idx_over_65.npy', allow_pickle=True)
        elif reverse == True:
            idx_train = np.load(BL_path+'saved/' + dataset_prefix + 'idx_over_65.npy', allow_pickle=True)
            idx_vt = np.load(BL_path+'saved/' + dataset_prefix + 'idx_under_65.npy', allow_pickle=True)

        # Uses the global NumPy RNG, so the val/test halves depend on its state.
        np.random.shuffle(idx_vt)
        idx_val = idx_vt[:round(len(idx_vt) / 2)]
        idx_test = idx_vt[round(len(idx_vt) / 2):]
    elif split_type == 'gender':
        # Train on one gender group; the other group is shuffled and halved into val/test.
        if reverse == False:
            idx_train = np.load(BL_path+'saved/' + dataset_prefix + 'idx_male.npy', allow_pickle=True)
            idx_vt = np.load(BL_path+'saved/' + dataset_prefix + 'idx_female.npy', allow_pickle=True)
        elif reverse == True:
            idx_train = np.load(BL_path+'saved/' + dataset_prefix + 'idx_female.npy', allow_pickle=True)
            idx_vt = np.load(BL_path+'saved/' + dataset_prefix + 'idx_male.npy', allow_pickle=True)

        np.random.shuffle(idx_vt)
        idx_val = idx_vt[:round(len(idx_vt) / 2)]
        idx_test = idx_vt[round(len(idx_vt) / 2):]

    # CD dataset: Convert string patient IDs to integer indices
    if dataset == 'CD':
        print("CD dataset: Converting patient IDs to indices")
        print(f"Original idx_train type: {type(idx_train)}, length: {len(idx_train)}")
        if len(idx_train) > 0:
            print(f"Sample idx_train values: {idx_train[:5]}")

        # Create mapping from patient ID to index
        patient_id_to_idx = {}
        for i, patient in enumerate(Pdict_list):
            patient_id_to_idx[patient['id']] = i

        print(f"Created mapping for {len(patient_id_to_idx)} patients")
        if len(patient_id_to_idx) > 0:
            sample_ids = list(patient_id_to_idx.keys())[:5]
            print(f"Sample patient IDs in data: {sample_ids}")

        # Convert string patient IDs to integer indices; IDs absent from the
        # mapping are dropped without a message.
        idx_train = [patient_id_to_idx[int(pid)] for pid in idx_train if int(pid) in patient_id_to_idx]
        idx_val = [patient_id_to_idx[int(pid)] for pid in idx_val if int(pid) in patient_id_to_idx]
        idx_test = [patient_id_to_idx[int(pid)] for pid in idx_test if int(pid) in patient_id_to_idx]

        print(f"Converted to indices: train={len(idx_train)}, val={len(idx_val)}, test={len(idx_test)}")
        # NOTE(review): the next two prints are identical and repeat the counts above.
        print(f"{len(idx_train)} {len(idx_val)} {len(idx_test)}")
        print(f"{len(idx_train)} {len(idx_val)} {len(idx_test)}")

    # extract train/val/test examples
    Ptrain = Pdict_list[idx_train]
    Pval = Pdict_list[idx_val]
    Ptest = Pdict_list[idx_test]

    # extract labels as a column vector y
    # NOTE(review): for P12/P19/PAM a predictive_label other than 'mortality'
    # or 'LoS' leaves y unassigned.
    if dataset == 'P12' or dataset == 'P19' or dataset == 'PAM':
        if predictive_label == 'mortality':
            # last column of arr_outcomes
            y = arr_outcomes[:, -1].reshape((-1, 1))
        elif predictive_label == 'LoS':  # for P12 only
            # column 3 of arr_outcomes, binarised: 0 if <= 3, else 1
            y = arr_outcomes[:, 3].reshape((-1, 1))
            y = np.array(list(map(lambda los: 0 if los <= 3 else 1, y)))[..., np.newaxis]
    elif dataset == 'eICU':
        y = arr_outcomes[..., np.newaxis]
    elif dataset == 'CD':
        y = arr_outcomes[..., np.newaxis]  # CD outcomes are already in the correct format
    ytrain = y[idx_train]
    yval = y[idx_val]
    ytest = y[idx_test]

    return Ptrain, Pval, Ptest, ytrain, yval, ytest


def getStats(P_tensor):
    """
    Compute per-feature mean and standard deviation of a 3-D tensor.

    Only strictly positive, non-NaN entries contribute. The standard deviation
    is floored at 1e-7; a feature with no valid entry gets mean 0.0 and std 1e-7.

    :param P_tensor: array with three axes, features on the last axis
    :return: (means, stds), each of shape (n_features, 1)
    """
    _, _, n_features = P_tensor.shape
    per_feature = P_tensor.transpose((2, 0, 1)).reshape(n_features, -1)
    means = np.zeros((n_features, 1))
    stds = np.ones((n_features, 1))
    floor = 1e-7
    for idx, values in enumerate(per_feature):
        # Keep observed entries only: positive and not NaN
        observed = values[np.logical_and(values > 0, ~np.isnan(values))]
        if len(observed) == 0:
            # Nothing observed for this feature: fall back to defaults
            means[idx, 0] = 0.0
            stds[idx, 0] = floor
            continue
        means[idx, 0] = np.mean(observed)
        stds[idx, 0] = max(np.std(observed), floor)
    return means, stds


def get_features_mean(X_features):
    """
    Calculate means of all time series features (36 features in P12 dataset).

    Only strictly positive entries are averaged (non-positive entries are
    treated as missing). A feature without any positive entry gets NaN.

    :param X_features: time series features for all samples in training set,
        array of shape (samples, timesteps, features)
    :return: list of means for all features
    """
    samples, timesteps, features = X_features.shape
    # Positional shape argument: the `newshape=` keyword is deprecated in NumPy >= 2.1.
    flattened = np.reshape(X_features, (samples * timesteps, features))
    means = []
    for column in flattened.T:
        observed = column[column > 0]
        # Same NaN result as np.mean([]) but without the empty-slice RuntimeWarning.
        means.append(np.mean(observed) if observed.size else np.nan)
    return means


def mean_imputation(X_features, X_time, mean_features, missing_value_num):
    """
    Fill X_features missing values with mean values of all train samples.

    The valid length of each sample is derived from X_time: it ends at the first
    time stamp equal to ``missing_value_num``. If the very first time stamp has
    that value (a series starting at time 0), the second occurrence is used, and
    when there is no second occurrence the whole series is treated as valid.

    X_features is modified in place and also returned.

    :param X_features: time series features for all samples, (samples, timesteps, features)
    :param X_time: times, when observations were measured
    :param mean_features: mean values of features from the training set
    :param missing_value_num: value that marks a missing observation / padded time stamp
    :return: X_features, filled with mean values instead of zeros (missing observations)
    """
    time_length = []
    for times in X_time:
        missing_rows = np.where(times == missing_value_num)[0]
        if missing_rows.size == 0:
            time_length.append(times.shape[0])
        elif missing_rows[0] != 0:
            time_length.append(missing_rows[0])
        elif missing_rows.size > 1:
            # First stamp is a genuine t == missing_value_num; padding starts at the next hit.
            time_length.append(missing_rows[1])
        else:
            # Only the first stamp matches: no padding at all.
            time_length.append(times.shape[0])

    # check for inconsistency
    for i in range(len(X_features)):
        if np.any(X_features[i, time_length[i]:, :]):
            print('Inconsistency between X_features and X_time: features are measured without time stamp.')

    # impute times series features
    means = np.asarray(mean_features)
    for i, length in enumerate(time_length):
        rows, cols = np.where(X_features[i, :length, :] == missing_value_num)
        X_features[i, rows, cols] = means[cols]

    return X_features


def forward_imputation(X_features, X_time, missing_value_num):
    """
    Fill X_features missing values with values, which are the same as its last measurement.

    Missing entries before the first measurement of a feature are left untouched.
    The valid length of each sample is derived from X_time: it ends at the first
    time stamp equal to ``missing_value_num``. If the very first time stamp has
    that value (a series starting at time 0), the second occurrence is used, and
    when there is no second occurrence the whole series is treated as valid.

    X_features is modified in place and also returned.

    :param X_features: time series features for all samples, (samples, timesteps, features)
    :param X_time: times, when observations were measured
    :param missing_value_num: value that marks a missing observation / padded time stamp
    :return: X_features, filled with last measurements instead of zeros (missing observations)
    """
    time_length = []
    for times in X_time:
        missing_rows = np.where(times == missing_value_num)[0]
        if missing_rows.size == 0:
            time_length.append(times.shape[0])
        elif missing_rows[0] != 0:
            time_length.append(missing_rows[0])
        elif missing_rows.size > 1:
            # First stamp is a genuine t == missing_value_num; padding starts at the next hit.
            time_length.append(missing_rows[1])
        else:
            # Only the first stamp matches: no padding at all.
            time_length.append(times.shape[0])

    # impute times series features
    n_features = X_features.shape[2] if len(X_features) else 0
    for i, length in enumerate(time_length):
        for j in range(n_features):
            seen_measurement = False
            last_value = None
            for k in range(length):
                if X_features[i, k, j] != missing_value_num:
                    last_value = X_features[i, k, j]
                    seen_measurement = True
                elif seen_measurement:
                    X_features[i, k, j] = last_value

    return X_features


def cubic_spline_imputation(X_features, X_time, missing_value_num):
    """
    Fill X_features missing values with cubic spline interpolation.

    For every feature with at least two measurements inside the valid part of
    the series, missing entries between measurements are replaced by the spline
    value; entries before the first / after the last measurement are set to the
    first / last measured value. Features with fewer than two measurements are
    left unchanged.

    The valid length of each sample is derived from X_time: it ends at the first
    time stamp equal to ``missing_value_num``. If the very first time stamp has
    that value (a series starting at time 0), the second occurrence is used, and
    when there is no second occurrence the whole series is treated as valid.

    X_features is modified in place and also returned.

    :param X_features: time series features for all samples, (samples, timesteps, features)
    :param X_time: times, when observations were measured, (samples, timesteps, 1)
    :param missing_value_num: value that marks a missing observation / padded time stamp
    :return: X_features, filled with interpolated values
    """
    from scipy.interpolate import CubicSpline

    time_length = []
    for times in X_time:
        missing_rows = np.where(times == missing_value_num)[0]
        if missing_rows.size == 0:
            time_length.append(times.shape[0])
        elif missing_rows[0] != 0:
            time_length.append(missing_rows[0])
        elif missing_rows.size > 1:
            # First stamp is a genuine t == missing_value_num; padding starts at the next hit.
            time_length.append(missing_rows[1])
        else:
            # Only the first stamp matches: no padding at all.
            time_length.append(times.shape[0])

    # impute times series features
    n_features = X_features.shape[2] if len(X_features) else 0
    for i, length in enumerate(time_length):
        stamps = X_time[i, :length, 0]
        for j in range(n_features):
            series = X_features[i, :length, j]   # view: writes go straight into X_features
            missing_idx = np.where(series == missing_value_num)[0]
            # Select measurements with the same marker that defines "missing"
            observed_idx = np.where(series != missing_value_num)[0]

            if len(observed_idx) > 1:   # we need at least 2 observations to fit cubic spline
                first_obs, last_obs = observed_idx[0], observed_idx[-1]
                first_value, last_value = series[first_obs], series[last_obs]

                spline = CubicSpline(stamps[observed_idx], series[observed_idx])
                series[missing_idx] = spline(stamps[missing_idx])

                # no extrapolation: hold the first / last measurement outside the observed range
                series[:first_obs] = first_value
                series[last_obs:] = last_value

    return X_features


def mask_normalize(P_tensor, mf, stdf):
    """
    Normalize time series variables and append the observation mask.

    Entries that are NaN or not strictly positive count as missing: they are 0
    in the normalized part and 0 in the mask. Observed entries become
    (value - mf[f]) / (stdf[f] + 1e-18) with mask 1.

    :param P_tensor: array of shape (N, T, F)
    :param mf: per-feature means, indexable by feature
    :param stdf: per-feature standard deviations, indexable by feature
    :return: array of shape (N, T, 2F): normalized values followed by the mask
    """
    n_samples, n_steps, n_features = P_tensor.shape

    # Observation mask: 1 where the value is positive and not NaN, else 0
    nan_positions = np.isnan(P_tensor)
    M = np.logical_and(P_tensor > 0, ~nan_positions).astype(int)

    # NaNs are zeroed so that arithmetic below stays finite
    cleaned = np.where(nan_positions, 0, P_tensor)

    # Feature-major 2-D views: one row per feature
    mask_rows = M.transpose((2, 0, 1)).reshape(n_features, -1)
    value_rows = cleaned.transpose((2, 0, 1)).reshape(n_features, -1)

    for f in range(n_features):
        value_rows[f] = (value_rows[f] - mf[f]) / (stdf[f] + 1e-18)
    value_rows = value_rows * mask_rows

    normalized = value_rows.reshape((n_features, n_samples, n_steps)).transpose((1, 2, 0))
    return np.concatenate([normalized, M], axis=2)


def getStats_static(P_tensor, dataset='P12'):
    """
    Compute mean and standard deviation of the non-categorical static features.

    Only strictly positive, non-NaN entries contribute. Categorical features,
    and features without any valid entry, keep mean 0.0 and std 1.0.

    :param P_tensor: static features, array of shape (N, S)
    :param dataset: one of 'P12', 'P19', 'eICU', 'CD'; selects which columns
        are categorical
    :return: (ms, ss), each of shape (S, 1)
    :raises ValueError: if ``dataset`` has no known static-feature layout
    """
    # 1 marks a categorical column (left un-normalized), 0 a numeric one.
    categorical_layouts = {
        # ['Age' 'Gender=0' 'Gender=1' 'Height' 'ICUType=1' 'ICUType=2' 'ICUType=3' 'ICUType=4' 'Weight']
        'P12': [0, 1, 1, 0, 1, 1, 1, 1, 0],
        # ['Age' 'Gender' 'Unit1' 'Unit2' 'HospAdmTime' 'ICULOS']
        'P19': [0, 1, 0, 0, 0, 0],
        # ['apacheadmissiondx' 'ethnicity' 'gender' 'admissionheight' 'admissionweight'] -> 399 dimensions
        'eICU': [1] * 397 + [0] * 2,
        # ['sex'] -> 1 dimension
        'CD': [1],
    }
    if dataset not in categorical_layouts:
        raise ValueError(
            "getStats_static: no static feature layout for dataset %r (expected one of %s)"
            % (dataset, sorted(categorical_layouts))
        )
    bool_categorical = categorical_layouts[dataset]

    N, S = P_tensor.shape
    Ps = P_tensor.transpose((1, 0))
    ms = np.zeros((S, 1))
    ss = np.ones((S, 1))

    for s in range(S):
        if bool_categorical[s] == 0:  # if not categorical
            vals_s = Ps[s, :]
            # Filter out NaN values and zeros
            vals_s = vals_s[np.logical_and(vals_s > 0, ~np.isnan(vals_s))]
            if len(vals_s) > 0:
                ms[s, 0] = np.mean(vals_s)
                ss[s, 0] = np.std(vals_s)
            # else: defaults (mean 0.0, std 1.0) are already in place
    return ms, ss


def mask_normalize_static(P_tensor, ms, ss):
    """
    Normalize static features and zero out missing ones.

    NaNs are replaced by 0, each feature s is mapped to
    (value - ms[s]) / (ss[s] + 1e-18), and afterwards every entry that is
    <= 0 or NaN is set to 0.

    NOTE(review): the "<= 0" test runs on the normalized values, so observed
    values that are not above the feature mean are zeroed as well.

    :param P_tensor: static features, array of shape (N, S)
    :param ms: per-feature means, indexable by feature
    :param ss: per-feature standard deviations, indexable by feature
    :return: array of shape (N, S)
    """
    n_samples, n_static = P_tensor.shape
    by_feature = P_tensor.transpose((1, 0))

    # NaN -> 0 ahead of the arithmetic
    cleaned = np.where(np.isnan(by_feature), 0, by_feature)

    for s in range(n_static):
        # normalize the feature row in place ...
        cleaned[s] = (cleaned[s] - ms[s]) / (ss[s] + 1e-18)
        # ... then clear everything flagged as missing
        row = cleaned[s, :]
        row[np.logical_or(row <= 0, np.isnan(row))] = 0

    # back to sample-major layout
    return cleaned.reshape((n_static, n_samples)).transpose((1, 0))


def tensorize_normalize(P, y, mf, stdf, ms, ss):
    """
    Pad per-patient records to a common length, normalize them and convert to tensors.

    :param P: sequence of dicts with keys 'arr' ([T_i, F]), 'time' ([T_i, 1])
        and 'extended_static' (length D)
    :param y: label array; column 0 is used
    :param mf: per-feature means for the time series (see mask_normalize)
    :param stdf: per-feature standard deviations for the time series
    :param ms: per-feature means for the static vector (see mask_normalize_static)
    :param ss: per-feature standard deviations for the static vector
    :return: P_tensor (N, T_max, 2F), P_static_tensor (N, D),
        P_time (N, T_max, 1; input times divided by 60), y_tensor (N,) of type long
    """
    # Dimensions are taken from the first record
    first_record = P[0]
    n_features = first_record['arr'].shape[1]
    n_static = len(first_record['extended_static'])
    n_patients = len(P)

    # Every record is zero-padded up to the longest one
    longest = max(len(record['arr']) for record in P)

    series = np.zeros((n_patients, longest, n_features))
    stamps = np.zeros((n_patients, longest, 1))
    statics = np.zeros((n_patients, n_static))
    for row in range(n_patients):
        values = P[row]['arr']
        times = P[row]['time']
        n_steps = min(len(values), longest)
        if n_steps > 0:
            series[row, :n_steps, :] = values[:n_steps]
            stamps[row, :n_steps, 0] = times[:n_steps, 0]
        statics[row] = P[row]['extended_static']

    P_tensor = torch.Tensor(mask_normalize(series, mf, stdf))
    P_time = torch.Tensor(stamps) / 60.0  # convert mins to hours
    P_static_tensor = torch.Tensor(mask_normalize_static(statics, ms, ss))
    y_tensor = torch.Tensor(y[:, 0]).type(torch.LongTensor)
    return P_tensor, P_static_tensor, P_time, y_tensor


def tensorize_normalize_other(P, y, mf, stdf):
    """
    Normalize an already rectangular dataset and convert it to tensors.

    All samples share the same synthetic time axis: T evenly spaced points from
    0 to T, divided by 60. No static features are produced.

    :param P: array of shape (N, T, F)
    :param y: label array; column 0 is used
    :param mf: per-feature means (see mask_normalize)
    :param stdf: per-feature standard deviations
    :return: P_tensor (N, T, 2F), None, P_time (N, T, 1), y_tensor (N,) of type long
    """
    n_steps, _n_features = P[0].shape

    # The time grid is identical for every sample, so build it once and broadcast
    grid = torch.linspace(0, n_steps, n_steps).reshape(-1, 1)
    stamps = np.zeros((len(P), n_steps, 1))
    stamps[:] = grid.numpy()

    P_tensor = torch.Tensor(mask_normalize(P, mf, stdf))
    P_time = torch.Tensor(stamps) / 60.0
    y_tensor = torch.Tensor(y[:, 0]).type(torch.LongTensor)
    return P_tensor, None, P_time, y_tensor


def masked_softmax(A, epsilon=0.000000001):
    """
    Softmax-like normalization that ignores entries equal to zero.

    Each entry is shifted by the maximum along dim 1 before exponentiation,
    entries where ``A == 0`` get weight 0, and the weights are divided by their
    sum along dim 0 plus ``epsilon``.

    :param A: tensor with at least two dimensions
    :param epsilon: added to the denominator to avoid division by zero
    :return: tensor of the same shape as A
    """
    shift = A.max(dim=1, keepdim=True).values
    weights = torch.exp(A - shift) * A.ne(0).float()
    denominator = weights.sum(dim=0, keepdim=True) + epsilon
    return weights / denominator


def random_sample(idx_0, idx_1, B, replace=False):
    """
    Draw a class-balanced batch of indices.

    ``int(B / 2)`` indices are drawn from ``idx_0`` and then ``int(B / 2)`` from
    ``idx_1`` using the global NumPy RNG.

    :param idx_0: candidate indices of the first class
    :param idx_1: candidate indices of the second class
    :param B: batch size
    :param replace: whether to sample with replacement
    :return: 1-D array, first-class draws followed by second-class draws
    """
    per_class = int(B / 2)
    draws = [
        np.random.choice(pool, size=per_class, replace=replace)
        for pool in (idx_0, idx_1)
    ]
    return np.concatenate(draws, axis=0)


def random_sample_8(ytrain, B, replace=False):
    """
    Draw a class-balanced batch of indices for an 8-class problem.

    For each label 0..7 (in that order) ``int(B / 8)`` positions where
    ``ytrain`` equals the label are drawn using the global NumPy RNG.

    :param ytrain: label array; positions are taken along its first axis
    :param B: batch size
    :param replace: whether to sample with replacement
    :return: 1-D array of indices, grouped by label in ascending order
    """
    per_class = int(B / 8)
    batches = []
    for label in range(8):
        candidates = np.where(ytrain == label)[0]
        batches.append(np.random.choice(candidates, size=per_class, replace=replace))
    return np.concatenate(batches, axis=0)


def evaluate(model, P_tensor, P_time_tensor, P_static_tensor, batch_size=100, n_classes=2, static=1):
    """
    Run the model in eval mode over all samples in mini-batches.

    Inputs are moved to the device of the model's first parameter; outputs are
    detached and collected on the CPU.

    :param model: module whose ``forward(P, Pstatic, Ptime, lengths)`` returns
        one row of ``n_classes`` values per sample
    :param P_tensor: time series tensor of shape (T, N, F)
    :param P_time_tensor: time stamps, sample axis at dim 1
    :param P_static_tensor: static features with the sample axis first, or None
    :param batch_size: number of samples per forward pass
    :param n_classes: width of the model output
    :param static: pass None to evaluate without static features; the model then
        receives ``Pstatic=None`` regardless of ``P_static_tensor``
    :return: CPU tensor of shape (N, n_classes)
    """
    model.eval()
    # Get device from model
    device = next(model.parameters()).device
    P_tensor = P_tensor.to(device)
    P_time_tensor = P_time_tensor.to(device)
    if static is None or P_static_tensor is None:
        # Same convention as evaluate_standard: no static input at all.
        P_static_tensor = None
    else:
        P_static_tensor = P_static_tensor.to(device)

    _, N, _ = P_tensor.shape

    out = torch.zeros(N, n_classes)
    # The final chunk is simply shorter when N is not a multiple of batch_size.
    for start in range(0, N, batch_size):
        stop = min(start + batch_size, N)
        P = P_tensor[:, start:stop, :]
        Ptime = P_time_tensor[:, start:stop]
        Pstatic = None if P_static_tensor is None else P_static_tensor[start:stop]
        # A time step counts as observed when its time stamp is positive
        lengths = torch.sum(Ptime > 0, dim=0)
        out[start:stop] = model.forward(P, Pstatic, Ptime, lengths).detach().cpu()
    return out


def evaluate_standard(model, P_tensor, P_time_tensor, P_static_tensor, batch_size=100, n_classes=2, static=1):
    """
    Run a single forward pass over all samples at once.

    Inputs are moved to the device of the model's first parameter. ``batch_size``
    and ``n_classes`` are accepted but not used.

    :param model: module with ``forward(P, Pstatic, Ptime, lengths)``
    :param P_tensor: time series tensor
    :param P_time_tensor: time stamps; entries > 0 are counted along dim 0 as lengths
    :param P_static_tensor: static features (ignored when ``static`` is None)
    :param static: pass None to call the model with ``Pstatic=None``
    :return: whatever ``model.forward`` returns
    """
    # Everything goes to wherever the model lives
    target_device = next(model.parameters()).device
    series = P_tensor.to(target_device)
    stamps = P_time_tensor.to(target_device)
    static_features = None if static is None else P_static_tensor.to(target_device)

    observed_steps = torch.sum(stamps > 0, dim=0)
    return model.forward(series, static_features, stamps, observed_steps)


def evaluate_MTGNN(model, P_tensor, P_static_tensor, static=1):
    """
    Reshape the input for an MTGNN-style model and run one forward pass.

    The input is rearranged from (d0, d1, d2) to (d1, 1, d2, d0). If the model
    has a ``seq_length`` attribute, the last axis is zero-padded or truncated to
    that length.

    :param model: module with ``forward(P, Pstatic)``
    :param P_tensor: 3-D time series tensor
    :param P_static_tensor: static features (ignored when ``static`` is None)
    :param static: pass None to call the model with ``Pstatic=None``
    :return: whatever ``model.forward`` returns
    """
    target_device = next(model.parameters()).device

    # (d0, d1, d2) -> (d1, d0, d2) -> (d1, 1, d0, d2) -> (d1, 1, d2, d0)
    series = P_tensor.to(target_device).permute(1, 0, 2).unsqueeze(1).transpose(2, 3)

    static_features = None if static is None else P_static_tensor.to(target_device)

    # Make the last axis match the sequence length the model was built for
    wanted_len = getattr(model, 'seq_length', None)
    if wanted_len is not None:
        have_len = series.size(-1)
        if have_len > wanted_len:
            series = series[..., :wanted_len]
        elif have_len < wanted_len:
            filler = series.new_zeros(
                series.size(0), series.size(1), series.size(2), wanted_len - have_len
            )
            series = torch.cat([series, filler], dim=-1)

    return model.forward(series, static_features)


def evaluate_DGM2(model, P_tensor, P_static_tensor, static=1):
    """
    Prepare the input for a DGM2-style model and run one forward pass.

    The time axis passed to the model is ``0, 1, ..., P_tensor.size(0) - 1`` for
    all patients; it is created with ``torch.arange`` and is not moved to the
    model's device. The series tensor has its first two axes swapped.

    :param model: module with ``forward(P, P_time, Pstatic)``
    :param P_tensor: 3-D time series tensor
    :param P_static_tensor: static features (ignored when ``static`` is None)
    :param static: pass None to call the model with ``Pstatic=None``
    :return: whatever ``model.forward`` returns
    """
    # Shared integer time grid: one step per index along the first axis
    n_steps = P_tensor.size()[0]
    time_grid = torch.arange(n_steps)

    target_device = next(model.parameters()).device
    series = torch.permute(P_tensor.to(target_device), (1, 0, 2))
    static_features = None if static is None else P_static_tensor.to(target_device)

    return model.forward(series, time_grid, static_features)


def linspace_vector(start, end, n_points):
    """
    Element-wise ``torch.linspace`` between two tensors of equal size.

    :param start: tensor holding one value or a vector of start values
    :param end: tensor of the same size holding the end values
    :param n_points: number of points per interval
    :return: a 1-D tensor of ``n_points`` values when ``start`` holds a single
        element, otherwise a tensor of shape (n_points, start.size(0)) whose
        column i runs from start[i] to end[i]
    """
    n_elements = np.prod(start.size())

    assert start.size() == end.size()
    if n_elements == 1:
        # single interval
        return torch.linspace(start, end, n_points)

    # one linspace per entry, laid out row by row and then transposed
    n_rows = start.size(0)
    rows = [torch.linspace(start[i], end[i], n_points) for i in range(n_rows)]
    flat = torch.cat([torch.Tensor(), *rows], 0)
    return torch.t(flat.reshape(n_rows, n_points))


# Adam using warmup
class NoamOpt:
    """Optimizer wrapper that sets the learning rate from a warmup schedule.

    The rate at step ``s`` is
    ``factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)``,
    i.e. it grows linearly for ``warmup`` steps and then decays as ``s ** -0.5``.

    :param model_size: model dimension used to scale the rate
    :param factor: constant multiplier of the schedule
    :param warmup: number of warmup steps
    :param optimizer: wrapped optimizer; must expose ``param_groups``,
        ``step()`` and ``zero_grad()``
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0      # number of step() calls so far
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0      # learning rate applied by the latest step()

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate for ``step`` (default: current step).

        Step 0 (queried before the first ``step()`` call) is evaluated as step 1,
        because ``0 ** -0.5`` is undefined.
        """
        if step is None:
            step = self._step
        step = max(step, 1)
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        "Clear the gradients of the wrapped optimizer"
        self.optimizer.zero_grad()
