# load data same as train.py

import pandas as pd
import numpy as np
import random
import torch
from tqdm import trange

import sys


# Notebook-style pandas settings: chained-assignment warnings are switched off
# (the simulations below add columns to filtered frames) and every column is
# shown when a frame is displayed.
pd.options.mode.chained_assignment = None
pd.options.display.max_columns = None

# ---- load the raw tables ----
SUBJECT = pd.read_csv('/share/fsmresfiles/ylo7832/allICU/SUBJECT.csv', index_col=0)

ICU = pd.read_csv('/share/fsmresfiles/ylo7832/allICU/ICU.csv', index_col=0)
ICU = ICU[ICU['age'].between(18, 95)]  # keep ages 18..95, both ends included

DAILY = pd.read_csv(
    '/share/fsmresfiles/ylo7832/allICU/DAILY_new.csv',
    index_col=0,
    parse_dates=['day_bucket_starts', 'day_bucket_ends'],
)

# ---- build the ventilated cohort ----
COHORT = ICU.merge(DAILY, on=['pat_enc_csn_id', 'patient_ir_id'])
COHORT = COHORT.merge(SUBJECT, how='left', on='patient_ir_id')
COHORT['day_bucket_starts'] = COHORT['day_bucket_starts'].dt.floor('D')
COHORT = COHORT[COHORT['vent_flag'] == 1]

# Number of distinct icu_day values per encounter; encounters with more than
# 30 of them are dropped.
los = COHORT.groupby('pat_enc_csn_id', as_index=False)['icu_day'].nunique()
COHORT = COHORT[COHORT['pat_enc_csn_id'].isin(los.loc[los['icu_day'] <= 30, 'pat_enc_csn_id'])]

# ---- train / validation / test split by calendar window ----
# An encounter belongs to a window if any of its rows falls inside it; an
# encounter touching several windows is kept only in the earliest one.
dayStart = COHORT['day_bucket_starts']
encounterId = COHORT['pat_enc_csn_id']

TrainPATIENT = encounterId[dayStart.between('2020-03-15', '2021-07-14')].unique().tolist()
inTrain = encounterId.isin(TrainPATIENT)
TrainCOHORT = COHORT[inTrain]

ValPATIENT = encounterId[dayStart.between('2021-07-15', '2021-10-14')].unique().tolist()
inVal = encounterId.isin(ValPATIENT)
ValCOHORT = COHORT[inVal & ~inTrain]

TestPATIENT = encounterId[dayStart.between('2021-10-15', '2023-01-15')].unique().tolist()
inTest = encounterId.isin(TestPATIENT)
TestCOHORT = COHORT[inTest & ~inTrain & ~inVal]

# Bare expression: only displays something when run as a notebook cell.
len(TrainPATIENT), len(ValPATIENT), len(TestPATIENT)

from sklearn import preprocessing


# Columns that make up one patient-day feature vector (38 entries).
repCols = [
    'p_f_ratio_points', 'platelet_points', 'bilirubin_points',
    'htn_points', 'gcs_points', 'renal_points',
    'PULSE', 'SP02', 'RESPIRATIONS', 'BMI',
    'BP_SYSTOLIC', 'BP_DIASTOLIC', 'TEMPERATURE',
    'age', 'COVID',
    'ami', 'chf', 'pvd', 'cevd', 'dementia', 'copd', 'rheumd', 'pud',
    'mld', 'diab', 'diabwc', 'hp', 'rend', 'canc', 'msld', 'metacanc', 'aids',
    'isFemale',
    'isAsian', 'isBlack', 'isHispanic', 'isNative', 'isWhite',
]  # , 'vent_flag']

# The six "*_points" columns that are summed into a total score further down.
SOFACols = [
    'p_f_ratio_points',
    'platelet_points',
    'bilirubin_points',
    'htn_points',
    'gcs_points',
    'renal_points',
]

# Same column names with a "next_" prefix.
repColsNext = [f'next_{name}' for name in repCols]
SOFAColsNext = [f'next_{name}' for name in SOFACols]

# Keep only rows flagged as ventilated in each split.
TrainCOHORT = TrainCOHORT.loc[TrainCOHORT['vent_flag'] == 1]
TestCOHORT = TestCOHORT.loc[TestCOHORT['vent_flag'] == 1]
ValCOHORT = ValCOHORT.loc[ValCOHORT['vent_flag'] == 1]

# Test-split encounters that have at least one row with death == 1.
deaths = TestCOHORT.loc[TestCOHORT['death'] == 1, 'pat_enc_csn_id'].unique().tolist()
#

# LOTTERY PROTOCOL
#
# For every capacity 0..84, replay the test period day by day. On a day with
# more ventilated encounters than `cap`, encounters that were on a ventilator
# the previous day (whoWasOn == 0) are ranked first and the rest are ordered by
# a random draw; the first `cap` get onoff = 1. Everyone past the cut gets
# onoff = 0, is recorded in earlyDeaths, and is removed from all later days.

RNDMdeathTally = {}   # cap -> list of encounter ids that were denied at least once
RNDMallocations = {}  # cap -> concatenated per-day allocation records


for cap in trange(0,85):

    dailyVENT = TestCOHORT.copy(deep=True)

    dailyVENT['whoWasOn'] = 1  # 1 = off, 0 = on

    random.seed(cap+202305102)
    # The lottery priorities are drawn with NumPy, which random.seed() does not
    # affect. Use a NumPy generator seeded per capacity so that reruns produce
    # the same allocations.
    lotteryRng = np.random.default_rng(cap+202305102)

    whoWasOn = []

    earlyDeaths = []
    allocation = []

    for t in pd.date_range(start='2021-10-15', end='2023-01-15'):
        needOfTheDay = dailyVENT[dailyVENT.day_bucket_starts == t]

        if len(needOfTheDay) <= cap:
            # Enough ventilators for everyone today.
            needOfTheDay['onoff'] = 1

            whoWasOn = needOfTheDay.pat_enc_csn_id.tolist()
            dailyVENT['whoWasOn'] = 1
            dailyVENT.loc[dailyVENT.pat_enc_csn_id.isin(whoWasOn), 'whoWasOn'] = 0

            allocation.append(needOfTheDay[['pat_enc_csn_id', 'onoff', 'day_bucket_starts', 'whoWasOn']])
            continue

        # Shortage: previously-on encounters first, then by random priority.
        needOfTheDay['priorityOfRandom'] = lotteryRng.random(len(needOfTheDay)).tolist()
        needOfTheDay = needOfTheDay.sort_values(['whoWasOn', 'priorityOfRandom'])
        needOfTheDay['onoff'] = [1]*cap + [0]*(len(needOfTheDay)-cap)
        allocation.append(needOfTheDay[['pat_enc_csn_id', 'onoff', 'day_bucket_starts', 'whoWasOn']])

        rankedIds = needOfTheDay.pat_enc_csn_id.tolist()
        whoWasOn = rankedIds[:cap]

        # Refresh the on/off marker used for tomorrow's ranking.
        dailyVENT['whoWasOn'] = 1
        dailyVENT.loc[dailyVENT.pat_enc_csn_id.isin(whoWasOn), 'whoWasOn'] = 0

        # Encounters past the cut are dropped from every later day.
        excludePAs = rankedIds[cap:]
        earlyDeaths += excludePAs
        dailyVENT = dailyVENT[~dailyVENT.pat_enc_csn_id.isin(earlyDeaths)]

    RNDMdeathTally[cap] = list(set(earlyDeaths))
    RNDMallocations[cap] = pd.concat(allocation)


# One row per capacity; "deaths" counts the union of denied encounters and the
# test-split encounters in `deaths`.
RNDMPolicy = pd.DataFrame(columns = ['capacity', 'deaths'])
RNDMPolicy['capacity'] = list(RNDMdeathTally.keys())
RNDMPolicy['deaths'] = [len(set(d+deaths)) for d in RNDMdeathTally.values()]



# YF PROTOCOL
#
# Same day-by-day replay as the other baselines, but on a shortage day the
# ranking is: encounters that were on yesterday (whoWasOn == 0) first, then
# ascending `age`. No random numbers are drawn in this protocol.

YFdeathTally = {}
YFallocations = {}

for cap in trange(0, 85):
    random.seed(cap+20230510)  # kept from the original script; unused below

    waiting = TestCOHORT.copy(deep=True)
    waiting['whoWasOn'] = 1  # 1 = off, 0 = on

    denied = []        # encounter ids that were refused at least once
    dailyRecords = []  # one allocation frame per simulated day

    for day in pd.date_range(start='2021-10-15', end='2023-01-15'):
        todays = waiting[waiting.day_bucket_starts == day]
        shortage = len(todays) > cap

        if shortage:
            todays = todays.sort_values(['whoWasOn', 'age'])
            todays['onoff'] = [1] * cap + [0] * (len(todays) - cap)
        else:
            todays['onoff'] = 1
        dailyRecords.append(todays[['pat_enc_csn_id', 'onoff', 'day_bucket_starts', 'whoWasOn']])

        # Without a shortage len(ranked) <= cap, so the slice keeps everyone.
        ranked = todays.pat_enc_csn_id.tolist()
        served = ranked[:cap]

        # Reset the marker for all, then flag today's recipients as "on".
        waiting['whoWasOn'] = 1
        waiting.loc[waiting.pat_enc_csn_id.isin(served), 'whoWasOn'] = 0

        if shortage:
            denied += ranked[cap:]
            waiting = waiting[~waiting.pat_enc_csn_id.isin(denied)]

    YFdeathTally[cap] = list(set(denied))
    YFallocations[cap] = pd.concat(dailyRecords)


# One row per capacity; "deaths" is the size of (denied ids UNION `deaths`).
YFPolicy = pd.DataFrame(columns=['capacity', 'deaths'])
YFPolicy['capacity'] = list(YFdeathTally)
YFPolicy['deaths'] = [len(set(deniedIds) | set(deaths)) for deniedIds in YFdeathTally.values()]

# SOFA PROTOCOL

def SOFAriage(sofascore):
    """Map a total SOFA score to a triage priority band.

    Returns 2 (low priority) for scores above 11, 1 (medium) for scores from
    8 up to and including 11, and 0 (high) for anything else.

    The bands are expressed as range comparisons rather than membership in
    ``[8, 9, 10, 11]`` so that a non-integer score such as 8.5 lands in the
    medium band instead of falling through to the highest priority.
    """
    if sofascore > 11:
        return 2  # low
    if sofascore >= 8:
        return 1  # medium
    return 0  # high

# Priority band per row: sum the six SOFA component columns, then bucket the
# total with SOFAriage (smaller band value sorts first below).
sofaTotal = TestCOHORT[SOFACols].sum(axis=1)
TestCOHORT['SOFA'] = sofaTotal.apply(SOFAriage)

# Deterministic replay: on a shortage day, encounters that were on yesterday
# (whoWasOn == 0) rank first, then ascending SOFA band.
SOFAdeathTally = {}
SOFAallocations = {}

for cap in trange(0, 85):
    random.seed(cap+20230510)  # kept from the original script; unused below

    pool = TestCOHORT.copy(deep=True)
    pool['whoWasOn'] = 1  # 1 = off, 0 = on

    refused = []
    perDay = []

    for day in pd.date_range(start='2021-10-15', end='2023-01-15'):
        demand = pool[pool.day_bucket_starts == day]
        overCapacity = len(demand) > cap

        if not overCapacity:
            demand['onoff'] = 1
        else:
            demand = demand.sort_values(['whoWasOn', 'SOFA'])
            demand['onoff'] = [1] * cap + [0] * (len(demand) - cap)
        perDay.append(demand[['pat_enc_csn_id', 'onoff', 'day_bucket_starts', 'whoWasOn']])

        orderedIds = demand.pat_enc_csn_id.tolist()
        # With no shortage this slice is the whole list (len <= cap).
        onToday = orderedIds[:cap]

        pool['whoWasOn'] = 1
        pool.loc[pool.pat_enc_csn_id.isin(onToday), 'whoWasOn'] = 0

        if overCapacity:
            # Everyone past the cut leaves the simulation for good.
            refused += orderedIds[cap:]
            pool = pool[~pool.pat_enc_csn_id.isin(refused)]

    SOFAdeathTally[cap] = list(set(refused))
    SOFAallocations[cap] = pd.concat(perDay)


# One row per capacity; "deaths" is the size of (refused ids UNION `deaths`).
SOFAPolicy = pd.DataFrame(columns=['capacity', 'deaths'])
SOFAPolicy['capacity'] = list(SOFAdeathTally)
SOFAPolicy['deaths'] = [len(set(refusedIds) | set(deaths)) for refusedIds in SOFAdeathTally.values()]


# MP PROTOCOL

def MPriage(MPscore):
    """Map a total score to a 1-4 priority-point band.

    Returns 4 for scores above 14, 3 for scores from 12 up to and including
    14, 2 for scores from 9 up to (but not including) 12, and 1 otherwise.

    Range comparisons are used instead of membership in explicit integer
    lists so that a non-integer score such as 11.5 is banded with its
    neighbours rather than dropping to the lowest band.
    """
    if MPscore > 14:
        return 4
    if MPscore >= 12:
        return 3
    if MPscore >= 9:
        return 2
    return 1
    
# Comorbidity indicator columns summed for the MP score adjustment.
CCICols = [
    'ami', 'chf', 'pvd', 'cevd', 'dementia', 'copd', 'rheumd', 'pud', 'mld',
    'diab', 'diabwc', 'hp', 'rend', 'canc', 'msld', 'metacanc', 'aids',
]

# MP = band of the summed SOFA component columns (via MPriage), plus 3 when
# the CCICols row sum is at least 8.
sofaBand = TestCOHORT[SOFACols].sum(axis=1).apply(MPriage)
comorbidityPenalty = (TestCOHORT[CCICols].sum(axis=1) >= 8).astype(int) * 3
TestCOHORT['MP'] = sofaBand + comorbidityPenalty

# Age bucket used as tie-breaker: (0, 50], (50, 70], (70, 85], (85, 300].
age_groups = [0, 50, 70, 85, 300]
age_labels = [0, 1, 2, 3]
TestCOHORT['ageG'] = pd.cut(TestCOHORT["age"], bins=age_groups, labels=age_labels)

# Deterministic replay: on a shortage day, encounters that were on yesterday
# (whoWasOn == 0) rank first, then ascending MP, then ascending age bucket.
MPdeathTally = {}
MPallocations = {}

for cap in trange(0, 85):
    random.seed(cap+20230510)  # kept from the original script; unused below

    candidates = TestCOHORT.copy(deep=True)
    candidates['whoWasOn'] = 1  # 1 = off, 0 = on

    turnedAway = []
    history = []

    for day in pd.date_range(start='2021-10-15', end='2023-01-15'):
        requests = candidates[candidates.day_bucket_starts == day]
        scarce = len(requests) > cap

        if scarce:
            requests = requests.sort_values(['whoWasOn', 'MP', 'ageG'])
            requests['onoff'] = [1] * cap + [0] * (len(requests) - cap)
        else:
            requests['onoff'] = 1
        history.append(requests[['pat_enc_csn_id', 'onoff', 'day_bucket_starts', 'whoWasOn']])

        idsInOrder = requests.pat_enc_csn_id.tolist()
        # When nothing is scarce the list has at most `cap` ids, so all are kept.
        receiving = idsInOrder[:cap]

        candidates['whoWasOn'] = 1
        candidates.loc[candidates.pat_enc_csn_id.isin(receiving), 'whoWasOn'] = 0

        if scarce:
            turnedAway += idsInOrder[cap:]
            candidates = candidates[~candidates.pat_enc_csn_id.isin(turnedAway)]

    MPdeathTally[cap] = list(set(turnedAway))
    MPallocations[cap] = pd.concat(history)


# One row per capacity; "deaths" is the size of (turned-away ids UNION `deaths`).
MPPolicy = pd.DataFrame(columns=['capacity', 'deaths'])
MPPolicy['capacity'] = list(MPdeathTally)
MPPolicy['deaths'] = [len(set(lostIds) | set(deaths)) for lostIds in MPdeathTally.values()]


# TDQN-fair
#
# Collect, for each capacity 1..84, the 'test' value of the best-ranked row in
# that capacity's res.csv (rows from position 10 onward only).
# NOTE(review): the preferred ranking column is 'test_KL' while the replay
# further down in this file ranks checkpoints by 'val_KL'; ranking on a
# test-named column looks like selection on the test split -- confirm intended.
# NOTE(review): TDQNFAIRPolicy is rebuilt from the replayed allocations later
# in this file, so the frame built here is overwritten before it is saved.

TDQNFAIRres = {}


for i in trange(1,85):
    res = pd.read_csv('/share/fsmresfiles/ylo7832/covidVent/KL/{}/res.csv'.format(i))
    candidates = res.iloc[10:]
    # Only the column lookup used for ranking is allowed to fall back; a
    # missing 'test' column still raises KeyError as before.
    try:
        ranked = candidates.sort_values('test_KL')
    except KeyError:
        ranked = candidates.sort_values('val')
    TDQNFAIRres[i] = ranked.iloc[:1]['test'].min()


TDQNFAIRPolicy = pd.DataFrame(columns = ['capacity', 'deaths'])
TDQNFAIRPolicy['capacity'] = list(TDQNFAIRres.keys())
TDQNFAIRPolicy['deaths'] = list(TDQNFAIRres.values())
# Capacity 0: every test-split encounter is counted.
TDQNFAIRPolicy.loc[len(TDQNFAIRPolicy)] = [0, TestCOHORT.pat_enc_csn_id.nunique()]
TDQNFAIRPolicy = TDQNFAIRPolicy.sort_values('capacity')

from utils import BIG_Q_ATTN
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def value(Q, CAP, state, mask, who_was_on):
    """Pick which slots receive a ventilator according to the Q-network.

    Args:
        Q: callable invoked as ``Q(state, mask)``; its output is indexed as
            ``[:, :, 0]`` and ``[:, :, 1]`` and the difference
            ``[..., 1] - [..., 0]`` is used as the per-slot score.
        CAP: maximum number of slots that may be selected per row.
        state: passed through to ``Q`` unchanged.
        mask: 2-D tensor, 1 for a real slot and 0 for padding.
        who_was_on: tensor shaped like ``mask``; slots equal to 1 are forced
            to the top of the ranking.

    Returns:
        Boolean NumPy array with the same shape as ``mask``; True marks the
        selected slots. Per row, ``min(mask.sum(), CAP)`` slots are selected.

    The number of slots is taken from ``mask`` instead of assuming a width of
    ``CAP + 24``, so the function no longer depends on the caller's padding
    constant (for a width of ``CAP + 24`` the result is unchanged).
    """
    n_slots = mask.shape[-1]

    with torch.no_grad():
        action_heads = Q(state, mask)
        scores = (action_heads[:, :, 1] - action_heads[:, :, 0]).double()

        # Padding can never be chosen ...
        scores = scores.masked_fill(mask == 0, float('-inf'))
        # ... and slots flagged in who_was_on always outrank everything else.
        scores = scores.masked_fill(who_was_on == 1, float('inf'))

        # argsort of argsort yields each slot's ascending rank; flip it so the
        # highest score gets position 0.
        ascending_rank = torch.argsort(torch.argsort(scores, dim=-1, descending=False),
                                       dim=-1, descending=False)
        position_from_top = (n_slots - 1) - ascending_rank

        # Per row: number of real slots, capped at CAP; broadcasts over slots.
        slots_to_fill = mask.sum(dim=1, keepdim=True).clamp(max=CAP)

        actions = position_from_top < slots_to_fill

    return actions.detach().cpu().numpy()


# TDQN-fair
#
# For every capacity 1..84: load the checkpoint whose res.csv row (from
# position 10 onward) has the smallest 'val_KL', then replay the test period
# day by day, letting value() decide who gets a ventilator. Encounters that
# are not selected are recorded in earlyDeaths and removed from later days.
TDQNFAIRdeathTally = {}
TDQNFAIRallocations = {}

raceCols = ['isAsian', 'isBlack', 'isHispanic', 'isWhite']

for cap in trange(1,85):
    res = pd.read_csv('/share/fsmresfiles/ylo7832/covidVent/KL/{}/res.csv'.format(cap))

    ckpt = res.iloc[10:].sort_values('val_KL').index[0]
    Q = BIG_Q_ATTN(42,2).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    Q.load_state_dict(torch.load('/share/fsmresfiles/ylo7832/covidVent/KL/{}/ckpt/{}.pth'.format(cap, ckpt),
                                 map_location=device))
    # NOTE(review): Q is left in its default (training) mode; if BIG_Q_ATTN
    # contains dropout or batch-norm, Q.eval() is probably wanted -- confirm.

    dailyVENT = TestCOHORT.copy(deep=True)

    dailyVENT['whoWasOn'] = 1 #off

    random.seed(cap+20230510)

    whoWasOn = []

    earlyDeaths = []
    allocation = []
    # Running per-race totals (order of raceCols): `races` counts rows that
    # were off the previous day (whoWasOn == 1); `racesYes` counts those of
    # them that were selected.
    races = [[0,0,0,0]]
    racesYes = [[0,0,0,0]]

    for t in pd.date_range(start='2021-10-15', end = '2023-01-15'):
        needOfTheDay = dailyVENT[dailyVENT.day_bucket_starts == t]

        # Selection rate per race so far. `out` must be supplied together with
        # `where`: without it the entries whose denominator is 0 are left as
        # uninitialised memory. They are defined as 0 here.
        seenSoFar = np.asarray(races[-1], dtype=float)
        grantedSoFar = np.asarray(racesYes[-1], dtype=float)
        raceRate = np.divide(grantedSoFar, seenSoFar,
                             out=np.zeros_like(seenSoFar), where=seenSoFar != 0)
        raceFeat = np.tile(raceRate, (len(needOfTheDay), 1))

        # The model input is padded with zero rows up to cap + 24 slots.
        # NOTE(review): a day with more than cap + 24 rows makes padLen negative.
        padLen = cap+24 - len(needOfTheDay)

        with torch.no_grad():
            features = np.hstack((needOfTheDay[repCols].values, raceFeat)).astype(np.float32)
            state = torch.nn.functional.pad(torch.from_numpy(features),
                                            pad = (0,0,0,padLen),
                                            mode = 'constant', value = 0).unsqueeze(0).to(device)
            mask = torch.from_numpy(np.array([1]*len(needOfTheDay) +
                                             [0]*padLen)).unsqueeze(0).to(device)

            # value() expects 1 = was on; the frame stores 0 = on, so flip it.
            nextwhowason = torch.from_numpy(np.array(list(1-needOfTheDay['whoWasOn'].values)+\
                                                     [0]*padLen)).unsqueeze(0).to(device)

            needOfTheDay['onoff'] = value(Q, cap, state,mask, nextwhowason)[0][:len(needOfTheDay)]

        # Update the running per-race totals with today's rows.
        previous_races = races[-1]
        current_races = needOfTheDay.query('whoWasOn == 1')[raceCols].sum().tolist()
        races.append([x + y for x, y in zip(previous_races, current_races)])

        previous_racesYes = racesYes[-1]
        current_racesYes = needOfTheDay.query('onoff == True and whoWasOn == 1')[raceCols].sum().tolist()
        racesYes.append([x + y for x, y in zip(previous_racesYes, current_racesYes)])

        allocation.append(needOfTheDay[['pat_enc_csn_id', 'whoWasOn', 'onoff', 'day_bucket_starts']])

        excludePAs = needOfTheDay[needOfTheDay.onoff == False].pat_enc_csn_id.unique().tolist()

        whoWasOn = needOfTheDay[needOfTheDay.onoff == True].pat_enc_csn_id.unique().tolist()

        earlyDeaths += excludePAs

        # Unselected encounters leave the simulation; then refresh the marker
        # used for tomorrow's ranking.
        dailyVENT = dailyVENT[~dailyVENT.pat_enc_csn_id.isin(earlyDeaths)]

        dailyVENT['whoWasOn'] = 1 # off
        dailyVENT.loc[dailyVENT.pat_enc_csn_id.isin(whoWasOn), 'whoWasOn'] = 0 # on

    TDQNFAIRdeathTally[cap] = list(set(earlyDeaths))
    TDQNFAIRallocations[cap] = pd.concat(allocation)


# One row per capacity; "deaths" counts the union of unselected encounters and
# the test-split encounters in `deaths`. Capacity 0 is added by hand: every
# test-split encounter is counted.
TDQNFAIRPolicy = pd.DataFrame(columns = ['capacity', 'deaths'])
TDQNFAIRPolicy['capacity'] = list(TDQNFAIRdeathTally.keys())
TDQNFAIRPolicy['deaths'] = [len(set(d+deaths)) for d in TDQNFAIRdeathTally.values()]
TDQNFAIRPolicy.loc[len(TDQNFAIRPolicy)] = [0, TestCOHORT.pat_enc_csn_id.nunique()]
TDQNFAIRPolicy = TDQNFAIRPolicy.sort_values('capacity')


# save the results

# Capacity-vs-deaths table of each protocol, written as CSV.
for policyFrame, csvPath in (
    (RNDMPolicy, '/share/fsmresfiles/ylo7832/RNDM.csv'),
    (YFPolicy, '/share/fsmresfiles/ylo7832/YF.csv'),
    (SOFAPolicy, '/share/fsmresfiles/ylo7832/SOFA.csv'),
    (MPPolicy, '/share/fsmresfiles/ylo7832/MP.csv'),
    (TDQNFAIRPolicy, '/share/fsmresfiles/ylo7832/FAIR.csv'),
):
    policyFrame.to_csv(csvPath)


import pickle

# Per-capacity allocation records of each protocol, pickled as dicts.
for allocationsByCap, picklePath in (
    (TDQNFAIRallocations, '/share/fsmresfiles/ylo7832/covidVent/FAIR.pickle'),
    (RNDMallocations, '/share/fsmresfiles/ylo7832/covidVent/RNDM.pickle'),
    (YFallocations, '/share/fsmresfiles/ylo7832/covidVent/YF.pickle'),
    (SOFAallocations, '/share/fsmresfiles/ylo7832/covidVent/SOFA.pickle'),
    (MPallocations, '/share/fsmresfiles/ylo7832/covidVent/MP.pickle'),
):
    with open(picklePath, 'wb') as file:
        pickle.dump(allocationsByCap, file)
