import nasbench.api as api
from localglobal.nasbench.lib import graph_util
# from localglobal.test_funcs.base import TestFunction
import numpy as np
import os
import copy
import random
import pickle

data_dir = '/mnt/08B82010B81FFAC0/PyCharm Projects/TurBODiscrete/data/'
# Maps the integer-coded operation choices to NAS-Bench-101 op labels.
categorical_mapping = {
    0: 'conv3x3-bn-relu',
    1: 'conv1x1-bn-relu',
    2: 'maxpool3x3'
}
# Local pickle cache of the parsed NAS-Bench-101 dataset, so the tfrecord
# only has to be parsed once.
# NOTE(review): pickle.load can execute arbitrary code -- only ever load a
# cache file that this project wrote itself.
_cache_path = './data/nasbench101.pickle'
try:
    with open(_cache_path, 'rb') as _cache_file:
        dataset = pickle.load(_cache_file)
except Exception:
    # Missing, truncated or stale cache (e.g. FileNotFoundError, EOFError,
    # UnpicklingError, AttributeError after a class change): rebuild it from
    # the tfrecord. Exception (not a bare except) keeps Ctrl-C working.
    dataset = api.NASBench(os.path.join(data_dir, 'nasbench_only108.tfrecord'), )
    os.makedirs(os.path.dirname(_cache_path), exist_ok=True)
    with open(_cache_path, 'wb') as _cache_file:
        pickle.dump(dataset, _cache_file)


def transform(ht, x):
    """Transform the input data to a format understood by the nasbench API.

    Args:
        ht: sequence of integer-coded operation choices for the intermediate
            vertices; each entry is mapped through ``categorical_mapping``.
        x: sequence whose leading entries are edge scores for the strict
            upper triangle of a 7x7 adjacency matrix (in the row-major order
            of ``np.triu_indices(7, k=1)``) and whose last entry is the number
            of edges ``k`` to keep (rounded to the nearest integer).

    Returns:
        ``(ops_string, adj_matrix)``: the op labels framed by ``'input'`` and
        ``'output'``, and a 7x7 int64 matrix with ones at the positions of the
        ``k`` highest edge scores (``k`` clamped to ``[0, number of scores]``).
    """
    x = np.asarray(x)
    ops = ht
    probs = x[:-1]
    n_scores = probs.shape[0]
    # Clamp k: with k == 0, argpartition(probs, -0)[-0:] would return *all*
    # indices (turning every edge on), and k > n_scores raises. A negative k
    # is meaningless, so it also means "no edges".
    k = int(np.clip(np.round(x[-1]), 0, n_scores))
    # Select the top-k scores to be 1
    adj_elements = np.zeros(n_scores)
    if k > 0:
        top_indices = np.argpartition(probs, -k)[-k:]
        adj_elements[top_indices] = 1
    ops_string = ['input'] + [categorical_mapping[int(x_)] for x_ in ops] + ['output']
    adj_matrix = np.zeros((7, 7))
    rows, cols = np.triu_indices(7, k=1)
    # Fill the strict upper triangle (21 entries) in row-major order; any
    # extra scores beyond the first 21 are ignored, as before.
    adj_matrix[rows, cols] = adj_elements[:rows.size]
    return ops_string, adj_matrix.astype(np.int64)


def nasbench101(ht_list, X, deterministic=True):
    """Return the log validation error of a NAS-Bench-101 cell.

    Args:
        ht_list: five integer-coded operation choices (see ``transform``).
        X: edge scores followed by the edge count, as accepted by
            ``transform``.
        deterministic: if True, average the validation accuracy over every
            training repeat stored for the architecture at 108 epochs;
            otherwise use one randomly sampled repeat via ``dataset.query``.

    Returns:
        ``log(1 - validation_accuracy)``. Architectures rejected by the
        benchmark (``api.OutOfDomainError``) score ``log(1.) == 0``.

    Raises:
        ValueError: if ``ht_list`` does not contain exactly five operations.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(ht_list) != 5:
        raise ValueError(f'expected 5 operations, got {len(ht_list)}')
    ops, matrix = transform(ht_list, X)
    model_spec = api.ModelSpec(matrix, ops)
    try:
        if deterministic:
            # Read all stored repeats directly rather than re-sampling with
            # `query`: de-duplicating sampled values drops repeats that share
            # the same accuracy (biasing the mean), and each `query` call
            # also adds to the dataset's simulated training-budget counters.
            _, computed_stats = dataset.get_metrics_from_spec(model_spec)
            accs = [run['final_validation_accuracy'] for run in computed_stats[108]]
            err = 1. - np.mean(accs)
        else:
            # This involves some stochasticity: `query` samples one repeat.
            err = 1. - dataset.query(model_spec)['validation_accuracy']
    except api.OutOfDomainError:
        err = 1.
    return np.log(err)