import sys
import os

sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from src.experments.sort_performance.sort_performance import load_fss, create_random_fss, prepare_subsets, \
    load_metafe, load_dataset, load_dataset_class, metafe_subsets_eval, KNN_subsets_eval, DT_subsets_eval, LR_subsets_eval, cal_rbo


def abs_path(path):
    """Resolve ``path`` relative to the directory that contains this script.

    ``__file__`` is made absolute first. It can be a relative path (for the
    ``__main__`` module on Python < 3.9). In that case
    ``os.path.split(__file__)[0]`` is ``''``, and the returned path would
    silently depend on the current working directory instead of the script
    location.

    Args:
        path: path relative to the script directory. An absolute path is
            returned unchanged by ``os.path.join``.

    Returns:
        The joined path, anchored at the script's absolute directory.
    """
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)


# Directory holding the trained FSS / MetaFE checkpoints. It is resolved via
# abs_path (like every output file below) so the script does not depend on
# the current working directory.
_MODEL_DIR = abs_path('../../../result/qsar/test_lr_256_embedding')

# MetaFE results previously recorded for different settings (kept for reference):
# lr128: 0.6028435781194115   lr1: 0.621203440959818   lr16: 0.5995717881053801  no:0.5284775660941631
# lr1_2:                    lr256: 0.5980914697994103  lr8:                      embedding256: 0.6147837013797305


def _f1_scores(result):
    """Return the ``'f1'`` value of every entry in ``result['data']``.

    The second element of each entry is read as a metrics dict. This matches
    the ``(subset, metrics)`` pairing that ``prepare_subsets`` results are
    unpacked into below. An empty ``data`` list yields ``[]`` rather than an
    ``IndexError``.
    """
    return [entry[1]['f1'] for entry in result['data']]


def _compare_proxies(datasets, fss, real_path, proxy_tag, metafe_file):
    """Compare proxy evaluators against the ``prepare_subsets`` F1 scores.

    Draws/evaluates feature subsets with ``prepare_subsets``. The same subsets
    are then scored with the KNN, decision-tree, logistic-regression and
    MetaFE evaluators. Finally it prints ``cal_rbo`` between the reference F1
    list and each proxy's F1 list.

    Args:
        datasets: ``(train_set, valid_set, test_set)`` as returned by ``load_dataset``.
        fss: subset sampler passed to ``prepare_subsets``.
        real_path: script-relative yaml path passed to ``prepare_subsets``.
        proxy_tag: suffix of the per-proxy yaml files, e.g. ``'random'`` gives
            ``./result/qsar/knn_random.yaml``.
        metafe_file: checkpoint file name inside ``_MODEL_DIR`` for ``load_metafe``.

    Returns:
        The ``cal_rbo`` values, ordered metafe, knn, dt, lr (the same order
        that is printed).
    """
    train, valid, test = datasets

    def proxy_path(model):
        return abs_path('./result/qsar/%s_%s.yaml' % (model, proxy_tag))

    real_res = prepare_subsets(train, valid, test, fss, abs_path(real_path),
                               num=100, k=100, hidden_layers=[64, 32, 2],
                               batch_size=4096)
    # zip(*...) keeps ``subsets`` a tuple, exactly as it was handed to the evaluators before.
    subsets, real_results = zip(*real_res['data'])
    real_f1 = [metrics["f1"] for metrics in real_results]

    knn_f1 = _f1_scores(KNN_subsets_eval(train, valid, subsets, proxy_path('knn')))
    dt_f1 = _f1_scores(DT_subsets_eval(train, valid, subsets, proxy_path('dt')))
    lr_f1 = _f1_scores(LR_subsets_eval(train, valid, subsets, proxy_path('lr')))

    metafe = load_metafe(_MODEL_DIR, metafe_file)
    metafe_f1 = _f1_scores(metafe_subsets_eval(valid, subsets, None, metafe))

    res = [cal_rbo(real_f1, proxy_f1) for proxy_f1 in (metafe_f1, knn_f1, dt_f1, lr_f1)]
    print("metafe_f1, knn_f1, dt_f1, lr_f1")
    print(res)
    return res


train_set, valid_set, test_set = load_dataset(load_dataset_class("QSAR"))
_datasets = (train_set, valid_set, test_set)

# Random samples: subsets drawn by create_random_fss(1024), scored with the pre-trained MetaFE.
_compare_proxies(_datasets, create_random_fss(1024),
                 './result/qsar_select100_randomsample.yaml',
                 'random', 'metafe_model_pre_train_fss.pth')

# Converged samples: subsets drawn by the final (iteration 146) FSS checkpoint.
# The 'converaged'/'converage' spellings below are existing result-file names; keep them.
fss = load_fss(os.path.join(_MODEL_DIR, 'fss_model_iter_146_final.pth'))
_compare_proxies(_datasets, fss,
                 './result/qsar_select100_converaged.yaml',
                 'converage', 'metafe_model_iter_146_final.pth')
