import csv
import os

from matplotlib import pyplot as plt
import numpy as np

# (file key, line / error-bar colour, legend label) for each compared method,
# listed in the order in which the curves are drawn.
_METHODS = (
    ('sv', 'purple', 'SV'),
    ('ap', 'orange', 'AP'),
    ('mr', 'red', 'MER'),
    ('mc', 'green', 'MC'),
)

# Raw string: '\e' is not a valid escape sequence in a normal string literal.
_EPS_MAX_LABEL = r'$\epsilon_{max}$'


def _load_csv(path):
    """Read a purely numeric CSV file into a 2-D float array.

    The file is closed as soon as it has been parsed. Completely empty lines
    are skipped so that a stray blank line does not produce a ragged array.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a cell cannot be converted to ``float``.
    """
    with open(path, newline='') as handle:
        return np.array(
            [[float(cell) for cell in row] for row in csv.reader(handle) if row]
        )


def _mean_and_std(samples):
    """Return the per-column mean and sample standard deviation (ddof=1)."""
    return samples.mean(axis=0), samples.std(axis=0, ddof=1)


def _draw_method_curves(x, means, stds=None):
    """Draw one dashed, marked line per method on the current axes.

    Args:
        x: x coordinates shared by all curves.
        means: mapping from method key (see ``_METHODS``) to y values.
        stds: optional mapping from method key to error-bar half-heights;
            when given, error bars in the curve's colour are drawn as well.
    """
    for key, colour, label in _METHODS:
        # NOTE(review): labels are attached but no legend is drawn anywhere
        # in this script - presumably it is added outside; confirm.
        plt.plot(x, means[key], color=colour, linestyle='dashed', linewidth=2,
                 marker='o', markerfacecolor='blue', markersize=6, label=label)
        if stds is not None:
            plt.errorbar(x, means[key], yerr=stds[key], capsize=5, fmt='none',
                         ecolor=colour)


def _label_save_close(fig, xlabel, ylabel, out_path):
    """Apply the common axis labelling, write ``fig`` to ``out_path``, close it."""
    plt.xlabel(xlabel, fontsize=30)
    plt.ylabel(ylabel, fontsize=30)
    plt.tick_params(labelsize=20)
    plt.tight_layout()
    # savefig does not create missing directories itself.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


def _plot_performance_monotonicity(agent, out_path, alpha_index=4,
                                   x_right=None, y_top=None):
    """Exp1: plot one row of each method's table against alpha'.

    Args:
        agent: agent number used in the data file names and the y label.
        out_path: destination of the figure.
        alpha_index: row of the tables to plot; the original script used
            row 4 and described it as "fix alpha = 0.4".
        x_right: optional upper x-axis limit.
        y_top: optional upper y-axis limit.
    """
    tables = {
        key: _load_csv('gridworld/data/exp1/{}_tr_{}.csv'.format(key, agent))
        for key, _, _ in _METHODS
    }
    alpha_prime = [.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0]

    fig = plt.figure()
    _draw_method_curves(
        alpha_prime,
        {key: table[alpha_index, :] for key, table in tables.items()},
    )

    axes = plt.gca()
    if x_right is not None:
        axes.set_xlim(right=x_right)
    if y_top is not None:
        axes.set_ylim(top=y_top)
    _label_save_close(fig, "\u03B1'", 'Blame A{}'.format(agent), out_path)


def _plot_sv_methods(agent, out_path, y_top=None):
    """Exp3a: grouped bars (mean +/- std over rows) for three SV variants.

    A dashed horizontal line marks element [0, 0] of the ``sv_tr`` table.

    Args:
        agent: agent number used in the data file names and the y label.
        out_path: destination of the figure.
        y_top: optional upper y-axis limit.
    """
    data_dir = 'gridworld/data/exp3a'
    target = _load_csv('{}/sv_tr_{}.csv'.format(data_dir, agent))

    # (bar offset, file key, colour, legend label)
    variants = (
        (0, 'es', 'green', 'Point-Estimate'),
        (.3, 'va', 'orange', 'Valid'),
        (.6, 'bc', 'red', 'Consistent'),
    )
    stats = {
        key: _mean_and_std(
            _load_csv('{}/sv_{}_{}.csv'.format(data_dir, key, agent)))
        for _, key, _, _ in variants
    }

    # estimation errors
    labels = [".05", ".1", ".15", ".2"]
    X = np.arange(len(labels))

    fig = plt.figure()
    for offset, key, colour, label in variants:
        mean, std = stats[key]
        plt.bar(X + offset, mean, width=.3, color=colour, label=label,
                yerr=std, capsize=5)
    plt.axhline(y=target[0, 0], linewidth=1, color='purple', linestyle='--',
                label='Targeted Value')

    # Ticks sit under the middle bar of each group.
    plt.xticks([.3, 1.3, 2.3, 3.3], labels)
    if y_top is not None:
        plt.gca().set_ylim(top=y_top)
    _label_save_close(fig, _EPS_MAX_LABEL, 'Blame A{}'.format(agent), out_path)


def _plot_estimation_error_curves(data_dir, ylabel, out_path, y_bottom=None):
    """Exp3b/3c: per-method mean curves with std error bars.

    Each method's table is read from ``<data_dir>/<key>.csv``; mean and
    sample standard deviation are taken over the rows of each table.

    Args:
        data_dir: directory holding ``sv.csv``, ``ap.csv``, ``mr.csv``, ``mc.csv``.
        ylabel: y-axis label.
        out_path: destination of the figure.
        y_bottom: optional lower y-axis limit.
    """
    means, stds = {}, {}
    for key, _, _ in _METHODS:
        means[key], stds[key] = _mean_and_std(
            _load_csv('{}/{}.csv'.format(data_dir, key)))

    # estimation errors
    est_err_2_lst = [0, .05, .1, .15, .2, .25, .3, .35, .4]

    fig = plt.figure()
    _draw_method_curves(est_err_2_lst, means, stds)
    if y_bottom is not None:
        plt.gca().set_ylim(bottom=y_bottom)
    _label_save_close(fig, _EPS_MAX_LABEL, ylabel, out_path)


def main():
    """Regenerate the gridworld experiment figures as PDF files.

    Reads CSV tables from ``gridworld/data/...`` and writes four figures
    below ``gridworld/plots/...``. All paths are relative to the current
    working directory.
    """
    # The Agent A1 variants of the Exp1 and Exp3a figures were disabled in the
    # original script. They can be regenerated by calling the same helpers
    # with agent=1 and no axis limits; their original output names were
    # 'gridworld/plots/exp1/gridworld-per_mono1.pdf' and
    # 'gridworld/plots/exp3a/gridworld-different_sv_methods1.pdf'.

    # Exp1 (performance monotonicity), Agent A2
    _plot_performance_monotonicity(
        agent=2,
        out_path='gridworld/plots/exp1/gridworld-per_mono.pdf',
        x_right=1.08,
        y_top=0.16,
    )

    # Exp3 part a (different SV approaches), Agent A2
    _plot_sv_methods(
        agent=2,
        out_path='gridworld/plots/exp3a/gridworld-different_sv_methods.pdf',
        y_top=0.07,
    )

    # Exp3 part b (Distance)
    _plot_estimation_error_curves(
        'gridworld/data/exp3b',
        'L1 Distance',
        'gridworld/plots/exp3b/gridworld-distance.pdf',
    )

    # Exp3 part c (Total Blame)
    _plot_estimation_error_curves(
        'gridworld/data/exp3c',
        'Total Blame',
        'gridworld/plots/exp3c/gridworld-total_blame.pdf',
        y_bottom=0,
    )


if __name__ == "__main__":
    # Script entry point. The data and plot paths used by main() are
    # relative, so run this from the directory that contains `gridworld/`:
    #     python example.py
    main()