import os
import sys
import time
sys.path.append('../')
import random
import numpy as np
import torch
import torch.nn as nn
from src.config import config as cfg

from src.meta_learning.metafs3 import FeatureSelection, FSS
from src.meta_learning.model import NormalFE, MetaFE, EmbeddingAttention

from trainer.utils import DatasetShell, SubShell
from src.util.utils import DataLoaderSampler, change_data
from torch.utils.data import DataLoader
from src.dataloaders import SelectedDataset
from src.verify.verify import verification, ntimes_verification, test_subset
from src.verify.trainer.utils.setshell import DatasetShell
from src.util.utils import append_to_txt
import argparse
import ruamel.yaml as yaml
from src.config import ResultLogging
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import f1_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from src.util.rbo import RankingSimilarity
from multiprocessing import Pool
from src.meta_learning.model2 import MetaFE as KStartMetaFE, EmbeddingMLP


def init_sys():
    """Select the module-level torch ``device`` and seed every RNG used here.

    Reads ``cfg.sys.device`` and ``cfg.sys.seed``. When the configured device
    is ``'auto'``, CUDA is used if available and CPU otherwise. Any other value
    is handed to ``torch.device`` unchanged (e.g. ``'cpu'``, ``'cuda:1'``).

    Side effects: rebinds the module-level ``device`` global, seeds
    ``random``, ``numpy`` and ``torch``, and puts cuDNN into deterministic,
    non-benchmark mode.
    """
    global device
    if cfg.sys.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        # Honour the configured device instead of silently forcing CUDA.
        device = torch.device(cfg.sys.device)

    random.seed(cfg.sys.seed)
    np.random.seed(cfg.sys.seed)
    torch.manual_seed(cfg.sys.seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_dataset_class(dataset: str):
    """Return the dataset class registered under the name ``dataset``.

    Each loader module is imported lazily, only when its name is requested,
    so unrelated dataset modules are never imported.

    Raises:
        FileNotFoundError: if ``dataset`` is not one of 'P53', 'QSAR',
            'Gisette' or 'Shopping'.
    """
    def _p53():
        from src.dataloaders.p53 import P53
        return P53

    def _qsar():
        from src.dataloaders.qsar import QSAR
        return QSAR

    def _gisette():
        from src.dataloaders.gisette import Gisette
        return Gisette

    def _shopping():
        from src.dataloaders.shopping import ShoppingGender
        return ShoppingGender

    registry = (
        ('P53', _p53),
        ('QSAR', _qsar),
        ('Gisette', _gisette),
        ('Shopping', _shopping),
    )
    for registered_name, importer in registry:
        if dataset == registered_name:
            return importer()
    raise FileNotFoundError(f'No dataset named {dataset}')


def load_dataset(dataset_class, max_sample_size=None):
    """Instantiate the train, validation and test splits of ``dataset_class``.

    The validation split is first requested as the combined
    ``['valid1', 'valid2']``. If the class rejects that (any exception), it
    falls back to a single ``'valid'`` split.

    Returns:
        ``(train, valid, test)`` dataset instances.
    """
    def _split(which):
        return dataset_class(which, max_sample_size=max_sample_size)

    train_split = _split('train')
    test_split = _split('test')
    try:
        valid_split = _split(['valid1', 'valid2'])
    except Exception:
        valid_split = _split('valid')
    return train_split, valid_split, test_split


def _test_subset(inputs):
    """Evaluate one feature subset; picklable worker for ``Pool.map``.

    Args:
        inputs: an 11-tuple in the layout built by ``_create_args``:
            ``(subset, n_iter, train_dataset, valid_dataset, test_dataset,
            hidden_layers, batch_size, max_epoch, verbose, device, lr)``.

    Returns:
        ``(subset, result)``, where ``result`` is the return value of
        ``test_subset`` for that subset.
    """
    (subset, n_iter, train_dataset, valid_dataset, test_dataset,
     hidden_layers, batch_size, max_epoch, verbose, device, lr) = inputs
    # Forward max_epoch/device from the packed arguments rather than
    # discarding them in favour of hard-coded values.
    result = test_subset(subset=subset, n_iter=n_iter,
                         train_dataset=train_dataset, valid_dataset=valid_dataset,
                         test_dataset=test_dataset, hidden_layers=hidden_layers,
                         batch_size=batch_size, max_epoch=max_epoch,
                         device=device, verbose=verbose, lr=lr)
    return subset, result


def _create_args(subsets, n_iter, train_dataset, valid_dataset, test_dataset, hidden_layers,
                 batch_size, max_epoch, verbose, device, lr):
    """Pack one argument tuple per subset for the ``_test_subset`` worker.

    All tuples share the same trailing arguments; only the leading
    ``subset`` element differs.

    Returns:
        A list of 11-tuples, one per entry of ``subsets``, in input order.
    """
    shared_tail = (n_iter, train_dataset, valid_dataset, test_dataset, hidden_layers,
                   batch_size, max_epoch, verbose, device, lr)
    return [(subset, *shared_tail) for subset in subsets]


def prepare_subsets(train_set, valid_set, test_set, fss, save_path, num=100, k=50, hidden_layers=(64, 32),
                    batch_size=4096, n_iter=10, verbose=0, lr=3 * 1e-3, multi_num=-1):
    """Sample ``num`` feature subsets of size ``k`` and score each one with ``test_subset``.

    Results are cached as YAML at ``save_path``. If that file already exists,
    it is loaded and returned without sampling or training anything.

    Args:
        train_set, valid_set, test_set: datasets forwarded to ``test_subset``.
        fss: object providing ``sample_subsets(k, num)``.
        save_path: YAML cache file path.
        num: number of subsets to sample.
        k: subset size; it is also prepended to ``hidden_layers`` to form the
            ``hidden_layers`` argument passed to ``test_subset``.
        hidden_layers: layer widths after ``k`` (any sequence, tuple or list).
        batch_size, n_iter, verbose, lr: forwarded to ``test_subset``.
        multi_num: number of worker processes; ``<= 0`` evaluates serially.

    Returns:
        On a cache hit, the loaded YAML document (written as
        ``{"use_time", "data"}``). Otherwise, the raw list of
        ``(subset, result)`` pairs.
        NOTE(review): the two paths return different shapes; callers must
        cope with both.
    """
    if os.path.exists(save_path):
        with open(save_path, 'r') as f:
            # NOTE(review): module-level ruamel ``yaml.load`` without an
            # explicit Loader is unsafe on untrusted files and is removed in
            # newer ruamel.yaml releases; only use on trusted cache files.
            return yaml.load(f)

    subsets = fss.sample_subsets(k, num)

    _t = time.time()
    # NOTE(review): max_epoch=2000 and device='cuda' are hard-coded here and
    # do not follow the device chosen by init_sys.
    if multi_num > 0:
        # ``[k, *hidden_layers]`` works for the tuple default, where
        # ``[k] + hidden_layers`` would raise TypeError.
        input_args = _create_args(subsets, n_iter, train_set, valid_set, test_set, [k, *hidden_layers],
                                  batch_size, 2000, verbose, 'cuda', lr)
        # The context manager terminates the workers instead of leaking them.
        with Pool(multi_num) as p:
            res = p.map(_test_subset, input_args)
    else:
        res = []
        for i, subset in enumerate(subsets):
            print(f'prepare {i}/{len(subsets)}')
            result = test_subset(subset=subset, n_iter=n_iter,
                                 train_dataset=train_set, valid_dataset=valid_set,
                                 test_dataset=test_set, hidden_layers=[k, *hidden_layers],
                                 batch_size=batch_size, max_epoch=2000,
                                 device='cuda', verbose=verbose, lr=lr)
            res.append((subset, result))

    with open(save_path, "w") as f:
        yaml.dump({
            "use_time": time.time() - _t,
            "data": change_data(res),
        }, f)
    return res


def create_random_fss(n):
    """Build a freshly initialised ``FSS`` sized by ``n``.

    ``n`` is presumably the number of candidate features; ``load_fss``
    derives it from the length of a checkpoint's ``'score'`` entry.
    """
    fresh = FSS(n)
    return fresh


def load_fss(path):
    """Restore an ``FSS`` from the state dict saved at ``path``.

    The checkpoint's ``'score'`` entry determines the size of the fresh
    instance; the whole state is then loaded into it.
    """
    with open(path, 'rb') as handle:
        state = torch.load(handle)
    restored = create_random_fss(len(state['score']))
    restored.load_state_dict(state)
    return restored


def load_metafe(dir_name, name):
    """Rebuild a MetaFE-style model from a run directory and load its weights.

    Loads ``<dir_name>/config.yaml`` into the shared ``cfg`` (side effect: the
    global config is overwritten). It then instantiates the architecture
    selected by ``globals.metafe_type`` and loads the state dict stored in
    ``<dir_name>/<name>``.

    Args:
        dir_name: run directory containing ``config.yaml`` and the checkpoint.
        name: checkpoint file name inside ``dir_name``.

    Returns:
        The model with the checkpoint loaded via ``load_state_dict``.

    Raises:
        ValueError: if ``globals.metafe_type`` is not 'allfeature' or 'k_start'.
    """
    cfg.load_config(os.path.join(dir_name, 'config.yaml'))

    metafe_type = cfg.get('globals.metafe_type', 'allfeature')
    if metafe_type == 'allfeature':
        embedding_module = EmbeddingAttention if cfg.get('globals.use_embedding', False) else None
        model_cls = NormalFE if cfg.get('globals.no_meta', False) else MetaFE
        metaFE = model_cls(cfg.metafe.dims, cfg.metafe.embedding_dim, cfg.metafe.generator_hidden, embedding_module)
    elif metafe_type == 'k_start':
        # Prefer the correctly spelled key, but fall back to the historical
        # misspelling 'enbedding_learner' so existing configs keep working.
        embedding_learner_str = cfg.get('globals.embedding_learner',
                                        cfg.get('globals.enbedding_learner', 'MLP'))
        # NOTE(review): an unrecognised name maps to None and is passed on
        # to KStartMetaFE unchanged; confirm that None is acceptable there.
        embedding_learner = {
            'MLP': EmbeddingMLP,
        }.get(embedding_learner_str)
        metaFE = KStartMetaFE(cfg.metafe.dims, cfg.metafe.embedding_dim, cfg.globals.select_k, cfg.globals.all_features,
                              cfg.metafe.generator_hidden, embedding_learner)
    else:
        raise ValueError(f'Unknown globals.metafe_type: {metafe_type!r}')

    with open(os.path.join(dir_name, name), 'rb') as f:
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts. Assuming metaFE is an nn.Module, load_state_dict then copies
        # the tensors into the model's existing parameters.
        data = torch.load(f, map_location='cpu')
    metaFE.load_state_dict(data)
    return metaFE


def metafe_subsets_eval(valid_set, subsets, save_name, metafe, device='cpu'):
    """Score each feature subset with a MetaFE model on ``valid_set``.

    For every subset, the model is run on the whole validation set. Its
    outputs are reduced with argmax over axis 1 and compared with the labels
    using macro F1.

    Args:
        valid_set: object exposing ``data`` (features) and ``label`` arrays.
        subsets: iterable of feature subsets passed to ``metafe.forward``.
        save_name: YAML cache path, or ``None`` to disable caching. If the
            file already exists, it is loaded and returned without evaluating.
        metafe: the model. It is moved with ``.to(device)``; if it is an
            ``nn.Module``, that move also affects the caller's object.
        device: torch device used for the forward passes.

    Returns:
        ``{"use_time": seconds, "data": change_data([(subset, {"f1": f1}), ...])}``.
    """
    if save_name is not None and os.path.exists(save_name):
        with open(save_name, 'r') as f:
            return yaml.load(f)

    metafe = metafe.to(device)
    # sklearn consumes the labels on the CPU: convert them once, not per subset.
    y_true = torch.LongTensor(valid_set.label).numpy()
    x = torch.FloatTensor(valid_set.data).to(device)

    # Evaluate in inference mode (dropout off, BatchNorm running stats not
    # updated), then restore whatever mode the caller had set.
    was_training = metafe.training
    metafe.eval()
    res = []
    _t = time.time()
    try:
        with torch.no_grad():
            for subset in subsets:
                logits = metafe.forward(x, subset)
                y_pred = np.argmax(logits.cpu().numpy(), 1)
                res.append((subset, {"f1": f1_score(y_true, y_pred, average='macro')}))
    finally:
        metafe.train(was_training)

    result = {
        "use_time": time.time() - _t,
        "data": change_data(res)
    }

    if save_name is not None:
        with open(save_name, 'w') as f:
            yaml.dump(result, f)

    return result


def baseline_subsets_eval(train_set, valid_set, subsets, save_name, model):
    """Score feature subsets with a classical (sklearn-style) estimator.

    For each subset, ``model`` is fit on the selected training columns and
    its predictions on the same validation columns are scored with macro F1.
    Progress is printed on a single carriage-return line.

    Args:
        train_set, valid_set: objects exposing 2-D ``data`` (indexed as
            ``data[:, subset]``) and ``label``.
        subsets: sized sequence of column selections.
        save_name: YAML cache path, or ``None`` to disable caching
            (mirroring ``metafe_subsets_eval``). If the file already exists,
            it is loaded and returned without evaluating.
        model: estimator with ``fit``/``predict``; it is refit for every
            subset.

    Returns:
        ``{"use_time": seconds, "data": change_data([(subset, {"f1": f1}), ...])}``.
    """
    if save_name is not None and os.path.exists(save_name):
        with open(save_name, 'r') as f:
            return yaml.load(f)

    y_train = train_set.label
    y_valid = valid_set.label
    res = []
    _t = time.time()
    for i, subset in enumerate(subsets):
        model.fit(train_set.data[:, subset], y_train)
        y_pred = model.predict(valid_set.data[:, subset])
        res.append((subset, {
            "f1": f1_score(y_valid, y_pred, average='macro'),
        }))
        print(f'\r{i}/{len(subsets)}   {res[-1]}', end='')
    print()  # terminate the carriage-return progress line

    result = {
        "use_time": time.time() - _t,
        "data": change_data(res)
    }
    if save_name is not None:
        with open(save_name, 'w') as f:
            yaml.dump(result, f)
    return result


def KNN_subsets_eval(train_set, valid_set, subsets, save_name):
    """Evaluate ``subsets`` with a 5-nearest-neighbour classifier.

    Thin wrapper over ``baseline_subsets_eval``, which handles caching and
    the result format.
    """
    knn = KNeighborsClassifier(n_neighbors=5)
    return baseline_subsets_eval(train_set, valid_set, subsets, save_name, knn)


def DT_subsets_eval(train_set, valid_set, subsets, save_name):
    """Evaluate ``subsets`` with a default-configured decision tree.

    Thin wrapper over ``baseline_subsets_eval``, which handles caching and
    the result format.
    """
    tree = DecisionTreeClassifier()
    return baseline_subsets_eval(train_set, valid_set, subsets, save_name, tree)


def LR_subsets_eval(train_set, valid_set, subsets, save_name):
    """Evaluate ``subsets`` with a default-configured logistic regression.

    Thin wrapper over ``baseline_subsets_eval``, which handles caching and
    the result format.
    """
    logreg = LogisticRegression()
    return baseline_subsets_eval(train_set, valid_set, subsets, save_name, logreg)


def SVC_subsets_eval(train_set, valid_set, subsets, save_name):
    """Evaluate ``subsets`` with a default-configured support vector classifier.

    Thin wrapper over ``baseline_subsets_eval``, which handles caching and
    the result format.
    """
    svm = SVC()
    return baseline_subsets_eval(train_set, valid_set, subsets, save_name, svm)


def cal_rbo(f1, f1_):
    """Rank-biased overlap between the rankings induced by two score lists.

    Each score list is turned into a ranking of indices, highest score first
    (via ``argsort`` of the negated scores). The two rankings are then
    compared with ``RankingSimilarity(...).rbo()``.

    Args:
        f1: reference scores, one per item.
        f1_: scores to compare against ``f1``, aligned index-by-index.

    Returns:
        The RBO value reported by ``RankingSimilarity``.
    """
    reference_order = np.argsort(-1 * np.asarray(f1))
    candidate_order = np.argsort(-1 * np.asarray(f1_))
    return RankingSimilarity(reference_order, candidate_order).rbo()


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     # parser.add_argument('--config', type=str)
#     args = parser.parse_args()
#
#     # cfg.load_config(args.config)
#     # cfg.save(os.path.join(cfg.globals.save_path, 'config.yaml'))
#     # main()
