from dataset import N101, N201, N301
from Graph_GP import graphGP_fit
from perf_metrics import *
import tensorflow as tf

# Upper bound (exclusive) of the index range sampled for NAS-Bench-101 initial
# points; presumably the number of NAS-Bench-101 architectures — verify
# against dataset.N101.
_N101_NUM_ARCHS = 423624
_N201_TASKS = ("cifar10-valid", "cifar100", "ImageNet16-120")
# Solver settings shared by every acquisition MIP.
_MIP_TIME_LIMIT = 1800
_MIP_TOL = 1e-4


def _configure_acquisition_mip(model, batch, show_log):
    """Apply the shared Gurobi settings to an acquisition MIP and sync it."""
    model.Params.TimeLimit = _MIP_TIME_LIMIT
    model.Params.FeasibilityTol = _MIP_TOL
    model.Params.OptimalityTol = _MIP_TOL
    if batch > 1:
        # Systematic pool search: keep the `batch` best solutions found.
        model.Params.PoolSearchMode = 2
        model.Params.PoolSolutions = batch
    model.update()
    if not show_log:
        model.Params.LogToConsole = 0


def _solve_random_mip(dataset, G, beta_t, GPmodel, num_solutions, seed, show_log, **mip_kwargs):
    """Build and solve a ``random=True`` MIP pooling ``num_solutions`` solutions.

    The caller owns the returned Gurobi model and must ``dispose()`` it.
    ``seed`` is applied to Gurobi's ``Seed`` parameter unless it is None.
    """
    model = dataset.get_MIP(G=G, beta_t=beta_t, GPmodel=GPmodel, random=True, **mip_kwargs)
    if not show_log:
        model.Params.LogToConsole = 0
    model.Params.PoolSearchMode = 2
    model.Params.PoolSolutions = num_solutions
    if seed is not None:
        model.Params.Seed = seed
    model.update()
    model.optimize()
    return model


def _propose_batch(dataset, G, beta_t, GPmodel, batch, seed, show_log, **mip_kwargs):
    """Solve one acquisition MIP, topping up with random solutions if needed.

    Returns ``(results, runtime)``: ``results`` is a list of raw
    ``dataset.get_optimized_graph`` outputs (one per model that contributed
    solutions) and ``runtime`` is the acquisition model's Gurobi ``Runtime``.
    ``mip_kwargs`` (e.g. ``N=n`` for N101) are forwarded to both ``get_MIP``
    and ``get_optimized_graph``. All Gurobi models created here are disposed.
    """
    model = dataset.get_MIP(G=G, beta_t=beta_t, GPmodel=GPmodel, random=False, **mip_kwargs)
    try:
        _configure_acquisition_mip(model, batch, show_log)
        model.optimize()
        runtime = model.Runtime
        # With batch > 1, PoolSolutions caps SolCount at batch, so this is the
        # "pool is full" test; with batch == 1 it requires at least one solution.
        if model.SolCount >= batch:
            return [dataset.get_optimized_graph(model, batch=batch, **mip_kwargs)], runtime
        print("Optimization failed")
        # Fill the missing slots with random feasible graphs.
        random_model = _solve_random_mip(
            dataset, G, beta_t, GPmodel, batch - model.SolCount, seed, show_log, **mip_kwargs)
        try:
            results = []
            if model.SolCount > 0:  # nothing to extract from an empty pool
                results.append(dataset.get_optimized_graph(model, **mip_kwargs))
            results.append(dataset.get_optimized_graph(random_model, **mip_kwargs))
            return results, runtime
        finally:
            random_model.dispose()
    finally:
        model.dispose()


def _keep_lowest(candidates, scores, k):
    """Return the (at most) ``k`` candidates with the smallest scores.

    Order within the result is unspecified; ties are broken arbitrarily.
    """
    k = min(k, len(scores))
    if k <= 0:
        return []
    keep_idxs = np.argpartition(scores, k - 1)[:k]
    return [candidates[i] for i in keep_idxs]


def _evaluate_graphs(dataset, graphs, dataset_name, task, noisy):
    """Evaluate ``graphs``; return (log val errors, log test errors) as 1-D float64.

    N301 is evaluated with a single score, so its test-error array is None.
    """
    if dataset_name == "N301":
        vals = np.array([dataset.eval(graph=g, score="log-err", noisy=noisy) for g in graphs])
        return vals.astype("float64"), None
    y_val_test = np.array(
        [list(dataset.eval(graph=g, score="log-err", task=task, noisy=noisy)) for g in graphs])
    return y_val_test[:, 0].astype("float64"), y_val_test[:, 1].astype("float64")


def NasGoat_BO(dataset_name, init_num, kernel_exp=True, num_iteration=30, batch=5, task="cifar10-valid", noisy=False, show_log=False, beta_t=3):
    """Run the NAS-GOAT Bayesian-optimisation loop on a NAS benchmark.

    Each iteration fits a graph GP to every architecture evaluated so far,
    solves a mixed-integer program built by ``dataset.get_MIP`` to propose the
    next ``batch`` architectures, evaluates them and records the results.

    Args:
        dataset_name (str): "N101", "N201" or "N301" ("NAS-Bench-101",
            "NAS-Bench-201" and the N301 benchmark loaded with the "xgb" surrogate).
        init_num (int): number of initial samples.
        kernel_exp (bool): True for exponential kernel, False for linear kernel.
        num_iteration (int): number of BO iterations.
        batch (int): number of architectures proposed at each iteration.
        task (str): NAS-Bench-201 task, one of "cifar10-valid", "cifar100" and
            "ImageNet16-120". Forced to "cifar10-valid" for N101; unused for N301.
        noisy (bool): True for a noisy objective, False for a deterministic one.
        show_log (bool): True to show Gurobi's optimisation log, False to hide it.
        beta_t (float): exploration weight passed to ``dataset.get_MIP``
            (presumably the LCB trade-off); defaults to the former constant 3.

    Returns:
        list[list[float]]: one row for the initial design, then one per
        iteration. Errors are log errors (accuracy = 1 - exp(value)).
        N101/N201 rows: [val error of the batch's best-by-val graph, its test
        error, best val error so far, best test error so far, runtime].
        N301 rows: [best error in the batch, best error so far, runtime].
        ``runtime`` is the Gurobi runtime of the acquisition MIP(s), summed
        over both node counts for N101; it is 0 for the initial row.

    Raises:
        ValueError: if ``dataset_name == "N201"`` and ``task`` is unknown.
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    # Validate before loading any benchmark data (raise, not assert: asserts
    # are stripped under `python -O`).
    if dataset_name == "N201" and task not in _N201_TASKS:
        raise ValueError(f"Unknown NAS-Bench-201 task {task!r}; expected one of {_N201_TASKS}.")

    # Prepare dataset and initial samples.
    if dataset_name == "N101":
        dataset = N101(path="data/")
        task = "cifar10-valid"
        # randomly sample initial graph samples from dataset
        sample_idxs = np.random.randint(_N101_NUM_ARCHS, size=init_num)
        G = dataset.index_sampling(sample_idxs)
    elif dataset_name in ("N201", "N301"):
        if dataset_name == "N201":
            dataset = N201(path="data/")
        else:
            dataset = N301(path="data/", surrogate="xgb")
        # randomly sample connected graphs as initial samples for the GP using MIP
        init_mip = _solve_random_mip(dataset, [], beta_t, None, init_num, None, show_log)
        try:
            G = dataset.get_optimized_graph(init_mip)
        finally:
            init_mip.dispose()
    else:
        raise NotImplementedError("Dataset not implemented yet.")

    # ys / y_ts hold every observed log val / test error as 1-D float64 arrays.
    ys, y_ts = _evaluate_graphs(dataset, G, dataset_name, task, noisy)
    if y_ts is None:
        init_g = np.min(ys)
        log = [[init_g, init_g, 0.]]
    else:
        init_g_val, init_g_test = np.min(ys), np.min(y_ts)
        log = [[init_g_val, init_g_test, init_g_val, init_g_test, 0.]]

    for t in range(num_iteration):
        print(f"Iteration {t+1}:")
        # Standardise targets. Guard zero spread (e.g. init_num == 1 or all
        # errors equal), which would otherwise produce NaN/inf targets.
        ys_mean, ys_std = ys.mean(), ys.std()
        if ys_std == 0:
            ys_std = 1.0
        Y = (tf.convert_to_tensor(ys[:, None]) - ys_mean) / ys_std

        # Fit GP using G, Y and suitable kernel
        kernel = dataset.get_kernel(kernel_exp)
        GPmodel = graphGP_fit(G, Y, kernel)

        # Propose the next batch by MIP optimisation of the acquisition.
        if dataset_name == "N101":
            # One MIP per value of N (presumably the cell's node count); keep
            # the `batch` candidates with the lowest LCB across both.
            candidates, lcb_scores, runtime = [], [], 0.0
            for n in (6, 7):
                results, n_runtime = _propose_batch(
                    dataset, G, beta_t, GPmodel, batch, t, show_log, N=n)
                runtime += n_runtime
                for graphs, lcbs in results:
                    candidates.extend(graphs)
                    lcb_scores.extend(lcbs)
            G_nexts = _keep_lowest(candidates, lcb_scores, batch)
        else:
            results, runtime = _propose_batch(dataset, G, beta_t, GPmodel, batch, t, show_log)
            G_nexts = [g for graphs in results for g in graphs]

        # Update samples
        G.extend(G_nexts)
        val_nexts, test_nexts = _evaluate_graphs(dataset, G_nexts, dataset_name, task, noisy)
        ys = np.concatenate([ys, val_nexts])

        # Record results
        if test_nexts is None:
            y_eval = np.min(val_nexts)
            y_best = np.min(ys)
            log.append([y_eval, y_best, runtime])
            print(
                f"Iteration {t + 1}/{num_iteration} current best graph has: val_acc={1 - np.exp(y_best)}")
        else:
            y_ts = np.concatenate([y_ts, test_nexts])
            # Reuse the values just observed: re-evaluating would cost another
            # query and, with noisy=True, log a different sample than `ys` holds.
            best_idx = np.argmin(val_nexts)
            y_val, y_test = val_nexts[best_idx], test_nexts[best_idx]
            best_y_val, best_y_test = np.min(ys), np.min(y_ts)
            log.append([y_val, y_test, best_y_val, best_y_test, runtime])
            print(
                f"Iteration {t + 1}/{num_iteration} current best graph has: val_acc={1 - np.exp(best_y_val)}, test_acc={1 - np.exp(best_y_test)}")

    print("Optimization complete.")
    return log
