import nasbench.api as api
from localglobal.nasbench.lib import graph_util
from localglobal.test_funcs.base import TestFunction
import numpy as np
import os
import copy
import random


# Operation labels passed to the nasbench API, keyed by the integer code that
# appears in the categorical dimensions of an input vector.
categorical_mapping = dict(enumerate((
    'conv3x3-bn-relu',
    'conv1x1-bn-relu',
    'maxpool3x3',
)))


class NASBench101(TestFunction):
    """NAS-Bench-101 lookup exposed as a mixed categorical/continuous test function.

    An input is a 27-dimensional vector laid out as:

    * dims 0-4: categorical codes (keys of ``categorical_mapping``) choosing
      the operation of each of the 5 intermediate vertices;
    * dims 5-25: 21 continuous scores, one per entry of the strict upper
      triangle of the 7x7 adjacency matrix (row-major order);
    * dim 26: the number of edges ``k`` (rounded to the nearest integer);
      the ``k`` highest-scoring entries become edges.
    """

    problem_type = 'mixed'

    def __init__(self, data_dir, seed=3, normalize=False, which='eval', log_scale=True, negative=False):
        """Load the benchmark table and set up the search-space description.

        :param data_dir: directory containing ``nasbench_only108.tfrecord``.
        :param seed: forwarded to ``api.NASBench``; also selects the query mode
            in ``evaluate_single_hyperparameter`` (``>= 3`` averages several
            queries, otherwise a single query is used).
        :param normalize: forwarded to the ``TestFunction`` base class.
        :param which: ``'eval'`` reports validation error; any other value
            reports test error.
        :param log_scale: if True, return the natural log of the error.
        :param negative: if True, negate the returned value.
        """
        super(NASBench101, self).__init__(normalize)
        self.seed = seed
        self.dataset = api.NASBench(os.path.join(data_dir, 'nasbench_only108.tfrecord'), seed=seed)
        # presumably the number of choices per categorical dimension
        # (3 operations for each of the 5 vertices) -- confirm against TestFunction
        self.config = self.n_vertices = np.array([3] * 5)
        self.continuous_dims = np.arange(5, 5+22)
        self.categorical_dims = np.arange(5)
        # The last dimension is the number of edges
        self.int_constrained_dims = np.array([26])
        self.which = which
        self.log_scale, self.negative = log_scale, negative
        # Bounds of the 22 continuous dims: 21 edge scores in [0, 1] followed
        # by the edge count in [1, 9].
        self.lb = np.hstack((np.zeros((21, )), np.array(1)))
        self.ub = np.hstack((np.ones((21, )), np.array(9)))

    def compute(self, x, normalize=None):
        """Evaluate one input vector or a batch of them.

        :param x: array of shape ``(27,)`` or ``(n, 27)``.
        :param normalize: accepted but not used by this method.
        :return: 1-D array with one objective value per input row.
        """
        if x.ndim == 1:
            x = x.reshape(1, -1)
        res = []
        for x_ in x:
            ops, matrix = self.transform(x_)
            res.append(self.evaluate_single_hyperparameter(ops, matrix))
        return np.array(res)

    def evaluate_single_hyperparameter(self, ops, matrix, ):
        """Query the benchmark for a single architecture.

        :param ops: list of operation labels, as produced by ``transform``.
        :param matrix: 7x7 integer adjacency matrix, as produced by ``transform``.
        :return: the error (``1 - accuracy``), optionally log-transformed and/or
            negated according to ``self.log_scale`` / ``self.negative``. An
            architecture rejected by the benchmark gets an error of 1.
        """
        if self.which == 'eval':
            seed_list = [0, 1, 2]
            if self.seed is None:
                seed = random.choice(seed_list)
            elif self.seed >= 3:
                seed = self.seed
            else:
                seed = seed_list[self.seed]
        else:
            # For testing, there should be no stochasticity
            seed = 3
        model_spec = api.ModelSpec(matrix, ops)

        # 'eval' reads the validation accuracy, everything else the test accuracy.
        acc_key = 'validation_accuracy' if self.which == 'eval' else 'test_accuracy'

        try:
            if seed is not None and seed >= 3:
                # Query repeatedly and average the distinct accuracies seen,
                # stopping after 3 distinct values or 50 queries.
                patience = 50
                accs = []
                while len(accs) < 3 and patience > 0:
                    patience -= 1
                    acc = self.dataset.query(model_spec)[acc_key]
                    if acc not in accs:
                        accs.append(acc)
                err = (1 - np.mean(accs))
            else:
                err = 1. - self.dataset.query(model_spec)[acc_key]
        except api.OutOfDomainError:
            # The benchmark rejected the architecture: report the worst error.
            err = 1.

        if self.log_scale:
            y = np.log(err)
        else:
            y = err
        if self.negative:
            y = -y
        return y

    def transform(self, x):
        """Transform the input data to a format understood by the nasbench API.

        :param x: 1-D array of length 27 (see the class docstring for the layout).
        :return: ``(ops_string, adj_matrix)`` where ``ops_string`` is the list
            ``['input', <5 operation labels>, 'output']`` and ``adj_matrix`` is a
            7x7 ``int64`` upper-triangular 0/1 matrix.
        :raises AssertionError: if ``x`` does not have the expected length.
        """
        n_dims = len(self.categorical_dims) + len(self.continuous_dims)
        assert len(x) == n_dims,\
            "expected a %d dimensional vector, but got %d dimensions" % (n_dims, len(x))
        ops = x[self.categorical_dims]
        probs = x[self.continuous_dims][:-1]
        # find the indices of the top-k elements, for k being the last element
        k = int(np.round(x[-1]))
        adj_elements = np.zeros(len(probs))
        # Guard k <= 0: a slice of [-0:] would select *every* element rather
        # than none, and a negative k would select an arbitrary subset.
        if k > 0:
            top_indices = np.argpartition(probs, -k)[-k:]
            # Select the top-k probabilities to be 1
            adj_elements[top_indices] = 1
        ops_string = ['input'] + [categorical_mapping[int(x_)] for x_ in ops] + ['output']
        adj_matrix = np.zeros((7, 7))
        # Fill the strict upper triangle in row-major order.
        adj_matrix[np.triu_indices(7, k=1)] = adj_elements
        adj_matrix = adj_matrix.astype(np.int64)
        return ops_string, adj_matrix


if __name__ == '__main__':
    import pickle

    # Smoke test: evaluate one random architecture with 9 edges.
    x_cat = np.array([1, 0, 1, 2, 1])
    x_cont = np.random.rand(21)
    n_edges = np.array(9)
    x = np.hstack((x_cat, x_cont, n_edges))

    cache_path = '../data/nasbench101.pickle'
    try:
        # NOTE: unpickling executes arbitrary code; only load a cache file
        # that was produced locally by this script.
        with open(cache_path, 'rb') as cache_file:
            f = pickle.load(cache_file)
    except Exception:
        # Best-effort cache: on any failure to read it, build the benchmark
        # object from the raw data and write the cache.
        f = NASBench101(data_dir='../data/')
        with open(cache_path, 'wb') as cache_file:
            pickle.dump(f, cache_file)
    f.compute(x)


