# Modified from Robin's CoCaBO code

from localglobal.test_funcs.base import TestFunction
import numpy as np
import ConfigSpace
import nasbench.api as api
import os
import pickle
import sys

# Size limits of the architecture encoding used by NASBench101 below.
# VERTICES is the side length of the adjacency matrix built in
# objective_function; VERTICES * (VERTICES - 1) // 2 "edge_%d" hyperparameters
# are declared in get_configuration_space.
VERTICES = 7
# Upper bound of the integer "num_edges" hyperparameter.
MAX_EDGES = 9


class NASBench101(TestFunction):
    """NAS-Bench-101 validation error as a mixed categorical/continuous test function.

    An input point has 5 categorical entries (one operation index per
    intermediate node) followed by 22 continuous entries scaled to [-1, 1]
    (one per non-categorical hyperparameter declared in
    ``get_configuration_space``: ``num_edges`` and the 21 ``edge_%d`` scores).
    The value returned for a point is ``1 - validation_accuracy`` of the
    decoded architecture as reported by the NAS-Bench-101 dataset.
    """
    problem_type = 'mixed'

    def __init__(self, data_dir, deterministic=False):
        """Load the benchmark table and set up the search-space description.

        Args:
            data_dir: Directory containing the NAS-Bench-101 ``.tfrecord`` file.
            deterministic: If ``True`` (the literal bool), ``objective_function``
                averages several distinct queried accuracies instead of
                returning a single query result.
        """
        super(NASBench101, self).__init__(normalize=False)
        self.multi_fidelity = False
        self.deterministic = deterministic
        if self.multi_fidelity:
            self.dataset = api.NASBench(os.path.join(data_dir, 'nasbench_full.tfrecord'))
        else:
            self.dataset = self._load_cached_dataset(data_dir)

        # Evaluation history containers (not filled anywhere in this class).
        self.X = []
        self.y_valid = []
        self.y_test = []
        self.costs = []

        self.y_star_valid = 0.04944576819737756  # lowest mean validation error
        self.y_star_test = 0.056824247042338016  # lowest mean test error

        # To conform with our API.
        self.categorical_dims = np.arange(5)
        self.continuous_dims = np.arange(5, 22 + 5)
        self.dim = len(self.categorical_dims) + len(self.continuous_dims)
        self.lb = -1. * np.ones(22)
        self.ub = np.ones(22)
        # Three choices for each of the five categorical variables.
        # NOTE: both attributes intentionally keep referring to the same array,
        # exactly as before.
        self.n_vertices = self.config = np.array([3, 3, 3, 3, 3])

    @staticmethod
    def _load_cached_dataset(data_dir, cache_path='./nasbench101.pickle'):
        """Return the 108-epoch dataset, using a pickle cache in the working directory.

        If the cache cannot be read, the dataset is parsed from
        ``nasbench_only108.tfrecord`` in ``data_dir`` and the cache is written.
        """
        # NOTE(security): pickle.load can execute arbitrary code. The cache is
        # read from the current working directory, so only run this from a
        # directory whose contents you trust.
        try:
            with open(cache_path, 'rb') as cache_file:
                return pickle.load(cache_file)
        except Exception:
            # Cache missing or unreadable: rebuild it from the tfrecord.
            # (KeyboardInterrupt / SystemExit are deliberately not caught.)
            dataset = api.NASBench(os.path.join(data_dir, 'nasbench_only108.tfrecord'))

        with open(cache_path, 'wb') as cache_file:
            pickle.dump(dataset, cache_file)
        return dataset

    def compute(self, X, normalize=None):
        """Evaluate one point or a batch of points.

        Args:
            X: Array of shape ``(dim,)`` or ``(N, dim)``; the columns listed in
                ``self.categorical_dims`` hold category indices and those in
                ``self.continuous_dims`` hold the scaled continuous values.
            normalize: Accepted for API compatibility; not used here.

        Returns:
            ``np.ndarray`` of shape ``(N,)`` with one error value per row.
        """
        if X.ndim == 1:
            X = X.reshape(1, -1)
        N = X.shape[0]
        res = np.zeros((N,))
        X_cat = X[:, self.categorical_dims]
        X_cont = X[:, self.continuous_dims]

        for i in range(N):
            h = [int(x_j) for x_j in X_cat[i, :]]
            res[i] = self.evaluate(h, X_cont[i, :])

        return res

    def evaluate(self, h, x_scaled):
        """Decode a single point into a configuration and return its error.

        Args:
            h: Sequence of category indices, one per categorical hyperparameter
                (in the order reported by ``get_info()['h_names']``).
            x_scaled: Array of continuous values in the scaled [-1, 1] range,
                one per entry of ``get_info()['x_names']``.

        Returns:
            The value of ``objective_function`` for the decoded configuration.
        """
        #  unscale inputs: map [-1, 1] linearly onto each [lower, upper] range
        info = self.get_info()
        bnds = info['unscaled_x_bounds']
        r = bnds[:, 1] - bnds[:, 0]
        x = r * (x_scaled + 1) / 2 + bnds[:, 0]
        x = x.flatten()
        print(h, x)

        config = {}

        # categorical variables: index -> choice value
        for i, (h_name, h_choice) in enumerate(zip(info['h_names'], info['h_choices'])):
            config[h_name] = h_choice[int(h[i])]

        # continuous variables; num_edges is truncated to an integer because it
        # is later used as a slice length
        for j, x_name in enumerate(info['x_names']):
            if x_name == 'num_edges':
                config[x_name] = int(x[j])
            else:
                config[x_name] = x[j]

        return self.objective_function(config)

    def objective_function(self, config, budget=108):
        """Return the validation error of the architecture described by ``config``.

        The ``config["num_edges"]`` highest-scoring ``edge_%d`` entries are
        switched on in the upper triangle of a ``VERTICES x VERTICES``
        adjacency matrix; the ``op_node_%d`` entries label the five
        intermediate nodes between ``'input'`` and ``'output'``.

        Args:
            config: Mapping with keys ``op_node_0..4``, ``num_edges`` and
                ``edge_0..20``.
            budget: Only checked (must be 108 when ``multi_fidelity`` is
                ``False``); it is not forwarded to the dataset query.

        Returns:
            ``1 - validation_accuracy``, or ``1.0`` when the dataset raises
            ``api.OutOfDomainError`` for the architecture.
        """
        if self.multi_fidelity is False:
            assert budget == 108

        n_possible_edges = VERTICES * (VERTICES - 1) // 2
        edge_prob = [config["edge_%d" % i] for i in range(n_possible_edges)]

        # Keep the num_edges entries with the largest scores.
        idx = np.argsort(edge_prob)[::-1][:config["num_edges"]]
        binary_encoding = np.zeros(len(edge_prob))
        binary_encoding[idx] = 1

        # Entry i of the encoding corresponds to the i-th upper-triangle
        # position in the order returned by np.triu_indices.
        matrix = np.zeros([VERTICES, VERTICES], dtype=np.int8)
        matrix[np.triu_indices(matrix.shape[0], k=1)] = binary_encoding

        labeling = [config["op_node_%d" % i] for i in range(5)]
        labeling = ['input'] + list(labeling) + ['output']
        model_spec = api.ModelSpec(matrix, labeling)
        try:
            if self.deterministic is True:
                # Query repeatedly (at most 50 times) until three distinct
                # accuracies have been seen, then average them.
                patience = 50
                accs = []
                while len(accs) < 3 and patience > 0:
                    patience -= 1
                    acc = self.dataset.query(model_spec)['validation_accuracy']
                    if acc not in accs:
                        accs.append(acc)
                err = (1. - np.mean(accs))
            else:
                # This involves some stochasticity
                data = self.dataset.query(model_spec)
                err = 1. - data["validation_accuracy"]
        except api.OutOfDomainError:
            # Architecture not covered by the dataset: report the worst error.
            err = 1.

        return err

    def get_info(self) -> dict:
        """Summarise the configuration space as plain lists and arrays.

        Returns:
            dict with keys
            ``h_bounds`` (number of choices per categorical hyperparameter),
            ``h_names``, ``h_choices``,
            ``x_names`` (names of the non-categorical hyperparameters),
            ``unscaled_x_bounds`` (their ``[lower, upper]`` ranges) and
            ``x_bounds`` (``[-1, 1]`` for each of them).
        """
        cs = self.get_configuration_space()
        info = {}

        h_choices_length = []
        h_names = []
        unscaled_x_bounds = []
        x_names = []
        h_choices = []
        for hyper in cs.get_hyperparameters():
            if isinstance(hyper, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                h_choices_length.append(len(hyper.choices))
                h_names.append(hyper.name)
                h_choices.append(hyper.choices)
            else:
                unscaled_x_bounds.append([hyper.lower, hyper.upper])
                x_names.append(hyper.name)

        info['h_bounds'] = np.array(h_choices_length)
        info['unscaled_x_bounds'] = np.array(unscaled_x_bounds)
        info['x_bounds'] = np.array([[-1, 1]] * len(unscaled_x_bounds))
        info['x_names'] = x_names
        info['h_choices'] = h_choices
        info['h_names'] = h_names

        return info

    @staticmethod
    def get_configuration_space():
        """Build the ConfigSpace description of the search space.

        Five categorical ``op_node_%d`` hyperparameters (three operations each),
        one integer ``num_edges`` in ``[0, MAX_EDGES]`` and one float
        ``edge_%d`` in ``[0, 1]`` for each of the
        ``VERTICES * (VERTICES - 1) // 2`` upper-triangle positions.
        """
        cs = ConfigSpace.ConfigurationSpace()

        ops_choices = ['conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3']
        for node in range(5):
            cs.add_hyperparameter(
                ConfigSpace.CategoricalHyperparameter("op_node_%d" % node, ops_choices))

        cs.add_hyperparameter(ConfigSpace.UniformIntegerHyperparameter("num_edges", 0, MAX_EDGES))

        for i in range(VERTICES * (VERTICES - 1) // 2):
            cs.add_hyperparameter(ConfigSpace.UniformFloatHyperparameter("edge_%d" % i, 0, 1))
        return cs

    def make_gpyopt_space(self, info: dict = None):
        """Describe the search space as a list of GPyOpt-style domain dicts.

        Args:
            info: Result of ``get_info()``; computed when omitted.

        Returns:
            ``(bounds, category)`` where ``bounds`` lists one dict per variable
            (categorical ``h1..`` first, then continuous ``x1..``) and
            ``category`` is the list of choice counts per categorical variable.
        """
        if info is None:
            info = self.get_info()

        category = list(info['h_bounds'])
        bounds = []

        for ii, h_bound in enumerate(info['h_bounds']):
            bounds.append(
                {'name': f'h{ii + 1}', 'type': 'categorical',
                 'domain': tuple(range(h_bound))}
            )

        for ii, x_bound in enumerate(info['x_bounds']):
            bounds.append(
                {'name': f'x{ii + 1}', 'type': 'continuous', 'domain': x_bound}
            )

        return bounds, category


if __name__ == '__main__':
    # Smoke test: evaluate one randomly drawn point with a fixed seed.
    # The dataset directory can be passed as the first command-line argument;
    # without it the original developer path is used, as before.
    default_data_dir = '/Users/binxinru/Documents/Ph.D/Projects/MABBO/Multi-Arm-Bandit-BO/testFunctions/tabular_benchmarks/benchmark_data/'
    data_dir = sys.argv[1] if len(sys.argv) > 1 else default_data_dir

    nas_bench_cifar10 = NASBench101(data_dir=data_dir, deterministic=True)
    np.random.seed(1)
    # Five operation indices, each drawn from {0, 1, 2}.
    h = list(np.random.choice(range(3), 5))
    # NOTE: rand() draws from [0, 1), so only the upper half of the [-1, 1]
    # range that evaluate() unscales from is exercised here.
    x = np.random.rand(22)
    err = nas_bench_cifar10.evaluate(h, x_scaled=x)
