import os
import sys

sys.path.append('../')
import random
import numpy as np
import torch
import torch.nn as nn
from src.config import config as cfg
from src.config import YamlLoader
# from src.meta_learning.model import FeatureSelection
from trainer.utils import DatasetShell
from src.util.utils import DataLoaderSampler
from torch.utils.data import DataLoader
from src.dataloaders import SelectedDataset
from src.verify.verify import verification, ntimes_verification
from src.util.utils import append_to_txt
import argparse
from src.verify.trainer.utils.setshell import DatasetShell
import ruamel.yaml as yaml
from src.baselines import FClassif, MutualInfoClassif, RFESelect, Lasso, MRMRSelect, DTRFESelect

# Module-level handles to the most recently loaded dataset splits.
# They start out empty and are rebound by calculate() through its `global`
# statement each time a subset is evaluated.
train_set = None
valid_set = None
test_set = None


def init_sys():
    """Reset every random number generator used here to a fixed seed of 0.

    Seeds Python's ``random``, NumPy's global generator and PyTorch, then
    switches cuDNN to deterministic mode with autotuning (benchmark) off.
    """
    # Same seeding order as before: stdlib, NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(0)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_dataset(dataset: str):
    """Return the dataset class registered under the name ``dataset``.

    The matching module is imported lazily, only for the requested name.

    Args:
        dataset: One of ``'Forest'``, ``'P53'``, ``'QSAR'``, ``'Gisette'``
            or ``'Shopping'``.

    Returns:
        The dataset class itself (not an instance).

    Raises:
        FileNotFoundError: If ``dataset`` is not one of the known names.
    """
    # NOTE(review): 'Shopping' is imported from ``src.dataloaders`` while the
    # other datasets come from ``dataloaders`` -- confirm both roots resolve.
    if dataset == 'Forest':
        from dataloaders.forest import Forest
        return Forest

    if dataset == 'P53':
        from dataloaders.p53 import P53
        return P53

    if dataset == 'QSAR':
        from dataloaders.qsar import QSAR
        return QSAR

    if dataset == 'Gisette':
        from dataloaders.gisette import Gisette
        return Gisette

    if dataset == 'Shopping':
        from src.dataloaders.shopping import ShoppingGender
        return ShoppingGender

    raise FileNotFoundError(f'No dataset named {dataset}')


def calculate(args, cfg, bestsubset):
    """Evaluate one feature subset by running the verification training.

    Loads the train/valid/test splits of ``cfg.dataset``, restricts each to
    the feature indices in ``bestsubset`` and passes them to
    ``ntimes_verification``.

    Args:
        args: Parsed command-line namespace; ``args.device`` and
            ``args.test_num`` are read.
        cfg: Loaded configuration; ``cfg.dataset`` and ``cfg.metafe.dims``
            are read.
        bestsubset: Iterable of feature indices; converted to an int array.

    Returns:
        Whatever ``ntimes_verification`` returns (``main_subset`` reads the
        ``"val_f1"`` and ``"f1"`` entries from it).

    Side effects:
        Rebinds the module-level ``train_set``, ``valid_set`` and
        ``test_set`` and prints diagnostic information.
    """
    init_sys()
    dataset = load_dataset(cfg.dataset)
    global train_set, valid_set, test_set

    train_set, test_set = dataset('train'), dataset('test')
    # Best-effort: try the two-part validation split first, then fall back
    # to a single 'valid' split if the dataset rejects it.
    try:
        valid_set = dataset(['valid1', 'valid2'])
    except Exception:
        valid_set = dataset('valid')

    selected_ids = np.array(list(bestsubset)).astype('int')

    # Re-seed after dataset loading, before the selected views are built.
    init_sys()
    train_dataset = SelectedDataset(train_set, select_ids=selected_ids)
    valid_dataset = SelectedDataset(valid_set, select_ids=selected_ids)
    test_dataset = SelectedDataset(test_set, select_ids=selected_ids)

    print(f'test  len: {len(selected_ids)}')
    print(f'{train_dataset[0]}')

    # Honour the requested device only when CUDA is actually available.
    if torch.cuda.is_available():
        device = args.device
    else:
        device = torch.device('cpu')
    print('dataset:', train_dataset)

    if cfg.dataset == 'Shopping':
        print('using Shopping hyper parameters:')
        # Fix: pass the resolved ``device`` (was ``args.device``), so the CPU
        # fallback above also applies to this branch.
        res = ntimes_verification(1, train_dataset, valid_dataset, test_dataset,
                                  (len(selected_ids), 64, 32, 2), 1024 * 32, 100, device=device, verbose=2)
    else:
        # Input width is the subset size; remaining layer sizes come from cfg.
        res = ntimes_verification(args.test_num, train_dataset, valid_dataset, test_dataset, [len(selected_ids)] + cfg.metafe.dims[1:],
                                  1024, 1000, device=device, verbose=0)
    return res


def main_subset(args, cfg):
    """Evaluate stored feature subsets and record the results in YAML.

    Reads ``<args.dir>/<args.file>.yaml``, whose ``'data'`` entry maps a
    score to a feature subset, evaluates subsets with ``calculate`` and
    writes the outcome to ``<args.dir>/test_subset_result.yaml``.

    Args:
        args: Parsed command-line namespace; ``args.dir``, ``args.file`` and
            ``args.useall`` are read here (``calculate`` reads more).
        cfg: Loaded configuration, forwarded to ``calculate``.

    Returns:
        None. Returns early without evaluating anything when ``args.file``
        is not ``"best_selects"``.
    """
    save_file = YamlLoader(os.path.join(args.dir, 'test_subset_result.yaml'))
    if not save_file.exists():
        # Seed the result file with the expected top-level layout.
        save_file.update({
            "best_subset": {},
            "best_eval_subset": {},
            "best_test_subset": {},
            "subsets": [],
        })

    if args.file == "best_selects":
        with open(os.path.join(args.dir, f'{args.file}.yaml')) as f:
            best_subsets = yaml.load(f, Loader=PrettySafeLoader)
    else:
        return

    candidates = best_subsets['data']

    if not args.useall:
        # Only the subset with the highest stored score is evaluated.
        score, subset = max(candidates.items(), key=lambda x: x[0])
        print(f"test score:<{score}>  subset: {subset}")
        res = calculate(args, cfg, subset)
        save_file.data["best_subset"] = {
            "subset": subset,
            "metafe_score": score,
            "eval_res": res,
        }
        save_file.save()
    else:
        ranked = sorted(candidates.items(), key=lambda x: x[0], reverse=True)
        # Fix: the progress total is the number of candidate subsets, not the
        # number of top-level keys in the loaded YAML document.
        total = len(ranked)
        for i, (score, subset) in enumerate(ranked):
            print(f"{i}/{total}  test score:<{score}>  subset: {subset}")
            res = calculate(args, cfg, subset)
            save_file.data['subsets'].append({
                "subset": subset,
                "metafe_score": score,
                "eval_res": res,
            })
            # Save after every subset so partial progress is kept.
            save_file.save()
        best_subset = max(save_file.data['subsets'], key=lambda x: x["metafe_score"])
        best_eval_subset = max(save_file.data['subsets'], key=lambda x: x["eval_res"]["val_f1"])
        best_test_subset = max(save_file.data['subsets'], key=lambda x: x["eval_res"]["f1"])
        save_file.data["best_eval_subset"] = best_eval_subset
        save_file.data["best_subset"] = best_subset
        save_file.data["best_test_subset"] = best_test_subset
        save_file.save()


    # print('result:', res)
    #
    # with open(os.path.join(args.dir, 'test_result.txt'), 'a+') as f:
    #     f.write('******************************************************')
    #     # f.write(f'---max_iter:{args.max_iter}-----------------------------------\n')
    #     f.write(f'---file:{args.file}-----------------------------------\n')
    #     f.write(f"selected_ids: {selected_ids}\n")
    #     f.write(f'{res}\n')
    #     f.write('******************************************************\n\n')


class PrettySafeLoader(yaml.SafeLoader):
    """Safe YAML loader that also understands the ``python/tuple`` tag."""

    def construct_python_tuple(self, node):
        """Build a tuple from the sequence held by ``node``."""
        items = self.construct_sequence(node)
        return tuple(items)


# Register the tuple constructor on the loader class for the python/tuple tag.
PrettySafeLoader.add_constructor(
    u'tag:yaml.org,2002:python/tuple',
    PrettySafeLoader.construct_python_tuple,
)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``. This
    parser interprets the text instead.

    Args:
        value: The raw option string (an existing ``bool`` is returned as is).

    Returns:
        The parsed boolean. An empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir', type=str, default='../result/qsar/test1')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--test_num', type=int, default=30)
    # Accepts ``--useall True`` / ``--useall False`` as well as a bare
    # ``--useall`` (treated as True).
    parser.add_argument('--useall', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--file', type=str, default='best_selects')

    # parser.add_argument('-e', '--embedding', type=str, default="128 64 32")
    args = parser.parse_args()
    # The run's configuration lives next to its results.
    cfg.load_config(os.path.join(args.dir, 'config.yaml'))
    main_subset(args, cfg)

