import sys
import os

sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from src.experments.sort_performance.sort_performance import load_fss, create_random_fss, prepare_subsets, \
    load_metafe, load_dataset, load_dataset_class, metafe_subsets_eval, KNN_subsets_eval, DT_subsets_eval, \
    LR_subsets_eval, SVC_subsets_eval, cal_rbo


def abs_path(path):
    """Join *path* onto the directory that contains this script.

    The directory is taken from ``__file__`` as-is, so the result is only
    absolute when ``__file__`` itself is absolute. No normalisation is
    applied to the joined path.

    :param path: path fragment, written relative to this script's directory.
    :return: ``os.path.join(<directory of __file__>, path)``.
    """
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


# Load the three Gisette splits once; every evaluation in this script reuses them.
train_set, valid_set, test_set = load_dataset(load_dataset_class("Gisette"))

# random samples
# Reference scores: prepare_subsets receives the three splits, the selector built by
# create_random_fss and a result-file path next to this script.
random_fss = create_random_fss(5000)
real_res = prepare_subsets(
    train_set, valid_set, test_set, random_fss,
    abs_path('./result/gisette_select50_randomsample.yaml'),
    num=100, k=50, hidden_layers=[64, 32, 2],
    batch_size=4096,
)
# Transpose the 'data' entries into two parallel tuples: the subsets and their metric dicts.
subsets, real_results = zip(*real_res['data'])
real_f1 = [metrics["f1"] for metrics in real_results]

# Baseline evaluators, all run on the same subsets. For each result, 'data' is
# transposed and column 1 holds the per-subset metric dicts that "f1" is read from.
svc_result = SVC_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/svc_random.yaml'),
)
svc_metrics = tuple(zip(*svc_result['data']))[1]
svc_f1 = [metrics['f1'] for metrics in svc_metrics]

knn_result = KNN_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/knn_random.yaml'),
)
knn_metrics = tuple(zip(*knn_result['data']))[1]
knn_f1 = [metrics['f1'] for metrics in knn_metrics]

dt_result = DT_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/dt_random.yaml'),
)
dt_metrics = tuple(zip(*dt_result['data']))[1]
dt_f1 = [metrics['f1'] for metrics in dt_metrics]

lr_result = LR_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/lr_random.yaml'),
)
lr_metrics = tuple(zip(*lr_result['data']))[1]
lr_f1 = [metrics['f1'] for metrics in lr_metrics]


# [5000,2]: 0.8498325226409793  [5000, 32, 2]: 0.8695063179941706   [5000, 64, 32, 2]: 0.8769790395919849
#                                                                                256e: 0.8978377638918487
#                                                                            256e+128: 0.853186954002595
# k_start
# 512: 0.8912423116699992       256: 0.875993135406819    1024: 0.8921157725293364

# Score every run directory under the Gisette result folder that contains a
# pre-trained metafe model file, and compare each predictor's F1 list against real_f1.
# NOTE(review): this directory is resolved against the current working directory,
# unlike the abs_path(...) result files above — confirm the script is always
# launched from the directory where this relative path is valid.
gisette_result_dir = '../../../result/gisette/'
pretrained_model_file = 'metafe_model_pre_train_fss.pth'

# os.listdir returns entries in arbitrary order; sorting keeps the printed
# report in the same order from run to run.
for path_name in sorted(os.listdir(gisette_result_dir)):
    dir_path = os.path.join(gisette_result_dir, path_name)
    # Skip entries that do not hold the pre-trained model file.
    if not os.path.exists(os.path.join(dir_path, pretrained_model_file)):
        continue
    metafe = load_metafe(dir_path, pretrained_model_file)
    metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe, 'cuda')
    # Column 1 of the transposed 'data' holds the per-subset metric dicts.
    metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]

    # One cal_rbo value per predictor, in the order of the header printed below
    # (cal_rbo is presumably rank-biased overlap — confirm in sort_performance).
    res = [cal_rbo(real_f1, to_cal_res)
           for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]]
    print(dir_path)
    print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
    print(res)

# metafe = load_metafe('../../../result/gisette/em256_6432_2_noattention', 'metafe_model_pre_train_fss.pth')
# metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe, 'cuda')
# metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
#
# res = []
# for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]:
#     res.append(cal_rbo(real_f1, to_cal_res))
# print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
# print(res)




# # converged samples
# fss = load_fss('../../../result/gisette/em256_6432_train_from_pre/fss_model_iter_63_final.pth')
# real_res = prepare_subsets(train_set, valid_set, test_set, fss,
#                            abs_path('./result/qsar_select50_converaged_em256_6432_frompre.yaml'),
#                            num=100, k=50, hidden_layers=[64, 32, 2],
#                            batch_size=4096)
# subsets, real_results = list(zip(*real_res['data']))
# real_f1 = [item["f1"] for item in real_results]
#
#
# knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/knn_converage_em256_6432_frompre.yaml'))
# knn_f1 = [item['f1'] for item in list(zip(*knn_result['data']))[1]]
#
# dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/dt_converage_em256_6432_frompre.yaml'))
# dt_f1 = [item['f1'] for item in list(zip(*dt_result['data']))[1]]
#
# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/lr_converage_em256_6432_frompre.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]
#
# metafe = load_metafe('../../../result/gisette/em256_6432_train_from_pre', 'metafe_model_iter_63_final.pth')
# metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe)
# metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
#
# res = []
# for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1]:
#     res.append(cal_rbo(real_f1, to_cal_res))
# print("metafe_f1, knn_f1, dt_f1, lr_f1")
# print(res)
#
#
# # em256_6432_2 converged samples
# fss = load_fss('../../../result/gisette/em256_6432_2/fss_model_iter_53_final.pth')
# real_res = prepare_subsets(train_set, valid_set, test_set, fss,
#                            abs_path('./result/qsar_select50_converaged_em256_6432_2.yaml'),
#                            num=100, k=50, hidden_layers=[64, 32, 2],
#                            batch_size=4096)
# subsets, real_results = list(zip(*real_res['data']))
# real_f1 = [item["f1"] for item in real_results]
#
#
# knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/knn_converage_em256_6432_2.yaml'))
# knn_f1 = [item['f1'] for item in list(zip(*knn_result['data']))[1]]
#
# dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/dt_converage_em256_6432_2.yaml'))
# dt_f1 = [item['f1'] for item in list(zip(*dt_result['data']))[1]]
#
# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/qsar/lr_converage_em256_6432_2.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]
#
# metafe = load_metafe('../../../result/gisette/em256_6432_2', 'metafe_model_iter_53_final.pth')
# metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe)
# metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
#
# res = []
# for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1]:
#     res.append(cal_rbo(real_f1, to_cal_res))
# print("metafe_f1, knn_f1, dt_f1, lr_f1")
# print(res)
#

# # em256_6432_2 converged, no attention
# fss = load_fss('../../../result/gisette/em256_6432_2_noattention/fss_model_iter_53_final.pth')
# real_res = prepare_subsets(train_set, valid_set, test_set, fss,
#                            abs_path('./result/gisette_select50_converage_em256_6432_2_noattention.yaml'),
#                            num=100, k=50, hidden_layers=[64, 32, 2],
#                            batch_size=4096)
# subsets, real_results = list(zip(*real_res['data']))
# real_f1 = [item["f1"] for item in real_results]
#
# knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/knn_converage_em256_6432_2_noattention.yaml'))
# knn_f1 = [item['f1'] for item in list(zip(*knn_result['data']))[1]]
#
# dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/dt_converage_em256_6432_2_noattention.yaml'))
# dt_f1 = [item['f1'] for item in list(zip(*dt_result['data']))[1]]
#
# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/lr_converage_em256_6432_2_noattention.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]
#
# svc_result = SVC_subsets_eval(train_set, valid_set, subsets, abs_path('./result/gisette/svc_converage_em256_6432_2_noattention.yaml'))
# svc_f1 = [item['f1'] for item in list(zip(*svc_result['data']))[1]]
#
# metafe = load_metafe('../../../result/gisette/em256_6432_2_noattention', 'metafe_model_iter_53_final.pth')
# metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe, 'cuda')
# metafe_f1 = [item['f1'] for item in list(zip(*metafe_result['data']))[1]]
#
# res = []
# for to_cal_res in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]:
#     res.append(cal_rbo(real_f1, to_cal_res))
# print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
# print(res)



# manual
# Repeat the comparison using the selector and metafe model files saved under
# em256_6432_train_manual (fss_model_iter_9.pth / metafe_model_iter_9.pth).
fss = load_fss('../../../result/gisette/em256_6432_train_manual/fss_model_iter_9.pth')
real_res = prepare_subsets(
    train_set, valid_set, test_set, fss,
    abs_path('./result/gisette_select50_manual_iter9.yaml'),
    num=100, k=50, hidden_layers=[64, 32, 2],
    batch_size=4096,
)
# Transpose the 'data' entries into two parallel tuples: the subsets and their metric dicts.
subsets, real_results = zip(*real_res['data'])
real_f1 = [metrics["f1"] for metrics in real_results]

# Baseline evaluators on the same subsets; column 1 of each transposed 'data'
# holds the per-subset metric dicts that "f1" is read from.
knn_result = KNN_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/knn_converage_manual_iter9.yaml'),
)
knn_metrics = tuple(zip(*knn_result['data']))[1]
knn_f1 = [metrics['f1'] for metrics in knn_metrics]

dt_result = DT_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/dt_converage_manual_iter9.yaml'),
)
dt_metrics = tuple(zip(*dt_result['data']))[1]
dt_f1 = [metrics['f1'] for metrics in dt_metrics]

lr_result = LR_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/lr_converage_manual_iter9.yaml'),
)
lr_metrics = tuple(zip(*lr_result['data']))[1]
lr_f1 = [metrics['f1'] for metrics in lr_metrics]

svc_result = SVC_subsets_eval(
    train_set, valid_set, subsets,
    abs_path('./result/gisette/svc_converage_manual_iter9.yaml'),
)
svc_metrics = tuple(zip(*svc_result['data']))[1]
svc_f1 = [metrics['f1'] for metrics in svc_metrics]

# Score the same subsets with the saved metafe model.
metafe = load_metafe('../../../result/gisette/em256_6432_train_manual', 'metafe_model_iter_9.pth')
metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe, 'cuda')
metafe_metrics = tuple(zip(*metafe_result['data']))[1]
metafe_f1 = [metrics['f1'] for metrics in metafe_metrics]

# One cal_rbo value per predictor, in the order of the header printed below.
res = [cal_rbo(real_f1, candidate_f1)
       for candidate_f1 in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]]
print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
print(res)
