import os
import numpy as np 
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

#meanReward
def meanReward(exp, rewards, folder_name):
    """Plot every algorithm's average reward as a bar chart, save and show it.

    Each array in ``rewards`` is first averaged over axis 1; the bar height
    is the mean of those averages and the error bar is their standard
    deviation.  The figure is written to ``{folder_name}/meanReward.png``
    and then displayed with ``plt.show()``.

    Args:
        exp: Experiment object.  ``exp.data`` and ``exp.env.N``/``B``/``M``
            are read to build the title; ``exp.env.cost`` is also read when
            ``exp.data == 'constantCosts'``.
        rewards: Mapping from algorithm name to an array with at least two
            dimensions (presumably runs x time steps -- confirm with the
            caller).
        folder_name: Directory that receives the PNG; it must already exist.
    """
    algo_names = list(rewards.keys())
    # Average over axis 1 once per algorithm and reuse it for both statistics.
    per_algo = [rewards[name].mean(axis=1) for name in algo_names]
    means = [avg.mean() for avg in per_algo]
    std = [avg.std() for avg in per_algo]

    fig, ax2 = plt.subplots(figsize=(3.5,3))
    ax2.bar(x=algo_names, height=means, yerr=std)
    ax2.set_ylabel(f'Average Reward')

    title = f'N={exp.env.N} B={exp.env.B} M={exp.env.M} \n {exp.data}'
    if exp.data == 'constantCosts':
        title += f' of {exp.env.cost}'
    fig.suptitle(title)

    plt.xticks(rotation=45)
    fig.tight_layout()
    # Save BEFORE showing: once plt.show() returns (blocking GUI backends) or
    # has rendered (notebook inline backend) the figure may already be
    # destroyed, and a later pyplot-level savefig would write an empty canvas.
    # Saving through the figure object also avoids relying on the implicit
    # "current figure".
    fig.savefig(f'{folder_name}/meanReward.png')
    plt.show()

#fractionFair
def fractionFair(exp, usedBudget, folder_name):
    """Plot, per algorithm, the fraction of 'fair' entries and save the chart.

    For each array in ``usedBudget`` the spread along axis 2 (max minus min)
    is compared with ``exp.env.C.max()``.  The share of entries along axis 1
    whose spread does not exceed that threshold is computed; the bar height
    is the mean of those shares and the error bar is their standard
    deviation.  The figure is written to ``{folder_name}/fractionFair.png``.

    Args:
        exp: Experiment object.  ``exp.env.C`` supplies the threshold;
            ``exp.data`` and ``exp.env.N``/``B``/``M`` (plus ``exp.env.cost``
            for ``'constantCosts'``) are read to build the title.
        usedBudget: Mapping from algorithm name to an array with at least
            three dimensions (presumably runs x steps x action types --
            confirm with the caller).
        folder_name: Directory that receives the PNG; it must already exist.
    """
    algo_names = list(usedBudget.keys())

    shares = []
    for name in algo_names:
        spent = usedBudget[name]
        spread = spent.max(axis=2) - spent.min(axis=2)
        within_limit = spread <= exp.env.C.max()
        shares.append(within_limit.mean(axis=1))

    fraction = [share.mean() for share in shares]
    fraction_std = [share.std() for share in shares]

    fig, ax2 = plt.subplots(figsize=(3.5,3))
    ax2.bar(x=algo_names, height=fraction, yerr=fraction_std)
    ax2.set_ylabel(f'Fraction Fair')

    title = f'N={exp.env.N} B={exp.env.B} M={exp.env.M} \n {exp.data}'
    if exp.data == 'constantCosts':
        title += f' of {exp.env.cost}'
    fig.suptitle(title)

    # A fraction always lies in [0, 1]; pin the axis so charts are comparable.
    ax2.set_ylim(0, 1)
    plt.xticks(rotation=45)
    fig.tight_layout()
    plt.savefig(f'{folder_name}/fractionFair.png')

#maxDiffBudget
def maxDiffBudget(exp, usedBudget, folder_name):
    """Plot the largest gap in spent budget per algorithm and save the chart.

    For each array in ``usedBudget`` the gap along axis 2 is computed as
    max minus min.  With exactly two entries on that axis this equals the
    absolute difference between them; with more entries it is the largest
    pairwise difference.  The gap is averaged over axis 0; the bar height is
    the mean of those averages and the error bar is their standard
    deviation.  For ``'constantCosts'`` experiments a dashed red line marks
    ``exp.env.C.max()``.  The figure is written to
    ``{folder_name}/maxDiffBudget.png``.

    Args:
        exp: Experiment object.  ``exp.data`` and ``exp.env.N``/``B``/``M``
            are read to build the title; ``exp.env.C`` and ``exp.env.cost``
            are read only for ``'constantCosts'`` experiments.
        usedBudget: Mapping from algorithm name to an array with at least
            three dimensions (presumably runs x steps x action types --
            confirm with the caller).
        folder_name: Directory that receives the PNG; it must already exist.
    """
    algo_names = list(usedBudget.keys())

    means = []
    std = []
    for name in algo_names:
        spent = usedBudget[name]
        # max - min over the last axis instead of hard-coded indices 0 and 1:
        # same value for two action types, still correct for more.
        gap = spent.max(axis=2) - spent.min(axis=2)
        # NOTE(review): this averages over axis 0, while meanReward and
        # fractionFair average over axis 1 -- kept as is, confirm intent.
        averaged_gap = gap.mean(axis=0)
        means.append(averaged_gap.mean())
        std.append(averaged_gap.std())

    fig, ax = plt.subplots(figsize=(3.5,3))
    ax.bar(x=algo_names, height=means, yerr=std)
    ax.set_ylabel(f'Max Cost Difference in \n Spent Budget')
    if exp.data == 'constantCosts':
        ax.axhline(exp.env.C.max(), color='r', linestyle='--', label='Max Cost')
        title = f'N={exp.env.N} B={exp.env.B} M={exp.env.M} \n {exp.data} of {exp.env.cost}'
        ax.legend(loc=1)
    else:
        title = f'N={exp.env.N} B={exp.env.B} M={exp.env.M} \n {exp.data}'
    fig.suptitle(title)
    plt.xticks(rotation=45)
    fig.tight_layout()
    # Save through the figure object rather than the implicit current figure.
    fig.savefig(f'{folder_name}/maxDiffBudget.png')

#usedBudget
def meanUsedBudget(exp, usedBudget, folder_name):
    """Plot the mean used budget per algorithm as grouped bars and save it.

    Each array in ``usedBudget`` is averaged over axis 1.  The mean and the
    standard deviation of the result over axis 0 become one row per
    algorithm in two DataFrames, which are drawn as grouped bars with error
    bars.  The figure is written to ``{folder_name}/usedBudget.png``.

    Args:
        exp: Experiment object.  ``exp.data`` and ``exp.env.N``/``B``/``M``
            (plus ``exp.env.cost`` for ``'constantCosts'``) are read to
            build the title.
        usedBudget: Mapping from algorithm name to a three-dimensional array
            (presumably runs x steps x action types -- confirm with the
            caller).
        folder_name: Directory that receives the PNG; it must already exist.
    """
    algo_names = list(usedBudget.keys())
    per_algo = [usedBudget[name].mean(axis=1) for name in algo_names]

    # Label the rows at construction time instead of renaming a positional
    # index afterwards.
    mean = pd.DataFrame([avg.mean(axis=0) for avg in per_algo], index=algo_names)
    mean_std = pd.DataFrame([avg.std(axis=0) for avg in per_algo], index=algo_names)

    fig, ax = plt.subplots(figsize=(8,4))
    mean.plot.bar(ax=ax, yerr=mean_std)
    # One legend entry per plotted column, so every bar series is labelled
    # even when there are more than two of them.
    ax.legend([f'Action type {i + 1}' for i in range(mean.shape[1])], fontsize=8)
    ax.set_ylabel('Mean Used Budget')

    title = f'N={exp.env.N} B={exp.env.B} M={exp.env.M} \n {exp.data}'
    if exp.data == 'constantCosts':
        title += f' of {exp.env.cost}'
    fig.suptitle(title)

    plt.xticks(rotation=0)
    # Save through the figure object rather than the implicit current figure.
    fig.savefig(f'{folder_name}/usedBudget.png')
    
def savePlots(exp, rewards, usedBudget):
    """Create the experiment's output folder, then save all plots and timings.

    The folder is ``./plots/<exp.data>[<cost>]_N<N>_B<B>_M<M>``; the cost is
    included only when ``exp.data == 'constantCosts'``.  The four plot
    functions are run against it, and ``exp.time`` is written to
    ``execution_time.csv`` (space-separated, one row per key, column
    ``seconds``).

    Args:
        exp: Experiment object.  ``exp.data``, ``exp.env.N``/``B``/``M`` and,
            for ``'constantCosts'``, ``exp.env.cost`` name the folder;
            ``exp.time`` is a mapping whose values are written as seconds.
        rewards: Mapping from algorithm name to reward array, passed to
            ``meanReward``.
        usedBudget: Mapping from algorithm name to used-budget array, passed
            to the three budget plots.
    """
    # Folder for this experiment
    if exp.data == 'constantCosts':
        folder_name = f'./plots/{exp.data}{exp.env.cost}_N{exp.env.N}_B{exp.env.B}_M{exp.env.M}'
    else:
        folder_name = f'./plots/{exp.data}_N{exp.env.N}_B{exp.env.B}_M{exp.env.M}'

    # makedirs creates ./plots when it is missing and accepts a folder that
    # already exists, so re-running an experiment overwrites its earlier
    # output instead of raising FileExistsError after the work is done.
    os.makedirs(folder_name, exist_ok=True)

    # Run and save plots
    meanReward(exp, rewards, folder_name)
    maxDiffBudget(exp, usedBudget, folder_name)
    fractionFair(exp, usedBudget, folder_name)
    meanUsedBudget(exp, usedBudget, folder_name)

    # Save execution time: one row per key of exp.time
    df = pd.DataFrame(
        {'seconds': list(exp.time.values())},
        index=list(exp.time.keys()),
    )
    df.to_csv(f'{folder_name}/execution_time.csv', sep=' ', index=True, header=True)


