import sys
import os
import torch

sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from src.experments.sort_performance.sort_performance import load_fss, create_random_fss, prepare_subsets, \
    load_metafe, load_dataset, load_dataset_class, metafe_subsets_eval, KNN_subsets_eval, DT_subsets_eval, \
    SVC_subsets_eval, LR_subsets_eval, cal_rbo


def abs_path(path):
    """Join *path* onto the directory component of this script's ``__file__``.

    Used so that result files are located next to the script. As with
    ``os.path.join``, an absolute *path* is returned unchanged.
    """
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


# TODO: cache files have already been generated, so the second argument to
# load_dataset (presumably the amount of data to load — confirm) is set to 1.
train_set, valid_set, test_set = load_dataset(load_dataset_class("Shopping"), 1)

# Random samples: evaluate the randomly generated feature subsets with
# prepare_subsets, using the YAML file below as its result path.
real_res = prepare_subsets(
    train_set,
    valid_set,
    test_set,
    create_random_fss(21912),
    abs_path('./result/shopping_select100_randomsample.yaml'),
    num=100,
    k=50,
    hidden_layers=[64, 32, 2],
    batch_size=4096,
    n_iter=1,
    verbose=1,
    lr=0.01,
)
# real_res['data'] appears to hold (subset, metrics) pairs; transpose them
# into a tuple of subsets and a tuple of metric dicts.
subsets, real_results = zip(*real_res['data'])
real_f1 = [metrics["f1"] for metrics in real_results]

# Reload the Shopping dataset, this time passing 10000 to load_dataset.
train_set, valid_set, test_set = load_dataset(load_dataset_class("Shopping"), 10000)

# Score the same random subsets with classical models. In each result,
# every entry of result['data'] carries its metrics dict at index 1.
# The "time" notes are previously measured run times (units not recorded,
# presumably seconds).

# time: 7930.96
knn_result = KNN_subsets_eval(train_set, valid_set, subsets, abs_path('./result/shopping/knn_random2.yaml'))
knn_f1 = [entry[1]['f1'] for entry in knn_result['data']]

# time: 3715.88
dt_result = DT_subsets_eval(train_set, valid_set, subsets, abs_path('./result/shopping/dt_random2.yaml'))
dt_f1 = [entry[1]['f1'] for entry in dt_result['data']]

# time: 3722.36
lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/shopping/lr_random2.yaml'))
lr_f1 = [entry[1]['f1'] for entry in lr_result['data']]

svc_result = SVC_subsets_eval(train_set, valid_set, subsets, abs_path('./result/shopping/svc_random2.yaml'))
svc_f1 = [entry[1]['f1'] for entry in svc_result['data']]

# lr_result = LR_subsets_eval(train_set, valid_set, subsets, abs_path('./result/shopping/lr_random.yaml'))
# lr_f1 = [item['f1'] for item in list(zip(*lr_result['data']))[1]]

# time(10000): 31.26(GPU) 375.85(CPU)
# knn_f1, dt_f1, lr_f1: 0.45518389706093265, 0.4290450457026555, 0.5034335611801741
# att_e128: 0.5289755446375323              test_att_e512_100(1): 0.5337302656635394, 0.4334448008665573 (v2,v3): 0.727194663497636
# 0.47564611675050084


# Alternative checkpoint directory: '../../../result/shopping/test_e512_o100_v3'
# Resolve relative to this script (like every other path here) instead of the cwd.
metafe = load_metafe(abs_path('../../../result/shopping/kstart_256_new'), 'metafe_model_pre_train_fss.pth')
# Use the GPU when available, but fall back to CPU rather than failing on a
# machine without CUDA.
metafe_result = metafe_subsets_eval(
    valid_set, subsets, None, metafe,
    torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)
metafe_f1 = [entry[1]['f1'] for entry in metafe_result['data']]

# Compare each score list against the real F1 scores with cal_rbo
# (presumably rank-biased overlap). The header names match the printed
# values one-to-one, including svc_f1.
res = [cal_rbo(real_f1, scores) for scores in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]]
print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
print(res)

# Evaluate every directory under the Shopping results folder that contains a
# 'metafe_model_pre_train_fss.pth' checkpoint. For each one, print the
# cal_rbo comparison against real_f1 for the metafe, KNN, DT, LR and SVC
# scores. Directories are visited in sorted order so that the output is
# deterministic (os.listdir order is arbitrary).
results_root = abs_path('../../../result/shopping/')
for dir_part_name in sorted(os.listdir(results_root)):
    dir_name = os.path.join(results_root, dir_part_name)
    if not os.path.exists(os.path.join(dir_name, 'metafe_model_pre_train_fss.pth')):
        continue
    metafe = load_metafe(dir_name, 'metafe_model_pre_train_fss.pth')
    metafe_result = metafe_subsets_eval(valid_set, subsets, None, metafe, torch.device('cpu'))
    metafe_f1 = [entry[1]['f1'] for entry in metafe_result['data']]
    print(dir_name)
    res = [cal_rbo(real_f1, scores) for scores in [metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1]]
    print("metafe_f1, knn_f1, dt_f1, lr_f1, svc_f1")
    print(res)
