import sys
import os
import matplotlib.pyplot as plt


sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from src.experments.sort_performance.sort_performance import load_fss, create_random_fss, prepare_subsets, \
    load_metafe, load_dataset, load_dataset_class, metafe_subsets_eval, KNN_subsets_eval, DT_subsets_eval, \
    LR_subsets_eval, SVC_subsets_eval, cal_rbo


def abs_path(path):
    """Join *path* onto the directory that contains this script.

    Lets result/cache paths be written relative to the script instead of the
    current working directory (fully so only when ``__file__`` is absolute;
    a relative ``__file__`` yields a relative directory).
    """
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


train_set, valid_set, test_set = load_dataset(load_dataset_class("Gisette"))

# Directory holding the trained FSS / MetaFE checkpoints of this experiment.
# Resolved against this script's directory (like the result files below) so
# the script does not silently depend on the current working directory.
base_dir = abs_path('../../../result/gisette/em256_6432_train_manual')


def _sample_and_score(fss, result_path):
    """Sample feature subsets with *fss* and evaluate them with a real model.

    Shared by the uniform and the converged runs so both use exactly the same
    ``prepare_subsets`` settings.

    Args:
        fss: subset sampler handed straight to ``prepare_subsets``.
        result_path: path relative to this script's directory; passed through
            ``abs_path`` as ``prepare_subsets``' result-file argument.

    Returns:
        ``(res, subsets, results, f1)``: the raw ``prepare_subsets`` output,
        the subsets and per-subset result dicts unzipped from ``res['data']``,
        and the ``"f1"`` value of each result, in the same order.
    """
    res = prepare_subsets(train_set, valid_set, test_set, fss,
                          abs_path(result_path),
                          num=100, k=50, hidden_layers=[64, 32, 2],
                          batch_size=4096)
    sampled, results = list(zip(*res['data']))
    f1 = [item["f1"] for item in results]
    return res, sampled, results, f1


# random samples
real_res, subsets, real_results, real_f1 = _sample_and_score(
    create_random_fss(5000), './result/gisette_select50_randomsample.yaml')

# svc_result = SVC_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/svc_random.yaml'))
# svc_f1 = [item['f1'] for item in list(zip(*svc_result['data']))[1]]
#
# knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/knn_random.yaml'))
# knn_f1 = [item['f1'] for item in list(zip(*knn_result['data']))[1]]
#
# dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/dt_random.yaml'))
# dt_f1 = [item['f1'] for item in list(zip(*dt_result['data']))[1]]
#
# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/lr_random.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]


# converged samples
# NOTE(review): the sampler is loaded from iteration 9, but the result file is
# named "..._iter8.yaml". If prepare_subsets reuses an existing result file,
# these scores may belong to a different checkpoint. Confirm which is intended.
fss = load_fss(os.path.join(base_dir, 'fss_model_iter_9.pth'))
con_real_res, con_subsets, con_real_results, con_real_f1 = _sample_and_score(
    fss, './result/gisette_select50_manual_iter8.yaml')


def cal_metafe_rbo(model_name, real_f1, eval_subsets=None, device='cuda'):
    """Compare real F1 scores with MetaFE-predicted F1 scores via ``cal_rbo``.

    Loads the MetaFE checkpoint *model_name* from ``base_dir``, scores
    *eval_subsets* with it on ``valid_set`` and returns
    ``cal_rbo(real_f1, metafe_f1)``.

    Args:
        model_name: checkpoint file name passed to ``load_metafe``,
            e.g. ``'metafe_model_iter_0.pth'``.
        real_f1: ground-truth F1 values, one per subset of *eval_subsets*
            and in the same order.
        eval_subsets: subsets to score with MetaFE. Defaults to the
            module-level ``subsets`` (the uniformly sampled ones), which was
            previously the only option. Pass the subsets *real_f1* was
            computed on, e.g. ``con_subsets`` together with ``con_real_f1``.
            Otherwise the two compared rankings refer to different subsets.
        device: device string forwarded to ``metafe_subsets_eval``.

    Returns:
        The value of ``cal_rbo(real_f1, metafe_f1)`` (presumably a float
        similarity score; TODO confirm).
    """
    if eval_subsets is None:
        eval_subsets = subsets
    metafe = load_metafe(base_dir, model_name)
    metafe_result = metafe_subsets_eval(valid_set, eval_subsets, None, metafe, device)
    # 'data' holds (subset, result) pairs; keep the second column's f1 values.
    metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
    # Optional baseline comparison (needs the *_f1 lists that are commented
    # out near the top of the script):
    # res = []
    # for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]:
    #     res.append(cal_rbo(real_f1, to_cal_res))
    # print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
    # print(res)
    return cal_rbo(real_f1, metafe_f1)



# RBO between the real F1 ranking of the uniformly sampled subsets and the
# ranking predicted by each of the ten MetaFE checkpoints (iterations 0-9).
y_random = [
    cal_metafe_rbo('metafe_model_iter_{}.pth'.format(it), real_f1)
    for it in range(10)
]

# Same series for the converged subsets. The literal values are presumably
# cached from an earlier run of the commented-out computation below.
# NOTE(review): as originally written, cal_metafe_rbo scores the uniformly
# sampled `subsets`, not `con_subsets`. Check this before re-enabling the call.
# y_con = [cal_metafe_rbo(f'metafe_model_iter_{i}.pth', con_real_f1) for i in range(10)]
y_con = [
    0.5355109537898294, 0.5494094367295154, 0.6446463602785656,
    0.5762150189331763, 0.5374165123928595, 0.5742116578397077,
    0.5933653039694397, 0.5696841047926285, 0.6095887203776226,
    0.6679480682410082,
]

# Plot RBO per MetaFE iteration for both subset distributions on twin y-axes.
# A dashed line joins the first and last point of each curve as a trend guide.
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
line1 = ax.plot(y_random, color='b', marker='*', label='Uniform Distribution')
# End the trend line at the last index instead of a hard-coded 9, so it still
# fits if the number of checkpoints changes.
ax.plot([0, len(y_random) - 1], [y_random[0], y_random[-1]],
        color='b', alpha=0.5, linestyle='--')
ax2 = ax.twinx()
line2 = ax2.plot(y_con, color='r', marker='o', label='Converged Distribution')
ax2.plot([0, len(y_con) - 1], [y_con[0], y_con[-1]],
         color='r', alpha=0.5, linestyle='--')
ax.grid(axis='y')

# The two curves live on different axes, so their handles are collected by
# hand to show them in one shared legend.
lines = line1 + line2
labels = [ln.get_label() for ln in lines]

ax.set_ylabel('(Uniform)', fontsize=14, fontdict={'family': "Times New Roman"})
# Hide x ticks on the primary axes explicitly. plt.xticks() acts on pyplot's
# "current" axes, which after twinx() is the twin axes ax2.
ax.set_xticks([])
ax.set_xlabel('Iteration', fontsize=14, fontdict={'family': "Times New Roman"})

# Fixed y-ranges tuned to the current data; points outside them are clipped.
ax.set_ylim(0.83, 0.925)
ax2.set_ylim(0.46, 0.78)
ax2.set_ylabel('(Converged)', fontsize=14, fontdict={'family': "Times New Roman"})
ax.legend(lines, labels, prop={'family': "Times New Roman", 'style': 'normal'})
ax.set_title('Gisette Dataset', fontsize=15, fontdict={'family': "Times New Roman"})
plt.tight_layout()
# NOTE(review): no plt.show() / fig.savefig() call is visible here. Add one if
# the script is run non-interactively.
