import pandas as pd
import os
import numpy as np
from matplotlib import pyplot as plt
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

def get_data(method, file, exp_type, number):
    """Load a single experiment result CSV into a DataFrame.

    The file is resolved relative to the last entry of ``sys.path`` as
    ``exp_time/<method>/<file>/<exp_type>/result_<number>.csv``.

    Args:
        method: Name of the method sub-directory.
        file: Name of the dataset sub-directory.
        exp_type: Name of the experiment-type sub-directory.
        number: Run identifier embedded in the CSV file name.

    Returns:
        The parsed CSV as returned by ``pandas.read_csv``.
    """
    csv_name = 'result_%s.csv' % (number,)
    result_dir = os.path.join(sys.path[-1], 'exp_time', method, file, exp_type)
    return pd.read_csv(os.path.join(result_dir, csv_name))

def makedir(path):
    """Create directory ``path`` (with any missing parents) if it is absent.

    The directory is created directly and an "already exists" error is
    treated as the not-created case, so there is no window between an
    existence check and the creation in which another process could create
    the same directory and make this call fail.

    Args:
        path: Directory path to create.

    Returns:
        True if this call created the directory, False if something already
        existed at ``path``.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        return False
    return True


def _aggregate_rmse(method_name, data_name, exp_type, seed_list):
    """Return ``(time, mean_rmse, std_rmse)`` for one method across all seeds.

    Each seed's result table is sorted by its ``time`` column before the
    ``rmse`` column is collected. The returned ``time`` axis is taken from the
    last seed in ``seed_list``.

    Raises:
        ValueError: If ``seed_list`` is empty.
    """
    if not seed_list:
        raise ValueError("seed_list must contain at least one seed")

    rmse_runs = []
    time = None
    for seed in seed_list:
        run = get_data(method_name, data_name, exp_type, seed)
        run = run.sort_values(by='time', ascending=True)
        rmse_runs.append(run['rmse'].values)
        # NOTE(review): only the last seed's time axis is kept; this assumes
        # all seeds share comparable time stamps -- confirm.
        time = run['time'].values

    # Stacking assumes every seed produced the same number of rows.
    rmse_runs = np.array(rmse_runs)
    return time, rmse_runs.mean(0), rmse_runs.std(0)


def errorbar_plot_rmse(data_name_list, train_sample_num, exp_type, paint_type, seed_list):
    """Plot mean RMSE against time with a +/- one-std band and save as PNG.

    For every dataset one figure is drawn containing one curve per method
    ('ifc' and 'conAR'); the curve is the mean over ``seed_list`` and the
    shaded band is the standard deviation over ``seed_list``. The figure is
    written to ``<sys.path[-1]>/paint/time/graphs/<data_name>/`` and then
    closed so that figures do not accumulate across datasets.

    Args:
        data_name_list: Dataset names; one figure is produced per entry.
        train_sample_num: Unused; kept so existing callers keep working.
        exp_type: Experiment-type sub-directory passed through to ``get_data``.
        paint_type: ``'log'`` plots against log10(time) and saves with the
            ``ifc_log`` suffix; any other value plots against raw time and
            saves with the ``ifc`` suffix.
        seed_list: Run identifiers to aggregate over; must be non-empty.

    Raises:
        ValueError: If ``seed_list`` is empty while there is data to plot.
    """
    color_dic = {'conAR':'#DC143C', 'dmfal':'#2ca02c', 'ifc':'#1f77b4', 'ar':'#ff7f0e', 'resgp':'#0033ff'}
    marker_dic = {'conAR':"o", 'dmfal':"s", 'ifc':"^", 'ar':"v", 'resgp':"P"}
    label_dic = {'conAR':'Ours', 'dmfal':'MF-BNN', 'ifc':'IFC-GPT', 'ar':'AR', 'resgp':'resGP'}

    method = ['ifc', 'conAR']
    use_log = paint_type == 'log'

    for data_name in data_name_list:
        fig = plt.figure(figsize=(10, 6), dpi=100)
        try:
            for method_name in method:
                time, vals, yerr = _aggregate_rmse(method_name, data_name, exp_type, seed_list)
                x_values = np.log10(time) if use_log else time

                plt.plot(x_values, vals, ls='solid', color=color_dic[method_name],
                         label=label_dic[method_name], marker=marker_dic[method_name],
                         markersize=4)
                plt.fill_between(x_values, vals - yerr, vals + yerr,
                                 color=color_dic[method_name], alpha=0.2)

            plt.xlabel("$log_{10}$(Time)" if use_log else "Time", fontsize=25)
            plt.ylabel("RMSE", fontsize=25)

            fig.subplots_adjust(top=0.93,
                                bottom=0.2,
                                left=0.18,
                                right=0.95,
                                hspace=0.2,
                                wspace=0.2)

            plt.tick_params(axis='both', labelsize=25)
            plt.legend(loc='upper right', fontsize=25)
            plt.grid()

            folder_path = os.path.join(sys.path[-1], 'paint', 'time', 'graphs', data_name)
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(folder_path, exist_ok=True)

            suffix = 'ifc_log' if use_log else 'ifc'
            fig_file = os.path.join(folder_path, data_name + '_' + suffix + '.png')

            plt.savefig(fig_file, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

if __name__ == '__main__':

    # Dataset names that can be listed under 'data_name_list':
    # 'Poisson_mfGent_v5', 'Heat_mfGent_v5', 'Burget_mfGent_v5_15', 'TopOP_mfGent_v6', 'plasmonic2_MF'
    run_config = {
        'data_name_list': ['Heat_mfGent_v5'],
        'train_sample_num': 128,
        'exp_type': 'dec_0.5',
        'paint_type': 'log',
        'seed_list': [0, 1, 2, 3, 4],
    }
    errorbar_plot_rmse(**run_config)

        

