import os
import sys

import matplotlib.pyplot as plt
import torch


sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from src.experments.sort_performance.sort_performance import load_fss, create_random_fss, prepare_subsets, \
    load_metafe, load_dataset, load_dataset_class, metafe_subsets_eval, KNN_subsets_eval, DT_subsets_eval, \
    LR_subsets_eval, SVC_subsets_eval, cal_rbo


def abs_path(path):
    """Resolve ``path`` relative to the directory containing this script.

    Lets the result/cache paths below work regardless of the current
    working directory the script is launched from.
    """
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


# TODO(review): translated from the original note, "since the cache files have
# already been generated, the amount of loaded data was adjusted to 1". The call
# below still passes 10000 -- confirm which value is intended.
_shopping_dataset_cls = load_dataset_class("Shopping")
train_set, valid_set, test_set = load_dataset(_shopping_dataset_cls, 10000)

# Ground truth for 100 uniformly random feature subsets (k=50 each). The meaning
# of create_random_fss's argument (21912) is not visible from here. The YAML path
# is presumably where prepare_subsets stores/reads its results -- confirm in
# sort_performance.
real_res = prepare_subsets(
    train_set, valid_set, test_set,
    create_random_fss(21912),
    abs_path('./result/shopping_select100_randomsample.yaml'),
    num=100, k=50, hidden_layers=[64, 32, 2],
    batch_size=4096, n_iter=1, verbose=1, lr=0.01,
)
# real_res['data'] holds (subset, metrics) entries; transpose them into two
# parallel tuples and keep each subset's F1.
subsets, real_results = zip(*real_res['data'])
real_f1 = [metrics["f1"] for metrics in real_results]

# Classical baselines on the random subsets (currently disabled).
# NOTE(review): these cache paths point at result/gisette/ although this script
# evaluates the Shopping dataset -- fix them before re-enabling.
# train_set, valid_set, test_set = load_dataset(load_dataset_class("Shopping"), 10000)


# svc_result = SVC_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/svc_random.yaml'))
# svc_f1 = [item['f1'] for item in list(zip(*svc_result['data']))[1]]
#
# knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/knn_random.yaml'))
# knn_f1 = [item['f1'] for item in list(zip(*knn_result['data']))[1]]
#
# dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/dt_random.yaml'))
# dt_f1 = [item['f1'] for item in list(zip(*dt_result['data']))[1]]
#
# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/lr_random.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]


# Converged samples: feature subsets drawn from the FSS model checkpoint saved at
# iteration 96 of the kstart_256_new_v3_b512_lr01 run.
# NOTE(review): unlike the random-sample call above, verbose and lr are not passed
# here, so prepare_subsets' defaults apply -- confirm both ground truths were
# produced with comparable training settings.
fss = load_fss('../../../result/shopping/kstart_256_new_v3_b512_lr01/fss_model_iter_96.pth')
con_real_res = prepare_subsets(
    train_set, valid_set, test_set,
    fss,
    abs_path('./result/shopping_select50_manual_iter96.yaml'),
    num=100, k=50, hidden_layers=[64, 32, 2],
    batch_size=4096, n_iter=1,
)
# Split the (subset, metrics) entries and keep each converged subset's F1.
con_subsets, con_real_results = zip(*con_real_res['data'])
con_real_f1 = [metrics["f1"] for metrics in con_real_results]


# Directory holding this run's FSS / metafe checkpoints.
base_dir = '../../../result/shopping/kstart_256_new_v3_b512_lr01'


def _f1_column(eval_result):
    """Return the 'f1' of the second field of every entry in ``eval_result['data']``."""
    return [metrics['f1'] for metrics in list(zip(*eval_result['data']))[1]]


# Classical-model baselines scored on the converged subsets.
# NOTE: the YAML file names say "random" although they hold converged-subset
# results; they are kept as-is.
svc_result = SVC_subsets_eval(train_set, valid_set, con_subsets,
                              abs_path('./result/shopping/svc_random_iter_96.yaml'))
con_svc_f1 = _f1_column(svc_result)

knn_result = KNN_subsets_eval(train_set, valid_set, con_subsets,
                              abs_path('./result/shopping/knn_random_iter_96.yaml'))
con_knn_f1 = _f1_column(knn_result)

dt_result = DT_subsets_eval(train_set, valid_set, con_subsets,
                            abs_path('./result/shopping/dt_random_iter_96.yaml'))
con_dt_f1 = _f1_column(dt_result)

lr_result = LR_subsets_eval(train_set, valid_set, con_subsets,
                            abs_path('./result/shopping/lr_random_iter_96.yaml'))
con_lr_f1 = _f1_column(lr_result)


# Score the converged subsets with the iteration-96 metafe checkpoint and compare
# every evaluator's F1 ranking against the converged ground truth via cal_rbo.
# metafe must score the SAME subsets as con_real_f1 and the baselines
# (con_subsets); scoring the random `subsets` here would compare unrelated lists.
metafe = load_metafe(base_dir, 'metafe_model_iter_96.pth')
metafe_result = metafe_subsets_eval(valid_set, con_subsets, None, metafe, torch.device('cuda'))
metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]

# One cal_rbo value per evaluator, in the same order as the printed header.
res = [cal_rbo(con_real_f1, to_cal_res)
       for to_cal_res in (metafe_f1, con_knn_f1, con_dt_f1, con_lr_f1, con_svc_f1)]
print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
print(res)


def cal_metafe_rbo(model_name, real_f1, eval_subsets=None):
    """Compare a metafe checkpoint's F1 ranking with a ground-truth F1 list.

    Loads ``model_name`` from ``base_dir``, predicts F1 for every subset in
    ``eval_subsets`` and returns ``cal_rbo(real_f1, predicted_f1)``.

    Args:
        model_name: checkpoint file name inside ``base_dir``,
            e.g. ``'metafe_model_iter_96.pth'``.
        real_f1: ground-truth F1 values, one per subset of ``eval_subsets``
            and in the same order.
        eval_subsets: subsets to score. Defaults to the module-level random
            ``subsets`` (the previous, fixed behavior). Pass ``con_subsets``
            together with ``con_real_f1`` so both lists describe the same
            subsets.

    Returns:
        The value returned by ``cal_rbo``.
    """
    if eval_subsets is None:
        eval_subsets = subsets
    metafe = load_metafe(base_dir, model_name)
    metafe_result = metafe_subsets_eval(valid_set, eval_subsets, None, metafe, 'cuda')
    metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
    return cal_rbo(real_f1, metafe_f1)


# Cached cal_metafe_rbo results, one value per metafe checkpoint saved at
# iterations 1, 2, 4, 8, 16, 32, 64 and 96 (in that order). They are presumably
# hard-coded so the plot can be redrawn without reloading every model.
#
# Agreement with the random-subset ground truth; regenerate with:
#   y_random = [cal_metafe_rbo(f'metafe_model_iter_{i}.pth', real_f1) for i in (1, 2, 4, 8, 16, 32, 64, 96)]
y_random = [
    0.8731200585588056,
    0.8613557259225226,
    0.9042305088251466,
    0.7716765349292095,
    0.7662121455372847,
    0.7752824452176378,
    0.589788066830532,
    0.5755599958482694,
]
# Agreement with the converged-subset ground truth; regenerate with:
#   y_con = [cal_metafe_rbo(f'metafe_model_iter_{i}.pth', con_real_f1) for i in (1, 2, 4, 8, 16, 32, 64, 96)]
# NOTE(review): that call passes con_real_f1, but cal_metafe_rbo scores the
# module-level random `subsets` unless told otherwise, so these numbers may
# compare mismatched subset lists -- recompute on con_subsets to confirm.
y_con = [
    0.5445774674692959,
    0.5351021458979488,
    0.5993332439281731,
    0.5743281517175438,
    0.5982294941830645,
    0.6286779408076039,
    0.6115737266397194,
    0.652301544616299,
]

# Two-axis line chart of cal_metafe_rbo per metafe checkpoint:
#   - agreement with the random-subset ground truth on the left axis (blue);
#   - agreement with the converged-subset ground truth on the right axis (red).
# Each dashed line joins a series' first and last point to show its overall trend.
_label_font = {'family': "Times New Roman"}

fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
line1 = ax.plot(y_random, marker='*', color='b', label='Uniform Distribution')
ax.plot([0, len(y_random) - 1], [y_random[0], y_random[-1]], color='b', alpha=0.5, linestyle='--')
ax.set_ylim(0.2, 0.92)

ax2 = ax.twinx()
ax2.set_ylim(0.45, 0.75)
line2 = ax2.plot(y_con, marker='o', color='r', label='Converged Distribution')
ax2.plot([0, len(y_con) - 1], [y_con[0], y_con[-1]], color='r', alpha=0.5, linestyle='--')
ax.set_ylabel('(Uniform)', fontsize=14, fontdict=_label_font)
ax2.set_ylabel('(Converged)', fontsize=14, fontdict=_label_font)
# Points sit at evenly spaced indices while the checkpoint iterations are not
# evenly spaced, so x tick labels are hidden. The x axis is shared by ax and ax2,
# so clearing it on ax clears it for both.
ax.set_xticks([])
ax.set_xlabel('Iteration', fontsize=14, fontdict=_label_font)
ax.set_title('Shopping Dataset', fontsize=15, fontdict=_label_font)
# Horizontal grid lines follow the right-hand (converged) axis, as the original
# plt.grid call did: ax2 was the current axes after twinx().
ax2.grid(axis='y')
lines = line1 + line2
labels = [line.get_label() for line in lines]
ax.legend(lines, labels, prop={'family': "Times New Roman", 'style': 'normal'})
plt.tight_layout()
