import pandas as pd
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.markers import MarkerStyle
import matplotlib.ticker as mtick
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

def get_data(file, number, type):
    """Read one cell from a CSV file.

    Args:
        file: path of the CSV file to read with ``pd.read_csv``.
        number: positional row index into the column (negative indices
            count from the end, as with numpy arrays).
        type: name of the column to read (e.g. a metric name such as 'rmse').

    Returns:
        The value stored at position ``number`` of column ``type``.
    """
    frame = pd.read_csv(file)
    column_values = frame[type].values
    return column_values[number]

def get_mean_and_std(fro, method, file, seed_list, type, dec_rate, number):
    """Aggregate one metric entry over several random seeds.

    For every seed in ``seed_list`` this reads
    ``<root>/<fro>/<method>/<file>/dec_<dec_rate>/result_<seed>.csv`` and takes
    row ``number`` of column ``type`` (via ``get_data``).

    ``<root>`` is ``sys.path[-1]``. NOTE(review): this relies on the project
    root being the last entry on ``sys.path``, which breaks if anything else is
    appended after this module is imported.

    Args:
        fro: top-level results folder under the root (e.g. 'exp').
        method: method sub-folder name.
        file: dataset sub-folder name.
        seed_list: iterable of seeds; one result file is read per seed.
        type: CSV column (metric) to read.
        dec_rate: value used to form the ``dec_<dec_rate>`` folder name.
        number: row position to read from each CSV.

    Returns:
        ``(mean, std)`` of the collected values as numpy scalars; ``std`` is
        numpy's default population standard deviation (ddof=0).
    """
    path = os.path.join(sys.path[-1], fro, method, file, 'dec_' + str(dec_rate))
    # Build every result path with os.path.join instead of appending '/' so
    # the separator is consistent with the prefix on every platform.
    vals = np.array([
        get_data(os.path.join(path, 'result_' + str(seed) + '.csv'), number, type)
        for seed in seed_list
    ])
    return vals.mean(), vals.std()

def makedir(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    Args:
        path: directory path to create.

    Returns:
        True if this call created the directory; False if something already
        existed at ``path``.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # A single create attempt avoids the race between checking
        # os.path.exists() and calling os.makedirs().
        return False
    return True


def errorbar_plot_rmse(data_name_list, decreasing_rate, metrix, seed_list):
    """Draw and save one error-bar figure per dataset.

    For each dataset, every method in the dataset's method list is plotted as
    mean +/- std (across ``seed_list``) of column ``metrix``. The values are
    read by ``get_mean_and_std`` from
    ``<root>/exp/<method>/<dataset>/dec_<decreasing_rate>/result_<seed>.csv``,
    where ``<root>`` is ``sys.path[-1]``. Each figure is saved as
    ``<root>/paint/subset/graphs/dec_<decreasing_rate>/<metrix>/<dataset>_dec_<decreasing_rate>.eps``.

    Args:
        data_name_list: dataset folder names; one figure is produced for each.
        decreasing_rate: value used in the ``dec_<rate>`` folder names. It also
            suppresses the legend for some datasets when it equals 0.75.
        metrix: CSV column to plot. 'rmse', 'nll' and 'nrmse' get a y-axis
            label; other values are plotted without one.
        seed_list: seeds whose result files are averaged.
    """
    color_dic = {'conAR': '#DC143C', 'dmfal': '#2ca02c', 'ifc': '#1f77b4', 'ar': '#ff7f0e', 'resgp': '#0033ff'}
    marker_dic = {'conAR': "o", 'dmfal': "s", 'ifc': "^", 'ar': "v", 'resgp': "P"}
    ls_dic = {'conAR': 'dashed', 'ifc': 'solid', 'dmfal': 'solid', 'ar': 'solid', 'resgp': 'solid'}
    label_dic = {'conAR': 'Ours', 'dmfal': 'MF-BNN', 'ifc': 'IFC-GPT', 'ar': 'AR', 'resgp': 'resGP'}
    ylabel_dic = {'rmse': 'RMSE', 'nll': 'NLL', 'nrmse': 'nRMSE'}

    # Datasets for which the IFC-GPT ('ifc') method is also plotted.
    ifc_list = ['Heat_mfGent_v5', 'plasmonic2_MF']
    # Datasets whose figure gets a legend...
    legend_list = ['Heat_mfGent_v5', 'TopOP_mfGent_v6', 'plasmonic2_MF']
    # ...except these, which get no legend when decreasing_rate == 0.75.
    no_legend_at_075 = ['TopOP_mfGent_v6', 'plasmonic2_MF']

    # x-axis positions. Row j of each result CSV is plotted at orders[j]
    # (presumably rows are ordered by training-set size; verify against the
    # experiment output).
    orders = [32, 64, 96, 128]
    fro = 'exp'

    for file_name in data_name_list:
        plt.figure(figsize=(7, 6), dpi=100)

        if file_name in ifc_list:
            methods = ['dmfal', 'resgp', 'ar', 'ifc', 'conAR']
        else:
            methods = ['dmfal', 'resgp', 'ar', 'conAR']

        for method_name in methods:
            means = []
            stds = []
            for row in range(len(orders)):
                # Argument order: fro, method, file, seed_list, type, dec_rate, number
                m, s = get_mean_and_std(fro, method_name, file_name, seed_list, metrix, decreasing_rate, row)
                means.append(m)
                stds.append(s)

            plt.errorbar(orders, means, yerr=stds, ls=ls_dic[method_name], linewidth=3.5,
                         color=color_dic[method_name], label=label_dic[method_name],
                         marker=MarkerStyle(marker_dic[method_name], fillstyle='full'),
                         elinewidth=3, capsize=13, markersize=17, alpha=0.8)

        plt.xlabel("# Training Samples $N^{0}$", fontsize=25)
        if metrix in ylabel_dic:
            plt.ylabel(ylabel_dic[metrix], fontsize=25)
        ax = plt.gca()
        plt.gcf().subplots_adjust(top=0.93,
                                  bottom=0.2,
                                  left=0.18,
                                  right=0.95,
                                  hspace=0.2,
                                  wspace=0.2)

        # Poisson tick labels are printed as integers ('%2d'); all others use
        # two decimals.
        if file_name == "Poisson_mfGent_v5":
            ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%2d'))
        else:
            ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))

        plt.xticks(orders)
        plt.tick_params(axis='both', labelsize=25)

        show_legend = file_name in legend_list and not (
            file_name in no_legend_at_075 and decreasing_rate == 0.75)
        if show_legend:
            plt.legend(loc='upper right', fontsize=25)
        plt.grid()

        folder_path = os.path.join(sys.path[-1], 'paint', 'subset', 'graphs', 'dec_' + str(decreasing_rate), str(metrix))
        os.makedirs(folder_path, exist_ok=True)
        fig_file = os.path.join(folder_path, file_name + '_dec_' + str(decreasing_rate) + '.eps')
        plt.savefig(fig_file, bbox_inches='tight')
        # Close the figure after saving; otherwise each dataset leaves an
        # open figure behind for the rest of the process.
        plt.close()

if __name__ == '__main__':
    # Plot the RMSE curves of every benchmark dataset at decreasing rate 0.5,
    # averaging over seeds 0..4.
    datasets = ['Poisson_mfGent_v5', 'Heat_mfGent_v5', 'Burget_mfGent_v5_15', 'TopOP_mfGent_v6', 'plasmonic2_MF']
    errorbar_plot_rmse(
        data_name_list=datasets,
        decreasing_rate=0.5,
        metrix='rmse',
        seed_list=list(range(5)),
    )

        

