from matplotlib.lines import Line2D
import xlrd
import pandas as pd
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.markers import MarkerStyle
from matplotlib.pyplot import MultipleLocator

def get_data(file, type):
    """Load one column of a CSV file as a NumPy array.

    Args:
        file: Path (or file-like object) accepted by ``pandas.read_csv``.
        type: Name of the column to extract. (The name shadows the builtin
            ``type``; it is kept because callers may pass it by keyword.)

    Returns:
        The values of column ``type`` as returned by ``Series.values``.
    """
    return pd.read_csv(file)[type].values

def get_mean_and_std(method, file, data, interp, type, n):
    """Aggregate one metric across the five seeded runs of a method.

    Reads ``exp/<method>/<file>/<data>_Seed[<k>]_<interp>.csv`` for
    k = 0..4 and, for each of the first ``n`` rows, computes the mean and
    the standard deviation (numpy default, ``ddof=0``) over the seeds.

    Args:
        method: Method directory name; also selects the column (see ``type``).
        file: Experiment sub-directory name.
        data: Prefix of the CSV file names.
        interp: Suffix of the CSV file names, e.g. ``'Interp[True]'``.
        type: Column to read. Ignored when ``method == 'dmfal'``: those
            files are always read from the ``'rmse'`` column.
        n: Number of leading rows to aggregate.

    Returns:
        Tuple ``(mean, std)`` of 1-D arrays of length ``n`` (empty arrays
        when ``n <= 0``).

    Raises:
        IndexError: If any seed file has fewer than ``n`` rows.
    """
    seeds = ['0', '1', '2', '3', '4']
    # 'dmfal' results are always taken from the 'rmse' column, whatever
    # metric was requested; every other method uses the requested column.
    column = 'rmse' if method == 'dmfal' else type

    runs = []
    for seed in seeds:
        path = "exp/{}/{}/{}_Seed[{}]_{}.csv".format(
            method, file, data, seed, interp)
        runs.append(get_data(path, column))

    if n <= 0:
        return np.array([]), np.array([])

    # Fail early with a clear message instead of silently truncating when a
    # run is shorter than requested.
    for seed, run in zip(seeds, runs):
        if len(run) < n:
            raise IndexError(
                "seed {} of method {!r} has {} rows, need {}".format(
                    seed, method, len(run), n))

    # Rows = seeds, columns = the first n entries; reduce across seeds.
    stacked = np.array([run[:n] for run in runs])
    return stacked.mean(axis=0), stacked.std(axis=0)

def makedir(path):
    """Create directory ``path``, including missing parents, if needed.

    The directory is created directly and an "already exists" error is
    treated as the no-op case, so there is no window between checking for
    the path and creating it.

    Args:
        path: Directory path to create.

    Returns:
        True if this call created the directory, False if something already
        existed at ``path``.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        return False
    return True

if __name__ == '__main__':
    # Plot mean RMSE with error bars (std over seeds) for each method and
    # save the figure as an EPS file.
    #
    # Experiment presets: one preset is active, the others are kept
    # commented out so they can be switched in by hand.

    # 32 r2
    # method = ['ResGP', 'LarGP', 'NAR','DC_cigp', 'dmfal', 'SGAR', 'GAR']
    # file_name = 'Heat_mfGent_v5_m2h_32_int'
    # data_name = 'Heat_mfGent_v5'
    # interp = 'Interp[True]'
    # max_num = 32

    # 64 r2
    # method = ['NAR','DC_cigp', 'dmfal', 'SGAR', 'GAR']
    # file_name = 'Heat_mfGent_v5_m2h_32'
    # data_name = 'Heat_mfGent_v5'
    # interp = 'Interp[False]'
    # max_num = 32

    # 128 rmse
    method = ['dmfal', 'SGAR', 'GAR']
    file_name = 'Heat_mfGent_v5_sub_true'
    data_name = 'Heat_mfGent_v5'
    interp = 'Interp[True]'
    max_num = 32

    # Number of x positions for each largest sample count; the positions are
    # the powers of two 4, 8, ..., max_num.
    dic = {32: 4, 64: 5, 128: 6}
    num_orders = dic[max_num]
    orders = [2 ** (i + 2) for i in range(num_orders)]

    # Per-method plot styling, keyed by the method's directory name.
    color_dic = {'GAR':'#DC143C', 'dmfal':'#2ca02c', 'SGAR':'#1f77b4', 'LarGP':'#ff7f0e', 'ResGP':'#8c564b', 'NAR':'#708090', 'DC_cigp':'#17becf'}
    marker_dic = {'GAR':"o", 'dmfal':"s", 'SGAR':"^", 'LarGP':"v", 'ResGP':"P", 'NAR':"d", 'DC_cigp':"h"}
    ls_dic = {'GAR':'dashed', 'SGAR':'dashed', 'dmfal':'solid', 'LarGP':'solid', 'ResGP':'solid', 'NAR':'solid', 'DC_cigp':'solid'}
    # NOTE: lw_dic is not consulted below; every curve is drawn with a fixed
    # linewidth of 3.5.
    lw_dic = {'GAR':5, 'SGAR':5, 'dmfal':2, 'LarGP':2, 'ResGP':2, 'NAR':2, 'DC_cigp':2}
    label_dic = {'GAR':'GAR', 'dmfal':'MF-BNN', 'SGAR':'CIGAR', 'LarGP':'AR', 'ResGP':'ResGP', 'NAR':'NAR', 'DC_cigp':'DC-I'}

    for name in method:
        mean, std = get_mean_and_std(name, file_name, data_name, interp, 'rmse', num_orders)
        plt.errorbar(
            orders, mean, yerr=std,
            ls=ls_dic[name], linewidth=3.5,
            color=color_dic[name], label=label_dic[name],
            marker=MarkerStyle(marker_dic[name], fillstyle='full'),
            elinewidth=3, capsize=8, markersize=12, alpha=0.8,
        )

    plt.xlabel("#HF Samples", fontsize=20)
    plt.ylabel("RMSE", fontsize=20)
    # plt.ylim((0, 0.15))
    plt.ylim((0, 0.09))
    plt.xticks(orders)
    plt.tick_params(axis='both', labelsize=18)
    # plt.legend(loc='upper right', fontsize=17)
    # plt.title(file_name,fontsize = 12)
    plt.grid()
    # plt.show()

    # savefig does not create missing directories, so make sure the output
    # directory exists before writing.
    out_dir = "pic_final"
    makedir(out_dir)
    fig_file = out_dir + "/fig_" + file_name + ".eps"
    plt.savefig(fig_file, bbox_inches='tight')