import torch

from nes.ensemble_selection.containers import Baselearner
from nes.ensemble_selection.utils import model_seeds, create_dataloader_dict_cifar, create_dataloader_dict_fmnist, create_dataloader_dict_imagenet
from nes.optimizers.baselearner_train.train import DARTSByGenotype

# nb201 specific
from nes.utils.nb201.models import get_cell_based_tiny_net, CellStructure
from nes.utils.nb201.config_utils import dict2config
from nes.utils.nb201.api_utils import ResultsCount


def load_nn_module(state_dict_dir, genotype, init_seed=0, dataset='fmnist'):
    """Build a network and load its trained weights from ``state_dict_dir``.

    Two sources are supported:

    * ``genotype`` is not None: a ``DARTSByGenotype`` model is built and the
      file at ``state_dict_dir`` is loaded as its state_dict.
    * ``genotype`` is None: the file is a NAS-Bench-201 checkpoint. The entry
      ``xdata['full']['all_results'][(nb201_dataset_name, seed)]`` is used to
      rebuild the cell-based network and its trained weights.

    Args:
        state_dict_dir: path of the file passed to ``torch.load``.
        genotype: DARTS genotype, or ``None`` to use the NAS-Bench-201 path.
        init_seed: only used when ``genotype`` is None. 0 selects NB201 seed
            777, 1 selects 888, any other value selects 999. If the checkpoint
            has no results for that seed, the remaining seeds are tried in
            the order 777, 888, 999.
        dataset: forwarded to ``DARTSByGenotype``. On the NB201 path it must
            be one of 'cifar10', 'cifar100' or 'imagenet'.

    Returns:
        The model with the loaded weights.

    Raises:
        AssertionError: on the NB201 path, if ``dataset`` has no NB201 name.
        KeyError: on the NB201 path, if the checkpoint holds results for none
            of the three NB201 seeds for ``dataset``.
    """
    if genotype is not None:
        # seed_init can be anything: every parameter is overwritten by the
        # state_dict loaded below, so initialization doesn't matter.
        model = DARTSByGenotype(genotype=genotype, seed_init=0,
                                dataset=dataset)
        # map_location='cpu' lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict copies the values into the model's
        # own parameters either way.
        model.load_state_dict(torch.load(state_dict_dir, map_location='cpu'))
        return model

    # Extract the trained network from the nasbench201 checkpoint.
    dataset_to_nb201_dict = {
        'cifar10': 'cifar10-valid',
        'cifar100': 'cifar100',
        'imagenet': 'ImageNet16-120',
    }
    assert dataset in dataset_to_nb201_dict, (
        f"dataset {dataset!r} has no NAS-Bench-201 equivalent")
    nb201_name = dataset_to_nb201_dict[dataset]

    nb201_seeds = (777, 888, 999)
    preferred_seed = {0: 777, 1: 888}.get(init_seed, 999)
    # Requested seed first, then the remaining ones in their listed order.
    candidate_seeds = [preferred_seed] + [
        s for s in nb201_seeds if s != preferred_seed]

    xdata = torch.load(state_dict_dir, map_location='cpu')
    all_results = xdata['full']['all_results']

    for seed in candidate_seeds:
        try:
            odata = all_results[(nb201_name, seed)]
        except KeyError:
            continue
        break
    else:
        raise KeyError(
            f"no NAS-Bench-201 results for {nb201_name!r} with any of the "
            f"seeds {candidate_seeds} in {state_dict_dir!r}")

    result = ResultsCount.create_from_state_dict(odata)
    arch_config = result.get_config(CellStructure.str2structure)
    net_config = dict2config(arch_config, None)
    model = get_cell_based_tiny_net(net_config)
    model.load_state_dict(result.get_net_param())
    return model


def create_baselearner(state_dict_dir, genotype, arch_seed, init_seed, scheme,
                       dataset, device, save_dir):
    """
    A function which wraps an nn.Module with the Baselearner container, computes
    predictions and evaluations and finally saves everything.

    The network comes from ``load_nn_module``. It is wrapped in a
    ``Baselearner`` created on CPU, which is then moved to ``device`` before
    predictions are computed.

    Args:
        state_dict_dir: checkpoint path forwarded to ``load_nn_module``.
        genotype: DARTS genotype, or ``None`` for a NAS-Bench-201 network.
        arch_seed, init_seed, scheme: combined by ``model_seeds`` into the
            model id. ``init_seed`` is also forwarded to ``load_nn_module``.
        dataset: one of "cifar10", "cifar100", "fmnist", "imagenet".
        device: device the Baselearner is moved to, and passed to the
            dataloader factories.
        save_dir: directory passed to ``Baselearner.save``.

    Returns:
        The Baselearner with predictions and evaluations computed.

    Raises:
        AssertionError: if ``dataset`` is not supported.
    """
    num_classes_by_dataset = {
        "cifar10": 10,
        "cifar100": 100,
        "fmnist": 10,
        "imagenet": 120,
    }
    assert dataset in num_classes_by_dataset, (
        f"unsupported dataset {dataset!r}")
    nb201 = genotype is None

    model_nn = load_nn_module(state_dict_dir, genotype, init_seed, dataset=dataset)

    # Severity levels to evaluate: 0-5 for CIFAR, only level 0 otherwise
    # (presumably 0 is the uncorrupted data -- confirm in the dataloader utils).
    severities = range(6) if dataset in ("cifar10", "cifar100") else range(1)

    model_id = model_seeds(arch_seed, init_seed, scheme)
    baselearner = Baselearner(model_id=model_id, severities=severities,
                              device=torch.device('cpu'), nn_module=model_nn)

    # Load dataloaders (val, test, all severities) to make predictions on
    if dataset == 'fmnist':
        dataloaders = create_dataloader_dict_fmnist(device)
    elif dataset in ('cifar10', 'cifar100'):
        dataloaders = create_dataloader_dict_cifar(device, dataset, nb201)
    elif dataset == 'imagenet':
        dataloaders = create_dataloader_dict_imagenet(device, dataset, nb201)
    else:
        # Only reachable when asserts are stripped (python -O). Fail loudly
        # instead of hitting an UnboundLocalError on `dataloaders` below.
        raise ValueError(f"unsupported dataset {dataset!r}")

    num_classes = num_classes_by_dataset[dataset]

    baselearner.to_device(device)
    baselearner.compute_preds(dataloaders, severities, num_classes=num_classes)
    baselearner.compute_evals(severities)

    # saves the model_id, nn_module and preds & evals.
    baselearner.save(save_dir)

    return baselearner
