import numpy as np
import os
import os.path as osp
import torch
from torch import Tensor
import time
import argparse
from tqdm import tqdm
from dataclasses import dataclass, asdict
from datetime import datetime
import os.path as osp

from botorch import fit_gpytorch_mll
from botorch.models.gp_regression import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import unnormalize, normalize
from botorch.optim.optimize import optimize_acqf
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

from botorch.acquisition.monte_carlo import (
    qExpectedImprovement,
    qProbabilityOfImprovement,
    qUpperConfidenceBound,
)
from botorch.acquisition.objective import ConstrainedMCObjective, GenericMCObjective

from utils.log import get_logger, log_fn
from utils.paths import (
    get_log_filepath,
    get_result_data_path,
    params_to_filename,
    get_result_plot_path,
)
from utils.data import set_all_seeds
from evaluation_utils import plot_optimization_batch, get_opt_dataset

from data.function import get_function_environment 
from evaluation_utils import from_seed_to_data, from_function_name_to_datapaths


# Passed to optimize_acqf as options["batch_limit"].
BATCH_LIMIT = 1
# Passed to optimize_acqf as options["maxiter"] (presumably the per-restart
# optimizer iteration cap — confirm against the BoTorch version in use).
MAXITER = 200
# When True, preprocess_datapoints negates y before GP fitting, so the
# (maximizing) acquisition step targets low values of the original objective.
NEGATE = True  
# Default for GPSingleObjectiveConfig.plot_enabled and the --plot_enabled flag.
PLOT_ENABLED = False


@dataclass
class GPSingleObjectiveConfig:
    """Settings for one single-objective GP Bayesian-optimization (or random-search) run.

    ``main`` logs ``asdict(config)``, and ``save_fig`` turns ``asdict(config)``
    into the default figure filename.
    """

    seed: int = 0  # RNG seed; also used as the per-seed results subdirectory name
    function_name: str = "Branin"  # test function forwarded to the data/environment helpers
    acq_fn_name: str = "qEI"  # "qEI", "qPI", "qUCB", or "rs" (uniform random search)
    T: int = 100  # number of optimization iterations
    batch_size: int = 1  # NOTE(review): not read anywhere in the visible code
    num_candidates: int = 1  # points proposed per iteration (q for optimize_acqf)
    num_initial_points: int = 1  # initial design size passed to test_function.init
    regret_type: str = "simple"  # forwarded to test_function.init / step
    num_restarts: int = 10  # optimize_acqf num_restarts
    raw_samples: int = 512  # optimize_acqf raw_samples
    mc_samples: int = 128  # QMC sample count for the acquisition sampler
    # Evaluated once at import time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Evaluated once when the class is defined, so every instance created in
    # the same process shares this default experiment id unless overridden.
    expid: str = time.strftime("%Y%m%d-%H%M%S")
    normalize_outputs: bool = True  # NOTE(review): not read anywhere in the visible code
    beta: float = 2.0  # exploration weight for qUCB
    plot_enabled: bool = PLOT_ENABLED  # whether main() plots and saves the optimization trace
    # Overwrite existing result files instead of writing to a timestamped subdirectory.
    override: bool = True

    def __post_init__(self):
        """Reject acquisition-function names that ``main`` cannot handle.

        Raises:
            ValueError: if ``acq_fn_name`` is not one of the supported names.
        """
        valid_acq_fns = ["qEI", "qPI", "qUCB", "rs"]
        if self.acq_fn_name not in valid_acq_fns:
            raise ValueError(f"acq_fn_name must be one of {valid_acq_fns}")


def save_fig(
    config, fig, override: bool = False, log: callable = print, filename=None
) -> str:
    """Save ``fig`` as a PNG under the optimization result-plot directory.

    The target is ``<plot_path>/<seed>/<filename>.png``. If that file already
    exists and ``override`` is False, the existing file is kept and the new
    figure is written to a timestamped subdirectory instead (mirroring
    ``save_data``).

    Args:
        config: Experiment config; provides acq_fn_name, expid, function_name, seed.
        fig: Object with a ``savefig(path)`` method; ``None`` means nothing to save.
        override: Overwrite an existing figure at the default path.
        log: Callable used for status messages.
        filename: Base name without extension; defaults to a name built from
            ``asdict(config)``.

    Returns:
        The path the figure was written to, or ``""`` when ``fig`` is None.
    """
    if fig is None:
        log("No figure to save.")
        return ""
    path = get_result_plot_path(
        model_name=config.acq_fn_name,
        expid=config.expid,
        task_type="optimization",
        suffix=config.function_name,
    )
    path = osp.join(path, str(config.seed))

    filename = (
        params_to_filename(params=asdict(config)) if filename is None else filename
    )
    fig_path = osp.join(path, f"{filename}.png")

    if osp.exists(fig_path) and not override:
        # Keep the existing figure; write this one into a timestamped subdirectory.
        # (Previously this branch returned the new path without saving anything.)
        ts_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        path = osp.join(path, ts_str)
        fig_path = osp.join(path, f"{filename}.png")
        log(f"\nFigure already exists. New figure will be saved to {fig_path}")
    else:
        log(f"Saving figure to {fig_path}\n")

    os.makedirs(path, exist_ok=True)
    fig.savefig(fig_path)

    return fig_path


def save_data(
    config,
    filename: str,
    data: Tensor,
    override: bool = False,
    log: callable = print,
):
    """Persist ``data`` with ``torch.save`` under the optimization result-data directory.

    The default target is ``<data_path>/<seed>/<filename>.pt``. If the file
    already exists and ``override`` is False, the new file goes into a
    timestamped subdirectory so the existing one is preserved; otherwise the
    default target is (over)written.

    Returns:
        The path ``data`` was saved to.
    """
    base_dir = osp.join(
        get_result_data_path(
            model_name=config.acq_fn_name,
            expid=config.expid,
            task_type="optimization",
            suffix=config.function_name,
        ),
        str(config.seed),
    )
    target = osp.join(base_dir, f"{filename}.pt")

    if not osp.exists(target):
        os.makedirs(base_dir, exist_ok=True)
        log(f"\nFile does not exist. Saving to {target}")
    elif override:
        os.makedirs(base_dir, exist_ok=True)
        log(f"\nOverriding existing file at {target}")
    else:
        # Preserve the existing file: redirect into a timestamped subdirectory.
        stamped_dir = osp.join(
            base_dir, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        )
        target = osp.join(stamped_dir, f"{filename}.pt")
        os.makedirs(stamped_dir, exist_ok=True)
        log(f"\nFile already exists. Saving to new path: {target}")

    torch.save(data, target)
    return target


def preprocess_datapoints(train_x, train_y, x_bounds, tkwargs):
    """Prepare observations for GP fitting.

    - Drops a leading singleton dimension from 3-D ``train_x`` / ``train_y``.
    - Casts both to ``tkwargs`` (dtype/device) and makes ``train_y`` 2-D.
    - Maps ``train_x`` into the unit cube using ``x_bounds`` (transposed to
      ``[2, dx]`` first — assumes ``x_bounds`` is ``[dx, 2]``; TODO confirm).
    - Negates ``train_y`` when the module-level ``NEGATE`` flag is set.

    Returns:
        (normalized_x, y, unit_cube_bounds) where ``unit_cube_bounds`` is a
        ``[2, dx]`` tensor of zeros (row 0) and ones (row 1).
    """
    x = train_x.squeeze(0) if train_x.ndim == 3 else train_x
    y = train_y.squeeze(0) if train_y.ndim == 3 else train_y

    x = x.to(**tkwargs)
    y = y.to(**tkwargs)
    if y.ndim == 1:
        y = y.unsqueeze(-1)

    dim = x.shape[-1]
    unit_cube = torch.stack(
        (torch.zeros(dim, **tkwargs), torch.ones(dim, **tkwargs))
    )

    box = x_bounds.to(**tkwargs).transpose(0, 1)
    x_unit = normalize(x, bounds=box)

    # Flip the sign so that maximizing acquisition functions minimize the objective.
    y = -y if NEGATE else y

    return x_unit, y, unit_cube


def fit_gp(train_x: Tensor, train_y: Tensor):
    """Construct (but do not fit) a SingleTaskGP and its exact marginal log-likelihood.

    Both modules are moved to ``train_x``'s device/dtype. Hyperparameters are
    fitted by the caller (e.g. ``fit_gpytorch_mll``).

    Returns:
        ``(mll, model)`` — note the likelihood comes first.
    """
    gp = SingleTaskGP(train_x, train_y)
    objective = ExactMarginalLogLikelihood(gp.likelihood, gp)
    for module in (gp, objective):
        module.to(train_x)
    return objective, gp


def optimize_acquisition_function(
    model: SingleTaskGP,
    train_y: Tensor,
    standard_bounds: Tensor,
    sampler,
    acq_fn_name: str,
    num_candidates: int,
    num_restarts: int,
    raw_samples: int,
    beta: float,
    tkwargs,
):
    """
    Build the requested MC acquisition function and maximize it with optimize_acqf.

    Args:
        model: Fitted GP.
        train_y: Training targets; their maximum is the improvement incumbent
            for qEI / qPI.
        standard_bounds: Search box passed to optimize_acqf (``[2, dx]``).
        sampler: MC sampler shared by the acquisition function.
        acq_fn_name: "qEI", "qPI" or "qUCB".
        num_candidates: q, the number of points to propose.
        num_restarts, raw_samples: optimize_acqf settings.
        beta: qUCB exploration weight (ignored by qEI / qPI).
        tkwargs: dtype/device for the beta tensor.

    Returns:
        (candidates, inference_time): candidates are made 3-D with a leading
        dim of 1 when optimize_acqf returns 2-D (``[1, num_candidates, dx]``);
        inference_time is wall-clock seconds for building + optimizing.

    Raises:
        NotImplementedError: for "qTS".
        ValueError: for any other unrecognized name.
    """
    start = time.time()

    # Computed before dispatch so failure ordering matches the original.
    incumbent = train_y.max()

    if acq_fn_name in ("qEI", "qPI"):
        improvement_cls = (
            qExpectedImprovement if acq_fn_name == "qEI" else qProbabilityOfImprovement
        )
        acq_fn = improvement_cls(model=model, best_f=incumbent, sampler=sampler)
    elif acq_fn_name == "qUCB":
        acq_fn = qUpperConfidenceBound(
            model=model,
            beta=torch.tensor(beta, **tkwargs),
            sampler=sampler,
        )
    elif acq_fn_name == "qTS":
        raise NotImplementedError("qTS not implemented yet.")
    else:
        raise ValueError(f"Unknown acquisition function: {acq_fn_name}")

    solver_options = {"batch_limit": BATCH_LIMIT, "maxiter": MAXITER}
    candidates, _ = optimize_acqf(
        acq_function=acq_fn,
        bounds=standard_bounds,
        q=num_candidates,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
        options=solver_options,
        sequential=True,
    )
    elapsed = time.time() - start

    # Callers expect a leading batch dimension.
    if candidates.ndim == 2:
        candidates = candidates.unsqueeze(0)

    return candidates, elapsed


def main(config: GPSingleObjectiveConfig = None):
    """Run one single-objective optimization experiment and save its results.

    Steps:
      1. Set torch's default dtype (double) and device process-wide, seed RNGs,
         and open a file logger.
      2. Load the optimization dataset and build the test-function environment.
      3. Draw ``config.num_initial_points`` initial points via ``test_function.init``.
      4. For ``config.T`` iterations, propose ``config.num_candidates`` points —
         uniformly at random when ``acq_fn_name == "rs"``, otherwise by fitting
         a GP and maximizing the configured acquisition function — and
         evaluate them with ``test_function.step``.
      5. Save the per-step regret and timing tensors, and optionally a plot.

    Args:
        config: Experiment configuration; a default ``GPSingleObjectiveConfig``
            is used when ``None``.

    Note:
        An exception while fitting the GP or optimizing the acquisition function
        is logged and ends the loop early; results gathered so far are still saved.
    """
    config = GPSingleObjectiveConfig() if config is None else config

    # dtype/device applied to tensors created or moved below.
    tkwargs = {
        "dtype": torch.double,
        "device": config.device,
    }
    # NOTE: these change torch's process-wide defaults, not only this function's.
    torch.set_default_dtype(torch.double)
    torch.set_default_device(config.device)
    set_all_seeds(config.seed)

    log_filename = get_log_filepath(group_name=config.acq_fn_name, expid=config.expid)
    # mode="w" presumably truncates an existing log file — confirm in utils.log.
    logger = get_logger(file_name=log_filename, mode="w")
    log = log_fn(logger)
    log(f"log_filename:\t{log_filename}\nconfiguration:\n{asdict(config)}\n")

    train_x, train_y, train_x_bounds, train_y_bounds = get_opt_dataset(
        function_name=config.function_name,
        seed=config.seed,
        device=config.device,
    )

    log(f"setup function: {config.function_name}")
    test_function = get_function_environment(
        function_name=config.function_name,
        sigma=0.0,  # presumably observation-noise scale (0.0 => noise-free) — confirm
        train_x=train_x,
        train_y=train_y,
        train_x_bounds=train_x_bounds,
        train_y_bounds=train_y_bounds,
    )
    # Assumes shape [dx, 2]: shape[0] is used as the input dimension and the
    # tensor is transposed to [2, dx] for unnormalize below — TODO confirm.
    x_bounds = test_function.x_bounds.to(**tkwargs)

    # batch_size is fixed to 1 here; config.batch_size is not read in this function.
    train_x, train_y, _, regret = test_function.init(
        input_bounds=x_bounds,
        batch_size=1,
        num_initial_points=config.num_initial_points,
        regret_type=config.regret_type,
        compute_hv=False,
        compute_regret=True,
        device=config.device,
    )

    log(f"Initial train_x:\n{train_x}")
    log(f"Initial train_y:\n{train_y}")
    log(f"Initial simple regret:\t{regret}\n")

    # regret is returned as a NumPy array (hence from_numpy); the initial
    # design is recorded with a 0.0 inference time.
    regret_list = [torch.from_numpy(regret).to(**tkwargs)]
    time_list = [0.0]

    for step in tqdm(
        range(1, config.T + 1), f"Optimization loop for config.T={config.T} steps"
    ):
        if config.acq_fn_name == "rs":
            # Random search: uniform samples in [0, 1]^dx mapped to the input box.
            t1 = time.time()
            x_next = torch.rand(
                (1, config.num_candidates, x_bounds.shape[0]),
                **tkwargs,
            )
            x_next = unnormalize(x_next, bounds=x_bounds.transpose(0, 1))
            inference_time = time.time() - t1
        else:
            try:
                # Normalize x to the unit cube (and negate y if NEGATE) before fitting.
                train_x_tfm, train_y_tfm, standard_bounds = preprocess_datapoints(
                    train_x=train_x, train_y=train_y, x_bounds=x_bounds, tkwargs=tkwargs
                )

                # Reported inference time = GP construction/fit + acquisition optimization.
                t1 = time.time()
                mll, model = fit_gp(train_x=train_x_tfm, train_y=train_y_tfm)
                fit_gpytorch_mll(mll)
                gp_fitting_time = time.time() - t1

                sampler = SobolQMCNormalSampler(
                    sample_shape=torch.Size([config.mc_samples])
                )
                x_next, acq_optimize_time = optimize_acquisition_function(
                    model=model,
                    train_y=train_y_tfm,
                    standard_bounds=standard_bounds,
                    sampler=sampler,
                    acq_fn_name=config.acq_fn_name,
                    num_candidates=config.num_candidates,
                    num_restarts=config.num_restarts,
                    raw_samples=config.raw_samples,
                    beta=config.beta,
                    tkwargs=tkwargs,
                )
                inference_time = gp_fitting_time + acq_optimize_time
                # Map candidates from the unit cube back to the original input box.
                x_next = unnormalize(x_next.detach(), bounds=x_bounds.transpose(0, 1))

            except Exception as e:
                # Best-effort: stop the loop; partial results are still saved below.
                log(f"Error at step {step}: {e}")
                break

        # Presumably evaluates x_next and appends it to the context, returning
        # the updated data and the regret after this step — confirm in data.function.
        train_x, train_y, _, regret = test_function.step(
            input_bounds=x_bounds,
            x_new=x_next,
            x_ctx=train_x,
            y_ctx=train_y,
            compute_hv=False,
            compute_regret=True,
            regret_type=config.regret_type,
        )
        regret_list.append(torch.from_numpy(regret).to(**tkwargs))
        time_list.append(inference_time)
        # NOTE(review): the message says "Simple Regret" even when
        # config.regret_type is something else.
        log(
            f"Step {step}: y_next: {train_y[:, -config.num_candidates:]}, Simple Regret: {regret}, Time (s): {inference_time:.4f}"
        )

    final_regret = regret_list[-1]
    # NOTE: this is the last iteration's time, not the cumulative run time.
    final_time = time_list[-1]
    log(
        f"[results, seed={config.seed}]\nfinal_regret:\t{final_regret}\nfinal_time (s):\t{final_time}"
    )


    # Last dim is T+1 only if the loop ran to completion (no early break).
    regret_stack = torch.stack(regret_list, dim=-1)  # [batch_size, T+1]
    time_stack = torch.tensor(time_list)  # [T+1]
    result_dict = {
        "regret": regret_stack.detach().cpu(),
        "time": time_stack.detach().cpu(),
    }

    # Each entry is written to "<key>.pt" by save_data.
    for key, val in result_dict.items():
        save_data(
            config=config, filename=key, data=val, override=config.override, log=log
        )

    if config.plot_enabled:
        fig = plot_optimization_batch(
            test_function=test_function,
            x_query=train_x,
            y_query=train_y,
        )
        save_fig(config=config, fig=fig, override=config.override, log=log)

    print("\nOptimization complete.")


def parse_bool_arg(value) -> bool:
    """Convert a command-line string such as "true", "False", "1" or "no" to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--function_name", type=str, default="Branin")
    parser.add_argument("--acq_fn_name", type=str, default="qEI")
    parser.add_argument("--num_candidates", type=int, default=1)
    parser.add_argument("--num_restarts", type=int, default=10)
    parser.add_argument("--raw_samples", type=int, default=512)
    parser.add_argument("--mc_samples", type=int, default=128)
    # Match the config default: only use CUDA when it is actually available.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--expid", type=str, default=time.strftime("%Y%m%d-%H%M%S"))
    parser.add_argument("--num_initial_points", type=int, default=1)
    parser.add_argument("--T", type=int, default=100)
    parser.add_argument("--regret_type", type=str, default="simple")
    parser.add_argument(
        "--beta", type=float, default=2.0, help="UCB exploration parameter"
    )
    # parse_bool_arg instead of type=bool, which would turn "False" into True.
    parser.add_argument("--plot_enabled", type=parse_bool_arg, default=PLOT_ENABLED)
    parser.add_argument("--override", type=parse_bool_arg, default=True)

    args = parser.parse_args()
    arg_dict = vars(args)

    config = GPSingleObjectiveConfig(**arg_dict)
    main(config)
