import pandas as pd
import numpy as np
import random
import torch
from tqdm import trange
import scipy.stats as stats

import sys

# Ventilator capacity for this run, read from the first command-line argument.
CAP = int(sys.argv[1])

# Seed torch with a capacity-dependent value so every CAP setting gets its
# own reproducible initialisation.
torch.manual_seed(2 * CAP + 210024)

print(f'CAP@ {CAP}')

# Turn off pandas' chained-assignment warning and print every column of a
# frame instead of eliding the middle ones.
pd.set_option('mode.chained_assignment', None)
pd.set_option('display.max_columns', None)

# Use the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

#load data
# Source tables.  ICU stays are restricted to ages 18-95; the daily table's
# bucket boundaries are parsed as timestamps.
SUBJECT = pd.read_csv('/share/fsmresfiles/ylo7832/allICU/SUBJECT.csv', index_col=0)
ICU = pd.read_csv('/share/fsmresfiles/ylo7832/allICU/ICU.csv', index_col=0)
ICU = ICU.query('age >= 18 and age <= 95')
DAILY = pd.read_csv(
    '/share/fsmresfiles/ylo7832/allICU/DAILY_new.csv',
    index_col=0,
    parse_dates=['day_bucket_starts', 'day_bucket_ends'],
)

# One row per encounter-day with subject-level columns attached; keep only
# ventilated days and truncate the bucket start to midnight.
COHORT = ICU.merge(DAILY, on=['pat_enc_csn_id', 'patient_ir_id'])
COHORT = COHORT.merge(SUBJECT, on='patient_ir_id', how='left')
COHORT['day_bucket_starts'] = COHORT['day_bucket_starts'].dt.floor('D')
COHORT = COHORT[COHORT['vent_flag'] == 1]

# Drop encounters with more than 30 distinct (ventilated) ICU days.
los = COHORT.groupby('pat_enc_csn_id', as_index=False)['icu_day'].nunique()
short_stay_ids = los.query('icu_day <= 30').pat_enc_csn_id.unique().tolist()
COHORT = COHORT[COHORT.pat_enc_csn_id.isin(short_stay_ids)]


def _encounters_between(first_day, last_day):
    """Return the encounter ids having at least one row dated in [first_day, last_day]."""
    in_window = (COHORT.day_bucket_starts >= first_day) & (COHORT.day_bucket_starts <= last_day)
    return COHORT[in_window].pat_enc_csn_id.unique().tolist()


# train_test_split: chronological windows.  An encounter that touches several
# windows is assigned to the earliest one (train, then validation, then test).
TrainPATIENT = _encounters_between('2020-03-15', '2021-07-14')
ValPATIENT = _encounters_between('2021-07-15', '2021-10-14')
TestPATIENT = _encounters_between('2021-10-15', '2023-01-15')

_enc = COHORT.pat_enc_csn_id
TrainCOHORT = COHORT[_enc.isin(TrainPATIENT)]
ValCOHORT = COHORT[_enc.isin(ValPATIENT) & ~_enc.isin(TrainPATIENT)]
TestCOHORT = COHORT[_enc.isin(TestPATIENT) & ~_enc.isin(TrainPATIENT) & ~_enc.isin(ValPATIENT)]

# Bare expression left over from interactive use; it has no effect in a script.
len(TrainPATIENT), len(ValPATIENT), len(TestPATIENT)

from sklearn import preprocessing


# The six "*_points" score columns; they are also the first six model features.
SOFACols = [
    'p_f_ratio_points',
    'platelet_points',
    'bilirubin_points',
    'htn_points',
    'gcs_points',
    'renal_points',
]

# Every per-row feature column handed to the model, in model input order.
# ('vent_flag' is deliberately not part of the list.)
repCols = [
    *SOFACols,
    'PULSE', 'SP02', 'RESPIRATIONS', 'BMI', 'BP_SYSTOLIC', 'BP_DIASTOLIC',
    'TEMPERATURE', 'age', 'COVID',
    'ami', 'chf', 'pvd', 'cevd', 'dementia', 'copd', 'rheumd', 'pud', 'mld',
    'diab', 'diabwc', 'hp', 'rend', 'canc', 'msld', 'metacanc', 'aids',
    'isFemale', 'isAsian', 'isBlack', 'isHispanic', 'isNative', 'isWhite',
]

# Column names used for the "following row" copy of each feature.
repColsNext = [f'next_{name}' for name in repCols]
SOFAColsNext = [f'next_{name}' for name in SOFACols]

# Snapshot the raw feature columns of each split before they are overwritten.
train_rep = TrainCOHORT[repCols]
val_rep = ValCOHORT[repCols]
test_rep = TestCOHORT[repCols]

# standardize: the min-max scaler is fitted on the training rows only and
# then applied unchanged to every split.
min_max_scaler = preprocessing.MinMaxScaler().fit(train_rep.values)

TrainCOHORT[repCols] = min_max_scaler.transform(train_rep.values)
ValCOHORT[repCols] = min_max_scaler.transform(val_rep.values)
TestCOHORT[repCols] = min_max_scaler.transform(test_rep.values)


def _next_day_features(frame):
    """Return each row's features taken from the next row of the same encounter.

    Missing values (including the absent successor of an encounter's last
    row) are replaced by 0.
    """
    shifted = frame.groupby('pat_enc_csn_id')[repCols].shift(-1)
    return shifted.fillna(0)


TrainCOHORT[repColsNext] = _next_day_features(TrainCOHORT)
ValCOHORT[repColsNext] = _next_day_features(ValCOHORT)
TestCOHORT[repColsNext] = _next_day_features(TestCOHORT)

# Keep ventilated rows only.
TrainCOHORT = TrainCOHORT.loc[TrainCOHORT['vent_flag'] == 1]
ValCOHORT = ValCOHORT.loc[ValCOHORT['vent_flag'] == 1]
TestCOHORT = TestCOHORT.loc[TestCOHORT['vent_flag'] == 1]

# mark for terminal state

def _last_row_labels(frame):
    """Return the index label of the final row of every encounter in *frame*."""
    flat = frame.reset_index()
    return flat.groupby('pat_enc_csn_id', as_index=True)['index'].last().tolist()


# 'ter' is 1 on the last row of each encounter and 0 everywhere else.
for _split in (TrainCOHORT, TestCOHORT, ValCOHORT):
    _split['ter'] = 0
    _split.loc[_last_row_labels(_split), 'ter'] = 1

# ventilator allocated: -0.1, alive: +1, death: -1, not allocated: -1
# Every row starts at -0.1; the terminal row is overwritten with
# (death - 0.5) * -2, i.e. -1 when death == 1 and +1 when death == 0.
for _split in (TrainCOHORT, ValCOHORT):
    _split['reward'] = -0.1
    _is_last = _split.ter == 1
    _split.loc[_is_last, 'reward'] = (_split.loc[_is_last, 'death'] - 0.5) * (-2)
import random
import copy
from tqdm import tqdm

# 1-based position of each row within its encounter.
TrainCOHORT['day'] = TrainCOHORT.groupby('pat_enc_csn_id').cumcount() + 1

# Re-index the training frame by "<encounter id>|<day>" so a single
# patient-day can be looked up by its string key.
_enc_text = TrainCOHORT['pat_enc_csn_id'].astype(str)
_day_text = TrainCOHORT['day'].astype(str)
TrainCOHORT['p_t_index'] = _enc_text + '|' + _day_text
TrainCOHORT = TrainCOHORT.set_index('p_t_index')

# All training encounter ids, and the number of rows (last day) of each one.
patients = TrainCOHORT.pat_enc_csn_id.unique().tolist()
patientsLength = TrainCOHORT.groupby('pat_enc_csn_id')['day'].max().to_dict()



from torch.utils.data import Dataset
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader



import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# model framework
class BIG_Q_ATTN(nn.Module):
    """Single-block self-attention Q-network over a padded set of bed slots.

    Each slot's feature vector is embedded, mixed with the other slots by
    multi-head self-attention (padding slots are excluded as keys), passed
    through a feed-forward layer, and projected to ``num_actions`` Q-values.
    The same ``LayerNorm`` instance is applied after both residual additions.

    Args:
        state_dim: Number of input features per slot.
        num_actions: Number of Q-value heads per slot.
        hidden_dim: Width of the embedding / attention / feed-forward layers.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
    """

    def __init__(self, state_dim, num_actions, hidden_dim=1024, num_heads=16):
        super(BIG_Q_ATTN, self).__init__()

        # Attribute names and creation order are kept as they were so that
        # existing checkpoints still load and seeded initialisation is
        # unchanged for the default sizes.
        self.LayerNorm = nn.LayerNorm(hidden_dim, eps=1e-8)

        self.att = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                         batch_first=True)

        self.q1 = nn.Linear(state_dim, hidden_dim)  # embedding
        self.q2 = nn.Linear(hidden_dim, hidden_dim)  # feed-forward
        self.q3 = nn.Linear(hidden_dim, num_actions)  # Q-value heads

    def forward(self, state, mask, action=None):
        """Score every slot, or return the summed Q-value of a given action.

        Args:
            state: Slot features, batch x slots x ``state_dim``.
            mask: batch x slots, 1 for an occupied slot and 0 for padding.
            action: Optional batch x slots tensor of 0/1 decisions.

        Returns:
            Without ``action``: the Q heads, batch x slots x ``num_actions``.
            With ``action``: a batch x 1 tensor holding, per sample, the sum
            over slots of the Q-value of the chosen action.  Padding slots
            contribute nothing for action 0; their action-1 term is weighted
            by ``action`` itself, so it vanishes when ``action`` is 0 there.
        """
        embed = F.relu(self.q1(state))

        # attention, residual add, norm; padding slots are masked out as keys
        attended = self.att(embed, embed, embed,
                            key_padding_mask=(1 - mask).bool())[0]
        # Intermediate activations stay on ``self`` as before, so code that
        # inspects them after a forward pass keeps working.
        self.trans_sub = self.LayerNorm(attended + embed)

        # feed-forward, residual add, norm
        self.trans_out = self.LayerNorm(
            F.relu(self.q2(self.trans_sub)) + self.trans_sub)

        self.output = self.q3(self.trans_out)

        if action is None:
            return self.output

        # Per-slot one-hot over the two actions; the action-0 entry is
        # forced to 0 wherever the slot is padding.
        chosen = torch.stack(
            (torch.where(mask == 0, 0, (1 - action).long()), action), dim=2)
        return (chosen * self.output).sum(axis=-1).sum(axis=-1).reshape(-1, 1)

# TDDQN model
class TDDQN(object):
    """Double-DQN style agent that decides which slots receive a ventilator.

    The online network ``Q`` scores every slot.  The greedy allocation gives
    ventilators to the highest-scoring occupied slots, limited by the
    module-level ``CAP``, with slots flagged in ``who_was_on`` ranked ahead
    of everything else.  ``Q_target`` evaluates that allocation when the
    temporal-difference target is built.

    Args:
        num_actions: Q heads per slot (index 0 = withhold, 1 = allocate).
        state_dim: Feature width of one slot.
        device: Device for both networks; ``None`` selects CUDA when available.
        discount: Discount applied to the next-state value.
        optimizer: Unused (Adam is always used); kept for call compatibility.
        lr: Adam learning rate.
        polyak_target_update: Soft (Polyak) target updates when true,
            periodic hard copies otherwise.
        target_update_frequency: Training iterations between hard copies.
        ckpt_frequency: Unused; kept for call compatibility.
        tau: Interpolation weight of the Polyak update.
    """

    def __init__(
        self,
        num_actions=2,
        state_dim=len(repCols)+4,
        device=device,
        discount=0.9,
        optimizer="Adam",
        lr=3e-5,
        polyak_target_update=False,
        target_update_frequency=5e2,
        ckpt_frequency = 500,
        tau=0.005,
    ):
        # Honour the caller's device (it used to be ignored and recomputed).
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device if isinstance(device, torch.device) else torch.device(device)

        self.Q = BIG_Q_ATTN(state_dim, num_actions).to(self.device)

        # double DQN: a separate copy of the network evaluates the target
        self.Q_target = copy.deepcopy(self.Q)
        self.Q_optimizer = torch.optim.Adam(self.Q.parameters(), lr = lr)
        self.discount = discount

        # Target update rule
        self.maybe_update_target = self.polyak_target_update if polyak_target_update else self.copy_target_update
        self.target_update_frequency = target_update_frequency
        self.tau = tau
        self.num_actions = num_actions
        self.iterations = 0

    def _allocate(self, state, mask, who_was_on, fill):
        """Return the greedy allocation (bool, batch x slots) under ``self.Q``.

        Slots are ranked by Q(a=1) - Q(a=0).  Padding slots are pushed to
        the bottom with ``-fill`` and slots flagged in ``who_was_on`` to the
        top with ``+fill``.  The best ``min(occupied slots, CAP)`` slots of
        each sample are selected.  Callers wrap this in ``torch.no_grad()``.
        """
        heads = self.Q(state, mask)
        advantage = torch.where(mask == 0, -fill,
                                (heads[:, :, 1] - heads[:, :, 0]).double())
        advantage = torch.where(who_was_on == 1, fill, advantage.double())

        # argsort of argsort yields each slot's ascending rank; flip it so
        # that position 0 is the highest-scoring slot.  The slot count comes
        # from the tensor itself rather than a hard-coded CAP + 24.
        n_slots = advantage.shape[-1]
        ascending_rank = torch.argsort(
            torch.argsort(advantage, dim = -1, descending = False),
            dim = -1, descending = False)
        position = (n_slots - 1) - ascending_rank

        # ventilators available per sample, broadcast across the slots
        capacity = torch.clip(mask.sum(axis = 1), max = CAP).unsqueeze(-1)
        return position < capacity

    def _td_target(self, next_state, reward, next_mask, next_who_was_on, fill):
        """Return ``reward + discount * Q_target(next_state, greedy action)``.

        The greedy action is chosen by the online network and evaluated by
        the target network, then summed over slots.  Callers wrap this in
        ``torch.no_grad()``.
        """
        next_action = self._allocate(next_state, next_mask, next_who_was_on, fill).float()

        # one-hot over the two actions; padding slots contribute nothing
        next_action = torch.stack((torch.where(next_mask == 0, 0, (1 - next_action).long()),
                                   next_action), dim = 2)

        next_action_heads = self.Q_target(next_state, next_mask)
        q = (next_action * next_action_heads).sum(axis = -1).sum(axis = -1)
        return reward + self.discount * q.reshape(-1,1)

    def train(self, state, action, next_state, reward, mask, next_mask, next_who_was_on):
        """Run one optimisation step on a batch and return the Huber loss as a float.

        Also advances the iteration counter and applies the configured
        target-network update rule.
        """
        # Compute the target Q value (+-inf sentinels, as before)
        with torch.no_grad():
            target_Q = self._td_target(next_state, reward, next_mask, next_who_was_on, fill=np.inf)

        # Get current Q estimate
        current_Q = self.Q(state, mask, action)

        # Compute Q loss
        Q_loss = F.smooth_l1_loss(current_Q, target_Q)

        # Optimize the Q
        self.Q_optimizer.zero_grad()
        Q_loss.backward()
        self.Q_optimizer.step()

        # Update target network by polyak or full copy every X iterations.
        self.iterations += 1
        self.maybe_update_target()

        return Q_loss.item()

    def evaluate(self, state, action, next_state, reward, mask, next_mask, next_who_was_on):
        """Return the Huber loss on a batch without updating any parameter.

        Uses +-1e8 (not +-inf) as ranking sentinels, as before.
        """
        with torch.no_grad():
            target_Q = self._td_target(next_state, reward, next_mask, next_who_was_on, fill=1e8)
            current_Q = self.Q(state, mask, action)
            Q_loss = F.smooth_l1_loss(current_Q, target_Q)

        return Q_loss.item()

    def polyak_target_update(self):
        """Move every target parameter a fraction ``tau`` toward the online one."""
        for param, target_param in zip(self.Q.parameters(), self.Q_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def copy_target_update(self):
        """Copy the online weights into the target every ``target_update_frequency`` iterations."""
        if self.iterations % self.target_update_frequency == 0:
            self.Q_target.load_state_dict(self.Q.state_dict())

    # return the actions when testing / validation
    def value(self, state, mask, who_was_on):
        """Return the greedy allocation for ``state`` as a bool numpy array (batch x slots)."""
        with torch.no_grad():
            actions = self._allocate(state, mask, who_was_on, fill=1e8)

        return actions.detach().cpu().numpy()
                
#policy = TDDQN()


# calculate KL divergence
def cal_kl(p,q):
    """Return KL(p || q) after normalising both count vectors to sum to 1."""
    p_counts = np.array(p)
    q_counts = np.array(q)
    p_dist = p_counts / p_counts.sum()
    q_dist = q_counts / q_counts.sum()
    # scipy's entropy with two arguments is the relative entropy sum(p * log(p / q)).
    return stats.entropy(p_dist, q_dist)

# show the allocation rate of each race
def show_race_dis(p,q):
    """Print the element-wise ratio p / q; returns None (the result of ``print``)."""
    ratio = np.divide(p, q)
    return print(ratio)

# Module-level agent shared by the whole script: OnlineSimulator calls
# policy.value() to pick actions while sampling, and the training loop calls
# policy.train() and checkpoints policy.Q.
policy = TDDQN()


# simulator generate MDP
class OnlineSimulator(Dataset):
    """Replay buffer built by rolling the global ``policy`` through a simulated ICU.

    On construction a trajectory of ``mdp_len`` transitions is sampled from
    the training cohort: each step carries over the patients who were on a
    ventilator and still have days left, admits a random batch of new
    patients, asks ``policy.value`` for the allocation, and records padded
    state / action / mask / "was on" tensors plus a scalar reward.

    Item ``idx`` is the transition
    ``(s[idx], a[idx], s[idx+1], r[idx], m[idx], m[idx+1], w[idx+1])``.

    Relies on module-level ``CAP``, ``patients``, ``patientsLength``,
    ``TrainCOHORT``, ``repCols``, ``policy``, ``device`` and ``cal_kl``.

    Args:
        mdp_len: Number of transitions to sample.
        batch_size: Stored on the instance; not used by the sampler itself.
    """

    # Race indicator columns tracked for the fairness feature and KL penalty.
    RACE_COLS = ['isAsian', 'isBlack', 'isHispanic', 'isWhite']

    def __init__(self, mdp_len = 16000, batch_size = 32):
        super().__init__()
        # Slots per step: up to CAP carried-over patients plus the arrivals,
        # which the sampling loop caps at 24 per step.
        self.DIM = CAP + 24
        self.mdp_len = mdp_len
        self.batch_size = batch_size
        self._sample_MDP()

    def __len__(self):
        return self.mdp_len

    def __getitem__(self, idx):
        return self.s[idx], self.a[idx], self.s[idx+1], self.r[idx], self.m[idx], self.m[idx+1], self.w[idx+1]

    # pad trailing zeros when there is vacant beds
    def _pad(self, toPad):
        """Zero-pad a (rows x features) tensor along dim 0 up to ``self.DIM`` rows."""
        return torch.nn.functional.pad(toPad, pad = (0,0,0,self.DIM - toPad.size()[0]), mode = 'constant', value = 0)

    @staticmethod
    def _race_rate(allocated, total):
        """Return allocated / total per race, with 0.0 where ``total`` is 0.

        ``np.divide(..., where=...)`` without ``out`` leaves the skipped
        entries uninitialised, so an explicit zero-filled output is supplied.
        """
        total = np.asarray(total, dtype=float)
        allocated = np.asarray(allocated, dtype=float)
        return np.divide(allocated, total, out=np.zeros_like(total), where=total != 0)

    def _pad_vector(self, values):
        """Turn a 1-D sequence into a float32 tensor on ``device`` padded to ``self.DIM``."""
        column = torch.from_numpy(np.array(values).astype(np.float32)).unsqueeze(-1).to(device)
        return self._pad(column).squeeze()

    def _state_tensor(self, keys, race_rate):
        """Stack the scaled features of ``keys`` with the shared race-rate feature and pad."""
        race_feat = np.tile(race_rate, (len(keys), 1))
        feats = np.hstack((TrainCOHORT.loc[keys][repCols].values, race_feat)).astype(np.float32)
        return self._pad(torch.from_numpy(feats).to(device))

    def _race_counts(self, keys):
        """Sum the race indicator columns over the rows named by ``keys``."""
        return TrainCOHORT.loc[keys][self.RACE_COLS].sum().tolist()

    @staticmethod
    def _base_reward(keys, action):
        """Sum of stored rewards of allocated rows, minus 1 per denied row."""
        action = np.array(action)
        return TrainCOHORT.loc[keys]['reward'][action == 1].sum() - (action == 0).sum()

    def _sample_MDP(self):
        """Sample the trajectory and store it on ``self`` (s, a, r, m, w, races, racesYes, kl)."""
        # Re-seeded on every call: the random draws are therefore the same
        # for every simulator built with the same CAP.
        random.seed(CAP+10086)
        np.random.seed(CAP+60202)

        inHouse = set()
        inPool = set(copy.deepcopy(patients))

        states_torch = []
        actions_torch = []
        masks_torch = []
        whowason_torch = []
        rewards_torch = []

        states = []
        actions = []
        whoWasOn = [0] # no body was on at t = 1
        races = [] # cumulative number of newly arrived patients per race
        racesYes = [] # cumulative number of those allocated on arrival
        kl = []

        current_state = []
        current_action = []

        #init, t = 1

        # take new patient
        newPaNO = np.random.poisson(12)
        newPaIDs = random.sample(list(inPool), newPaNO)

        # exclude these new patients from pool, add them to house
        inPool.difference_update(newPaIDs)
        inHouse.update(newPaIDs)

        # add new patients to current state
        current_state += [str(p)+'|1' for p in newPaIDs]

        if len(inHouse) >= CAP: # over cap: randomly choose who is denied
            new_action = [1]*(newPaNO - (len(inHouse) - CAP)) + (len(inHouse) - CAP)*[0]
            random.shuffle(new_action)
            current_action += new_action
        else: # under cap
            current_action = [1]*len(current_state)

        states.append(current_state)
        actions.append(current_action)

        # add the existing allocation results to state features
        races.append(self._race_counts(current_state))
        # Count only the patients actually allocated; the previous code
        # counted everyone, including those denied when over capacity.
        allocated_keys = [k for k, a in zip(current_state, current_action) if a == 1]
        racesYes.append(self._race_counts(allocated_keys))

        # Same argument order as the per-step KL in the loop below.
        kl.append(cal_kl(racesYes[-1], races[-1]))

        #put them on torch
        states_torch.append(self._state_tensor(current_state, self._race_rate(racesYes[-1], races[-1])))
        actions_torch.append(self._pad_vector(current_action))
        masks_torch.append(self._pad_vector(len(current_action)*[1]))
        whowason_torch.append(self._pad_vector(whoWasOn[-1]*[1]))
        # The first step carries no KL penalty.
        rewards_torch.append(torch.tensor(self._base_reward(current_state, current_action)).to(device))

        for t in trange(1,self.mdp_len+1):
            # query previous state
            previous_state = states[-1]
            previous_action = actions[-1]

            previous_races = races[-1]
            previous_racesYes = racesYes[-1]

            current_state = []

            # progress to current
            for sst, a in zip(previous_state, previous_action):
                st = int(sst.split('|')[1]) # parse id|day
                s = int(sst.split('|')[0])
                if (st+1 <= patientsLength[s]) and (a == 1):
                    current_state.append(str(s)+str('|')+ str(st+1))
                else:
                    # s is discharged (or was denied), add it back to pool
                    inHouse.discard(s)
                    inPool.add(s)

            # carried-over patients occupy the first slots of the new state
            whoWasOn.append(len(current_state))

            # take new patient: at least enough that demand exceeds CAP,
            # and at most 24 from the Poisson draw -- compete every day
            newPaNO = max(CAP+1 - whoWasOn[-1], min(np.random.poisson(12), 24))
            newPaIDs = random.sample(list(inPool), newPaNO)

            # exclude these new patients from pool, add them to house
            inPool.difference_update(newPaIDs)
            inHouse.update(newPaIDs)

            # add new patients to current state
            new_keys = [str(p)+'|1' for p in newPaIDs]
            current_state += new_keys

            states_torch.append(self._state_tensor(current_state, self._race_rate(racesYes[-1], races[-1])))
            masks_torch.append(self._pad_vector(len(current_state)*[1]))
            whowason_torch.append(self._pad_vector(whoWasOn[-1]*[1]))

            # let the current policy decide, then drop the padding slots
            current_action = policy.value(states_torch[-1].unsqueeze(0),
                                          masks_torch[-1].unsqueeze(0),
                                          whowason_torch[-1].unsqueeze(0)).astype(int).tolist()[0][:len(current_state)]

            actions_torch.append(self._pad_vector(current_action))

            states.append(current_state)
            actions.append(current_action)

            # fairness bookkeeping considers the new arrivals only
            current_races = self._race_counts(new_keys)
            new_allocated = [new_keys[i] for i, a in enumerate(current_action[whoWasOn[-1]:]) if a == 1]
            current_racesYes = self._race_counts(new_allocated)

            races.append([x + y for x, y in zip(previous_races, current_races)])
            racesYes.append([x + y for x, y in zip(previous_racesYes, current_racesYes)])

            kl.append( cal_kl( racesYes[-1], races[-1]) )

            # Fairness penalty in [-5, 0].  The bounds used to be passed as
            # (0, -5); with a_min > a_max np.clip returns a_max for every
            # finite input, which made the penalty a constant -5.
            kl_penalty = np.clip(-kl[-1] * 1e3, -5, 0)
            rewards_torch.append(torch.tensor(self._base_reward(current_state, current_action)
                                              + kl_penalty).to(device))

        self.races = races
        self.racesYes = racesYes
        self.s = torch.stack(states_torch)
        self.a = torch.stack(actions_torch)
        self.r = torch.stack(rewards_torch).unsqueeze(-1)
        self.m = torch.stack(masks_torch)
        self.w = torch.stack(whowason_torch)

        self.kl = kl
            
            
import shutil
import os

# initialize tensorboard for saving checkpoints and metrics
from tensorboardX import SummaryWriter
import os, shutil
# Per-capacity output directory with its TensorBoard and checkpoint folders.
save_dir = '/share/fsmresfiles/ylo7832/covidVent/KL/' + str(CAP) + '/'
_tb_dir = save_dir + 'tb'
_ckpt_dir = save_dir + 'ckpt'

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Start every run with a fresh TensorBoard log directory.
if os.path.exists(_tb_dir):
    shutil.rmtree(_tb_dir, ignore_errors=True)
writer = SummaryWriter(_tb_dir)

# Remove checkpoints of a previous run, then (re)create the folder.
if os.path.exists(_ckpt_dir):
    shutil.rmtree(_ckpt_dir)
os.makedirs(_ckpt_dir)

from tqdm import trange
          
            
        


from tqdm import trange
q_loss = 0

res = {'step':[], 'val':[], 'test':[], 'val_KL':[], 'test_KL':[]}

from random import sample


def _allocation_rate(allocated, total):
    """Return allocated / total per race, with 0.0 where ``total`` is 0.

    ``np.divide(..., where=...)`` without ``out`` leaves the skipped entries
    uninitialised, so an explicit zero-filled output is supplied.
    """
    total = np.asarray(total, dtype=float)
    allocated = np.asarray(allocated, dtype=float)
    return np.divide(allocated, total, out=np.zeros_like(total), where=total != 0)


def _simulate_allocation(cohort, start, end):
    """Replay ``cohort`` day by day under the current ``policy``.

    For every calendar day in [start, end] the rows of that day are scored
    by ``policy.value``.  Encounters denied a ventilator are removed from
    all later days and counted as deaths; encounters that received one are
    flagged as "was on" for the following day.

    Args:
        cohort: Pre-processed frame (e.g. ValCOHORT or TestCOHORT).
        start: First day of the replay, anything ``pd.date_range`` accepts.
        end: Last day of the replay (inclusive).

    Returns:
        ``(death_tally, kl)``: the number of distinct encounters that were
        either denied or have ``death == 1``, and
        ``cal_kl(arrivals per race, allocated arrivals per race)``.
    """
    race_cols = ['isAsian', 'isBlack', 'isHispanic', 'isWhite']
    n_slots = CAP + 24

    deaths = cohort[cohort.death == 1]['pat_enc_csn_id'].unique().tolist()

    dailyVENT = cohort.copy(deep=True)
    dailyVENT['whoWasOn'] = 0
    earlyDeaths = []

    # cumulative per-race counts of first-time requests / first-time grants
    races = [[0,0,0,0]]
    racesYes = [[0,0,0,0]]

    for t in pd.date_range(start=start, end=end):

        needOfTheDay = dailyVENT[dailyVENT.day_bucket_starts == t]
        n_today = len(needOfTheDay)

        raceFeat = np.tile(_allocation_rate(racesYes[-1], races[-1]), (n_today, 1))

        with torch.no_grad():
            state = torch.nn.functional.pad(
                torch.from_numpy(np.hstack((needOfTheDay[repCols].values,
                                            raceFeat)).astype(np.float32)),
                pad = (0,0,0,n_slots - n_today),
                mode = 'constant', value = 0).unsqueeze(0).to(device)
            mask = torch.from_numpy(np.array([1]*n_today +
                                             [0]*(n_slots - n_today))).unsqueeze(0).to(device)
            nextwhowason = torch.from_numpy(np.array(list(needOfTheDay['whoWasOn'].values) +
                                                     [0]*(n_slots - n_today))).unsqueeze(0).to(device)

            needOfTheDay['on_off'] = policy.value(state, mask, nextwhowason)[0][:n_today]

        # fair feature: only patients not already on a ventilator count
        current_races = needOfTheDay.query('whoWasOn == 0')[race_cols].sum().tolist()
        races.append([x + y for x, y in zip(races[-1], current_races)])

        current_racesYes = needOfTheDay.query('on_off == True and whoWasOn == 0')[race_cols].sum().tolist()
        racesYes.append([x + y for x, y in zip(racesYes[-1], current_racesYes)])

        excludePAs = needOfTheDay[needOfTheDay.on_off == False].pat_enc_csn_id.unique().tolist()
        whoWasOn = needOfTheDay[needOfTheDay.on_off == True].pat_enc_csn_id.unique().tolist()

        # denied encounters leave the simulation for good
        earlyDeaths += excludePAs
        dailyVENT = dailyVENT[~dailyVENT.pat_enc_csn_id.isin(earlyDeaths)]

        dailyVENT['whoWasOn'] = 0
        dailyVENT.loc[dailyVENT.pat_enc_csn_id.isin(whoWasOn), 'whoWasOn'] = 1

    deathTally = len(set(earlyDeaths + deaths))
    return deathTally, cal_kl(races[-1], racesYes[-1])


for step in range(0, 30001):
    # every 500 steps: regenerate the replay buffer with the current policy,
    # evaluate on validation and test, and checkpoint
    if (step % 500 == 0):

        train = OnlineSimulator(16000)
        trainL = DataLoader(train, batch_size = 32, shuffle=True)

        stepIn500 = int(step / 500)

        # test the trained models on validation set
        val_deaths, val_kl = _simulate_allocation(ValCOHORT, '2021-07-15', '2021-10-15')
        res['step'].append(stepIn500)
        res['val'].append(val_deaths)
        res['val_KL'].append(val_kl)

        # test the trained models on testing set
        test_deaths, test_kl = _simulate_allocation(TestCOHORT, '2021-10-15', '2023-01-15')
        res['test'].append(test_deaths)
        res['test_KL'].append(test_kl)

        q_loss = 0

        torch.save(policy.Q.state_dict(), save_dir+'ckpt/'+str(stepIn500)+'.pth')

    # A fresh shuffled iterator per step: each step trains on one random batch.
    state, action, next_state, reward, mask, next_mask, next_who_was_on = next(iter(trainL))
    q_loss_running = policy.train(state, action, next_state, reward, mask, next_mask, next_who_was_on)
    q_loss+=q_loss_running


res = pd.DataFrame(res)
res.to_csv(save_dir+'res.csv')
