import os
import sys
import time
sys.path.append('../')
import random
import numpy as np
import torch
import torch.nn as nn
from src.config import config as cfg

from src.meta_learning.metafs3 import FeatureSelection, FSS
from src.meta_learning.model import NormalFE, MetaFE, EmbeddingAttention
from src.meta_learning.model2 import MetaFE as KStartMetaFE, EmbeddingMLP
from trainer.utils import DatasetShell, SubShell
from src.util.utils import DataLoaderSampler, change_data
from torch.utils.data import DataLoader
from src.dataloaders import SelectedDataset
from src.verify.verify import verification, ntimes_verification, test_subset
from src.verify.trainer.utils.setshell import DatasetShell
from src.util.utils import append_to_txt
import argparse
import ruamel.yaml as yaml
from src.config import ResultLogging
import copy


# Module-level run state.
# torch.device selected by init_sys(); stays None until init_sys() has run.
device = None
# Dataset splits assigned by main() via `global`; main() reloads them with
# cfg.verify.max_sample_size before the evaluation step.
train_set = None
valid_set = None
test_set = None


def init_sys():
    """Select the torch device and seed every RNG from the config.

    Sets the module-level ``device``:

    * ``cfg.sys.device == 'auto'``: CUDA when available, otherwise CPU.
    * any other value: passed to ``torch.device`` as-is (e.g. ``'cpu'``,
      ``'cuda'``, ``'cuda:1'``). Values torch cannot parse fall back to CUDA.

    Seeds Python's ``random``, NumPy and torch with ``cfg.sys.seed``. It also
    forces deterministic cuDNN kernels (benchmark mode off) so runs are
    reproducible.
    """
    global device
    requested = cfg.sys.device
    if requested == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        # Honour the configured device instead of always forcing CUDA, so e.g.
        # 'cpu' works on GPU-less hosts.
        try:
            device = torch.device(requested)
        except (RuntimeError, TypeError):
            # Legacy configs may hold strings torch cannot parse (e.g. 'gpu');
            # keep the previous CUDA default for them.
            device = torch.device('cuda')

    random.seed(cfg.sys.seed)
    np.random.seed(cfg.sys.seed)
    torch.manual_seed(cfg.sys.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_dataset_class(dataset: str):
    """Map a configured dataset name to its dataset class.

    The matching dataloader module is imported lazily, so only the selected
    dataset's module (and its dependencies) is loaded.

    Args:
        dataset: one of ``'P53'``, ``'QSAR'``, ``'Gisette'`` or ``'Shopping'``.

    Returns:
        The dataset class itself (not an instance).

    Raises:
        FileNotFoundError: if ``dataset`` is not one of the known names.
    """
    # (config name, module to import, class exported by that module)
    registry = (
        ('P53', 'src.dataloaders.p53', 'P53'),
        ('QSAR', 'src.dataloaders.qsar', 'QSAR'),
        ('Gisette', 'src.dataloaders.gisette', 'Gisette'),
        ('Shopping', 'src.dataloaders.shopping', 'ShoppingGender'),
    )
    for known_name, module_path, class_name in registry:
        if dataset == known_name:
            # A non-empty fromlist makes __import__ return the leaf module.
            loaded = __import__(module_path, fromlist=[class_name])
            return getattr(loaded, class_name)
    raise FileNotFoundError(f'No dataset named {dataset}')


def load_dataset(dataset_class, max_sample_size=None):
    """Instantiate the train, validation and test splits of ``dataset_class``.

    The splits are built in the order train, test, validation. The validation
    split is first requested as the combined ``['valid1', 'valid2']``. If that
    constructor call raises anything, it falls back to the single ``'valid'``
    split.

    Args:
        dataset_class: callable taking ``(split, max_sample_size=...)``.
        max_sample_size: forwarded unchanged to every split.

    Returns:
        ``(train, valid, test)`` dataset instances.
    """
    def build(split):
        return dataset_class(split, max_sample_size=max_sample_size)

    train_split = build('train')
    test_split = build('test')
    try:
        valid_split = build(['valid1', 'valid2'])
    except Exception:
        # Datasets without the two-part validation split use a single 'valid' one.
        valid_split = build('valid')
    return train_split, valid_split, test_split


def _log_finish(stage, start_time):
    """Append ``[<stage>] <finish>  using time:<seconds>`` to ``<save_path>/log_msg.log``.

    Args:
        stage: tag written between the brackets (e.g. ``'pre-train'``).
        start_time: ``time.time()`` value the elapsed seconds are measured from.
    """
    append_to_txt(os.path.join(cfg.globals.save_path, "log_msg.log"),
                  f'[{stage}] <finish>  using time:{time.time()-start_time}')


def _build_metafe():
    """Build the meta feature-evaluator network selected by ``globals.metafe_type``.

    * ``'allfeature'`` (default): ``NormalFE`` when ``globals.no_meta`` is set,
      otherwise ``MetaFE``. ``EmbeddingAttention`` is passed as the embedding
      module only when ``globals.use_embedding`` is set (``None`` otherwise).
    * ``'k_start'``: ``KStartMetaFE`` with the embedding learner named by
      ``globals.embedding_learner`` (legacy key ``globals.enbedding_learner``,
      default ``'MLP'``).

    Raises:
        ValueError: for any other ``globals.metafe_type``.
    """
    metafe_type = cfg.get('globals.metafe_type', 'allfeature')
    if metafe_type == 'allfeature':
        embedding_module = EmbeddingAttention if cfg.get('globals.use_embedding', False) else None
        fe_class = NormalFE if cfg.get('globals.no_meta', False) else MetaFE
        return fe_class(cfg.metafe.dims, cfg.metafe.embedding_dim, cfg.metafe.generator_hidden, embedding_module)
    if metafe_type == 'k_start':
        # Prefer the correctly spelled key; fall back to the historical
        # misspelling so existing configs keep working.
        embedding_learner_str = cfg.get('globals.embedding_learner',
                                        cfg.get('globals.enbedding_learner', 'MLP'))
        # NOTE(review): an unrecognised name maps to None and is passed through
        # as-is — confirm KStartMetaFE treats None as "no embedding learner".
        embedding_learner = {
            'MLP': EmbeddingMLP,
        }.get(embedding_learner_str)
        return KStartMetaFE(cfg.metafe.dims, cfg.metafe.embedding_dim, cfg.globals.select_k,
                            cfg.globals.all_features, cfg.metafe.generator_hidden, embedding_learner)
    raise ValueError(f"Unknown globals.metafe_type: {metafe_type!r} (expected 'allfeature' or 'k_start')")


def main(dir_name=None):
    """Run the feature-selection pipeline; results go under ``cfg.globals.save_path``.

    Steps:
      1. Pre-train the meta feature-evaluator, or load the saved one when
         ``dir_name`` already contains it.
      2. Iterate feature-subset steps, search steps (after the first
         ``training.iters.ignore_search_iter_num`` rounds) and weight steps.
      3. Run a final search.
      4. Evaluate every candidate subset with ``test_subset``. Results are
         written to ``select_res.yaml`` (rewritten after each subset) and
         ``select_res_final.yaml``.

    Args:
        dir_name: directory of a previous run whose pre-trained model should be
            reused; ``None`` for a fresh run.
    """
    init_sys()
    dataset_class = load_dataset_class(cfg.dataset)
    global train_set, valid_set, test_set

    train_set, valid_set, test_set = load_dataset(dataset_class, cfg.globals.max_sample_size)

    os.makedirs(cfg.globals.save_path, exist_ok=True)

    metaFE = _build_metafe()
    print(cfg.get('globals.no_meta'), metaFE)
    metaFE.to(device)

    start_time = time.time()

    model = FeatureSelection(metaFE, k=cfg.globals.select_k, n=cfg.globals.all_features,
                             train_dataset=train_set, valid_dataset=valid_set,
                             hold_subsets_num=cfg.globals.hold_subsets_num,
                             sample_batch_size=cfg.weight_train.sample_batch_size,
                             max_batch_size=cfg.globals.max_batch_size,
                             weight_update_batch_num=cfg.weight_train.weight_update_batch_num,
                             feat_sample_batch_size=cfg.feat_train.feat_sample_batch_size,
                             feat_update_batch_num=cfg.feat_train.feat_update_batch_num,
                             weight_step_lr=cfg.weight_train.step_lr,
                             feature_step_lr=cfg.feat_train.step_lr,
                             device=device, LOG_DIR=cfg.globals.save_path)

    # Step 1: pre-train
    # NOTE(review): existence is checked for 'metafe_model_pre_train_fss.pth' in
    # dir_name, while load()/save() receive 'model_pre_train_fss.pth' and the
    # model was given LOG_DIR=cfg.globals.save_path. Presumably save() adds a
    # 'metafe_' prefix and writes into LOG_DIR — confirm the paths agree.
    if dir_name is not None and os.path.exists(os.path.join(dir_name, 'metafe_model_pre_train_fss.pth')):
        model.load("model_pre_train_fss.pth")
        print('loaded pre-train model.')
    else:
        print('start pre-train')
        model.weight_step(max_iter=cfg.training.pre_train.max_iter,
                          check_period=cfg.training.pre_train.check_period,
                          weight_early_stop_patience=cfg.training.pre_train.early_stop.patience,
                          weight_early_stop_filter=cfg.training.pre_train.early_stop.filter)

        _log_finish('pre-train', start_time)

        model.save('model_pre_train_fss.pth')

    # Snapshot of the pre-trained weights. It is restored before each round's
    # weight step when training.iters.metafe.train_from_pre is set.
    pretrain_metafe_paramters = copy.deepcopy(model.metafe.state_dict())

    # Step 2: iteration
    print('start iteration')

    # Either an explicit per-round schedule of feat_step iteration counts, or a
    # fixed number of rounds that all use training.iters.fss.train_iter.
    changeable_iter = cfg.get('training.iters.changeable_iter', False)
    if changeable_iter:
        fss_schedule = cfg.get('training.iters.changeable_fss')
        print(fss_schedule)
        _iter_len = len(fss_schedule)
    else:
        fss_schedule = range(cfg.training.iters.global_train_iter)
        _iter_len = cfg.training.iters.global_train_iter

    for i, _fss_iters in enumerate(fss_schedule):
        prefix = f'global:{i}/{_iter_len}  '
        do_search = i >= cfg.training.iters.ignore_search_iter_num
        res = model.feat_step(max_iter=_fss_iters if changeable_iter else cfg.training.iters.fss.train_iter,
                              top_k_p=cfg.training.iters.fss.end_top_k,
                              search_max_iter=cfg.training.iters.fss.search_max_iter,
                              prefix=prefix,
                              search=do_search)
        if res:
            # A truthy feat_step result ends the iteration phase early.
            model.save(f'model_iter_{i}_final.pth')
            break

        if do_search:
            model.search_step(max_iter=cfg.training.iters.search.max_iter,
                              search_max_iter=cfg.training.iters.search.search_max_iter)

        if cfg.get('training.iters.metafe.train_from_pre', False):
            model.metafe.load_state_dict(pretrain_metafe_paramters)

        model.weight_step(max_iter=cfg.training.iters.metafe.train_iter,
                          check_period=cfg.training.iters.metafe.check_period,
                          weight_early_stop_patience=cfg.training.iters.metafe.early_stop.patience,
                          weight_early_stop_filter=cfg.training.iters.metafe.early_stop.filter,
                          prefix=prefix)

        model.save(f'model_iter_{i}.pth')

    _log_finish('iter', start_time)

    # Step 3: final search
    print('start final search')
    model.search_step(max_iter=cfg.training.final_search.max_iter,
                      search_max_iter=cfg.training.final_search.search_max_iter)

    _log_finish('search', start_time)

    # Step 4: evaluate searching result
    print('start evaluate')
    # Candidate subsets ordered by their metaFE score, best first.
    subsets = sorted([(subset, score) for subset, score in model.best_select.data.items()],
                     key=lambda x: x[1], reverse=True)
    # Reload the splits with the verification sample-size limit.
    train_set, valid_set, test_set = load_dataset(dataset_class, cfg.verify.max_sample_size)

    results = []

    for i, (subset, score) in enumerate(subsets):
        print(f'testing: {i}/{len(subsets)} subset score: {score}')
        result = test_subset(subset=subset, n_iter=cfg.verify.times,
                             train_dataset=train_set, valid_dataset=valid_set,
                             test_dataset=test_set, hidden_layers=[cfg.globals.select_k] + cfg.metafe.dims[1:],
                             batch_size=cfg.globals.max_batch_size, max_epoch=2000,
                             device=device)
        results.append({
            'subset': subset,
            'metaFE_score': score,
            'true_results': result,
        })

        # Rewritten after every subset so partial results survive an interrupted run.
        with open(os.path.join(cfg.globals.save_path, "select_res.yaml"), 'w') as f:
            yaml.dump(change_data(results), f)

    if results:
        best_subset = max(results, key=lambda d: d['true_results']['val_f1'])
    else:
        # max() raises on an empty sequence; record "no best subset" instead of
        # crashing at the very end of a long run.
        print('no candidate subsets to evaluate')
        best_subset = None

    # NOTE(review): best_subset is dumped without change_data() while 'data'
    # goes through it — confirm the raw entry is serialisable.
    with open(os.path.join(cfg.globals.save_path, "select_res_final.yaml"), 'w') as f:
        yaml.dump({
                'best_subset': best_subset,
                'data': change_data(results),
            }, f)

    _log_finish('evaluate', start_time)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Meta-learning feature selection: pre-train, iterate, search, then evaluate subsets.')
    parser.add_argument('--config', type=str,
                        help='config file for a fresh run (required unless --dir is given)')
    parser.add_argument('--dir', type=str, default='',
                        help='directory of a previous run; its config.yaml is loaded instead of --config '
                             'and its saved pre-trained model is reused when present')
    args = parser.parse_args()

    if args.dir == '':
        if args.config is None:
            parser.error('--config is required when --dir is not given')
        cfg.load_config(args.config)
        # The config snapshot is written before main() runs, so the output
        # directory must exist already.
        os.makedirs(cfg.globals.save_path, exist_ok=True)
        cfg.save(os.path.join(cfg.globals.save_path, 'config.yaml'))
        main()
    else:
        cfg.load_config(os.path.join(args.dir, 'config.yaml'))
        main(args.dir)
