import pandas as pd
import numpy as np
import pickle
import os
# Cohort tables. ICU is restricted to rows whose `age` is within 18..95
# (inclusive on both ends).
SUBJECT = pd.read_csv(
    '/share/fsmresfiles/ylo7832/allICU/SUBJECT.csv',
    index_col=0,
)
ICU = pd.read_csv(
    '/share/fsmresfiles/ylo7832/allICU/ICU.csv',
    index_col=0,
).query('age >= 18 and age <= 95')

# Every relative path below (pickles read here, PDFs written later) resolves
# against this directory.
os.chdir('/share/fsmresfiles/ylo7832/covidVent/')


def _unpickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Per-protocol simulation results. Later code iterates each of these with
# `.items()`, yielding (key, DataFrame) pairs.
RNDM = _unpickle('RNDM.pickle')
YF = _unpickle('YF.pickle')
SOFA = _unpickle('SOFA.pickle')
MP = _unpickle('MP.pickle')
TDQN = _unpickle('TDQN.pickle')
FAIR = _unpickle('FAIR.pickle')

import pandas as pd

def _read_policy_table(path):
    """Read one per-protocol summary CSV, ordered by ascending `capacity`.

    Every table is drawn as a line plot against `capacity`, so all of them are
    sorted the same way; an unsorted table would make the line double back on
    itself. Sorting is a no-op for files that are already ordered.
    """
    return pd.read_csv(path).sort_values('capacity')


RNDMPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/RNDM.csv')
YFPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/YF.csv')
SOFAPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/SOFA.csv')
MPPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/MP.csv')

TDQNONPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/TDQN.csv')
TDQNFAIRPolicy = _read_policy_table('/share/fsmresfiles/ylo7832/FAIR.csv')


import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# One colour per protocol, in plotting order:
# Lottery, Youngest, SOFA, MP, TDDQN(-on), TDDQN-fair.
# NOTE: this rebinds the name `colors`, shadowing the `matplotlib.colors`
# module imported above under the same alias.
# (The previous `cm.get_cmap('tab10_r')` assignment was removed: its result
# was overwritten immediately, and `matplotlib.cm.get_cmap` no longer exists
# in Matplotlib >= 3.9.)
colors = [
    '#800000',
    '#911eb4',
    '#f032e6',
    'goldenrod',
    '#808000',
    '#42d4f4',
]

# Figure for the normalized-survival line plot.
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)

lw = 1.5  # Line weight

def auc(policy):
    """Return the mean min-max-normalized survival of *policy*.

    Each row's `deaths` value is mapped to
    ``(max_deaths - deaths) / (max_deaths - min_deaths)`` so the row with the
    most deaths scores 0 and the row with the fewest scores 1; the mean of
    those scores is returned. If every row has the same `deaths` value the
    denominator is zero and the result is NaN.
    """
    deaths = policy['deaths']
    most = deaths.max()
    fewest = deaths.min()
    scores = (most - deaths) / (most - fewest)
    return scores.mean()

def stand(policy):
    """Return the `deaths` column of *policy* rescaled to a 0-100 score.

    The row with the most deaths maps to 0 and the row with the fewest maps
    to 100 (min-max normalization, inverted, times 100). The result is a
    Series aligned with ``policy``'s index. If every row has the same
    `deaths` value the denominator is zero and every entry is NaN.
    """
    deaths = policy['deaths']
    most = deaths.max()
    fewest = deaths.min()
    return (most - deaths) / (most - fewest) * 100

# Normalized-survival line plot: one curve per protocol on the `ax` created
# above. Each entry is (policy table, legend name, dashed?); the i-th entry is
# drawn with colors[i].
_survival_curves = [
    (RNDMPolicy, 'Lottery', True),
    (YFPolicy, 'Youngest', False),
    (SOFAPolicy, 'SOFA', True),
    (MPPolicy, 'MP', False),
    (TDQNONPolicy, 'TDDQN-on', True),
    (TDQNFAIRPolicy, 'TDDQN-fair', False),
]

lines = []
for _color, (_policy, _name, _dashed) in zip(colors, _survival_curves):
    # Only dashed curves pass a linestyle; solid ones keep Matplotlib's default.
    _style = {'linestyle': '--'} if _dashed else {}
    # The label carries the curve's own AUC, truncated to 5 characters.
    # (Previously the TDDQN-fair label was built from the TDDQN-on table.)
    _label = '{}: {}'.format(_name, str(auc(_policy))[:5])
    lines.append(ax.plot(
        _policy['capacity'], stand(_policy),
        linewidth=lw, color=_color, alpha=0.7,
        label=_label,
        **_style))

ax.set_xlabel('Ventilator Capacities (%)', fontsize=14, fontweight='bold')
ax.set_ylabel('Normalized Survival Rate (%)', fontsize=14, fontweight='bold')

ax.tick_params(axis='both', labelsize=14, width=1, length=5)

# Ticks are placed at 0/20/.../100 % of the number of rows in RNDMPolicy and
# labelled as percentages.
# NOTE(review): this assumes `capacity` runs over 0..len(RNDMPolicy) - confirm.
ax.set_xticks([np.round(t/100*len(RNDMPolicy)) for t in [0, 20, 40, 60, 80, 100]])
ax.set_xticklabels(['0', '20', '40', '60', '80', '100'])
ax.set_ylim(bottom=0, top=100)

ax.grid(True, linestyle='--', linewidth=0.5)

plt.tight_layout()
plt.savefig('deaths.pdf', format='pdf', dpi=1200)



import numpy as np
import matplotlib.pyplot as plt

# Bar chart of the area under the survival curve, one bar per protocol.
fig, ax = plt.subplots(figsize=(4, 7), dpi=300)

# Data for the bar plot (hard-coded, not recomputed from the policy tables).
# NOTE(review): TDDQN and TDDQN-fair are both 0.766 - confirm the TDDQN-fair
# value was computed from the FAIR policy table and is not a copy of TDDQN's.
x = ['Lottery', 'Youngest', 'SOFA', 'MP', 'TDDQN', 'TDDQN-fair']
y = [0.672, 0.685, 0.727, 0.728, 0.766, 0.766]
colors = [
    '#800000',
    '#911eb4',
    '#f032e6',
    'goldenrod',
    '#808000',
    '#42d4f4',
]

# Create the bar plot; the y-axis is zoomed to 0.66-0.78 so the differences
# between protocols are visible (bars therefore do not start at zero).
ax.bar(x, y, capsize=5, color=colors)
ax.set_ylim(0.66, 0.78)

# Customize the plot
ax.set_ylabel('Area under survival curve', fontsize=20, fontweight='bold')
ax.set_yticks([0.66, 0.72, 0.78])
ax.set_xticklabels(x, rotation=90, ha='right', fontsize=20, fontweight='bold')
ax.set_yticklabels([0.66, 0.72, 0.78], fontweight='bold', rotation=90, fontsize=20)

for _side in ('bottom', 'left', 'top', 'right'):
    ax.spines[_side].set_linewidth(1.5)

# Colour each protocol's tick label like its bar. (The original loop also
# zipped over an undefined name `bars`, which raised NameError before the
# figure could be saved.)
for tick, c in zip(ax.get_xticklabels(), colors):
    tick.set_color(c)

plt.tight_layout()
# NOTE(review): 'deahtsB' looks like a typo for 'deathsB'; the filename is
# left unchanged in case downstream tooling expects it.
plt.savefig('deahtsB.pdf', format='pdf', dpi=1200)

def calAllo(policy):
    """Return one mean `onoff` value per entry of *policy*, as a NumPy array.

    For each DataFrame in ``policy`` (in iteration order), the frame is
    joined to the ICU id columns and then to SUBJECT (both merges use pandas'
    default of joining on the shared column names), rows with
    ``whoWasOn == 1`` are kept, and the mean of their `onoff` column is taken.
    """
    icu_ids = ICU[['pat_enc_csn_id', 'patient_ir_id']]
    rates = []
    for df in policy.values():
        joined = df.merge(icu_ids).merge(SUBJECT)
        selected = joined.query('whoWasOn == 1')
        rates.append(selected['onoff'].mean())
    return np.array(rates)

import matplotlib.pyplot as plt
# Allocation-rate line plot: one curve per protocol.
colors = ['#800000', '#911eb4', '#f032e6', 'goldenrod', '#808000', '#42d4f4']

fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
lw = 1.5  # Line weight

# Per-protocol allocation rates (fractions; scaled to percent when plotted).
RNDMallo = calAllo(RNDM)
YFallo = calAllo(YF)
SOFAallo = calAllo(SOFA)
MPallo = calAllo(MP)
TDQNONallo = calAllo(TDQN)
TDQNFAIRallo = calAllo(FAIR)

# (result mapping, allocation rates, legend name, dashed?); the i-th entry is
# drawn with colors[i].
_allocation_curves = [
    (RNDM, RNDMallo, 'Lottery', True),
    (YF, YFallo, 'Youngest', False),
    (SOFA, SOFAallo, 'SOFA', True),
    (MP, MPallo, 'MP', False),
    (TDQN, TDQNONallo, 'TDDQN', True),
    (FAIR, TDQNFAIRallo, 'TDDQN-fair', False),
]

lines = []
for _color, (_runs, _rates, _name, _dashed) in zip(colors, _allocation_curves):
    # Only dashed curves pass a linestyle; solid ones keep Matplotlib's default.
    _style = {'linestyle': '--'} if _dashed else {}
    lines += [ax.plot(_runs.keys(), _rates*100,
                      alpha=0.8, label=_name, linewidth=lw, color=_color,
                      **_style)]

ax.set_xlabel('Ventilator Capacities (%)', fontsize=14, fontweight='bold')
ax.set_ylabel('Allocation Rate (%)', fontsize=14, fontweight='bold')

ax.tick_params(axis='both', labelsize=14, width=1, length=5)

# Tick positions are derived from the row count of the RNDMPolicy table, not
# from the keys plotted on this axis.
# NOTE(review): presumably both cover the same capacity range - confirm.
_percent_marks = [0, 20, 40, 60, 80, 100]
ax.set_xticks([np.round(t/100*len(RNDMPolicy)) for t in _percent_marks])
ax.set_xticklabels(['0', '20', '40', '60', '80', '100'])

ax.set_ylim(bottom=0, top=100)
ax.set_yticks([0, 20, 40, 60, 80, 100])
ax.set_yticklabels(['0', '20', '40', '60', '80', '100'])

ax.grid(True, linestyle='--', linewidth=0.5)

plt.tight_layout()
plt.savefig('pA.pdf', format='pdf', dpi=1200)


import numpy as np
import matplotlib.pyplot as plt

# Bar chart of the area under the allocation curve, one bar per protocol.
fig, ax = plt.subplots(figsize=(4, 7), dpi=300)

# Data for the bar plot (hard-coded, not recomputed from the simulation
# results loaded above).
x = ['Lottery', 'Youngest', 'SOFA', 'MP', 'TDDQN', 'TDDQN-fair']
y = [0.676, 0.674, 0.703, 0.701, 0.742, 0.741]
colors = [
    '#800000',
    '#911eb4',
    '#f032e6',
    'goldenrod',
    '#808000',
    '#42d4f4',
]

# Create the bar plot; the y-axis is zoomed to 0.65-0.75 so the differences
# between protocols are visible (bars therefore do not start at zero).
ax.bar(x, y, capsize=5, color=colors)
ax.set_ylim(0.65, 0.75)

# Customize the plot
ax.set_ylabel('Area under allocation curve', fontsize=20, fontweight='bold')
ax.set_yticks([0.66, 0.70, 0.74])
ax.set_xticklabels(x, rotation=90, ha='right', fontsize=20, fontweight='bold')
ax.set_yticklabels([0.66, 0.70, 0.74], fontweight='bold', rotation=90, fontsize=20)

for _side in ('bottom', 'left', 'top', 'right'):
    ax.spines[_side].set_linewidth(1.5)

# Colour each protocol's tick label like its bar. (The original loop also
# zipped over an undefined name `bars`, which raised NameError before the
# figure could be saved.)
for tick, c in zip(ax.get_xticklabels(), colors):
    tick.set_color(c)

plt.tight_layout()
plt.savefig('pB.pdf', format='pdf', dpi=1200)

def plotAllo(policy, title, color, blackLowCapScale=1.05, lowCapThreshold=20):
    """Plot per-group allocation rates for one protocol and save them as a PDF.

    For every ``(key, DataFrame)`` pair in *policy*, the frame is joined to
    the ICU id columns and to SUBJECT, rows with ``whoWasOn == 1`` are kept,
    and the mean of `onoff` is computed separately for rows flagged
    `isAsian`, `isBlack`, `isHispanic` and `isWhite`. The four resulting
    curves are drawn against the keys of *policy*.

    Parameters
    ----------
    policy : mapping
        Maps a numeric key (compared against *lowCapThreshold*) to a
        DataFrame of results.
    title : str
        Axes title. Its first character also names the output file
        (``'p' + title[0] + '.pdf'``).
    color : int
        Index into the six-entry protocol palette; colours the title.
    blackLowCapScale : float, optional
        Factor applied to the Black group's rate when
        ``key <= lowCapThreshold``. Defaults to 1.05, the value that was
        previously hard-coded.
    lowCapThreshold : number, optional
        Largest key for which *blackLowCapScale* is applied. Defaults to 20.

    Notes
    -----
    NOTE(review): the default scaling inflates the plotted Black allocation
    rate by 5 % at low keys and is not derived from the data in this file.
    It is kept only so existing figures reproduce; confirm it is intentional
    and documented, or call with ``blackLowCapScale=1``.

    NOTE(review): titles starting with 'A' or 'B' write 'pA.pdf' / 'pB.pdf',
    the same filenames used by the all-protocol allocation figures saved
    earlier in this script, so those files get overwritten.
    """
    protocol_palette = [
        '#800000',
        '#911eb4',
        '#f032e6',
        'goldenrod',
        '#808000',
        '#42d4f4',
    ]
    group_palette = cm.tab10_r.colors

    # (legend label, indicator column, index into group_palette)
    groups = [
        ('Asian', 'isAsian', 6),
        ('Black', 'isBlack', 7),
        ('Hispanic', 'isHispanic', 8),
        ('White', 'isWhite', 9),
    ]

    icu_ids = ICU[['pat_enc_csn_id', 'patient_ir_id']]
    rates = {label: [] for label, _, _ in groups}
    for key, df in policy.items():
        tmpdf = df.merge(icu_ids).merge(SUBJECT).query('whoWasOn == 1')
        for label, flag, _ in groups:
            rate = tmpdf[tmpdf[flag] == 1]['onoff'].mean()
            if label == 'Black' and key <= lowCapThreshold:
                # See the NOTE(review) in the docstring about this adjustment.
                rate = rate * blackLowCapScale
            rates[label].append(rate)

    fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
    lw = 1.5
    # Materialize the keys so Matplotlib receives a plain sequence.
    capacities = list(policy.keys())

    lines = [
        ax.plot(capacities, np.array(rates[label]) * 100,
                label=label, linewidth=lw, color=group_palette[idx])
        for label, _, idx in groups
    ]

    # Legend entries are bold and coloured like the curve they describe.
    legend = ax.legend(fontsize=12)
    for text, drawn in zip(legend.texts, lines):
        text.set_color(drawn[0].get_color())
        text.set_fontweight('bold')

    ax.set_xlabel('Ventilator Capacities (%)', fontsize=16, fontweight='bold')
    ax.set_ylabel('Allocation Rate (%)', fontsize=16, fontweight='bold')

    ax.set_ylim(bottom=0, top=100)
    ax.set_yticks([0, 20, 40, 60, 80, 100])
    ax.set_yticklabels(['0', '20', '40', '60', '80', '100'])
    ax.set_title(title, fontsize=16, fontweight='bold',
                 color=protocol_palette[color])

    ax.grid(True, linestyle='--', linewidth=0.5)
    for side in ('bottom', 'left', 'top', 'right'):
        ax.spines[side].set_linewidth(1.5)

    # The original called tick_params three times (label size 16, then 12
    # twice); only the last setting took effect, so a single call is kept.
    ax.tick_params(axis='both', labelsize=12, width=1.5, length=5)

    plt.tight_layout()
    plt.savefig('p{}.pdf'.format(title[0]), format='pdf', dpi=1200)


# One per-group allocation panel per protocol; the position in this list is
# also the palette index that colours the panel title.
_panels = [
    (RNDM, 'A. Lottery'),
    (YF, 'B. Youngest'),
    (SOFA, 'C. SOFA'),
    (MP, 'D. MP'),
    (TDQN, 'E. TDDQN'),
    (FAIR, 'F. TDDQN-fair'),
]
for _panel_index, (_runs, _panel_title) in enumerate(_panels):
    plotAllo(_runs, _panel_title, _panel_index)

import numpy as np
import matplotlib.pyplot as plt



# Horizontal grouped bar chart: area under the allocation curve for each
# group within each protocol.
fig, ax = plt.subplots(figsize=(4, 8), dpi=1200)

# Data for the bar plot (hard-coded). Each list is ordered
# Asian, Black, Hispanic, White; the dict is ordered
# Lottery, Youngest, SOFA, MP, TDDQN, TDDQN-fair (matching `tn` below).
x = ['Asian', 'Black', 'Hispanic', 'White']*6

y = {}
y['C'] = [0.667, 0.677, 0.677, 0.677]
y['D'] = [0.668, 0.717, 0.742, 0.655]
y['E'] = [0.680, 0.677, 0.701, 0.711]
y['F'] = [0.690, 0.683, 0.725, 0.703]
y['G'] = [0.736, 0.698, 0.732, 0.758]
y['H'] = [0.743, 0.731, 0.741, 0.745]

# Flatten `y` into one bar per value, with a zero-height spacer bar between
# consecutive protocols. Derived from `y` so the two cannot drift apart.
yy = []
for _i, _vals in enumerate(y.values()):
    if _i:
        yy.append(0)
    yy.extend(_vals)
# barh draws position 0 at the bottom; reversing puts the first protocol on top.
yy.reverse()
xx = range(len(yy))

# After the reversal each group of five bars reads
# White, Hispanic, Black, Asian, spacer - so the colour cycle is reversed too.
_tab10_r = cm.tab10_r.colors
_bar_colors = [_tab10_r[9], _tab10_r[8], _tab10_r[7], _tab10_r[6], _tab10_r[0]]

# Create the bar plot; the x-axis is zoomed to 0.64-0.76, which also hides the
# zero-height spacer bars.
ax.barh(xx, yy, capsize=5, color=_bar_colors)
ax.set_xlim(0.64, 0.76)

# Customize the plot
ax.set_xlabel('Area under \n allocation curve', fontsize=16, fontweight='bold')
ax.set_xticks([0.64, 0.68, 0.72, 0.76])
# One tick per protocol, centred on its four bars (positions 0-3, 5-8, ...).
ax.set_yticks([1.5, 6.5, 11.5, 16.5, 21.5, 26.5])

tn = ['Lottery', 'Youngest', 'SOFA', 'MP', 'TDDQN', 'TDDQN-fair']
tn.reverse()
ax.set_yticklabels(tn, ha='right', fontsize=16, fontweight='bold')

# The legend uses proxy rectangles because the bars carry no labels.
_legend_colors = [_tab10_r[6], _tab10_r[7], _tab10_r[8], _tab10_r[9]]
legend_labels = ['Asian', 'Black', 'Hispanic', 'White']
legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in _legend_colors]
legend = ax.legend(legend_handles, legend_labels, fontsize=12)

for text in legend.get_texts():
    text.set_fontweight('bold')

# Colour each protocol's tick label with that protocol's palette colour
# (reversed to match `tn`). The original loop also zipped over an undefined
# name `bars`, which raised NameError before the figure could be saved.
_protocol_colors = [
    '#800000',
    '#911eb4',
    '#f032e6',
    'goldenrod',
    '#808000',
    '#42d4f4',
]
_protocol_colors.reverse()
for tick, c in zip(ax.get_yticklabels(), _protocol_colors):
    tick.set_color(c)

ax.set_xticklabels([0.64, 0.68, 0.72, 0.76], fontweight='bold')
ax.tick_params(axis='both', labelsize=12, width=1.5, length=5)
for _side in ('bottom', 'left', 'top', 'right'):
    ax.spines[_side].set_linewidth(1.5)

ax.set_title('G. All protocols', fontsize=16, fontweight='bold')

plt.tight_layout()
plt.savefig('pG.pdf', format='pdf', dpi=1200)

# Leave `colors` bound to the tab10_r colour tuple, as the original script did.
colors = cm.tab10_r.colors
