"""
* Apply time based decay
* Normalize the data
* Filter negative time

change param for normalized time -> absolute max vs max wrt to each HADM_IDs
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle

# set seed
np.random.seed(42)

dt = pd.read_csv('MIMIC_GIB_24h_with_severe_liver.csv')

# Average number of rows per admission (exploratory; not used below).
rows_per_admission = dt.groupby('HADM_ID').size().mean()
grouped = dt.groupby('HADM_ID')

# Rows ordered by admission, then by time within each admission. Equivalent to
# sorting each group and concatenating (groupby drops NaN keys by default, so
# the NaN-HADM_ID rows are dropped here too), without the deprecated
# groupby.apply over grouping columns.
grouped_sorted = (
    dt.dropna(subset=['HADM_ID'])
    .sort_values(['HADM_ID', 'TIME_FROM_ADM'])
    .reset_index(drop=True)
)

# Mean spacing between consecutive measurements *within* an admission.
# Diffing per admission avoids pairing the last row of one admission with the
# first row of the next. Each admission's first diff is NaN, and mean() skips it.
mean_time = grouped_sorted.groupby('HADM_ID')['TIME_FROM_ADM'].diff().mean()


# time based decay between infusions
def apply_time_based_decay(df, bloodprod_column='bloodprod', time_column='TIME_FROM_ADM', decay_rate=1,
                           group_column='HADM_ID', window=1, min_value=0.01):
    """Spread each administration event forward in time with a Gaussian decay.

    Rows where ``bloodprod_column == 1`` mark an administration. Every other
    row whose time is strictly less than ``window`` after the most recent
    administration *in the same group* is set to
    ``exp(-(elapsed / decay_rate) ** 2 / 2)``. Values not above ``min_value``
    are written as 0. Rows outside the window, or with no earlier
    administration in their group, are left unchanged.

    Rows are processed in time order within each ``group_column`` value, so the
    input need not be pre-sorted, and an administration in one group never
    affects another. If ``group_column`` is None or not a column of ``df``,
    the whole frame is treated as a single group.

    Args:
        df: DataFrame, modified in place. Its index must be unique.
        bloodprod_column: 0/1 indicator column to decay. It is cast to float.
        time_column: numeric time column.
        decay_rate: standard deviation of the Gaussian, in units of time_column.
        group_column: column identifying independent sequences.
        window: only rows with elapsed time < window are decayed.
        min_value: decayed values <= min_value are written as 0.

    Returns:
        ``df`` (the same object).
    """
    # decay function normalized to 1 at x=0
    def gaussian_decay(x):
        return np.exp(-((x / decay_rate) ** 2) / 2)

    by_group = group_column is not None and group_column in df.columns
    sort_keys = [group_column, time_column] if by_group else [time_column]
    # Stable sort: rows sharing a timestamp keep their original relative order.
    ordered = df.sort_values(sort_keys, kind='mergesort')

    times = ordered[time_column]
    is_event = ordered[bloodprod_column] == 1
    # Time of the latest administration at or before each row (NaN if none yet).
    event_times = times.where(is_event)
    if by_group:
        last_event = event_times.groupby(ordered[group_column]).ffill()
    else:
        last_event = event_times.ffill()
    elapsed = times - last_event  # NaN when there is no earlier event

    to_decay = ~is_event & (elapsed < window)
    decay = gaussian_decay(elapsed[to_decay])

    # Cast first: writing fractional weights into an int column is deprecated in pandas.
    df[bloodprod_column] = df[bloodprod_column].astype(float)
    if not decay.empty:
        df.loc[decay.index, bloodprod_column] = decay.where(decay > min_value, 0.0)
    return df

dt = apply_time_based_decay(dt, bloodprod_column='bloodprod')
dt = apply_time_based_decay(dt, bloodprod_column='pressor')

time_cols = ['1','2','3'] # 1 is HR, 2 is systolic BP, 3 is diastolic BP
class_cols = ['SEX_ID', 'AGE_AT_ADM',
       'prbc_outcome', 'pressor', 'bloodprod', 'severe_liver','HADM_ID']
treatment_conditional = ['prbc_outcome', 'pressor', 'bloodprod','severe_liver']

# Keep rows with all three vitals present and non-negative time.
# Filter before computing normalization statistics so the z-scores describe
# the rows that are actually kept. .copy() gives an independent frame, so the
# column assignments below do not write through a slice of dt.
traindt = dt.dropna(subset=time_cols)
traindt = traindt.loc[
    traindt['TIME_FROM_ADM'] >= 0,
    time_cols + ['TIME_FROM_ADM', 'severe_liver', 'HADM_ID', 'prbc_outcome', 'pressor', 'bloodprod'],
].copy()

# MAP = DP + 1/3(SP - DP), computed from the raw (unnormalized) pressures
traindt["MAP"] = traindt['3'] + 1/3*(traindt['2'] - traindt['3'])

# z-score normalize each vital-sign column and MAP
# NOTE(review): the statistics are pooled over all admissions, before the
# train/val/test split, so val/test rows influence the scaling — consider
# fitting on the train split only.
for col in time_cols + ['MAP']:
    traindt[col] = (traindt[col] - traindt[col].mean())/traindt[col].std()

abs_time_max = traindt['TIME_FROM_ADM'].max()

def normalize_time(group, time_max = abs_time_max):
    """Return ``group`` with an added ``time_normalized`` column.

    Meant to be used with ``DataFrame.groupby('HADM_ID').apply``. Two modes:

    * ``time_max`` given (default: the module-level ``abs_time_max``):
      ``TIME_FROM_ADM / time_max``. The lower bound is fixed at 0, so every
      group shares one time scale.
    * ``time_max=None``: min-max scaling using this group's own min and max.

    When the range is degenerate (max == min), every row gets 0.5.

    A new DataFrame is returned; ``group`` itself is not modified.
    """
    times = group['TIME_FROM_ADM']
    if time_max is None:
        time_min, time_max = times.min(), times.max()
    else:
        time_min = 0

    if time_max == time_min:
        # No spread to scale by; use the midpoint as a neutral value.
        normalized = 0.5
    else:
        normalized = (times - time_min) / (time_max - time_min)
    # assign() builds a new frame instead of inserting a column into the caller's object.
    return group.assign(time_normalized=normalized)


df_normalized = traindt.groupby('HADM_ID', group_keys=False).apply(normalize_time)

# Randomly assign HADM_IDs to train, val and test. The split is by admission,
# so each admission's rows land in exactly one split.
test_frac = 0.1
val_frac = 0.1

HADM_IDs = traindt['HADM_ID'].unique()
np.random.shuffle(HADM_IDs)
n_ids = len(HADM_IDs)
test_end = int(n_ids*test_frac)
val_end = int(n_ids*(test_frac+val_frac))
test_ids = HADM_IDs[:test_end]
val_ids = HADM_IDs[test_end:val_end]
train_ids = HADM_IDs[val_end:]

print(len(train_ids), len(val_ids), len(test_ids)) # 2082 260 260

# set as new data
df_D = {
    name: df_normalized[df_normalized['HADM_ID'].isin(ids)].reset_index(drop=True)
    for name, ids in (('train', train_ids), ('val', val_ids), ('test', test_ids))
}
train, val, test = df_D['train'], df_D['val'], df_D['test']

# save the train, val and test data
save_path = "/home/XXXX-1/x/XXXX-2/scratch/XXXX-3/mimic/"

# Use a context manager so the file is flushed and closed even if dump fails.
with open(save_path + "MIMIC_GIB_24h_with_severe_liver_normalized_time.pkl", "wb") as fh:
    pickle.dump(df_D, fh)


