from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
import itertools
import random
import numpy as np
from models.cell_searchs.genotypes import Structure
from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype
import json
import torch


class Nasbench201:
    """Search-space wrapper around the NAS-Bench-201 API.

    Architectures are addressed by their integer index (``uid``) in the
    benchmark; ``len(self)`` is fixed at 15625.
    """

    def __init__(self, dataset, apiloc):
        """Store the dataset name and load the benchmark API from ``apiloc``."""
        self.dataset = dataset
        self.api = API(apiloc, verbose=False)
        # Passed as the ``hp`` argument when querying the short training run.
        self.epochs = '12'

    def get_network(self, uid):
        """Build the network for architecture ``uid``.

        The config is always taken from the 'cifar10-valid' entry (not from
        ``self.dataset``) and its ``num_classes`` is overridden to 1.
        """
        net_config = self.api.get_net_config(uid, 'cifar10-valid')
        net_config['num_classes'] = 1
        return get_cell_based_tiny_net(net_config)

    def __iter__(self):
        """Yield ``(uid, network)`` for every architecture index."""
        for uid in range(len(self)):
            yield uid, self.get_network(uid)

    def __getitem__(self, index):
        """Return ``index`` unchanged: the index is the architecture id."""
        return index

    def __len__(self):
        """Return the number of architectures in the search space."""
        return 15625

    def num_activations(self):
        """Return ``classifier.in_features`` of the network with uid 0."""
        return self.get_network(0).classifier.in_features

    def get_12epoch_accuracy(self, uid, acc_type, trainval, traincifar10=False):
        """Return the 'valid-accuracy' of the ``hp=self.epochs`` run on 'cifar10-valid'.

        ``acc_type``, ``trainval`` and ``traincifar10`` are accepted but not
        used by the current implementation.
        """
        results = self.api.get_more_info(
            uid, 'cifar10-valid', iepoch=None, hp=self.epochs, is_random=True
        )
        return results['valid-accuracy']

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Return the 'test-accuracy' of the ``hp='200'`` run.

        'cifar10' is queried as 'cifar10-valid'; every other dataset name is
        passed through as-is. ``acc_type`` and ``trainval`` are unused.
        """
        if self.dataset == 'cifar10':
            benchmark_dataset = 'cifar10-valid'
        else:
            benchmark_dataset = self.dataset
        results = self.api.get_more_info(
            uid, benchmark_dataset, iepoch=None, hp='200', is_random=False
        )
        return results['test-accuracy']

    def get_accuracy(self, uid, acc_type, trainval=True):
        """Return the ``acc_type`` accuracy from the architecture's meta info.

        When the dataset is 'cifar10' and ``trainval`` is truthy the
        'cifar10-valid' metrics are used; otherwise ``self.dataset``.
        """
        archinfo = self.api.query_meta_info_by_index(uid)
        use_valid_split = self.dataset == 'cifar10' and trainval
        dataset_name = 'cifar10-valid' if use_valid_split else self.dataset
        return archinfo.get_metrics(dataset_name, acc_type)['accuracy']

    def get_accuracy_for_all_datasets(self, uid):
        """Return test/validation accuracies of the ``hp='200'`` run.

        The result is the 6-tuple
        ``(c10, c10_val, c100, c100_val, imagenet, imagenet_val)``.
        """
        archinfo = self.api.query_meta_info_by_index(uid, hp='200')
        # (dataset, split) pairs in the order of the returned tuple.
        queries = (
            ('cifar10', 'ori-test'),
            ('cifar10-valid', 'x-valid'),
            ('cifar100', 'x-test'),
            ('cifar100', 'x-valid'),
            ('ImageNet16-120', 'x-test'),
            ('ImageNet16-120', 'x-valid'),
        )
        return tuple(
            archinfo.get_metrics(dataset_name, split)['accuracy']
            for dataset_name, split in queries
        )

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        """Return ``(12-epoch accuracy, final accuracy, training time)`` for ``arch``.

        ``dataname`` is accepted for interface compatibility and not used.
        """
        unique_hash = self[arch]
        elapsed = self.get_training_time(unique_hash)
        short_acc = self.get_12epoch_accuracy(unique_hash, acc_type, trainval, traincifar10)
        final_acc = self.get_final_accuracy(unique_hash, acc_type, trainval)
        return short_acc, final_acc, elapsed

    def random_arch(self):
        """Return a uniformly random architecture index."""
        last_index = len(self) - 1
        return random.randint(0, last_index)

    def get_training_time(self, unique_hash):
        """Return 'train-all-time' + 'valid-per-time' of the ``hp='12'`` 'cifar10-valid' run."""
        timing = self.api.get_more_info(
            unique_hash, 'cifar10-valid', iepoch=None, hp='12', is_random=True
        )
        train_time = timing['train-all-time']
        valid_time = timing['valid-per-time']
        return train_time + valid_time

    def mutate_arch(self, arch):
        """Return the index of ``arch`` with one randomly chosen edge op replaced.

        A random node and a random incoming edge of that node are picked, and
        the edge's operation is resampled until it differs from the current one.
        """
        candidate_ops = get_search_spaces('cell', 'nas-bench-201')
        config = self.api.get_net_config(arch, 'cifar10-valid')
        parent = Structure(self.api.str2lists(config['arch_str']))
        child = deepcopy(parent)

        node_idx = random.randint(0, len(child.nodes) - 1)
        edges = list(child.nodes[node_idx])
        edge_idx = random.randint(0, len(edges) - 1)
        old_edge = edges[edge_idx]

        # Resample until the operation actually changes.
        while True:
            new_op = random.choice(candidate_ops)
            if new_op != old_edge[0]:
                break

        edges[edge_idx] = (new_op, old_edge[1])
        child.nodes[node_idx] = tuple(edges)
        return self.api.query_index_by_arch(child)

class Nasbench101:
    """Search-space wrapper around the NAS-Bench-101 API.

    Architectures are identified by the hash keys of
    ``self.api.fixed_statistics``; integer indices refer to positions in
    ``self.api.hash_iterator()``.
    """

    def __init__(self, dataset, apiloc, args):
        """Store the dataset name, load the benchmark from ``apiloc`` and keep ``args``.

        ``args`` is forwarded to ``Network`` when a model is built.
        """
        self.dataset = dataset
        self.api = nasbench101api.NASBench(apiloc)
        self.args = args

    def get_accuracy(self, unique_hash, acc_type, trainval=True):
        """Return the highest 'final_test_accuracy' over all recorded runs.

        The result is never below ``0.``. ``acc_type`` and ``trainval`` are
        accepted for interface compatibility and not used.
        """
        spec = self.get_spec(unique_hash)
        _, stats = self.api.get_metrics_from_spec(spec)
        maxacc = 0.
        for ep in stats:
            for statmap in stats[ep]:
                newacc = statmap['final_test_accuracy']
                if newacc > maxacc:
                    maxacc = newacc
        return maxacc

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Alias of :meth:`get_accuracy`."""
        return self.get_accuracy(uid, acc_type, trainval)

    def get_training_time(self, unique_hash):
        """Return 'final_training_time' of the run with the best 'final_test_accuracy'.

        Returns ``0.`` when no run is recorded.
        """
        spec = self.get_spec(unique_hash)
        _, stats = self.api.get_metrics_from_spec(spec)
        maxacc = -1.
        maxtime = 0.
        for ep in stats:
            for statmap in stats[ep]:
                newacc = statmap['final_test_accuracy']
                if newacc > maxacc:
                    maxacc = newacc
                    maxtime = statmap['final_training_time']
        return maxtime

    def get_network(self, unique_hash):
        """Build the ``Network`` for the architecture with the given hash."""
        spec = self.get_spec(unique_hash)
        return Network(spec, self.args)

    def get_spec(self, unique_hash):
        """Return the ``ModelSpec`` built from the stored adjacency matrix and operations."""
        entry = self.api.fixed_statistics[unique_hash]
        return ModelSpec(entry['module_adjacency'], entry['module_operations'])

    def __iter__(self):
        """Yield ``(unique_hash, network)`` for every architecture."""
        for unique_hash in self.api.hash_iterator():
            network = self.get_network(unique_hash)
            yield unique_hash, network

    def __getitem__(self, index):
        """Return the hash at position ``index`` of the hash iterator (linear scan)."""
        return next(itertools.islice(self.api.hash_iterator(), index, None))

    def __len__(self):
        """Return the number of architecture hashes."""
        return len(self.api.hash_iterator())

    def num_activations(self):
        """Return ``classifier.in_features`` of the first architecture's network.

        Returns ``None`` if the hash iterator is empty.
        """
        for unique_hash in self.api.hash_iterator():
            network = self.get_network(unique_hash)
            return network.classifier.in_features
        return None

    def train_and_eval(self, arch, dataname, acc_type, trainval=True, traincifar10=False):
        """Return ``(acc, acc, time)`` for the architecture at index ``arch``.

        The same accuracy is returned twice, and the recorded training time is
        scaled by 12/108. ``dataname`` and ``traincifar10`` are not used.
        """
        unique_hash = self.__getitem__(arch)
        # NOTE(review): presumably rescales a 108-epoch time to a 12-epoch budget - confirm.
        time = 12. * self.get_training_time(unique_hash) / 108.
        acc = self.get_accuracy(unique_hash, acc_type, trainval)
        return acc, acc, time

    def random_arch(self):
        """Return a uniformly random architecture index."""
        return random.randint(0, len(self) - 1)

    def mutate_arch(self, arch):
        """Return the index of a neighbour of ``arch`` that differs in one adjacency entry.

        Upper-triangular matrix positions are tried in random order; the first
        modified matrix whose hash is present in the benchmark wins.

        Returns:
            The integer index of the mutated architecture, or ``None`` when no
            single-entry change maps to a known architecture.
        """
        unique_hash = self.__getitem__(arch)
        matrix = self.api.fixed_statistics[unique_hash]['module_adjacency']
        operations = self.api.fixed_statistics[unique_hash]['module_operations']
        coords = [(i, j) for i in range(matrix.shape[0]) for j in range(i + 1, matrix.shape[1])]
        random.shuffle(coords)
        # The set of values present in the matrix does not change between attempts.
        candidate_values = np.unique(matrix)
        # loop through changes until we find a change that's allowed
        for i, j in coords:
            for k in candidate_values:
                if k == matrix[i, j]:
                    continue
                newmatrix = matrix.copy()
                newmatrix[i, j] = k
                spec = ModelSpec(newmatrix, operations)
                try:
                    newhash = self.api._hash_spec(spec)
                except Exception:
                    # Specs that cannot be hashed are skipped; unlike a bare
                    # ``except`` this lets KeyboardInterrupt/SystemExit through.
                    continue
                if newhash in self.api.fixed_statistics:
                    return next(
                        n for n, m in enumerate(self.api.fixed_statistics.keys())
                        if m == newhash
                    )
        return None




class ReturnFeatureLayer(torch.nn.Module):
    """Wrap a callable so the forward pass also returns the layer's input."""

    def __init__(self, mod):
        """Keep ``mod`` as the wrapped layer."""
        super().__init__()
        self.mod = mod

    def forward(self, x):
        """Return ``(self.mod(x), x)``: the wrapped output and the untouched input."""
        output = self.mod(x)
        return output, x
                

def return_feature_layer(network, prefix=''):
    """Recursively replace every ``torch.nn.Linear`` child with a ``ReturnFeatureLayer``.

    The module tree is modified in place. ``prefix`` grows by one tab per
    recursion level but is otherwise unused.
    """
    # Snapshot the children first: setattr below changes the module's child set.
    children = list(network.named_children())
    for child_name, child in children:
        if not isinstance(child, torch.nn.Linear):
            return_feature_layer(child, prefix + '\t')
            continue
        setattr(network, child_name, ReturnFeatureLayer(child))
             

class NDS:
    """Search-space wrapper over an NDS JSON dump (``nds_data/<searchspace>.json``).

    Architectures are addressed by their integer position in ``self.data``.
    """

    def __init__(self, searchspace):
        """Load the records for ``searchspace``.

        If the JSON document has 'top' and 'mid' entries they are concatenated;
        otherwise the loaded document is used as-is.
        """
        self.searchspace = searchspace
        # Context manager so the file handle is closed deterministically.
        with open(f'nds_data/{searchspace}.json', 'r') as f:
            data = json.load(f)
        try:
            data = data['top'] + data['mid']
        except (KeyError, TypeError):
            # Not a dict with 'top'/'mid' sections: keep the document unchanged.
            pass
        self.data = data

    def __iter__(self):
        """Yield ``(uid, network)`` for every record."""
        for unique_hash in range(len(self)):
            network = self.get_network(unique_hash)
            yield unique_hash, network

    def get_network_config(self, uid):
        """Return the stored 'net' config of record ``uid``."""
        return self.data[uid]['net']

    def get_network_optim_config(self, uid):
        """Return the stored 'optim' config of record ``uid``."""
        return self.data[uid]['optim']

    def get_network(self, uid):
        """Build the network described by record ``uid``.

        Records with a 'genotype' produce a ``NetworkImageNet`` (when the
        search-space name contains '_in') or a ``NetworkCIFAR``; all other
        records are normalised into ``AnyNet`` keyword arguments.

        The stored config is not modified: normalisation happens on a shallow
        copy, so ``get_network_config`` keeps returning the loaded data.
        """
        netinfo = self.data[uid]
        config = dict(netinfo['net'])
        if 'genotype' in config:
            gen = config['genotype']
            genotype = Genotype(
                normal=gen['normal'],
                normal_concat=gen['normal_concat'],
                reduce=gen['reduce'],
                reduce_concat=gen['reduce_concat'],
            )
            if '_in' in self.searchspace:
                network = NetworkImageNet(config['width'], 1, config['depth'], config['aux'], genotype)
            else:
                network = NetworkCIFAR(config['width'], 10, config['depth'], config['aux'], genotype)
            network.drop_path_prob = 0.
        else:
            # Rename legacy keys to the names passed to AnyNet.
            if 'bot_muls' in config and 'bms' not in config:
                config['bms'] = config['bot_muls']
                del config['bot_muls']
            if 'num_gs' in config and 'gws' not in config:
                config['gws'] = config['num_gs']
                del config['num_gs']
            config['nc'] = 1
            config['se_r'] = None
            config['stem_w'] = 12
            # Stem choices seen so far: res_stem_cifar, res_stem_in, simple_stem_in.
            if 'ResN' in self.searchspace:
                config['stem_type'] = 'res_stem_in'
            else:
                config['stem_type'] = 'simple_stem_in'
            if config['block_type'] == 'double_plain_block':
                config['block_type'] = 'vanilla_block'
            network = AnyNet(**config)
        return network

    def __getitem__(self, index):
        """Return ``index`` unchanged: the index is the architecture id."""
        return index

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def random_arch(self):
        """Return a uniformly random record index."""
        return random.randint(0, len(self.data) - 1)

    def get_final_accuracy(self, uid, acc_type, trainval):
        """Return ``100 - last 'test_ep_top1' value`` for record ``uid``.

        ``acc_type`` and ``trainval`` are accepted for interface compatibility
        and not used.
        """
        return 100. - self.data[uid]['test_ep_top1'][-1]


def get_search_space(args):
    """Return the search-space object selected by ``args.nasspace``.

    'nasbench201' and 'nasbench101' build their wrappers from ``args.dataset``
    and ``args.api_loc``; the 'nds_*' names map to an ``NDS`` instance for the
    matching data file. Any other value returns ``None``.
    """
    # nasspace option -> NDS search-space (data file) name.
    nds_names = {
        'nds_resnet': 'ResNet',
        'nds_amoeba': 'Amoeba',
        'nds_amoeba_in': 'Amoeba_in',
        'nds_darts_in': 'DARTS_in',
        'nds_darts': 'DARTS',
        'nds_darts_fix-w-d': 'DARTS_fix-w-d',
        'nds_darts_lr-wd': 'DARTS_lr-wd',
        'nds_enas': 'ENAS',
        'nds_enas_in': 'ENAS_in',
        'nds_enas_fix-w-d': 'ENAS_fix-w-d',
        'nds_pnas': 'PNAS',
        'nds_pnas_fix-w-d': 'PNAS_fix-w-d',
        'nds_pnas_in': 'PNAS_in',
        'nds_nasnet': 'NASNet',
        'nds_nasnet_in': 'NASNet_in',
        'nds_resnext-a': 'ResNeXt-A',
        'nds_resnext-a_in': 'ResNeXt-A_in',
        'nds_resnext-b': 'ResNeXt-B',
        'nds_resnext-b_in': 'ResNeXt-B_in',
        'nds_vanilla': 'Vanilla',
        'nds_vanilla_lr-wd': 'Vanilla_lr-wd',
        'nds_vanilla_lr-wd_in': 'Vanilla_lr-wd_in',
    }
    space = args.nasspace
    if space == 'nasbench201':
        return Nasbench201(args.dataset, args.api_loc)
    if space == 'nasbench101':
        return Nasbench101(args.dataset, args.api_loc, args)
    nds_name = nds_names.get(space)
    if nds_name is None:
        return None
    return NDS(nds_name)

