from nasbench import api
from nasbench.lib.model_spec import ModelSpec
from nas_201_api import NASBench201API as API
import nasbench301 as nb
from Graph_GP.graph_kernels import SPKernel
from MIP import MIP_NASBench101, MIP_NASBench201, MIP_NASBench301, get_var_value

from graph import graph

from collections import namedtuple
import numpy as np
import os

class N101:
    def __init__(self, path="data/"):
        ''' NAS benchmark 101
        Args:
            path (str): path to N101 benchmark models directory
        '''
        self.path = path
        self.name = "N101"

        # Only the 108-epoch record is loaded; eval looks statistics up at self.epoch.
        self.api = api.NASBench(self.path + "nasbench_only108.tfrecord")
        self.epoch = 108
        self.Ln = 5  # number of node labels (entries in OPS_encoder)
        # Node/edge-count bounds; not referenced inside this class.
        self.N_max = 7
        self.N_min = 2
        self.E_max = 9

        self.INPUT = 'input'
        self.OUTPUT = 'output'
        self.CONV1X1 = 'conv1x1-bn-relu'
        self.CONV3X3 = 'conv3x3-bn-relu'
        self.MAXPOOL3X3 = 'maxpool3x3'
        # NAS-Bench-101 op name -> integer node label stored in graph.node_attr.
        self.OPS_encoder = {
            self.INPUT: 0,
            self.OUTPUT: 4,
            self.CONV1X1: 1,
            self.CONV3X3: 2,
            self.MAXPOOL3X3: 3,
        }
        # Inverse mapping: integer node label -> op name.
        self.OPS_decoder = {code: op for op, code in self.OPS_encoder.items()}

    def index_sampling(self, idxs):
        '''Build graphs for the benchmark architectures at the given indices.

        Args:
            idxs (iterable): keys of ``self.api.index_hash_dict``.
        Returns:
            list of ``graph`` objects, in the same order as ``idxs``.
        '''
        index_dict = self.api.index_hash_dict
        # Resolve every index up front so an unknown index fails before any graph is built.
        arch_hashes = [index_dict[idx] for idx in idxs]
        graph_pool = []
        for arch_hash in arch_hashes:  # not named `hash`, to avoid shadowing the builtin
            stats = self.api.fixed_statistics[arch_hash]
            graph_spec = ModelSpec(matrix=stats['module_adjacency'], ops=stats['module_operations'])
            graph_pool.append(self.retrieve_graph(graph_spec))
        return graph_pool

    def convert_graph(self, graph):
        '''Convert a ``graph`` into a NAS-Bench-101 ``ModelSpec``.

        ``graph.A`` carries the identity added by ``retrieve_graph``; it is
        subtracted back out before building the spec.
        '''
        node_labeling = [self.OPS_decoder[label] for label in graph.node_attr]
        adjacency_matrix = np.array(graph.A, dtype=np.int32) - np.eye(graph.N, dtype=np.int32)
        return ModelSpec(adjacency_matrix, node_labeling)

    def retrieve_graph(self, spec):
        '''Convert a NAS-Bench-101 ``ModelSpec`` into a ``graph``.

        The identity is added to the adjacency matrix (self-loops on the diagonal),
        and op names are encoded with ``OPS_encoder``.
        '''
        node_labeling = spec.ops
        N = len(node_labeling)
        adjacency_matrix = spec.matrix + np.eye(N, dtype=np.int32)
        node_attr = [self.OPS_encoder[op] for op in node_labeling]
        return graph(A=adjacency_matrix, node_attr=node_attr, Ln=self.Ln, edge_attr=None, Le=None)

    def eval(self, graph, score="log-err", task="cifar10-valid", noisy=False):
        '''Look up the validation and test accuracy of ``graph`` in NAS-Bench-101.

        Args:
            graph: graph to evaluate (converted with ``convert_graph``).
            score (str): "log-err" returns log(1 - accuracy); any other value
                returns the raw accuracies.
            task (str): unused here; kept so the signature matches the other benchmarks.
            noisy (bool): if True, the validation accuracy comes from one randomly
                chosen training run instead of the mean over all runs.
        Returns:
            (validation, test) pair. The test value is always the mean over all runs.
        '''
        model_spec = self.convert_graph(graph)
        _, computed_stat = self.api.get_metrics_from_spec(model_spec)
        val_acc, test_acc = self._run_accuracies(computed_stat[self.epoch], noisy)
        if score == "log-err":
            return np.log(1 - val_acc), np.log(1 - test_acc)
        return val_acc, test_acc

    @staticmethod
    def _run_accuracies(runs, noisy=False):
        '''Aggregate per-run records into a (validation, test) accuracy pair.

        Args:
            runs (list of dict): per-run records with 'final_validation_accuracy'
                and 'final_test_accuracy' keys.
            noisy (bool): take the validation accuracy from one random run.
        '''
        if noisy:
            # Sample among the runs that actually exist instead of assuming exactly three.
            picked = runs[np.random.choice(len(runs))]
            val_acc = picked['final_validation_accuracy']
        else:
            val_acc = np.mean([run['final_validation_accuracy'] for run in runs])
        test_acc = np.mean([run['final_test_accuracy'] for run in runs])
        return val_acc, test_acc

    def get_kernel(self, exp_option=False):
        '''Return the shortest-path ("SP") graph kernel used for NAS-Bench-101.

        Args:
            exp_option (bool): whether the kernel variance is trainable.
        '''
        return SPKernel(trainable_lengthscales=False,
                        trainable_variance=exp_option,
                        trainable_alpha=True,
                        trainable_beta=True,
                        trainable_gamma=False,
                        kernel_type="SP",
                        )

    def get_MIP(self, G, beta_t, GPmodel, N, random=False):
        '''Build the NAS-Bench-101 MIP acquisition model for graphs with ``N`` nodes.'''
        return MIP_NASBench101(G=G, beta_t=beta_t, GraphGP=GPmodel, N=N, random=random)

    def get_optimized_graph(self, MIPmodel, batch=5, N=7):
        '''Decode the solutions of a solved MIP model into graphs.

        Args:
            MIPmodel: solved model exposing ``SolCount``, ``params.SolutionNumber``
                and ``PoolObjVal`` (presumably a gurobipy model — confirm).
            batch (int): if 1, only the first pool solution is decoded. Any other
                value decodes all ``SolCount`` solutions.
            N (int): number of nodes in the decoded graphs.
        Returns:
            (solutions, objs): decoded graphs and their pool objective values.
        '''
        solutions = []
        objs = []
        for sol_idx in range(MIPmodel.SolCount):
            MIPmodel.params.SolutionNumber = sol_idx
            A = get_var_value(MIPmodel, "A", (N, N))
            Fn = get_var_value(MIPmodel, "Fn", (N, self.Ln))
            # Each row of Fn is one-hot; the position of the 1 is the node label.
            node_attr = [Fn[n].index(1) for n in range(N)]
            solutions.append(graph(A=A, node_attr=node_attr, Ln=self.Ln, edge_attr=None, Le=None))
            objs.append(MIPmodel.PoolObjVal)
            if batch == 1:
                break
        return solutions, objs


class N201:
    def __init__(self, path="data/"):
        ''' NAS benchmark 201
        Args:
            path (str): path to N201 benchmark models directory
        '''
        self.path = path
        self.name = "N201"

        self.api = API(self.path + "NAS-Bench-201-v1_1-096897.pth", verbose=False)
        self.epoch = "200"  # passed as `hp` to api.get_more_info
        self.Le = 5  # number of edge labels in OPS_encoder, including 'none'
        self.N = 4  # nodes per cell
        self.E_max = 6  # N * (N - 1) / 2 candidate edges

        # NAS-Bench-201 op name -> integer edge label ('none' means no edge).
        self.OPS_encoder = {
            'none': None,
            'skip_connect': 0,
            'nor_conv_1x1': 1,
            'nor_conv_3x3': 2,
            'avg_pool_3x3': 3,
        }
        # Inverse mapping: integer edge label (or None) -> op name.
        self.OPS_decoder = {code: op for op, code in self.OPS_encoder.items()}

    def index_sampling(self, idxs):
        '''Build graphs for the benchmark architectures at the given indices.'''
        return [self.retrieve_graph(self.api.arch(idx)) for idx in idxs]

    def convert_graph(self, graph):
        '''Convert a ``graph`` into a NAS-Bench-201 architecture string.

        The string has one "+"-separated group per target node v = 1..N-1. Each
        group lists the op on every edge (u, v) with u < v as "op~u", e.g.
        '|op~0|+|op~0|op~1|+|op~0|op~1|op~2|' for N = 4.
        '''
        groups = []
        for v in range(1, self.N):
            # Only edges u < v exist in edge_attr; do not look up u >= v.
            edges = "|".join(
                f"{self.OPS_decoder[graph.edge_attr[u, v]]}~{u}" for u in range(v)
            )
            groups.append(f"|{edges}|")
        return "+".join(groups)

    def retrieve_graph(self, arch_str):
        '''Convert a NAS-Bench-201 architecture string into a ``graph``.

        The adjacency matrix is upper triangular including the diagonal. Edges
        whose op is 'none' are removed.
        '''
        edge_attr = {}
        for v, group in enumerate(arch_str.split("+"), start=1):
            for u, token in enumerate(group.strip("|").split("|")):
                op_name = token.rsplit("~", 1)[0]  # drop the "~<source node>" suffix
                edge_attr[u, v] = self.OPS_encoder[op_name]
        adj = np.triu(np.ones((self.N, self.N), dtype=np.int32))
        for u in range(self.N):
            for v in range(u + 1, self.N):
                if edge_attr[u, v] is None:  # 'none' op: no connection
                    adj[u, v] = 0
        return graph(A=adj.tolist(), node_attr=None, Ln=None, edge_attr=edge_attr, Le=self.Le - 1)

    def eval(self, graph, score="log-err", task="cifar10-valid", noisy=False):
        '''Look up the validation and test accuracy of ``graph`` in NAS-Bench-201.

        Args:
            graph: graph to evaluate (converted with ``convert_graph``).
            score (str): "log-err" returns log(1 - accuracy); any other value
                returns the raw accuracies.
            task (str): dataset split name passed to ``api.get_more_info``.
            noisy (bool): forwarded as ``is_random`` for the validation query only.
                The test query always uses ``is_random=False``.
        Returns:
            (validation, test) pair of accuracies in [0, 1] (or their log-errors).
        '''
        idx = self.api.query_index_by_arch(self.convert_graph(graph))
        val_info = self.api.get_more_info(idx, task, None, hp=self.epoch, is_random=noisy)
        val_acc = val_info['valid-accuracy'] / 100
        test_info = self.api.get_more_info(idx, task, None, hp=self.epoch, is_random=False)
        test_acc = test_info['test-accuracy'] / 100
        if score == "log-err":
            return np.log(1 - val_acc), np.log(1 - test_acc)
        return val_acc, test_acc

    def get_kernel(self, exp_option=False):
        '''Return the "SSP" graph kernel used for NAS-Bench-201.

        Args:
            exp_option (bool): whether the kernel variance is trainable.
        '''
        return SPKernel(trainable_lengthscales=False,
                        trainable_variance=exp_option,
                        trainable_alpha=True,
                        trainable_beta=False,
                        trainable_gamma=True,
                        kernel_type="SSP",
                        )

    def get_MIP(self, G, beta_t, GPmodel, random=False):
        '''Build the NAS-Bench-201 MIP acquisition model.'''
        return MIP_NASBench201(G=G, beta_t=beta_t, GraphGP=GPmodel, random=random)

    def get_optimized_graph(self, MIPmodel, batch=5):
        '''Decode the solutions of a solved MIP model into graphs.

        Args:
            MIPmodel: solved model exposing ``SolCount`` and ``params.SolutionNumber``
                (presumably a gurobipy model — confirm).
            batch (int): if 1, only the first pool solution is decoded. Any other
                value decodes all ``SolCount`` solutions.
        '''
        solutions = []
        for sol_idx in range(MIPmodel.SolCount):
            MIPmodel.params.SolutionNumber = sol_idx
            A = get_var_value(MIPmodel, "A", (self.N, self.N))
            Fe = get_var_value(MIPmodel, "Fe", (self.N, self.N, self.Le - 1))
            edge_attr = {}
            for u in range(self.N - 1):
                for v in range(u + 1, self.N):
                    # Fe[u][v] is one-hot over real ops; all zeros means 'none'.
                    edge_attr[u, v] = Fe[u][v].index(1) if 1 in Fe[u][v] else None
            solutions.append(graph(A=A, node_attr=None, Ln=None, edge_attr=edge_attr, Le=self.Le - 1))
            if batch == 1:
                break
        return solutions


class N301:
    # DARTS genotype container, built once rather than on every convert_graph call.
    _Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
    # Short surrogate names accepted by __init__, mapped to model directory keys.
    _SURROGATE_ALIASES = {"gin": "gnn_gin"}

    def __init__(self, path="data/", surrogate="xgb"):
        ''' NAS benchmark 301
        search space: DARTS
        dataset: CIFAR10
        Args:
            path (str): path to N301 benchmark models directory
            surrogate (str): performance surrogate, "xgb" or "gin" ("gnn_gin" is also accepted)
        Raises:
            KeyError: if `surrogate` does not name a known model.
        Each cell is modelled as a subgraph with N=6 nodes. The cell's output node is excluded
        because its incoming edges are fixed.
        '''
        self.path = path
        self.surrogate = surrogate
        self.name = "N301"
        model_dir = os.path.join(self.path, "nb_models_0.9")
        self.model_paths = {
            model_name: os.path.join(model_dir, '{}_v0.9'.format(model_name))
            for model_name in ['xgb', 'gnn_gin', 'lgb_runtime']
        }
        # Validate the surrogate name before any download is attempted.
        surrogate_key = self._SURROGATE_ALIASES.get(surrogate, surrogate)
        if surrogate_key not in self.model_paths:
            raise KeyError(
                f"unknown surrogate {surrogate!r}; expected one of "
                f"{sorted(set(self.model_paths) | set(self._SURROGATE_ALIASES))}"
            )
        # Download automatically if not available
        if not all(os.path.exists(model) for model in self.model_paths.values()):
            nb.download_models(version="0.9", delete_zip=True,
                               download_dir=self.path)

        self.performance_model_path = self.model_paths[surrogate_key]
        self.performance_model = nb.load_ensemble(self.performance_model_path)
        self.runtime_model_path = self.model_paths["lgb_runtime"]
        self.runtime_model = nb.load_ensemble(self.runtime_model_path)

        self.Le = 8  # number of edge labels in OPS_encoder, including 'none'
        self.N = 6  # nodes per cell subgraph
        self.E_max = 14

        # DARTS op name -> integer edge label ('none' means no edge).
        self.OPS_encoder = {
            'none': None,
            'max_pool_3x3': 0,
            'avg_pool_3x3': 1,
            'skip_connect': 2,
            'sep_conv_3x3': 3,
            'sep_conv_5x5': 4,
            'dil_conv_3x3': 5,
            'dil_conv_5x5': 6,
        }
        # Inverse mapping: integer edge label (or None) -> op name.
        self.OPS_decoder = {code: op for op, code in self.OPS_encoder.items()}

    def separate_to_pair(self, spliced_graph):
        '''Split a spliced graph into its [normal, reduce] cell graphs.

        Node layout of the spliced graph:
        - 0..N-1 are the normal cell.
        - N is dropped (its output node).
        - N+1..2N are the reduce cell.
        - The last node is dropped.
        Reduce-cell node indices are shifted down by N+1.
        '''
        offset = self.N + 1  # first node index of the reduce cell
        A = np.array(spliced_graph.A)
        An, Ar = A[:self.N, :self.N], A[offset:-1, offset:-1]
        nor_edge_attr, red_edge_attr = {}, {}
        for (u, v), edge_label in spliced_graph.edge_attr.items():
            if u < self.N and v < self.N:
                nor_edge_attr[u, v] = edge_label
            elif offset <= u < offset + self.N and offset <= v < offset + self.N:
                red_edge_attr[u - offset, v - offset] = edge_label
        return [graph(A=An.tolist(), node_attr=None, Ln=None, edge_attr=nor_edge_attr, Le=self.Le - 1),
                graph(A=Ar.tolist(), node_attr=None, Ln=None, edge_attr=red_edge_attr, Le=self.Le - 1)]

    def _cell_ops(self, cell_graph):
        '''Return the DARTS ``(op_name, input_node)`` list for one cell graph.

        Each intermediate node v (2..N-1) gets its inputs in ascending order of u.
        Only strictly-upper entries ``A[u, v]`` with u < v are read, so any diagonal
        self-loop in A is ignored.
        '''
        A = np.array(cell_graph.A)
        ops = []
        for v in range(2, self.N):
            for u in range(v):
                if A[u, v] == 1:
                    ops.append((self.OPS_decoder[cell_graph.edge_attr[u, v]], u))
        return ops

    def convert_graph(self, graph_pair):
        '''Convert a [normal, reduce] graph pair into a DARTS genotype.'''
        normal_graph, reduce_graph = graph_pair[0], graph_pair[1]
        return self._Genotype(normal=self._cell_ops(normal_graph),
                              normal_concat=list(range(2, self.N)),
                              reduce=self._cell_ops(reduce_graph),
                              reduce_concat=list(range(2, self.N)))

    def retrieve_graph(self, genotype):
        '''Convert a DARTS genotype into a [normal, reduce] graph pair.

        Each cell's ops are read two at a time, one pair per intermediate node,
        starting at node 2. The adjacency matrix starts from the identity.
        '''
        graph_pair = []
        for cell in (genotype.normal, genotype.reduce):
            A = np.eye(self.N)
            edge_attr = {}
            for node_idx in range(self.N - 2):
                target = node_idx + 2
                for edge_label, source in cell[node_idx * 2:node_idx * 2 + 2]:
                    A[source, target] = 1
                    edge_attr[source, target] = self.OPS_encoder[edge_label]
            graph_pair.append(graph(A=A.tolist(), node_attr=None, Ln=None, edge_attr=edge_attr, Le=self.Le - 1))
        return graph_pair

    def eval(self, graph, score="log-err", noisy=True, runtime=False):
        '''Predict the performance (or runtime) of a spliced graph with the surrogate.

        Args:
            graph: spliced graph containing both cells (see separate_to_pair).
            score (str): "log-err" returns log(1 - acc/100); any other value returns acc/100.
            noisy (bool): forwarded as ``with_noise`` to the performance model.
            runtime (bool): if True, return the runtime model's prediction instead.
        '''
        genotype = self.convert_graph(self.separate_to_pair(graph))
        if runtime:
            return self.runtime_model.predict(config=genotype, representation="genotype")
        acc = self.performance_model.predict(config=genotype, representation="genotype", with_noise=noisy)
        if score == "log-err":
            return np.log(1 - acc / 100)
        return acc / 100

    def get_kernel(self, exp_option=False):
        '''Return the "SSP" graph kernel used for NAS-Bench-301.

        Args:
            exp_option (bool): whether the kernel variance is trainable.
        '''
        return SPKernel(trainable_lengthscales=False,
                        trainable_variance=exp_option,
                        trainable_alpha=True,
                        trainable_beta=False,
                        trainable_gamma=True,
                        kernel_type="SSP",
                        )

    def get_MIP(self, G, beta_t, GPmodel, random=False):
        '''Build the NAS-Bench-301 MIP acquisition model.'''
        return MIP_NASBench301(G=G, beta_t=beta_t, GraphGP=GPmodel, random=random)

    def get_optimized_graph(self, MIPmodel, batch=5):
        '''Decode the solutions of a solved MIP model into spliced graphs.

        Args:
            MIPmodel: solved model exposing ``SolCount`` and ``params.SolutionNumber``
                (presumably a gurobipy model — confirm).
            batch (int): if 1, only the first pool solution is decoded. Any other
                value decodes all ``SolCount`` solutions.
        '''
        n_nodes = (self.N + 1) * 2  # two cells, each N nodes plus an output node
        solutions = []
        for sol_idx in range(MIPmodel.SolCount):
            MIPmodel.params.SolutionNumber = sol_idx
            A = get_var_value(MIPmodel, "A", (n_nodes, n_nodes))
            Fe = get_var_value(MIPmodel, "Fe", (n_nodes, n_nodes, self.Le - 1))
            edge_attr = {}
            for u in range(n_nodes - 1):
                for v in range(u + 1, n_nodes):
                    # Fe[u][v] is one-hot over real ops; all zeros means 'none'.
                    edge_attr[u, v] = Fe[u][v].index(1) if 1 in Fe[u][v] else None
            solutions.append(graph(A=A, node_attr=None, Ln=None, edge_attr=edge_attr, Le=self.Le - 1))
            if batch == 1:
                break
        return solutions