import os
import sys
import logging
from datetime import datetime
import numpy as np
import torch
import warnings

from functions_mujoco import Humanoid
from grabbo.grabbo import GRABBO
from grabbo.util.behaviors import BaxusBehavior
from grabbo.util.behaviors.gp_configuration import GPBehaviour
from botorch.exceptions import InputDataWarning

# Suppress botorch's InputDataWarning for the whole process.
warnings.filterwarnings("ignore", category=InputDataWarning)

# --- Experiment configuration -------------------------------------------
BENCHMARK_NAME = "humanoid"  # used for the results sub-directory and log file name
# NOTE(review): named ANT_* although the benchmark is "humanoid"; presumably the
# input dimensionality -- confirm. Only used to build `bounds` below.
ANT_DIM = 6392
NUM_INIT = 30     # passed to GRABBO as n_init
NUM_ITERS = 1000  # passed to GRABBO as max_evals
TARGET_DIM = 1    # passed to GRABBO as target_dim
NOISE_STD = 0     # NOTE(review): not referenced elsewhere in this file
REPEAT = 9        # number of independent runs performed by the loop below

# [lower, upper] box of -1.0 / +1.0 per coordinate.
# NOTE(review): not referenced elsewhere in this file.
bounds = [[edge] * ANT_DIM for edge in (-1.0, 1.0)]

# --- Output locations ----------------------------------------------------
RESULTS_BASE_DIR = "results"
benchmark_results_dir = os.path.join(RESULTS_BASE_DIR, BENCHMARK_NAME)
# Created eagerly so every run directory can be placed underneath it.
os.makedirs(benchmark_results_dir, exist_ok=True)

# Format string shared by the per-run logging configuration.
LOG_FORMAT = "%(asctime)s %(levelname)s: %(message)s"

class NormalizedBenchmark:
    """Adapter that rescales inputs before forwarding them to a base benchmark.

    Every input ``x`` is converted to a float32 array and transformed with
    ``0.5 * (x + 1.0)`` (so -1 maps to 0 and +1 maps to 1) before the wrapped
    callable is invoked.

    Attributes set from the wrapped object:
        f: the wrapped benchmark callable.
        dim: ``base_benchmark.dims``.
        fun_name: ``base_benchmark.name``.
        lb_vec / ub_vec: float32 vectors of zeros / ones with ``dims`` entries.
    """

    def __init__(self, base_benchmark):
        """Wrap *base_benchmark*, which must expose ``dims`` and ``name``."""
        n_dims = base_benchmark.dims
        self.f = base_benchmark
        self.dim = n_dims
        self.fun_name = base_benchmark.name
        self.lb_vec = np.zeros(n_dims, dtype=np.float32)
        self.ub_vec = np.ones(n_dims, dtype=np.float32)

    @staticmethod
    def _rescale(x):
        """Return ``0.5 * (x + 1.0)`` computed on a float32 copy of *x*."""
        point = np.array(x, dtype=np.float32)
        return 0.5 * (point + 1.0)

    def __call__(self, x):
        """Evaluate the wrapped benchmark at the rescaled point and return its raw result."""
        return self.f(self._rescale(x))

    def evaluate_true(self, x):
        """Same evaluation as ``__call__`` but with the result wrapped in a float32 tensor."""
        raw_value = self.f(self._rescale(x))
        return torch.tensor(raw_value, dtype=torch.float32)

# Perform REPEAT independent optimisation runs; run i uses seed i.
seeds = list(range(REPEAT))
for repeat_id, SEED in enumerate(seeds):
    # Every run writes into its own timestamped directory.
    run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + f"_repeat{repeat_id}_seed{SEED}"
    run_dir = os.path.join(benchmark_results_dir, run_timestamp)
    os.makedirs(run_dir, exist_ok=True)

    # Detach the previous run's handlers so records only reach this run's
    # log file. removeHandler() alone leaves the old FileHandler's file open,
    # so close each handler as well to avoid leaking one descriptor per run.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(os.path.join(run_dir, f"{BENCHMARK_NAME}_grabbo.log")),
        ],
    )

    # Actually apply the seed that is advertised in the directory name and log
    # line; previously SEED was only a label. This seeds the global NumPy and
    # PyTorch generators.
    # NOTE(review): if GRABBO / Humanoid use their own RNG objects they need to
    # be seeded separately -- confirm against their implementations.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    logging.info("%d/%d SEED = %d", repeat_id + 1, REPEAT, SEED)

    try:
        benchmark = NormalizedBenchmark(Humanoid())
        grabbo = GRABBO(
            f=benchmark,
            target_dim=TARGET_DIM,
            n_init=NUM_INIT,
            max_evals=NUM_ITERS,
            run_dir=run_dir,
            behavior=BaxusBehavior(),
            gp_behaviour=GPBehaviour(),
        )

        grabbo.optimize()
        logging.info("Finished")

        # Re-evaluate every point visited by the optimiser and persist the values.
        fx_true = [float(benchmark.evaluate_true(x).item()) for x in grabbo.X]
        true_values_path = os.path.join(run_dir, "fX_true.csv")
        np.savetxt(true_values_path, fx_true, delimiter=",")

        logging.info("saved at: %s", true_values_path)
        logging.info(f"real values: {min(fx_true):.6f}")

    except Exception:
        # Top-level boundary: record the traceback in this run's log and move
        # on to the next repeat instead of aborting the whole sweep.
        logging.exception("error")
