from matplotlib import pyplot as plt
import numpy as np
import csv


def _load_csv(path):
    """Load a numeric CSV file into a 2-D float array, one row per CSV line.

    The file is opened in a ``with`` block so the handle is always closed.
    Blank lines (e.g. an extra newline at EOF) are skipped so they do not
    produce ragged rows.
    """
    with open(path, newline='') as f:
        return np.array([[float(x) for x in row] for row in csv.reader(f) if row])


def _mean_std(samples):
    """Return the per-column mean and sample standard deviation (ddof=1).

    ``samples`` is a 2-D array laid out as (seeds, estimation settings), so
    the statistics are taken over seeds for each setting.
    """
    return np.mean(samples, axis=0), np.std(samples, axis=0, ddof=1)


def _style_axes(ax, xlabel, ylabel):
    """Apply the axis labels and tick-label font sizes shared by every plot."""
    ax.set_xlabel(xlabel, fontsize=30)
    ax.set_ylabel(ylabel, fontsize=30)
    ax.tick_params(labelsize=20)


def _save_legend(fig, ax, path, ncol):
    """Export the legend of ``ax`` as its own PDF, sized to fit the legend.

    The legend is first laid out on ``fig`` to measure its size in inches; a
    new figure of exactly that size then receives a copy of the legend and is
    saved to ``path``.  The temporary legend figure is closed afterwards.
    """
    legend = ax.legend(loc='upper center', ncol=ncol, fancybox=True,
                       shadow=True, fontsize='xx-large')
    fig.canvas.draw()
    # Measure the rendered legend, converting from display units to inches.
    legend_bbox = legend.get_tightbbox(fig.canvas.get_renderer())
    legend_bbox = legend_bbox.transformed(fig.dpi_scale_trans.inverted())

    legend_fig, legend_ax = plt.subplots(figsize=(legend_bbox.width, legend_bbox.height))
    legend_squared = legend_ax.legend(
        *ax.get_legend_handles_labels(),
        bbox_to_anchor=(0, 0, 1, 1),
        bbox_transform=legend_fig.transFigure,
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=ncol,
        fontsize='xx-large',
    )
    legend_ax.axis('off')
    legend_fig.savefig(path,
                       bbox_inches='tight',
                       bbox_extra_artists=[legend_squared])
    # Without this the legend-only figure stays registered with pyplot.
    plt.close(legend_fig)


def _plot_lines_with_errorbars(ax, x, series, methods):
    """Draw, per method, a dashed line of column means with std error bars.

    ``series`` maps each method key to its (seeds, settings) array;
    ``methods`` is a sequence of (key, colour, legend label) triples and
    fixes the drawing (and therefore legend) order.
    """
    for key, color, label in methods:
        mean, std = _mean_std(series[key])
        ax.plot(x, mean, color=color, linestyle='dashed', linewidth=2,
                marker='o', markerfacecolor='blue', markersize=6, label=label)
        # Error bars are unlabelled, so only the line above gets a legend entry.
        ax.errorbar(x, mean, yerr=std, capsize=5, fmt='none', ecolor=color)


def main():
    """Render the experiment plots from the CSV files under ``graph/data/``.

    Produces, for Exp2 (coordination), Exp3a (different SV approaches),
    Exp3b (L1 distance) and Exp3c (total blame), a PDF under
    ``graph/plots/<exp>/`` drawn without a legend, plus a matching
    ``*_lgd.pdf`` containing only the legend.  All paths are relative, so
    this must be run from the directory that contains ``graph/``.
    """
    # (data-file key, colour, legend label) per blame method, in plot order.
    methods = [
        ('sv', 'purple', 'SV'),
        ('ap', 'orange', 'AP'),
        ('bi', 'blue', 'BI'),
        ('mc', 'green', 'MC'),
        ('mr', 'red', 'MER'),
    ]
    epsilon_label = r'$\epsilon_{max}$'

    # ---- Exp2 (coordination) ----
    # least number of agents to satisfy constraint
    cases = [1, 2, 3, 4]
    blame = {key: _load_csv('graph/data/exp2/{}.csv'.format(key))
             for key, _, _ in methods}

    width = .35
    X = 2 * np.arange(len(cases))
    fig, ax = plt.subplots()
    for i, (key, color, label) in enumerate(methods):
        # Total blame for case c is the sum of row c of that method's CSV.
        totals = [np.sum(blame[key][c]) for c in range(len(cases))]
        ax.bar(X + i * width, totals, width=width, color=color, label=label)
    # Centre each tick under the middle bar of its group.
    ax.set_xticks(X + width * (len(methods) - 1) / 2)
    ax.set_xticklabels([str(c) for c in cases])
    ax.set_ylim(top=21)
    _style_axes(ax, 'm', 'Total Blame')
    fig.tight_layout()
    # The figure is saved before the legend is attached; the legend is exported separately.
    fig.savefig('graph/plots/exp2/graph-coordination.pdf')
    _save_legend(fig, ax, 'graph/plots/exp2/graph-coordination_lgd.pdf', ncol=3)
    plt.close(fig)

    # ---- Exp3 part a (different SV approaches), blame of agent 0 ----
    sv_runs = {key: _load_csv('graph/data/exp3a/sv_{}.csv'.format(key))
               for key in ('tr', 'es', 'va', 'bc')}
    approaches = [
        ('es', 'green', 'Point-Estimate'),
        ('va', 'orange', 'Valid'),
        ('bc', 'red', 'Consistent'),
    ]
    # estimation errors
    labels = ['.01', '.05', '.1']

    width = .3
    X = np.arange(len(labels))
    fig, ax = plt.subplots()
    for i, (key, color, label) in enumerate(approaches):
        mean, std = _mean_std(sv_runs[key])
        ax.bar(X + i * width, mean, width=width, color=color, label=label,
               yerr=std, capsize=5)
    # The first entry of the 'tr' data is drawn as the targeted reference value.
    ax.axhline(y=sv_runs['tr'][0, 0], linewidth=1, color='purple',
               linestyle='--', label='Targeted Value')
    ax.set_xticks(X + width * (len(approaches) - 1) / 2)
    ax.set_xticklabels(labels)
    ax.set_ylim(top=0.7)
    _style_axes(ax, epsilon_label, 'Blame Agent 1')
    fig.tight_layout()
    # NOTE: the 'diferent' spelling is kept so existing references to these files still resolve.
    fig.savefig('graph/plots/exp3a/graph-diferent_sv_methods.pdf')
    _save_legend(fig, ax, 'graph/plots/exp3a/graph-diferent_sv_methods_lgd.pdf', ncol=2)
    plt.close(fig)

    # ---- Exp3 part b and c (Distance and Total Blame) ----
    # estimation errors, always begin with zero
    est_err_lst = [0, .025, .05, .075, .1, .125, .15, .175, .2]
    line_plots = [
        ('exp3b', 'L1 Distance', 'graph-distance'),
        ('exp3c', 'Total Blame', 'graph-total_blame'),
    ]
    for exp, ylabel, stem in line_plots:
        series = {key: _load_csv('graph/data/{}/{}.csv'.format(exp, key))
                  for key, _, _ in methods}
        fig, ax = plt.subplots()
        _plot_lines_with_errorbars(ax, est_err_lst, series, methods)
        ax.set_xlim(right=0.215)
        _style_axes(ax, epsilon_label, ylabel)
        fig.tight_layout()
        fig.savefig('graph/plots/{}/{}.pdf'.format(exp, stem))
        _save_legend(fig, ax, 'graph/plots/{}/{}_lgd.pdf'.format(exp, stem), ncol=3)
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point. main() reads 'graph/data/...' and writes
    # 'graph/plots/...' relative to the current working directory, so
    # launch this script from the directory that contains 'graph/'.
    main()