import os
import sys

sys.path.append('../')
import random
import numpy as np
import torch
import torch.nn as nn
from src.config import config as cfg
from src.config import YamlLoader
# from src.meta_learning.model import FeatureSelection
from trainer.utils import DatasetShell
from src.util.utils import DataLoaderSampler
from torch.utils.data import DataLoader
from src.dataloaders import SelectedDataset
from src.verify.verify import verification, ntimes_verification
from src.util.utils import append_to_txt
import argparse
from src.verify.trainer.utils.setshell import DatasetShell
import ruamel.yaml as yaml
from src.baselines import FClassif, MutualInfoClassif, RFESelect, Lasso, MRMRSelect, DTRFESelect

# Dataset splits most recently loaded by ``calculate`` (it rebinds them via
# ``global``); they stay None until ``calculate`` has run.
train_set = valid_set = test_set = None


def init_sys(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed applied to ``random``, ``numpy`` and ``torch``. Defaults to
            0, the value this function has always used.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_dataset(dataset: str):
    """Return the dataset class registered under the name ``dataset``.

    Each dataset module is imported lazily, only when its name is requested,
    so unused loaders (and their dependencies) are never imported.

    Raises:
        FileNotFoundError: if ``dataset`` is not one of the known names.
    """
    # NOTE(review): most loaders come from ``dataloaders.*`` while Shopping
    # comes from ``src.dataloaders.*``; presumably both resolve via the
    # ``sys.path`` tweak at the top of the file — confirm.
    def _forest():
        from dataloaders.forest import Forest
        return Forest

    def _p53():
        from dataloaders.p53 import P53
        return P53

    def _qsar():
        from dataloaders.qsar import QSAR
        return QSAR

    def _gisette():
        from dataloaders.gisette import Gisette
        return Gisette

    def _shopping():
        from src.dataloaders.shopping import ShoppingGender
        return ShoppingGender

    registry = (
        ('Forest', _forest),
        ('P53', _p53),
        ('QSAR', _qsar),
        ('Gisette', _gisette),
        ('Shopping', _shopping),
    )
    for name, importer in registry:
        if dataset == name:
            return importer()

    # NOTE(review): ValueError would be more conventional for an unknown name;
    # FileNotFoundError is kept because callers may already catch it.
    raise FileNotFoundError(f'No dataset named {dataset}')


def calculate(device, dataset_name, bestsubset, test_num=None):
    """Run ``ntimes_verification`` on ``dataset_name`` restricted to a feature subset.

    Loads the train/valid/test splits of ``dataset_name`` (also storing them in
    the module globals ``train_set``, ``valid_set`` and ``test_set``), wraps each
    split in ``SelectedDataset`` keeping only the columns in ``bestsubset``, and
    runs ``ntimes_verification`` on the result.

    Args:
        device: Device forwarded to ``ntimes_verification`` when CUDA is
            available; replaced by ``torch.device('cpu')`` otherwise.
        dataset_name: One of the names accepted by ``load_dataset``.
        bestsubset: Iterable of feature indices to keep.
        test_num: Number of verification repetitions for non-Shopping datasets.
            When None, falls back to ``args.test_num`` if a module-level
            ``args`` exists (script mode), else to 30 (the CLI default), so the
            function can also be called after importing the module.

    Returns:
        Whatever ``ntimes_verification`` returns.
    """
    global train_set, valid_set, test_set

    init_sys()
    dataset = load_dataset(dataset_name)

    train_set, test_set = dataset('train'), dataset('test')
    try:
        # Prefer the two-part validation split; presumably datasets without it
        # raise here, in which case fall back to the single 'valid' split.
        valid_set = dataset(['valid1', 'valid2'])
    except Exception:
        # NOTE(review): broad catch kept because the exception type raised by
        # datasets lacking 'valid1'/'valid2' is not visible from here.
        valid_set = dataset('valid')

    selected_ids = np.array(list(bestsubset)).astype('int')

    # Re-seed so everything below starts from the same RNG state regardless of
    # how much randomness the loaders above consumed.
    init_sys()
    train_dataset = SelectedDataset(train_set, select_ids=selected_ids)
    valid_dataset = SelectedDataset(valid_set, select_ids=selected_ids)
    test_dataset = SelectedDataset(test_set, select_ids=selected_ids)

    print(f'test  len: {len(selected_ids)}')
    print(f'{train_dataset[0]}')

    if not torch.cuda.is_available():
        device = torch.device('cpu')
    print('dataset:', train_dataset)

    if dataset_name == 'Shopping':
        print('using Shopping hyper parameters:')
        # Use the resolved ``device`` (not the global CLI value) so the CPU
        # fallback above also applies to this branch.
        res = ntimes_verification(1, train_dataset, valid_dataset, test_dataset,
                                  (len(selected_ids), 64, 32, 2), 1024 * 32, 100, device=device, verbose=2)
    else:
        if test_num is None:
            cli_args = globals().get('args')
            test_num = cli_args.test_num if cli_args is not None else 30
        res = ntimes_verification(test_num, train_dataset, valid_dataset, test_dataset, [len(selected_ids)] + [64, 32],
                                  1024, 1000, device=device, verbose=0)
    return res


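# NOTE: legacy driver, currently disabled. It calls calculate(args, cfg, subset),
# which does not match calculate()'s current (device, dataset_name, bestsubset)
# parameters, so it must be updated before it can be re-enabled.
#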
# def main_subset(args, cfg):
#     save_file = YamlLoader(os.path.join(args.dir, 'test_subset_result.yaml'))
#     if not save_file.exists():
#         save_file.update({
#             "best_subset": {},
#             "best_eval_subset": {},
#             "best_test_subset": {},
#             "subsets": [],
#         })
#
#     if args.file == "best_selects":
#         with open(os.path.join(args.dir, f'{args.file}.yaml')) as f:
#             best_subsets = yaml.load(f, Loader=PrettySafeLoader)
#     else:
#         return
#
#     if not args.useall:
#         score, subset = max(best_subsets['data'].items(), key=lambda x: x[0])
#         print(f"test score:<{score}>  subset: {subset}")
#         res = calculate(args, cfg, subset)
#         save_file.data["best_subset"] = {
#             "subset": subset,
#             "metafe_score": score,
#             "eval_res": res,
#         }
#         save_file.save()
#     else:
#         for i, (score, subset) in enumerate(sorted(best_subsets['data'].items(), key=lambda x: x[0], reverse=True)):
#             print(f"{i}/{len(best_subsets)}  test score:<{score}>  subset: {subset}")
#             res = calculate(args, cfg, subset)
#             save_file.data['subsets'].append({
#                 "subset": subset,
#                 "metafe_score": score,
#                 "eval_res": res,
#             })
#             save_file.save()
#         best_subset = max(save_file.data['subsets'], key=lambda x: x["metafe_score"])
#         best_eval_subset = max(save_file.data['subsets'], key=lambda x: x["eval_res"]["val_f1"])
#         best_test_subset = max(save_file.data['subsets'], key=lambda x: x["eval_res"]["f1"])
#         save_file.data["best_eval_subset"] = best_eval_subset
#         save_file.data["best_subset"] = best_subset
#         save_file.data["best_test_subset"] = best_test_subset
#         save_file.save()


class PrettySafeLoader(yaml.SafeLoader):
    """Safe YAML loader that also builds tuples from ``!!python/tuple`` nodes.

    The constructor registered below maps the
    ``tag:yaml.org,2002:python/tuple`` tag to a plain Python ``tuple``, so files
    containing that tag can be read without switching to an unsafe loader.
    """

    def construct_python_tuple(self, node):
        """Build a ``tuple`` from the items of a YAML sequence node."""
        items = self.construct_sequence(node)
        return tuple(items)


# Register the tuple constructor on this subclass only.
PrettySafeLoader.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    PrettySafeLoader.construct_python_tuple,
)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Gisette')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--test_num', type=int, default=30)
    # ``type=bool`` turns every non-empty string (including "False") into True,
    # so parse the flag's text explicitly; ``--useall True`` keeps working.
    parser.add_argument('--useall',
                        type=lambda text: text.strip().lower() in ('1', 'true', 'yes', 'y'),
                        default=False)
    parser.add_argument('--file', type=str, default='best_selects')
    parser.add_argument('--num_features', type=int, default=400,
                        help='how many leading indices of the hard-coded subset to evaluate')

    args = parser.parse_args()

    # The hard-coded subset below runs up to index 4999 and was always run
    # against Gisette; reject other datasets instead of silently ignoring
    # --dataset.
    if args.dataset != 'Gisette':
        parser.error('the hard-coded feature subset is only valid for --dataset Gisette')

    bestsubset = [
        17, 19, 24, 25, 38, 43, 61, 68, 74, 83, 96, 110, 122, 125, 131,
        133, 135, 145, 155, 200, 233, 235, 252, 262, 275, 280, 286, 290, 295, 297, 314,
        322, 338, 356, 364, 370, 372, 379, 391, 399, 401, 402, 412, 423, 424, 439, 445,
        456, 457, 467, 477, 481, 485, 490, 492, 498, 502, 511, 537, 557, 576, 599, 623,
        626, 631, 639, 660, 677, 681, 692, 694, 695, 696, 714, 722, 726, 729, 736, 752,
        760, 761, 776, 782, 786, 801, 817, 818, 826, 834, 836, 837, 868, 881, 885, 892,
        898, 949, 989, 1000, 1008, 1009, 1034, 1065, 1067, 1094, 1100, 1109, 1118, 1125,
        1131, 1143, 1145, 1175, 1180, 1183, 1184, 1204, 1209, 1212, 1218, 1221, 1225,
        1228, 1230, 1236, 1240, 1243, 1258, 1271, 1275, 1293, 1298, 1300, 1302, 1332,
        1333, 1345, 1349, 1359, 1365, 1381, 1387, 1394, 1404, 1431, 1457, 1469, 1471,
        1479, 1480, 1482, 1495, 1508, 1509, 1516, 1530, 1534, 1536, 1555, 1556, 1558,
        1565, 1567, 1588, 1614, 1626, 1632, 1638, 1640, 1644, 1654, 1660, 1669, 1670,
        1695, 1706, 1708, 1715, 1716, 1717, 1725, 1731, 1733, 1737, 1743, 1755, 1760,
        1776, 1791, 1805, 1818, 1843, 1847, 1855, 1869, 1873, 1883, 1885, 1891, 1922,
        1929, 1941, 1952, 1971, 1979, 1994, 1995, 2010, 2012, 2024, 2029, 2043, 2044,
        2068, 2079, 2091, 2101, 2102, 2104, 2116, 2155, 2160, 2168, 2169, 2175, 2183,
        2187, 2189, 2199, 2206, 2212, 2222, 2234, 2241, 2244, 2264, 2265, 2270, 2301,
        2304, 2325, 2343, 2353, 2378, 2384, 2388, 2437, 2438, 2448, 2451, 2455, 2456,
        2464, 2469, 2472, 2474, 2481, 2488, 2491, 2498, 2515, 2516, 2550, 2560, 2582,
        2587, 2616, 2617, 2630, 2648, 2654, 2655, 2662, 2675, 2677, 2707, 2725, 2729,
        2731, 2742, 2752, 2765, 2767, 2769, 2772, 2781, 2784, 2788, 2817, 2835, 2843,
        2863, 2870, 2874, 2883, 2893, 2921, 2923, 2927, 2931, 2964, 2996, 3002, 3035,
        3044, 3047, 3056, 3061, 3062, 3077, 3078, 3103, 3106, 3142, 3146, 3150, 3162,
        3171, 3191, 3196, 3197, 3218, 3240, 3244, 3251, 3252, 3254, 3269, 3277, 3281,
        3283, 3299, 3307, 3315, 3318, 3320, 3327, 3328, 3336, 3354, 3358, 3374, 3376,
        3379, 3391, 3397, 3398, 3418, 3423, 3433, 3446, 3454, 3463, 3478, 3480, 3496,
        3501, 3503, 3508, 3541, 3543, 3557, 3594, 3604, 3608, 3637, 3656, 3663, 3683,
        3688, 3694, 3699, 3700, 3707, 3719, 3721, 3725, 3746, 3755, 3759, 3765, 3797,
        3798, 3814, 3830, 3833, 3834, 3851, 3868, 3869, 3885, 3888, 3900, 3902, 3904,
        3907, 3914, 3930, 3943, 3954, 3956, 3975, 3990, 3999, 4000, 4008, 4026, 4027,
        4040, 4048, 4064, 4075, 4081, 4085, 4090, 4094, 4102, 4107, 4108, 4111, 4114,
        4130, 4143, 4146, 4153, 4168, 4177, 4180, 4183, 4187, 4193, 4198, 4202, 4228,
        4232, 4245, 4248, 4267, 4278, 4298, 4304, 4313, 4316, 4343, 4347, 4353, 4368,
        4369, 4374, 4384, 4386, 4391, 4392, 4403, 4409, 4412, 4422, 4424, 4425, 4429,
        4450, 4461, 4464, 4466, 4474, 4475, 4486, 4489, 4497, 4507, 4510, 4553, 4557,
        4572, 4573, 4575, 4577, 4583, 4586, 4594, 4598, 4605, 4608, 4610, 4619, 4631,
        4642, 4652, 4655, 4674, 4684, 4689, 4690, 4711, 4725, 4735, 4753, 4763, 4767,
        4774, 4779, 4782, 4790, 4814, 4828, 4832, 4833, 4844, 4856, 4858, 4862, 4869,
        4874, 4876, 4889, 4893, 4916, 4917, 4925, 4933, 4934, 4936, 4941, 4944, 4945,
        4947, 4949, 4955, 4963, 4967, 4970, 4976, 4979, 4981, 4984, 4989, 4990, 4991,
        4994, 4999,
    ]
    bestsubset = bestsubset[:args.num_features]
    # Honour --device instead of the previously hard-coded 'cuda'.
    calculate(args.device, args.dataset, bestsubset)

