import numpy as np
import ConfigSpace
import nasbench.api as api
import os
import pickle
import sys
MAX_EDGES = 9
VERTICES = 7

class NAS_Bench(object):
    """NAS-Bench-101 wrapped as a mixed categorical/continuous objective.

    A configuration (see ``get_configuration_space``) consists of:
      * five categorical operations ``op_node_0`` .. ``op_node_4``,
      * an integer ``num_edges`` in [0, MAX_EDGES],
      * one continuous score in [0, 1] per upper-triangular entry of the
        VERTICES x VERTICES adjacency matrix (``edge_0`` ..).
    The ``num_edges`` highest-scoring edges are switched on to build the cell.
    """

    def __init__(self, data_dir, deterministic=False,
                 cache_path='./nasbench101.pickle'):
        """Load the benchmark dataset.

        Args:
            data_dir: directory holding the NAS-Bench-101 ``.tfrecord`` files.
                It is only read when the pickle cache cannot be loaded.
            deterministic: if True, ``objective_function`` averages the
                distinct validation accuracies returned by repeated queries
                instead of using the value from a single query.
            cache_path: pickle file used to cache the parsed ``api.NASBench``
                object. Defaults to the previously hard-coded path, which is
                relative to the current working directory.
        """
        self.multi_fidelity = False
        self.deterministic = deterministic
        if self.multi_fidelity:
            self.dataset = api.NASBench(os.path.join(data_dir, 'nasbench_full.tfrecord'))
        else:
            self.dataset = self._load_cached_dataset(data_dir, cache_path)

        # Evaluation history containers (not populated within this class).
        self.X = []
        self.y_valid = []
        self.y_test = []
        self.costs = []

        self.y_star_valid = 0.04944576819737756  # lowest mean validation error
        self.y_star_test = 0.056824247042338016  # lowest mean test error

    @staticmethod
    def _load_cached_dataset(data_dir, cache_path):
        """Return the 108-epoch dataset, using the pickle cache when possible.

        If the cache is missing or unreadable, the tfrecord is parsed and the
        result is written back to ``cache_path``.
        """
        # NOTE: pickle.load can execute arbitrary code; cache_path must be a
        # trusted file (normally one written by this method).
        try:
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError):
            # Missing, truncated or stale cache: fall through and rebuild it.
            pass
        dataset = api.NASBench(os.path.join(data_dir, 'nasbench_only108.tfrecord'))
        with open(cache_path, 'wb') as f:
            pickle.dump(dataset, f)
        return dataset

    def evaluate(self, h, x_scaled):
        """Evaluate one point given in the optimiser's scaled representation.

        Args:
            h: category indices, one per categorical hyper-parameter, in the
                order of ``get_info()['h_names']``.
            x_scaled: array-like of continuous values scaled to [-1, 1], in the
                order of ``get_info()['x_names']``. It is flattened after
                unscaling.

        Returns:
            The value returned by ``objective_function`` for the decoded config.
        """
        info = self.get_info()
        bnds = info['unscaled_x_bounds']
        lower, upper = bnds[:, 0], bnds[:, 1]
        # Map [-1, 1] back to [lower, upper]; asarray lets plain lists work too.
        x = ((upper - lower) * (np.asarray(x_scaled, dtype=float) + 1) / 2 + lower).flatten()

        config = {}
        h_choices_list = info['h_choices']
        for i, h_name in enumerate(info['h_names']):
            config[h_name] = h_choices_list[i][int(h[i])]
        for j, x_name in enumerate(info['x_names']):
            # NOTE(review): int() truncates, so num_edges == MAX_EDGES is only
            # produced when the scaled input is exactly 1. Confirm this is intended.
            config[x_name] = int(x[j]) if x_name == 'num_edges' else x[j]

        return self.objective_function(config)

    @staticmethod
    def _edges_to_matrix(edge_prob, num_edges, vertices):
        """Build the upper-triangular 0/1 adjacency matrix (dtype int8).

        The ``num_edges`` entries of ``edge_prob`` with the highest scores are
        set to 1. ``edge_prob`` must have ``vertices * (vertices - 1) // 2``
        entries, ordered like ``np.triu_indices(vertices, k=1)``.
        """
        edge_prob = np.asarray(edge_prob)
        keep = np.argsort(edge_prob)[::-1][:int(num_edges)]
        binary_encoding = np.zeros(len(edge_prob))
        binary_encoding[keep] = 1
        matrix = np.zeros([vertices, vertices], dtype=np.int8)
        # triu_indices(k=1) walks the strict upper triangle row by row, which is
        # the same ordering that assigns edge_i to a matrix entry.
        matrix[np.triu_indices(vertices, k=1)] = binary_encoding
        return matrix

    def objective_function(self, config, budget=108):
        """Return the validation error of the architecture encoded by ``config``.

        Args:
            config: mapping with keys ``edge_<i>``, ``num_edges`` and
                ``op_node_<i>`` as defined by ``get_configuration_space``.
            budget: training epochs. Only 108 is supported when
                ``multi_fidelity`` is False.

        Returns:
            ``1 - validation_accuracy``, or ``1.0`` when the dataset rejects the
            spec with ``api.OutOfDomainError``.

        Raises:
            ValueError: if ``budget != 108`` and ``multi_fidelity`` is False.
        """
        if not self.multi_fidelity and budget != 108:
            raise ValueError('budget must be 108 when multi_fidelity is False, got %r' % (budget,))

        n_edges = VERTICES * (VERTICES - 1) // 2
        edge_prob = [config["edge_%d" % i] for i in range(n_edges)]
        matrix = self._edges_to_matrix(edge_prob, config["num_edges"], VERTICES)

        # The first and last vertices are fixed as input and output.
        labeling = (['input']
                    + [config["op_node_%d" % i] for i in range(VERTICES - 2)]
                    + ['output'])
        model_spec = api.ModelSpec(matrix, labeling)
        try:
            if self.deterministic is True:
                # Query repeatedly until 3 distinct accuracies are seen (or
                # patience runs out) and average them. NOTE(review): repeats
                # with identical accuracy are counted only once.
                patience = 50
                accs = []
                while len(accs) < 3 and patience > 0:
                    patience -= 1
                    acc = self.dataset.query(model_spec)['validation_accuracy']
                    if acc not in accs:
                        accs.append(acc)
                err = (1. - np.mean(accs))
            else:
                # This involves some stochasticity
                data = self.dataset.query(model_spec)
                err = 1. - data["validation_accuracy"]
        except api.OutOfDomainError:
            err = 1.

        return err

    def get_info(self) -> dict:
        """Summarise the configuration space for array-based optimisers.

        Returns:
            dict with keys:
              'h_bounds'          -- np.ndarray, number of choices per categorical
              'unscaled_x_bounds' -- np.ndarray of [lower, upper] per non-categorical
              'x_bounds'          -- np.ndarray of [-1, 1] rows, one per non-categorical
              'x_names'           -- names of the non-categorical hyper-parameters
              'h_choices'         -- choices of each categorical hyper-parameter
              'h_names'           -- names of the categorical hyper-parameters
            Both groups follow the order of ``cs.get_hyperparameters()``.
        """
        cs = self.get_configuration_space()

        h_choices_length, h_names, h_choices = [], [], []
        unscaled_x_bounds, x_names = [], []
        for hyper in cs.get_hyperparameters():
            if isinstance(hyper, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                h_choices_length.append(len(hyper.choices))
                h_names.append(hyper.name)
                h_choices.append(hyper.choices)
            else:
                unscaled_x_bounds.append([hyper.lower, hyper.upper])
                x_names.append(hyper.name)

        return {
            'h_bounds': np.array(h_choices_length),
            'unscaled_x_bounds': np.array(unscaled_x_bounds),
            'x_bounds': np.array([[-1, 1]] * len(unscaled_x_bounds)),
            'x_names': x_names,
            'h_choices': h_choices,
            'h_names': h_names,
        }

    @staticmethod
    def get_configuration_space():
        """Build the ConfigSpace search space described in the class docstring."""
        cs = ConfigSpace.ConfigurationSpace()

        ops_choices = ['conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3']
        for i in range(VERTICES - 2):
            cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter("op_node_%d" % i, ops_choices))

        cs.add_hyperparameter(ConfigSpace.UniformIntegerHyperparameter("num_edges", 0, MAX_EDGES))

        for i in range(VERTICES * (VERTICES - 1) // 2):
            cs.add_hyperparameter(ConfigSpace.UniformFloatHyperparameter("edge_%d" % i, 0, 1))
        return cs

    def make_gpyopt_space(self, info: dict = None):
        """Translate ``info`` (default: ``self.get_info()``) into a GPyOpt domain.

        Returns:
            (bounds, category): ``bounds`` is a list of GPyOpt variable dicts.
            Categoricals come first, named ``h1..``, with domain
            ``(0, .., n-1)``. Continuous variables follow, named ``x1..``, with
            the scaled bounds as domain. ``category`` lists the number of
            choices per categorical.
        """
        if info is None:
            info = self.get_info()

        category = list(info['h_bounds'])
        bounds = [
            {'name': f'h{ii + 1}', 'type': 'categorical', 'domain': tuple(range(h_bound))}
            for ii, h_bound in enumerate(info['h_bounds'])
        ]
        bounds += [
            {'name': f'x{ii + 1}', 'type': 'continuous', 'domain': x_bound}
            for ii, x_bound in enumerate(info['x_bounds'])
        ]

        return bounds, category


if __name__ == '__main__':
    # Smoke test: evaluate one random architecture and print its error.
    # The data directory may be passed as the first command-line argument.
    default_data_dir = '/Users/binxinru/Documents/Ph.D/Projects/MABBO/Multi-Arm-Bandit-BO/testFunctions/tabular_benchmarks/benchmark_data/'
    data_dir = sys.argv[1] if len(sys.argv) > 1 else default_data_dir
    nas_bench_cifar10 = NAS_Bench(data_dir=data_dir, deterministic=True)
    np.random.seed(1)
    info = nas_bench_cifar10.get_info()
    # One category index per categorical variable.
    h = [np.random.randint(n_choices) for n_choices in info['h_bounds']]
    # evaluate() expects continuous inputs scaled to [-1, 1], not [0, 1).
    x = np.random.uniform(-1, 1, len(info['x_names']))
    err = nas_bench_cifar10.evaluate(h, x_scaled=x)
    print(f'validation error: {err:.4f}')
