import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from paper_plotting_code.utils import dataPlot, barPlot, getRuntimeBarData, getData

# Use the serif "Times" font for all figure text.
rc('font', family='serif', serif=['Times'])
# LaTeX text rendering is left disabled; uncomment to enable it:
# rc('text', usetex=True)


# (N, B) configurations run for each experiment. The key order fixes the
# column order of the per-experiment figures below (they enumerate dictExp).
dictExp = {
    'constantCosts': [{'N': n, 'B': 4} for n in (5, 10, 15)],
    'orderedWorkers': [{'N': n, 'B': 18} for n in (5, 10, 15)],
    'decoupledCounterexample': [{'N': n, 'B': 4} for n in (5, 10, 15)],
}

# Panel captions, keyed by experiment name.
titles = dict(
    constantCosts='(a) Constant unitary costs',
    orderedWorkers='(b) Ordered workers',
    decoupledCounterexample='(c) Specialist domain',
)

# REDUCTION IN REWARD
# Relative reward drop of MWRMAB_bt versus the 'hawkins' baseline,
# (hawkins - MWRMAB_bt) / hawkins, for every configuration of every
# experiment in dictExp. The mean over all configurations is printed.
reward = []
for exp in dictExp:
    # 'orderedWorkers' is loaded with cost=False, every other experiment with cost=True.
    df1_height, _ = dataPlot(dictExp[exp], exp, height='reward', yerr='std_reward', normalize=True, cost=(exp != 'orderedWorkers'))
    x = (df1_height.loc[:, 'hawkins'] - df1_height.loc[:, 'MWRMAB_bt']) / df1_height.loc[:, 'hawkins']
    reward.append(x.values)
# Concatenate so experiments with different numbers of configurations still
# combine. Mask inf/NaN so a zero 'hawkins' reward (division by zero) does
# not poison the mean.
print('Mean relative reward reduction (MWRMAB_bt vs hawkins):',
      np.mean(np.ma.masked_invalid(np.concatenate(reward))))

# INCREASE IN FAIRNESS
# Relative increase in the fraction of fair allocations of MWRMAB_bt over
# the 'hawkins' baseline, (MWRMAB_bt - hawkins) / hawkins, for every
# configuration of every experiment in dictExp. The mean is printed.
# NOTE(review): normalize=True is used here, while the fairness panels
# plotted below use normalize=False. Confirm which is intended for this
# statistic.
fairness = []
for exp in dictExp:
    # 'orderedWorkers' is loaded with cost=False, every other experiment with cost=True.
    df1_height, _ = dataPlot(dictExp[exp], exp, height='fraction_fair', yerr='fraction_fair_std', normalize=True, cost=(exp != 'orderedWorkers'))
    x = (df1_height.loc[:, 'MWRMAB_bt'] - df1_height.loc[:, 'hawkins']) / df1_height.loc[:, 'hawkins']
    fairness.append(x.values)
# A zero 'hawkins' fairness fraction yields inf/NaN ratios; mask them so the
# mean covers only finite values. Concatenate to tolerate ragged experiments.
print('Mean relative fairness increase (MWRMAB_bt vs hawkins):',
      np.mean(np.ma.masked_invalid(np.concatenate(fairness))))

# PLOT REWARD AND FAIR FRACTION
# 2 x 3 grid: row 0 shows the normalized mean reward per arm, row 1 the
# fraction of fair allocations. Column i is the i-th experiment of dictExp,
# captioned with its entry in `titles`.
minor_yticks = np.arange(0, 1, 0.25)  # minor grid lines at 0, 0.25, 0.5, 0.75
fig, ax = plt.subplots(figsize=(10,3.5), ncols=3, nrows=2)
for i, exp in enumerate(dictExp):
    # 'orderedWorkers' is loaded with cost=False, every other experiment with cost=True.
    cost = exp != 'orderedWorkers'
    df1_height, df1_yerr = dataPlot(dictExp[exp], exp, height='reward', yerr='std_reward', normalize=True, cost=cost)
    df2_height, df2_yerr = dataPlot(dictExp[exp], exp, height='fraction_fair', yerr='fraction_fair_std', normalize=False, cost=cost)

    barPlot(df1_height, df1_yerr, ax[0, i])
    barPlot(df2_height, df2_yerr, ax[1, i])
    if i != 0:
        # Only the leftmost column keeps its major y ticks.
        ax[0, i].yaxis.set_ticks([])
        ax[1, i].yaxis.set_ticks([])
    ax[1, i].set_xlabel(titles[exp], labelpad=15)  # caption under the bottom panel
    ax[0, i].set_xlabel(None)
    for panel in ax[:, i]:
        panel.set_yticks(minor_yticks, minor=True)
        panel.grid(axis='y', which='both', alpha=0.8, zorder=0)

ax[0, 0].set_ylabel('Mean \n reward per arm')
ax[1, 0].set_ylabel('Fraction of \n fair allocations')
# NOTE(review): legend entries are matched to bars by position; confirm this
# order matches the column order barPlot draws.
fig.legend(['CWI+BA','PWI+BA','CWI+GA','Hawkins','OPT','OPT Fair'],loc='upper center', bbox_to_anchor=(0.5, 0.95),ncol=6)
fig.tight_layout()
plt.subplots_adjust(top=0.8)  # leave room above the panels for the legend
# To save the figure, call savefig before plt.show(), e.g.:
# plt.savefig('reward_fairness.pdf')
plt.show()


# HIGHER N
# PLOT REWARD AND FAIR FRACTION
# Reward (left) and fairness (right) bars for constantCosts at larger N.
exps = [{'N':5, 'B':4},{'N':50, 'B':40},{'N':100, 'B':40},{'N':150, 'B':40}]

# Keep the rows at positions 3, 1, 2. This is the same as dropping row 0
# (presumably the N=5 configuration) and then reordering with [2, 0, 1].
# NOTE(review): the reorder presumably puts N in ascending order; confirm
# against the row order dataPlot returns.
higher_n_rows = [3, 1, 2]
panels = []
for metric, metric_std, normalized in (('reward', 'std_reward', True),
                                       ('fraction_fair', 'fraction_fair_std', False)):
    heights, errors = dataPlot(exps, 'constantCosts', height=metric, yerr=metric_std, normalize=normalized, cost=True)
    panels.append((heights.iloc[higher_n_rows], errors.iloc[higher_n_rows]))
(df1_height, df1_yerr), (df2_height, df2_yerr) = panels

minor_yticks = np.arange(0, 1, 0.25)
fig, ax = plt.subplots(figsize=(6,1.8),ncols=2)
for panel, (heights, errors) in zip(ax, panels):
    barPlot(heights, errors, panel)

for panel, ylabel in zip(ax, ('Mean \n reward per arm', 'Fraction of \n fair allocations')):
    panel.set_xlabel(None)
    panel.set_yticks(minor_yticks, minor=True)
    panel.grid(axis='y', which='both', alpha=0.8, zorder=0)
    panel.set_ylabel(ylabel)

fig.legend(['CWI+BA','PWI+BA','CWI+GA','Hawkins'],loc='upper center', bbox_to_anchor=(0.5, 0.95),ncol=6)
fig.tight_layout()
plt.subplots_adjust(top=0.7)
plt.show()

# PLOT RUNTIMES - BAR GRAPHS
# One panel per experiment showing the average runtime (seconds) of each
# method, plus OPT / OPT_fair when getRuntimeBarData provides those columns.
fig, ax = plt.subplots(figsize=(8,2),ncols=3)
runtime_cols = ['MWRMAB_adj_bt','MWRMAB_bt','hawkins']
# Bar colours; presumably consumed by barPlot in column order.
colors = ['tab:blue','tab:orange','white','dimgray','darkgray']
for i, exp in enumerate(dictExp):
    df_height, df_yerr = getRuntimeBarData(exp,dictExp)
    opt = ['OPT','OPT_fair'] if 'OPT' in df_height.columns else []
    df_height = df_height.loc[:, runtime_cols + opt]
    df_yerr = df_yerr.loc[:, runtime_cols + opt]
    # Optional log scaling of the runtimes (disabled):
    # df_height = np.log(1+df_height)
    # df_yerr = np.log(1+df_yerr)
    barPlot(df_height, df_yerr, ax[i], lim=False, colors=colors)
    # Minor-tick/grid spacing: 0.5 for decoupledCounterexample, 0.1 otherwise.
    # Ticks run up to the largest bar height plus the largest error.
    size = 0.5 if exp=='decoupledCounterexample' else 0.1
    minor_yticks = np.arange(0, df_height.max().max() + df_yerr.max().max() + size, size)
    ax[i].set_yticks(minor_yticks, minor=True)
    ax[i].grid(axis='y', which='both',alpha=0.8, zorder=0)
    ax[i].set_xlabel(titles[exp], labelpad=15)

ax[0].set_ylabel('Average runtime \n (seconds)')
fig.legend(['CWI+BA','PWI+BA','Hawkins','OPT','OPT Fair'],loc='upper center', bbox_to_anchor=(0.5, 1),ncol=6)
fig.tight_layout()
plt.subplots_adjust(top=0.8)
plt.show()

# PLOT RUNTIMES - BAR GRAPHS - HIGHER N
# Average runtime of each method for constantCosts at larger N.
exps = [{'N':5, 'B':4},{'N':50, 'B':40},{'N':100, 'B':40},{'N':150, 'B':40}]

df_height, df_yerr = getRuntimeBarData('constantCosts',{'constantCosts':exps})
# Include the OPT / OPT_fair columns only when the runtime data has them.
opt = ['OPT','OPT_fair'] if 'OPT' in df_height.columns else []
shown_cols = ['MWRMAB_adj_bt','MWRMAB_bt','hawkins'] + opt
# Rows at positions 3, 1, 2: the same as dropping row 0 (presumably N=5)
# and then reordering with [2, 0, 1]. Columns are limited to the plotted methods.
df_height = df_height.iloc[[3, 1, 2]].loc[:, shown_cols]
df_yerr = df_yerr.iloc[[3, 1, 2]].loc[:, shown_cols]
colors = ['tab:blue','tab:orange','white','dimgray','darkgray']

fig, ax = plt.subplots(figsize=(3.5,1.5))
barPlot(df_height, df_yerr, ax, lim=False, colors=colors)
size = 1  # minor-tick / grid spacing on the runtime axis
ymax = df_height.max().max() + df_yerr.max().max()
minor_yticks = np.arange(0, ymax + size, size)
ax.set_yticks(minor_yticks, minor=True)
ax.grid(axis='y', which='both', alpha=0.8, zorder=0)
ax.set_xlabel('', labelpad=15)
ax.set_ylabel('Average runtime \n (seconds)')

fig.legend(['CWI+BA','PWI+BA','Hawkins'],loc='upper center',ncol=3, prop={'size':9})
fig.tight_layout()
plt.subplots_adjust(top=0.7)
plt.show()