# Copyright (c) 2022 Copyright holder of the paper Structural Kernel Search via Bayesian Optimization and Symbolical Optimal Transport submitted to NeurIPS 2022 for review.
# All rights reserved.
from bosot.configs.kernels.base_kernel_config import BaseKernelConfig
from bosot.configs.kernels.grammar_tree_kernel_kernel_configs import OTWeightedDimsExtendedGrammarKernelConfig
from bosot.configs.kernels.linear_configs import LinearWithPriorConfig
from bosot.configs.kernels.periodic_configs import PeriodicWithPriorConfig

from bosot.configs.kernels.rational_quadratic_configs import RQWithPriorConfig
from bosot.configs.kernels.rbf_configs import RBFWithPriorConfig
from bosot.configs.models.gp_model_config import BasicGPModelConfig
from bosot.configs.models.object_gp_model_config import BasicObjectGPModelConfig
from bosot.kernels.kernel_factory import KernelFactory
from bosot.models.gp_model import GPModel, PredictionQuantity
from bosot.models.model_factory import ModelFactory
from bosot.bayesian_optimization.bayesian_optimizer_objects import BayesianOptimizerObjects
from bosot.models.object_mean_functions import BICMean, ObjectConstant
from bosot.oracles.gp_model_bic_oracle import GPModelBICOracle
from bosot.configs.kernels.kernel_grammar_generators.cks_with_rq_generator_config import CKSWithRQGeneratorConfig
from bosot.configs.bayesian_optimization.bayesian_optimizer_objects_configs import ObjectBOExpectedImprovementEAConfig
import logging
import numpy as np
from bosot.bayesian_optimization.enums import AcquisitionFunctionType, ValidationType
from bosot.oracles.gp_model_evidence_oracle import GPModelEvidenceOracle
from bosot.kernels.kernel_grammar.generator_factory import GeneratorFactory
from bosot.oracles.test_oracle import TestOracle
from bosot.utils.plotter import Plotter
from bosot.utils.plotter2D import Plotter2D
from bosot.utils.utils import calculate_rmse

# Silence matplotlib below WARNING so its debug/info records do not flood the console.
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.WARNING)

# Root logging at INFO (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

########### Start #########################

# Forwarded to the oracle as ``fast_inference`` when it is constructed.
FAST_VERSION: bool = True
# When enabled, the toy oracle is plotted before any data is drawn.
PLOT_TOY_DATA: bool = False

############ Generate artificial dataset ####################

# Noise level handed to the synthetic TestOracle.
observation_noise = 0.03
test_oracle = TestOracle(observation_noise)

# Optional visual inspection of the toy function, gated by PLOT_TOY_DATA.
# Kept before the data draws so the original call order is preserved.
if PLOT_TOY_DATA:
    test_oracle.plot()

# Two separate random draws from the oracle: 400 points for training, 200 for testing.
x_train, y_train = test_oracle.get_random_data(400)
x_test, y_test = test_oracle.get_random_data(200)

############ Configuration of kernel search via SOT (Symbolical Optimal Transport) ############

# Input dimension shared by the grammar generator and the kernel-kernel below.
# Defined once so the two configs cannot silently drift apart.
# NOTE(review): assumed to match the dimensionality of x_train from TestOracle -- confirm.
INPUT_DIMENSION = 2

# Grammar that generates candidate kernel expressions (CKS-style grammar incl. RQ, per the config name).
kernel_grammar_generator_config = CKSWithRQGeneratorConfig(input_dimension=INPUT_DIMENSION)
kernel_grammar_generator = GeneratorFactory.build(kernel_grammar_generator_config)

# Oracle that evaluates candidate kernels on the toy data; FAST_VERSION toggles fast_inference.
oracle = GPModelEvidenceOracle(
    x_train,
    y_train,
    kernel_grammar_generator,
    fast_inference=FAST_VERSION,
    x_test=x_test,
    y_test=y_test,
)

# Surrogate model over kernel expressions, using the OT-based kernel-kernel config.
kernel_kernel_config = OTWeightedDimsExtendedGrammarKernelConfig(input_dimension=INPUT_DIMENSION)
object_gp = ModelFactory.build(
    BasicObjectGPModelConfig(
        kernel_config=kernel_kernel_config,
        prediction_quantity=PredictionQuantity.PREDICT_F,
        perform_multi_start_optimization=False,
    )
)
object_gp.set_mean_function(ObjectConstant())

# BO loop configuration (expected improvement with evolutionary candidate search, per the
# config name); the pydantic-style config is unpacked into the optimizer's constructor.
bo_config = ObjectBOExpectedImprovementEAConfig(n_steps_evolutionary=3, population_evolutionary=100)
optimizer = BayesianOptimizerObjects(**bo_config.dict())
optimizer.set_model(object_gp)
optimizer.set_oracle(oracle)
optimizer.set_candidate_generator(kernel_grammar_generator)

########### Start search procedure ################

# Draw 24 initial kernel expressions, then run 40 BO iterations.
optimizer.sample_train_set(24)
optimizer.maximize(40)

# Report the best expression found; sep="\n" reproduces the original two-line output.
best_kernel_expression = optimizer.get_current_best()
print("Best kernel:", best_kernel_expression, sep="\n")
