from __future__ import annotations
import sys
import os
os.environ['GPFLOW_FLOAT'] = 'float32'

# Cap the BLAS / OpenMP / numexpr thread pools at 10 threads each. They are set
# here, before the numerical libraries are imported further down, because such
# libraries typically read these variables when they initialise.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_env_var] = "10"
del _thread_env_var

# Experiment configuration.
DIM = None  # search-space dimension; set to sum(DIM_LIST) once the server reports it
BATCH_V = 1  # points per target-function call in the objective wrappers' reshape
NUM_EPISODES = 5  # rollout server processes; each is queried once per evaluation
EPOCH_LEN = 500  # sent alongside every parameter vector to the rollout servers
scenario = 'manyagent_swimmer'
agent_conf = '6x1'


import tensorflow as tf


from bo_lib import *
from scipy_fast_optimizer import Scipy_fast
import pickle

import numpy as np
import trieste
import gpflow
import networkit
from trieste.objectives.utils import mk_observer
from trieste.space import Box
from trieste.data import Dataset
from trieste.models.gpflow import GaussianProcessRegression, build_gpr, SparseVariational
from trieste.models.gpflow.builders import _get_data_stats, _get_mean_function, _set_gaussian_likelihood_variance, _get_inducing_points
from trieste.acquisition.rule import EfficientGlobalOptimization
import pickle
from trieste.models import ProbabilisticModel
from trieste.acquisition.interface import AcquisitionFunction, SingleModelAcquisitionBuilder, SingleModelVectorizedAcquisitionBuilder, AcquisitionFunctionClass
import scipy
from trieste.models.gpflow.inducing_point_selectors import KMeansInducingPointSelector, RandomSubSampleInducingPointSelector

from typing import Any, Callable, List, Optional, Tuple, Union, cast
from trieste.types import TensorType
import tensorflow_probability as tfp
from trieste.models.optimizer import Optimizer, BatchOptimizer

from typing import TypeVar, Union, Tuple, Callable, Optional, Any
import subprocess
import time
import botorch
import glob
import pickle
import networkit
import gpytorch
import datetime
import sys

from scipy_fast_optimizer import *

def _negated_vectorized_evals(target_func, vectorized_x: TensorType) -> Tuple[TensorType, TensorType]:
    """Evaluate ``-target_func`` on a batch of points given flat or as [N, D].

    Shared body of the three ``tf.function`` wrappers below; it is a plain Python
    function, so it is traced inside whichever wrapper calls it.

    NOTE: the reshape uses the module-level ``BATCH_V`` and ``DIM``, so callers must
    pass points whose dimension is ``DIM``.

    :param target_func: Function evaluated on points of shape [N/V, V, D], returning
        [N/V, V] (per the original shape comments).
    :param vectorized_x: Points, either flattened to [N * D] (the form scipy passes)
        or already shaped [N, D]. Any float dtype; cast to float32.
    :return: ``(x32, evals)``: the float32 input (the tensor gradients must be taken
        with respect to) and the negated evaluations with shape [N, 1].
    """
    x32 = tf.cast(vectorized_x, dtype=tf.float32)
    x = tf.reshape(x32, [-1, BATCH_V, DIM])  # [N/V, V, D]
    evals = -target_func(x)  # [N/V, V]
    return x32, tf.reshape(evals, [-1, 1])  # [N, 1]


@tf.function
def _objective_value_real(target_func, vectorized_x: TensorType) -> TensorType:  # [N*D] or [N, D] -> [N, 1]
    """Per-point negated target values, shape [N, 1] (no reduction)."""
    _, vectorized_evals = _negated_vectorized_evals(target_func, vectorized_x)
    return vectorized_evals


@tf.function
def _objective_value(target_func, vectorized_x: TensorType) -> TensorType:  # [N*D] or [N, D] -> []
    """Scalar objective for scipy: mean of the negated target over all N points."""
    _, vectorized_evals = _negated_vectorized_evals(target_func, vectorized_x)
    return tf.math.reduce_mean(vectorized_evals)


@tf.function
def _objective_value_gradient(target_func, vectorized_x: TensorType) -> TensorType:  # same shape as input
    """Gradient of ``_objective_value`` w.r.t. the (float32-cast) input points."""
    x32, vectorized_evals = _negated_vectorized_evals(target_func, vectorized_x)
    return tf.gradients(tf.math.reduce_mean(vectorized_evals), x32)[0]

def _perform_parallel_continuous_optimization(
    target_func: AcquisitionFunction,
    space: SearchSpace,
    starting_points: TensorType,
    optimizer_args: dict[str, Any],
) -> Tuple[TensorType, TensorType, TensorType]:
    """
    Jointly maximise ``target_func`` from every starting point with ONE L-BFGS-B call.

    All ``num_optimization_runs * V`` starting points are flattened into a single
    decision vector of length ``num_optimization_runs * V * D``. That vector is
    passed to one :func:`scipy.optimize.minimize` call, with the box bounds of
    ``space`` repeated once per point. The scalar objective is the mean of the
    negated target over all points (``_objective_value``). Its gradient comes from
    ``_objective_value_gradient``.

    Assuming ``target_func`` scores each point independently, the points do not
    interact through the objective. They do share one line search and one
    convergence test, so ``successes`` holds the same flag for every run. There is
    no greenlet / per-run parallelism. Only spaces exposing ``lower``/``upper``
    (e.g. :class:`Box`) are supported.

    NOTE: the objective wrappers reshape with the module-level ``BATCH_V`` and
    ``DIM``, so ``V``/``D`` here are assumed to match them.

    :param target_func: The function to maximise, with input shape [..., V, D] and
        output shape [..., V].
    :param space: The search space; its ``lower``/``upper`` give the bounds.
    :param starting_points: Initial points, shape [num_optimization_runs, V, D].
    :param optimizer_args: Extra keyword arguments for ``scipy.optimize.minimize``
        (must not contain ``method``, ``jac`` or ``bounds``).
    :return: ``(successes, fun_values, chosen_x)`` with shapes
        [num_optimization_runs, V], [num_optimization_runs, V] (target values at the
        optimised points) and [num_optimization_runs, V, D] (in the dtype of
        ``starting_points``).
    """
    tf_dtype = starting_points.dtype  # dtype the caller works in; chosen_x is cast back to it

    num_optimization_runs_per_function = tf.shape(starting_points)[0].numpy()
    V = tf.shape(starting_points)[-2].numpy()  # vectorized batch size
    D = tf.shape(starting_points)[-1].numpy()  # search space dimension
    num_optimization_runs = num_optimization_runs_per_function * V

    # Single flat decision vector of length num_optimization_runs * D, in float64 for scipy.
    x0 = tf.reshape(starting_points, [-1]).numpy().astype(np.float64)

    # Per-coordinate box bounds, repeated once per starting point.
    bounds = scipy.optimize.Bounds(
        np.tile(space.lower.numpy(), num_optimization_runs),
        np.tile(space.upper.numpy(), num_optimization_runs),
    )

    spo_out = scipy.optimize.minimize(
        lambda x: _objective_value(target_func, x).numpy().astype(np.float64),
        x0,
        jac=lambda x: _objective_value_gradient(target_func, x).numpy().astype(np.float64),
        method="L-BFGS-B",
        bounds=bounds,
        **optimizer_args,
    )

    # One scipy call for all runs -> one shared success flag.
    successes = tf.constant([spo_out.success] * num_optimization_runs)
    chosen_x = tf.reshape(tf.constant(spo_out.x, dtype=tf.float32), [-1, D])  # [runs * V, D]
    # _objective_value_real returns the NEGATED target; flip back to target values.
    fun_values = -1.0 * _objective_value_real(target_func, chosen_x)  # [runs * V, 1]

    successes = tf.reshape(successes, [-1, V])  # [num_optimization_runs, V]
    fun_values = tf.reshape(fun_values, [-1, V])  # [num_optimization_runs, V]
    chosen_x = tf.cast(tf.reshape(chosen_x, [-1, V, D]), tf_dtype)  # [num_optimization_runs, V, D]

    return (successes, fun_values, chosen_x)

def generate_continuous_optimizer(
    num_initial_samples: int = 300,
    num_optimization_runs: int = 10,
    num_recovery_runs: int = 10,
    optimizer_args: Optional[dict[str, Any]] = None,
) -> AcquisitionOptimizer[Box | TaggedProductSearchSpace]:
    """
    Build an acquisition optimizer that samples candidates and refines them jointly
    with L-BFGS-B.

    The returned function draws ``num_initial_samples * V`` random points from the
    search space and groups them into ``num_initial_samples`` starting points of
    shape [V, D]. It refines all of them in a single joint L-BFGS-B call (see
    ``_perform_parallel_continuous_optimization``). For each of the V functions it
    returns the refined point with the highest target value.

    NOTE: ``num_optimization_runs`` and ``num_recovery_runs`` are validated but
    otherwise unused. Every sampled point is optimised, with no top-k selection,
    and no recovery runs are performed. Both are kept for interface compatibility.

    :param num_initial_samples: Number of random starting points (per vectorized
        function). Must be positive.
    :param num_optimization_runs: Validated only; must be positive and at most
        ``num_initial_samples``.
    :param num_recovery_runs: Validated only; must be zero or greater.
    :param optimizer_args: Keyword arguments passed to ``scipy.optimize.minimize``
        (``method``, ``jac`` and ``bounds`` are set internally). ``None`` means none.
    :return: The acquisition optimizer.
    :raises ValueError: If any of the above constraints is violated.
    """
    if num_initial_samples <= 0:
        raise ValueError(f"num_initial_samples must be positive, got {num_initial_samples}")

    if num_optimization_runs <= 0:
        raise ValueError(f"num_optimization_runs must be positive, got {num_optimization_runs}")

    if num_initial_samples < num_optimization_runs:
        raise ValueError(
            f"""
            num_initial_samples {num_initial_samples} must be at
            least num_optimization_runs {num_optimization_runs}
            """
        )

    if num_recovery_runs < 0:
        raise ValueError(f"num_recovery_runs must be zero or greater, got {num_recovery_runs}")

    # Avoid a shared mutable default; the dict is only read (splatted into scipy).
    scipy_args = {} if optimizer_args is None else optimizer_args

    def optimize_continuous(
        space: Box | TaggedProductSearchSpace,
        target_func: Union[AcquisitionFunction, Tuple[AcquisitionFunction, int]],
    ) -> TensorType:
        """
        Maximise ``target_func`` over ``space``.

        When ``target_func`` is an ``(acquisition, V)`` tuple, V functions are
        optimised at once; otherwise V = 1.

        :param space: The space over which to search.
        :param target_func: The function to maximise, with input shape [..., V, D]
            and output shape [..., V], optionally paired with V.
        :return: The V best points found, with shape [V, D].
        :raises ValueError: If V is not positive.
        """
        if isinstance(target_func, tuple):  # check if we need a vectorized optimizer
            target_func, V = target_func
        else:
            V = 1

        if V <= 0:
            raise ValueError(f"vectorization must be positive, got {V}")

        # num_initial_samples * V points grouped as [num_initial_samples, V, D].
        candidates = tf.cast(space.sample(num_initial_samples * V)[:, None, :], dtype=tf.float32)
        candidates = tf.reshape(candidates, [-1, V, DIM])

        # Evaluated only to validate the output shape of target_func.
        target_func_values = target_func(candidates)  # [num_samples, V]
        tf.debugging.assert_shapes(
            [(target_func_values, ("_", V))],
            message=(
                f"""
                The result of function target_func has shape
                {tf.shape(target_func_values)}, however, expected a trailing
                dimension of size {V}.
                """
            ),
        )

        (
            successes,
            fun_values,
            chosen_x,
        ) = _perform_parallel_continuous_optimization(  # [num_samples, V]
            target_func,
            space,
            candidates,
            scipy_args,
        )

        best_run_ids = tf.math.argmax(fun_values, axis=0)  # [V]
        chosen_points = tf.gather(
            tf.transpose(chosen_x, [1, 0, 2]), best_run_ids, batch_dims=1
        )  # [V, D]

        return chosen_points

    return optimize_continuous

BATCH_DATA = 400  # largest mini-batch size considered by the main loop
MAX_KERNELS = 1500
wdir = './work/'  # FIFOs and the parameter-count file live here
dbg_dir = './dbg/'  # debug pickles (data, hessians, cliques) are written here
num_layers = 3
num_hidden = 4
disable_ra = False
disable_ic = False
max_cliques = 13  # largest clique size accepted when building the additive kernel
poten_reg = 0.65
# exist_ok avoids the check-then-create race. dbg_dir must exist too, because
# open(..., 'wb') does not create missing directories.
os.makedirs(wdir, exist_ok=True)
os.makedirs(dbg_dir, exist_ok=True)
workingdir = wdir

N_INIT = 15  # size of the initial Sobol design

def ssh_exec_blocking(cmd, stdout=None, block = True, env = None):
    """Start ``cmd`` as a local child process (no shell) and optionally wait for it.

    Despite the name, nothing here uses ssh; the command runs on this machine.

    :param cmd: Command line as a string (split with ``shlex.split``, so repeated
        whitespace does not produce empty arguments) or an argv sequence.
    :param stdout: Destination for BOTH stdout and stderr of the child; ``None``
        inherits the parent's streams.
    :param block: If true, wait for the child to exit before returning.
    :param env: Environment for the child; defaults to a copy of ``os.environ``.
    :return: The ``subprocess.Popen`` handle (already exited when ``block``).
    """
    import shlex  # local import keeps this helper self-contained

    if env is None:
        env = os.environ.copy()
    argv = shlex.split(cmd) if isinstance(cmd, str) else list(cmd)
    proc = subprocess.Popen(argv, shell=False, stdout=stdout, stderr=stdout, env=env)
    if block:
        proc.wait()
    return proc

# Run the rollout server once in blocking mode. It is expected to pickle DIM_LIST
# into `param_file`; DIM_LIST is presumably one parameter count per agent, and the
# search-space dimension is their sum.
param_file = os.path.join(workingdir, 'param_file.p')
cmd_str = './the_script.sh python -u ./mujoco_server_GOLD.py %s %s %s %s %s %s %s %f' % (param_file, scenario, agent_conf, num_layers, num_hidden, disable_ra, disable_ic, poten_reg)

ssh_exec_blocking(cmd_str)
with open(param_file, 'rb') as param_fh:  # closed promptly, even if unpickling fails
    DIM_LIST = pickle.load(param_fh)
DIM = sum(DIM_LIST)



hessians = []  # one averaged Hessian per evaluated point (appended by branin_emb)
processes = []  # handles of the background rollout servers

# Empty the working directory so stale pipes/files from an earlier run are gone.
for stale_path in glob.glob(os.path.join(workingdir, '*')):
    os.remove(stale_path)

# One FIFO pair + one non-blocking server process per episode.
for episode in range(NUM_EPISODES):
    pipe_base = os.path.join(workingdir, str(episode))
    send_pipe = pipe_base + '_pipe_send'
    recv_pipe = pipe_base + '_pipe_recv'
    os.mkfifo(send_pipe)
    os.mkfifo(recv_pipe)

    cmd_str = './the_script.sh python -u ./mujoco_server_GOLD.py %s %s %d %s %s %s %s %s %s %f' % (
        send_pipe, recv_pipe, 1, scenario, agent_conf, num_layers, num_hidden,
        disable_ra, disable_ic, poten_reg,
    )
    processes.append(ssh_exec_blocking(cmd_str, env=os.environ.copy(), block=False))

def branin_emb(x):
    """Score policy-parameter vectors by their mean episode reward.

    The name is historical; this is not the Branin function. For every row of
    ``x``, the flattened vector is sent (pickled with ``EPOCH_LEN``) to each of the
    ``NUM_EPISODES`` server FIFOs. One ``(rew, hess)`` reply is then read per server.
    The score is the mean over servers of ``np.sum(rew)``. The element-wise mean of
    the ``hess`` replies is appended to the module-level ``hessians`` list.

    :param x: TensorFlow tensor of points, shape [N, D] or [D]; the search space
        used with this objective is [-1, 1]^DIM.
    :return: float32 tensor of shape [N, 1] with the mean rewards.
    """
    xsnumpy = x.numpy()
    if xsnumpy.ndim == 1:
        xsnumpy = np.expand_dims(xsnumpy, axis=0)

    outcomes_outer = []
    for row in xsnumpy:
        # Same payload for every server, so serialise it once.
        payload = pickle.dumps((row.reshape(-1), EPOCH_LEN))
        for j in range(NUM_EPISODES):
            finame = os.path.join(workingdir, str(j))
            with open(finame + '_pipe_send', 'wb') as fifo_write:
                fifo_write.write(payload)

        outcomes = []
        for j in range(NUM_EPISODES):
            finame = os.path.join(workingdir, str(j))
            with open(finame + '_pipe_recv', 'rb') as fifo_read:
                input_stuff = fifo_read.read()
            # NOTE(review): unpickling is only acceptable because the peer is our own
            # child process; never point these FIFOs at untrusted writers.
            rew, hess = pickle.loads(input_stuff)
            outcomes.append((np.sum(rew), hess))

        rew_output = np.mean([oi[0] for oi in outcomes])
        hess_output = np.mean(np.array([oi[1] for oi in outcomes]), axis=0)
        outcomes_outer.append((rew_output, hess_output))
        hessians.append(hess_output)

    rewards = np.array([oi[0] for oi in outcomes_outer], dtype=np.float32)
    return tf.expand_dims(tf.constant(rewards), axis=-1)


# NOTE(review): the triple-quoted string below is disabled code, not a docstring.
# It is a bare module-level string expression, so evaluating it has no effect and
# `contract_cliques` is never defined or called in this file. Read as code, it
# would greedily merge each clique with fewer than 6 indices into later,
# not-yet-consumed cliques. It picks the clique with the largest summed
# cross-Hessian entries (hess[i, j] + hess[j, i]), as long as the merged index
# set stays at most 6.
'''
def contract_cliques(hess, cliques):
    contract_cliques = []
    consumed_cliques = [0] * len(cliques)
    def hess_sum(cl1, cl2):
        sum_ = 0.
        for i in cl1:
            for j in cl2:
                sum_ += hess[i,j] + hess[j,i]
        return sum_

    for i, ci in enumerate(cliques):
        if len(ci) >= 6:
            contract_cliques.append(ci)
            consumed_cliques[i] = 1
        else:
            if consumed_cliques[i]:
                continue
            ci_builder = []
            ci_builder.extend(ci)
            consumed_cliques[i] = 1

            while len(ci_builder) < 6:
                best_idx = -1
                best_val = 0.
                for j in range(i+1, len(cliques)):
                    if consumed_cliques[j] == 0 and hess_sum(ci_builder, cliques[j]) > best_val\
                        and len(set(cliques[j] + ci_builder)) <= 6:
                        best_val = hess_sum(ci_builder, cliques[j])
                        best_idx = j
                if best_idx != -1:
                    ci_builder.extend(cliques[best_idx])
                    ci_builder = list(set(ci_builder))
                    consumed_cliques[best_idx] = 1
                else:
                    break
            contract_cliques.append(ci_builder)
    cliques = contract_cliques
    return cliques
'''


def geo_mean(iterable):
    """Geometric mean of the values in ``iterable`` (all elements, flattened).

    Strictly positive inputs are averaged in log space, which avoids the overflow
    or underflow of ``prod()`` (e.g. ``[1e200, 1e200]``). Inputs containing zeros or
    negatives fall back to ``prod() ** (1 / n)``.

    :raises ValueError: If ``iterable`` is empty.
    """
    a = np.asarray(iterable, dtype=float).ravel()
    if a.size == 0:
        raise ValueError("geo_mean() requires at least one value")
    if np.all(a > 0):
        return np.exp(np.mean(np.log(a)))
    return a.prod() ** (1.0 / a.size)

def _standardize_hessian_blocks(hess, dim_list):
    """Return a copy of ``hess`` with each (block a, block b) sub-matrix z-scored.

    Blocks are delimited by the cumulative sums of ``dim_list``. A block whose
    standard deviation is 0 (e.g. a 1x1 diagonal block, or constant entries) is
    only mean-centred. Dividing by 0 there would put NaNs into the matrix that is
    thresholded afterwards.
    """
    breaks = [sum(dim_list[0:i]) for i in range(len(dim_list) + 1)]
    n_blocks = len(dim_list)

    def _block(a, b):
        return hess[breaks[a]:breaks[a + 1], breaks[b]:breaks[b + 1]]

    means = np.array([[np.mean(_block(a, b)) for b in range(n_blocks)] for a in range(n_blocks)])
    stds = np.array([[np.sqrt(np.var(_block(a, b))) for b in range(n_blocks)] for a in range(n_blocks)])
    stds = np.where(stds > 0, stds, 1.0)

    # block_of[k] = index of the block that parameter k belongs to.
    block_of = np.repeat(np.arange(n_blocks), dim_list)
    rows, cols = block_of[:, None], block_of[None, :]
    return (hess - means[rows, cols]) / stds[rows, cols]


def _maximal_cliques_above(hess, threshold):
    """Maximal cliques of the graph with an edge (i, j), i > j, wherever hess[i, j] > threshold."""
    graph = networkit.Graph(n=hess.shape[0])
    # Lower triangle only, visited row-major (i ascending, then j ascending).
    rows, cols = np.nonzero(np.tril(hess > threshold, k=-1))
    for i, j in zip(rows.tolist(), cols.tolist()):
        graph.addEdge(i, j)
    finder = networkit.clique.MaximalCliques(graph)
    finder.run()
    return finder.getCliques()


def build_additive_kernels():
    """Derive the additive-kernel structure from the Hessians recorded so far.

    1. Average all entries of the module-level ``hessians`` and z-score each block
       delimited by ``DIM_LIST``.
    2. Push every diagonal entry below all other entries so it never passes the
       threshold.
    3. Bisect (10 steps) over a percentile in [60, 100] of that matrix. Entries
       above the percentile become graph edges, and the graph's maximal cliques
       become kernel components. A split is feasible when no clique exceeds
       ``max_cliques`` indices and there are at most ``MAX_KERNELS`` cliques. The
       lowest feasible percentile found is used; if none was feasible, the last
       one tried is used.
    4. Build ``gpflow.kernels.Matern52List`` with one selector per clique and unit
       variances/lengthscales.

    :return: ``(adjacency, cliques, kernel)``: the boolean mask ``hess > threshold``
        for the chosen threshold, the cliques (lists of indices), and the kernel.
    """
    global hessians
    global DIM_LIST

    hess = _standardize_hessian_blocks(np.mean(np.stack(hessians), axis=0), DIM_LIST)

    # Each assignment re-reads the current minimum, so successive diagonal entries
    # step down by 5; only their being below everything else matters.
    for i in range(hess.shape[0]):
        hess[i, i] = np.min(hess) - 5.0

    ub_percentile = 100.
    lb_percentile = 60.
    feasible = None  # (threshold, cliques) of the lowest feasible percentile so far
    for _ in range(10):
        mid_percentile = (ub_percentile + lb_percentile) / 2.
        pct = np.percentile(hess, mid_percentile)
        cliques = _maximal_cliques_above(hess, pct)
        largest = max((len(ci) for ci in cliques), default=0)
        if largest > max_cliques or len(cliques) > MAX_KERNELS:
            lb_percentile = mid_percentile  # too many edges: raise the threshold
        else:
            ub_percentile = mid_percentile  # feasible: try a lower threshold
            feasible = (pct, cliques)
    if feasible is not None:
        pct, cliques = feasible

    variances = [1.] * len(cliques)
    selectors = list(cliques)
    lengthscales = [[1.] * len(ci) for ci in cliques]
    ki = gpflow.kernels.Matern52List(variances, selectors, lengthscales)
    return (hess > pct), cliques, ki


# Objective wrapper, search space [-1, 1]^DIM and the initial Sobol design.
observer = mk_observer(branin_emb)
lower_bounds = [-1.0] * DIM
upper_bounds = [1.0] * DIM
search_space = Box(lower_bounds, upper_bounds)

initial_query_points = tf.cast(search_space.sample_sobol(N_INIT), dtype=tf.float32)
initial_data = observer(initial_query_points)

cliques = None  # overwritten by create_model(); included in the debug dumps

def create_model(data, batch_size, indp):
    """Build a Trieste ``SparseVariational`` model with the current additive kernel.

    Calls ``build_additive_kernels()`` (which reads the recorded Hessians) and stores
    the resulting cliques in the module-level ``cliques``.

    :param data: Trieste ``Dataset`` used for the empirical mean/variance and the
        ``num_data`` scaling.
    :param batch_size: Mini-batch size stored on the GPflow model.
    :param indp: Number of inducing points.
    :return: The wrapped model. When either ``disable_ic`` or ``disable_ra`` is set
        (as a bool or as the string ``'True'``), no inducing-point selector is
        attached. Otherwise a ``KMeansInducingPointSelector`` is attached.
    """
    global cliques
    _, cliques, kernel = build_additive_kernels()

    trainable_likelihood = True
    likelihood_variance = None
    trainable_inducing_points = True
    empirical_mean, empirical_variance, num_data_points = _get_data_stats(data)
    optimizer = Optimizer(
        optimizer=Scipy_fast(), compile=True,
        minimize_args={"options": {'maxiter': 200, 'disp': True}, "compile": True}
    )

    model_likelihood = gpflow.likelihoods.Gaussian()
    mean = _get_mean_function(empirical_mean)

    inducing_points = _get_inducing_points(search_space, indp)
    inducing_point_selector = KMeansInducingPointSelector(search_space)

    gpflow_model = SVGP_Opt(
        kernel,
        model_likelihood,
        inducing_points,
        mean_function=mean,
        num_data=num_data_points
    )
    gpflow_model.batch_size = batch_size
    _set_gaussian_likelihood_variance(gpflow_model, empirical_variance, likelihood_variance)
    gpflow.set_trainable(gpflow_model.likelihood, trainable_likelihood)
    gpflow.set_trainable(gpflow_model.inducing_variable, trainable_inducing_points)

    # The flags are bools in this file but may arrive as strings (they are also
    # forwarded to the server via '%s'); str() makes both spellings compare equal.
    selector_disabled = str(disable_ic) == 'True' or str(disable_ra) == 'True'
    if selector_disabled:
        return SparseVariational(gpflow_model, optimizer=optimizer)
    return SparseVariational(
        gpflow_model, optimizer=optimizer,
        inducing_point_selector=inducing_point_selector
    )

def prep_batches(data, batch_size):
    """Shuffle ``data`` and pad it with resampled rows to a multiple of ``batch_size``.

    Rows are permuted without replacement. Then ``(-n) % batch_size`` rows are drawn
    with replacement from the shuffled rows and appended; this is 0 when n is
    already a multiple, so an aligned dataset gets no duplicates.

    :param data: Trieste ``Dataset`` with ``query_points`` [n, D] and
        ``observations`` with n rows.
    :param batch_size: Positive mini-batch size.
    :return: A new ``Dataset`` with a row count that is a multiple of ``batch_size``;
        observations are reshaped to [m, 1].
    """
    X_numpy = data.query_points.numpy()
    Y_numpy = data.observations.numpy()
    num_points = X_numpy.shape[0]

    # Rows missing to reach the next multiple of batch_size (0 when aligned).
    round_up = (-num_points) % batch_size

    shuf = np.random.choice(np.arange(0, num_points, dtype=np.int32), size=num_points, replace=False)
    X_numpy = X_numpy[shuf]
    Y_numpy = Y_numpy[shuf]

    shuf_rep = np.random.choice(np.arange(0, num_points, dtype=np.int32), size=round_up, replace=True)
    X_all = np.concatenate([X_numpy, X_numpy[shuf_rep]], axis=0)
    Y_all = np.concatenate([Y_numpy, Y_numpy[shuf_rep]], axis=0)
    X_all = X_all.reshape((-1, X_numpy.shape[1]))
    Y_all = Y_all.reshape((-1, 1))

    return Dataset(X_all, Y_all)

# ---- Warm-up phase setup ---------------------------------------------------------
data = initial_data

INNER_LOOP_LENGTH = 50  # acquisitions per round of the main phase
global_step = 0 #TODO: MAKE SURE THIS IS UPDATED EVERY STEP!!

# The warm-up makes 50 extra queries to get a better Hessian estimate. The batch
# size is the final warm-up dataset size, so one mini-batch covers all of it.
# The third create_model argument (also 50) is the number of inducing points.
batch_size = at_end = data.observations.shape[0] + 50
model = create_model(data, batch_size, 50)
copt = generate_continuous_optimizer(
    num_initial_samples=100,
    num_optimization_runs=5,
    optimizer_args={"options": {'maxiter': 200, 'disp': True}},
)

beta = 0.5 * np.log(global_step + 50.0)  # exploration weight for the confidence bound
acq = NegativeLowerConfidenceBound_(beta=beta)
rule = EfficientGlobalOptimization(builder=acq, optimizer=copt)  # type: ignore

randint = np.random.randint(1234567890)  # run id used in the debug-dump file names

# Warm-up: 50 acquisitions with the model built above.
for j in range(50):
    prepped_data = prep_batches(data, batch_size)
    print('doing updates and recording times')
    print(datetime.datetime.now())

    for retry in range(10):
        try:
            if retry > 5:
                # Several attempts failed: rebuild the model from scratch.
                model = create_model(data, batch_size, 50)
            model.update(prepped_data)
            print(datetime.datetime.now())
            model.optimize(prepped_data)
            print(datetime.datetime.now())
            break
        except Exception as e:
            # Python 3 exceptions have no `.message`; print args and str instead.
            print('exception!!')
            print(e.args)
            print(e)
    else:
        print('model update/optimize failed on every retry; acquiring with the last model')

    pnts = rule.acquire(search_space, {'OBJECTIVE': model}, None)
    print(datetime.datetime.now())

    new_obs = observer(pnts)

    data = Dataset(tf.concat([data.query_points, new_obs.query_points], axis = 0),
                   tf.concat([data.observations, new_obs.observations], axis = 0))
    print('global step')
    print(datetime.datetime.now())
    global_step += 1

# Checkpoint the warm-up state (data, recorded Hessians, current cliques).
os.makedirs(dbg_dir, exist_ok=True)  # open(..., 'wb') does not create directories
with open(os.path.join(dbg_dir, '%s_%s_%s_%s_%s_%s_%d_%f_%d_dbg_file.p' % (scenario, agent_conf, num_layers, num_hidden, disable_ra, disable_ic, max_cliques, poten_reg, randint)), 'wb') as dbg_file:
    pickle.dump((data, hessians, cliques), dbg_file)

# Main phase: `rpt` rounds. Each round rebuilds the model (and with it the
# additive-kernel structure from the Hessians recorded so far) and then makes
# INNER_LOOP_LENGTH acquisitions with it.
rpt = 21  # same for every scenario (the former 'particle' special case also used 21)

for i in range(rpt):
    n_observed = data.observations.shape[0]
    at_end = n_observed + INNER_LOOP_LENGTH*BATCH_V
    if at_end > BATCH_DATA:
        # Largest k <= BATCH_DATA with no multiple of k in (n_observed, at_end], so
        # n // k is the same for every dataset size this round will see.
        for k in range(BATCH_DATA, 1, -1):
            if (at_end // k) == (n_observed // k):
                batch_size = k
                break
    else:
        batch_size = at_end
    indp = min(at_end, 130)  # number of inducing points

    model = create_model(data, batch_size, indp)
    copt = generate_continuous_optimizer(
        num_initial_samples = 100,
        num_optimization_runs = 5,
        optimizer_args =  {"options": {'maxiter':200, 'disp' : True}}
    )

    beta = 0.5*np.log((global_step + INNER_LOOP_LENGTH))
    acq = NegativeLowerConfidenceBound_(beta = beta)
    rule = EfficientGlobalOptimization(builder=acq, optimizer=copt)

    for j in range(INNER_LOOP_LENGTH):
        prepped_data = prep_batches(data, batch_size)
        print('doing updates and recording times')
        print(datetime.datetime.now())

        for retry in range(10):
            try:
                if retry > 5:
                    # Several attempts failed: rebuild the model from scratch.
                    model = create_model(data, batch_size, indp)
                model.update(prepped_data)
                print(datetime.datetime.now())
                model.optimize(prepped_data)
                print(datetime.datetime.now())
                break
            except Exception as e:
                print('exception!!')
                print(e.args)
                print(e)
        else:
            print('model update/optimize failed on every retry; acquiring with the last model')

        pnts = rule.acquire(search_space, {'OBJECTIVE': model}, None)
        print(datetime.datetime.now())

        new_obs = observer(pnts)

        data = Dataset(tf.concat([data.query_points, new_obs.query_points], axis = 0),
                       tf.concat([data.observations, new_obs.observations], axis = 0))

        global_step += 1

        print('global step')
        print(datetime.datetime.now())

    # Checkpoint after every round (overwrites the same file).
    os.makedirs(dbg_dir, exist_ok=True)
    with open(os.path.join(dbg_dir, '%s_%s_%s_%s_%s_%s_%d_%f_%d_dbg_file.p' % (scenario, agent_conf, num_layers, num_hidden, disable_ra, disable_ic, max_cliques, poten_reg, randint)), 'wb') as dbg_file:
        pickle.dump((data, hessians, cliques), dbg_file)
