import xlrd
import pandas as pd
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.pyplot import MultipleLocator

def get_data(file, type):
    """Read a CSV file and return the values of a single column.

    Args:
        file: Path of the CSV file (anything ``pandas.read_csv`` accepts).
        type: Name of the column to extract. The name shadows the builtin
            but is kept because callers may pass it by keyword.

    Returns:
        The ``.values`` of the selected column.

    Raises:
        KeyError: If the CSV has no column named ``type``.
    """
    return pd.read_csv(file)[type].values

def get_d(method, file, data, interp, type, n):
    """Load one metric column from a method's experiment result CSV.

    The file read is
    ``exp/<method>/<file>/<data>_Seed[None]_<interp>.csv``.

    Args:
        method: Method name; used as a directory component of the path.
        file: Experiment directory name under the method directory.
        data: Prefix of the CSV file name.
        interp: Suffix of the CSV file name (e.g. ``'Interp[False]'``).
        type: Column to read. Ignored when ``method == 'dmfal'``, in which
            case the ``'r2'`` column is read instead.
        n: Not used by this function; kept so existing calls keep working.

    Returns:
        The values of the selected column, as returned by ``get_data``.
    """
    seed = 'None'
    csv_path = "".join((
        "exp/", method, "/", file, "/",
        data, "_Seed[", seed, "]_", interp, ".csv",
    ))
    # Results for 'dmfal' are always taken from the 'r2' column.
    column = 'r2' if method == 'dmfal' else type
    return get_data(csv_path, column)

def makedir(path):
    """Create directory ``path``, including missing parents.

    The directory is created directly and an "already exists" error is
    treated as the not-created case. This avoids the race in a separate
    exists-check followed by a create, where another process could create
    the directory in between.

    Args:
        path: Directory path to create.

    Returns:
        True if the directory was created by this call, False if something
        already exists at ``path``.

    Raises:
        OSError: For failures other than the path already existing
            (e.g. permission denied, empty path).
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        return False
    return True

if __name__ == '__main__':
    # Plot the RMSE of each method against the number of high-fidelity
    # samples and save the figure as an EPS file.

    # Alternative experiment configuration (swap with the active one below):
    # method = ['NAR','DC_cigp', 'dmfal', 'SGAR', 'GAR']
    # file_name = 'TopOP_mfGent_v5_128_int'
    # data_name = 'TopOP_mfGent_v5'
    # interp = 'Interp[True]'
    # max_num = 128

    method = ['NAR', 'DC_cigp', 'dmfal', 'SGAR', 'GAR']
    file_name = 'TopOP_mfGent_v5_64'
    data_name = 'TopOP_mfGent_v5_64'
    interp = 'Interp[False]'
    max_num = 64

    # Maps max_num to the number of x-axis points, so that the sample
    # counts are 4, 8, ..., max_num.
    dic = {32: 4, 64: 5, 128: 6}

    # Per-method plot styling, keyed by the method's directory name.
    color_dic = {'GAR':'#DC143C', 'dmfal':'#2ca02c', 'SGAR':'#1f77b4', 'LarGP':'#ff7f0e', 'ResGP':'#8c564b', 'NAR':'#708090', 'DC_cigp':'#17becf'}
    marker_dic = {'GAR':"o", 'dmfal':"s", 'SGAR':"^", 'LarGP':"v", 'ResGP':"P", 'NAR':"d", 'DC_cigp':"h"}
    ls_dic = {'GAR':'dashed', 'SGAR':'dashed', 'dmfal':'solid', 'LarGP':'solid', 'ResGP':'solid', 'NAR':'solid', 'DC_cigp':'solid'}
    # NOTE: lw_dic is currently not applied; every line uses linewidth=3.5.
    lw_dic = {'GAR':5, 'SGAR':5, 'dmfal':2, 'LarGP':2, 'ResGP':2, 'NAR':2, 'DC_cigp':2}
    # Display names shown in the legend.
    label_dic = {'GAR':'GAR', 'dmfal':'MF-BNN', 'SGAR':'CIGAR', 'LarGP':'AR', 'ResGP':'ResGP', 'NAR':'NAR', 'DC_cigp':'DC-I'}

    orders = [2 ** (i + 2) for i in range(dic[max_num])]
    for name in method:
        temp = get_d(name, file_name, data_name, interp, 'rmse', dic[max_num])
        plt.plot(
            orders,
            temp,
            ls=ls_dic[name],
            linewidth=3.5,
            color=color_dic[name],
            # Use the display name rather than the internal directory name.
            label=label_dic[name],
            marker=marker_dic[name],
            markersize=12,
        )

    plt.xlabel("#HF Samples", fontsize=20)
    plt.ylabel("RMSE", fontsize=20)
    plt.ylim((0.1, 0.5))
    plt.xticks(orders)
    plt.tick_params(axis='both', labelsize=18)
    plt.legend(loc='upper right', fontsize=17)
    plt.grid()

    fig_file = r"pic_final/none/fig_" + file_name + '_none' + ".eps"
    # savefig does not create missing directories, so make sure the output
    # directory exists before writing.
    os.makedirs(os.path.dirname(fig_file), exist_ok=True)
    plt.savefig(fig_file, bbox_inches='tight')