import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def getData(data, N, B, minC=None, maxC=None, cost=None):
    """Load the mean-results CSV for one experiment configuration.

    The file is looked up under ``./logs/meanResults/<data>/`` with a name
    built from the experiment name, ``N``, ``B``, an experiment-specific
    ``M`` value and a cost suffix.

    Args:
        data: Experiment name; one of ``'orderedWorkers'``,
            ``'constantCosts'`` or ``'decoupledCounterexample'``.
        N: Value inserted after ``_N`` in the file name.
        B: Value inserted after ``_B`` in the file name.
        minC: Lower cost bound; only used by ``'orderedWorkers'``.
        maxC: Upper cost bound; only used by ``'orderedWorkers'``.
        cost: Constant cost; used by ``'constantCosts'`` and as the
            ``'orderedWorkers'`` fallback when ``minC`` is ``None``.

    Returns:
        The CSV contents as a DataFrame, with the first CSV column as index.

    Raises:
        ValueError: If ``data`` is not a known experiment name.
    """
    if data == 'orderedWorkers':
        M = 3
        # NOTE(review): the fallback suffix has no leading underscore, unlike
        # every other branch ('_C...'); kept as-is -- confirm against the
        # file names actually present on disk.
        path_cost = f'_minC{minC}_maxC{maxC}' if minC is not None else f'C{cost}'
    elif data == 'constantCosts':
        M = 2
        path_cost = f'_C{cost}'
    elif data == 'decoupledCounterexample':
        M = 2
        path_cost = ''
    else:
        # Fail with a clear message instead of an UnboundLocalError on M.
        raise ValueError(
            f"unknown experiment {data!r}; expected 'orderedWorkers', "
            "'constantCosts' or 'decoupledCounterexample'"
        )
    path = f'./logs/meanResults/{data}/{data}_N{N}_B{B}_M{M}{path_cost}.csv'

    return pd.read_csv(path, index_col=0)

def dataPlot(dictExp, exp, height, yerr, normalize=True, cost=True):
    """Build bar-height and error-bar tables for one experiment.

    For every setup in ``dictExp`` the mean-results CSV is loaded through
    ``getData``, restricted to the plotted methods, optionally divided by N,
    and finally pivoted so that rows are setups (``'N:<N> B:<B>'``) and
    columns are methods.

    Args:
        dictExp: Iterable of mappings, each providing ``'N'`` and ``'B'``.
        exp: Experiment name forwarded to ``getData``.
        height: Name of the CSV column used for the bar heights.
        yerr: Name of the CSV column used for the error bars.
        normalize: If true, divide both columns by ``int(N)`` of the setup.
        cost: If true, load with ``cost=1``; otherwise with ``minC=1`` and
            ``maxC=10``.

    Returns:
        Tuple ``(df_height, df_yerr)`` of DataFrames indexed by setup label,
        rows ordered by increasing N, with the method columns in a fixed
        order. ``OPT`` / ``OPT_fair`` columns appear only if at least one
        setup provides them.
    """
    base_methods = ['MWRMAB_bt', 'MWRMAB_adj_bt', 'hawkins', 'S1S2']
    column_order = ['MWRMAB_adj_bt', 'MWRMAB_bt', 'S1S2', 'hawkins', 'OPT', 'OPT_fair']

    dataEnv = []
    for x in dictExp:
        if cost:
            data = getData(exp, N=x['N'], B=x['B'], cost=1)
        else:
            data = getData(exp, N=x['N'], B=x['B'], minC=1, maxC=10)
        # The OPT rows are optional in the result files.
        opt = ['OPT', 'OPT_fair'] if 'OPT' in data.index else []
        # Copy so the column assignments below never write into a view.
        data = data.loc[base_methods + opt].copy()
        data['setup'] = f'N:{x["N"]} B:{x["B"]}'

        if normalize:
            data[height] = data[height] / int(x['N'])
            data[yerr] = data[yerr] / int(x['N'])

        dataEnv.append(data[[height, yerr, 'setup']])

    # Name the index explicitly so reset_index() always yields an 'index'
    # column, whatever header the CSV gave its first column.
    combined = pd.concat(dataEnv).rename_axis('index').reset_index()

    def _pivot_and_order(values):
        table = combined.pivot_table(index='setup', columns='index', values=values)
        # Keep only the methods that are actually present; selecting the OPT
        # columns unconditionally raised KeyError when no setup had them.
        table = table[[name for name in column_order if name in table.columns]]
        # Order from lowest to highest N. The full number after 'N:' is
        # parsed so that N with three or more digits sorts correctly; the
        # stable sort keeps the existing order among equal N.
        n_values = [int(label.split()[0][2:]) for label in table.index]
        return table.iloc[np.argsort(n_values, kind='stable')]

    df_height = _pivot_and_order(height)
    df_yerr = _pivot_and_order(yerr)

    return df_height, df_yerr

def barPlot(data_height, data_yerr, ax, lim=True, colors=['tab:blue','tab:orange','tab:green','white','dimgray','darkgray']):
    """Draw ``data_height`` as a bar chart with error bars on ``ax``.

    Args:
        data_height: DataFrame whose ``plot`` method is called with
            ``kind='bar'``.
        data_yerr: Passed to the plot call as ``yerr``.
        ax: Axes to draw on; its legend is removed after plotting.
        lim: If true, the y-axis limits are set to ``(0, 1)``.
        colors: Passed to the plot call as ``color``.
    """
    bar_options = dict(
        kind='bar',
        yerr=data_yerr,
        ax=ax,
        edgecolor='black',
        color=colors,
        zorder=3,
    )
    data_height.plot(**bar_options)

    # Drop the legend created by the plot call and keep x tick labels unrotated.
    legend = ax.get_legend()
    legend.remove()
    ax.tick_params(axis='x', labelrotation=0)

    if not lim:
        return
    ax.set_ylim(0, 1)

import pickle
def getRuntimeDataMeans(x, exp):
    """Return per-key column means of the pickled runtime log for one setup.

    The pickle under ``./logs/runtimes/`` is expected to hold a dict; each
    value is turned into a DataFrame and averaged column-wise.

    Args:
        x: Mapping providing ``'N'`` and ``'B'`` for the file name.
        exp: Experiment name; ``'orderedWorkers'`` uses the
            ``_M3_minC1_maxC10`` file suffix, anything else ``_M2_C1``.

    Returns:
        DataFrame with one row per dict key (index named ``'index'``) and
        one column per averaged column.
    """
    if exp == 'orderedWorkers':
        path = f'./logs/runtimes/{exp}_N{x["N"]}_B{x["B"]}_M3_minC1_maxC10.pkl'
    else:
        path = f'./logs/runtimes/{exp}_N{x["N"]}_B{x["B"]}_M2_C1.pkl'

    with open(path, 'rb') as handle:
        runs_by_key = pickle.load(handle)

    labels = list(runs_by_key.keys())
    per_key = []
    for runs in runs_by_key.values():
        per_key.append(pd.DataFrame(runs).mean())

    # One Series per key side by side, then transpose so keys become rows.
    table = pd.concat(per_key, axis=1).T
    table['index'] = labels
    return table.set_index('index')

def getRuntimeDataStds(x, exp):
    """Return per-key column standard deviations of the pickled runtime log.

    The pickle under ``./logs/runtimes/`` is expected to hold a dict; each
    value is turned into a DataFrame and ``DataFrame.std`` is taken
    column-wise.

    Args:
        x: Mapping providing ``'N'`` and ``'B'`` for the file name.
        exp: Experiment name; ``'orderedWorkers'`` uses the
            ``_M3_minC1_maxC10`` file suffix, anything else ``_M2_C1``.

    Returns:
        DataFrame with one row per dict key (index named ``'index'``) and
        one column per input column.
    """
    if exp == 'orderedWorkers':
        path = f'./logs/runtimes/{exp}_N{x["N"]}_B{x["B"]}_M3_minC1_maxC10.pkl'
    else:
        path = f'./logs/runtimes/{exp}_N{x["N"]}_B{x["B"]}_M2_C1.pkl'

    with open(path, 'rb') as handle:
        runs_by_key = pickle.load(handle)

    labels = list(runs_by_key.keys())
    per_key = []
    for runs in runs_by_key.values():
        per_key.append(pd.DataFrame(runs).std())

    # One Series per key side by side, then transpose so keys become rows.
    table = pd.concat(per_key, axis=1).T
    table['index'] = labels
    return table.set_index('index')

def getRuntimeBarData(exp, dictExp):
    """Build runtime bar-height and error-bar tables for one experiment.

    For each setup in ``dictExp[exp]`` the last column of the tables
    returned by ``getRuntimeDataMeans`` and ``getRuntimeDataStds`` is taken
    and pivoted so that rows are setups (``'N:<N> B:<B>'``) and columns are
    the row labels of those tables.

    Args:
        exp: Experiment name; also the key into ``dictExp``.
        dictExp: Mapping from experiment name to a sequence of setups, each
            a mapping providing ``'N'`` and ``'B'``.

    Returns:
        Tuple ``(df_height, df_yerr)`` of DataFrames indexed by setup label,
        with rows ordered by increasing N.
    """
    means = []
    stds = []
    for x in dictExp[exp]:
        setup = f'N:{x["N"]} B:{x["B"]}'
        # Copy the single-column slices so the assignments below never
        # write into a view of the loaded tables.
        data_mean = getRuntimeDataMeans(x, exp).iloc[:, [-1]].copy()
        data_std = getRuntimeDataStds(x, exp).iloc[:, [-1]].copy()
        data_mean.columns = ['runtime']
        data_std.columns = ['runtime_std']
        data_mean['setup'] = setup
        data_std['setup'] = setup
        means.append(data_mean)
        stds.append(data_std)

    def _pivot_and_order(frames, values):
        table = pd.concat(frames).reset_index().pivot_table(
            index='setup', columns='index', values=values
        )
        # Order from lowest to highest N. The full number after 'N:' is
        # parsed so that N with three or more digits sorts correctly; the
        # stable sort keeps the existing order among equal N.
        n_values = [int(label.split()[0][2:]) for label in table.index]
        return table.iloc[np.argsort(n_values, kind='stable')]

    df_height = _pivot_and_order(means, 'runtime')
    df_yerr = _pivot_and_order(stds, 'runtime_std')

    return df_height, df_yerr