import pandas as pd
import os
from matplotlib import pyplot as plt
from matplotlib.markers import MarkerStyle
import matplotlib.ticker as mtick
import sys
# Append the directory two levels above this file to sys.path. The functions
# below read it back as sys.path[-1] when building their input/output paths,
# so this entry must stay the last one on sys.path.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

def get_data(method, file, number):
    """Read one result CSV and return it as a pandas DataFrame.

    The file is looked up at
    ``<sys.path[-1]>/exp_cost/<method>/<file>/cost/result_<number>.csv``.

    NOTE(review): the root directory is whatever currently sits last on
    ``sys.path``; anything appended to ``sys.path`` after module import
    changes where this function looks.
    """
    root_dir = sys.path[-1]
    csv_name = 'result_' + str(number) + '.csv'
    csv_path = os.path.join(root_dir, 'exp_cost', method, file, 'cost', csv_name)
    return pd.read_csv(csv_path)

def makedir(path):
    """Create directory *path*, including missing parents, if it is absent.

    Returns:
        True if this call created the directory, False if something already
        existed at *path*.

    Raises:
        OSError: for failures other than the path already existing (for
            example, insufficient permissions or an empty path).
    """
    # Attempt the creation directly instead of checking first: a separate
    # exists() check can race with another process creating the same path.
    try:
        os.makedirs(path)
    except FileExistsError:
        return False
    return True


def errorbar_plot_rmse(data_name_list, metrix, seed_num):
    """Plot a metric against cost for every method, one figure per dataset.

    For each dataset name in *data_name_list*, the result CSV of each method
    ('dmfal', 'resgp', 'ar', 'conAR') is loaded through ``get_data``, sorted
    by its ``cost`` column, and the *metrix* column is drawn as a line with
    markers. Despite the function's name, plain lines are drawn, not error
    bars. Each figure is saved to
    ``<sys.path[-1]>/paint/cost/graphs/<metrix>/<dataset>_cost_<seed_num>.eps``
    and then closed.

    Args:
        data_name_list: Iterable of dataset folder names.
        metrix: Name of the CSV column to plot. 'rmse', 'nll' and 'nrmse'
            get a y-axis label; any other value leaves the y-axis unlabeled.
        seed_num: Number selecting ``result_<seed_num>.csv`` for every method;
            it is also embedded in the output file name.
    """
    # Per-method (color, marker, line style, legend label), in drawing order.
    method_styles = {
        'dmfal': ('#2ca02c', "s", 'solid', 'MF-BNN'),
        'resgp': ('#0033ff', "P", 'solid', 'resGP'),
        'ar': ('#ff7f0e', "v", 'solid', 'AR'),
        'conAR': ('#DC143C', "o", 'dashed', 'Ours'),
    }
    y_labels = {'rmse': "RMSE", 'nll': "NLL", 'nrmse': "nRMSE"}

    folder_path = os.path.join(sys.path[-1], 'paint', 'cost', 'graphs', str(metrix))

    for file_name in data_name_list:
        fig = plt.figure(figsize=(18, 6), dpi=100)
        try:
            ax = fig.gca()

            for method_name, (color, marker, line_style, label) in method_styles.items():
                data = get_data(method_name, file_name, seed_num)
                sorted_data = data.sort_values(by='cost', ascending=True)

                # NOTE(review): figures are saved as EPS; matplotlib's
                # PostScript backend reportedly ignores transparency, so
                # alpha may have no visible effect in the output - confirm.
                ax.plot(
                    sorted_data['cost'].values,
                    sorted_data[metrix].values,
                    ls=line_style,
                    linewidth=3.5,
                    color=color,
                    label=label,
                    marker=MarkerStyle(marker, fillstyle='full'),
                    markersize=9,
                    alpha=0.8,
                )

            ax.set_xlabel("# Cost", fontsize=25)
            if metrix in y_labels:
                ax.set_ylabel(y_labels[metrix], fontsize=25)

            fig.subplots_adjust(top=0.93,
                                bottom=0.2,
                                left=0.18,
                                right=0.95,
                                hspace=0.2,
                                wspace=0.2)

            # The Poisson dataset uses integer tick labels; all others use
            # two decimal places.
            if file_name == "Poisson_mfGent_v5":
                ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%2d'))
            else:
                ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))

            ax.set_xticks([930, 1258, 1586, 1914, 2242, 2570, 2898, 3226, 3554])
            ax.tick_params(axis='both', labelsize=25)

            # Only the Heat figure carries the legend.
            if file_name == "Heat_mfGent_v5":
                ax.legend(loc='upper right', fontsize=25)

            # exist_ok avoids the race between an exists() check and creation.
            os.makedirs(folder_path, exist_ok=True)
            fig_file = os.path.join(folder_path, file_name + '_cost_' + str(seed_num) + '.eps')
            fig.savefig(fig_file, bbox_inches='tight')
        finally:
            # Release the figure even if loading or saving fails; otherwise
            # every dataset leaves an open figure behind in pyplot's registry.
            plt.close(fig)

if __name__ == '__main__':
    # Datasets to plot; each entry is passed on as one dataset name.
    datasets = [
        'Poisson_mfGent_v5',
        'Heat_mfGent_v5',
        'Burget_mfGent_v5_15',
        'TopOP_mfGent_v6',
        'plasmonic2_MF',
    ]
    errorbar_plot_rmse(data_name_list=datasets, metrix='rmse', seed_num=50)

        

