"""Baselines for multi-objective maximization"""

import numpy as np
import pickle
from typing import Optional
from matplotlib.pyplot import Figure
from datetime import datetime
import os
import os.path as osp
import argparse
from dataclasses import dataclass, asdict
import time
from tqdm import tqdm
import torch
from torch import Tensor
from botorch import fit_gpytorch_mll
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import unnormalize, normalize
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from botorch.optim.optimize import optimize_acqf, optimize_acqf_list
from botorch.acquisition.multi_objective.monte_carlo import (
    qExpectedHypervolumeImprovement,
    qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qHypervolumeKnowledgeGradient,
)
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.acquisition.multi_objective.logei import (
    qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.utils.sampling import sample_simplex
from botorch.acquisition.objective import GenericMCObjective

from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from utils.log import get_logger, log_fn
from utils.paths import (
    get_log_filepath,
    get_result_data_path,
    params_to_filename,
    get_result_plot_path,
)
from utils.config import get_train_x_range, get_train_y_range
from utils.data import set_all_seeds
from data.function import get_function_environment, TestFunction
from evaluation_utils import get_opt_dataset, plot_optimization_batch, expand_metrics
from utils.plot import plot_fronts

# Passed as "batch_limit" in the options dict of optimize_acqf / optimize_acqf_list.
BATCH_LIMIT = 1
# Passed as "maxiter" in the options dict of optimize_acqf / optimize_acqf_list.
MAXITER = 200
# NOTE(review): not referenced anywhere in this file; presumably a plotting
# cadence where -1 disables per-step plots — confirm before relying on it.
PLOT_PER_N_STEPS = -1
# When True, scaled outputs (preprocess_datapoints) and the reference point
# (main) are negated before modeling. Presumably the test functions are
# minimization problems while BoTorch maximizes — confirm.
NEGATE = True
PLOT_ENABLED = False  # whether to plot at all
# Output range that observations are scaled to before GP fitting
# (see preprocess_datapoints); comes from the project config.
SCALE_Y_RANGE = get_train_y_range()


@dataclass
class GPBaselineConfig:
    """Configuration for one baseline run executed by ``main``.

    The fields mirror the command-line options of this script.
    """

    # Seed for set_all_seeds; also selects the dataset split and result folder.
    seed: int = 0
    # Name of the test function / environment to optimize.
    function_name: str = "BraninCurrin"
    # "rs" (random search), "BOFormer" (replay a stored trajectory), or an
    # acquisition name handled by optimize_acquisition_function.
    acq_fn_name: str = "qNEHVI"
    # Budget: the optimization loop runs while get_cost_used(train_x) <= T.
    T: int = 100
    # NOTE(review): not read by main (init is called with batch_size=1) — confirm intent.
    batch_size: int = 1
    # Forwarded to get_function_environment; presumably observation noise std.
    sigma: float = 0.0
    # Number of points proposed per step (q).
    num_candidates: int = 1
    # Number of initial points sampled by test_function.init.
    num_initial_points: int = 1
    # Regret definition forwarded to test_function.init / step.
    regret_type: str = "ratio"
    # Acquisition optimizer settings forwarded to optimize_acqf(_list).
    num_restarts: int = 10
    raw_samples: int = 512
    # Number of quasi-MC samples for SobolQMCNormalSampler.
    mc_samples: int = 128
    # NOTE: evaluated once at import time, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: evaluated once at class-definition time, so every default-constructed
    # config in the same process shares this experiment id.
    expid: str = time.strftime("%Y%m%d-%H%M%S")
    # Whether fit_gp attaches a Standardize outcome transform to each GP.
    normalize_outputs: bool = True
    # Whether main saves plots at the end of the run.
    plot_enabled: bool = PLOT_ENABLED
    # Whether existing result / figure files may be overwritten.
    override: bool = True
    dim: Optional[int] = None  # function input dimension; if None, use default
    # Extra path component for datasets and results.
    subfolder: Optional[str] = None
    suffix: Optional[str] = None  # extra suffix for saving results

    def __post_init__(self) -> None:
        """Validate field values after dataclass initialization.

        Raises:
            ValueError: If ``sigma`` is negative.
        """
        if self.sigma < 0:
            raise ValueError("sigma must be non-negative.")


def get_cost_used(train_x):
    """Return the evaluation budget consumed so far.

    Every queried point costs one unit, so the cost is the length of axis 1
    of ``train_x`` (presumably ``[batch, n, dx]`` — the point axis).
    """
    points_axis = 1
    return train_x.shape[points_axis]


def _collect_results(hv_list, regret_list, time_list, instantaneous_hv_list):
    """Stack the per-step metric lists into detached CPU tensors.

    Returns a dict keyed by ``"hv"``, ``"regret"``, ``"time"`` and
    ``"instantaneous_hv"``. Metric tensors are concatenated along their last
    dimension; ``time`` becomes a 1-D tensor with one entry per recorded step.
    """
    hv_stack = torch.cat(hv_list, dim=-1)  # [batch_size, T+1]
    regret_stack = torch.cat(regret_list, dim=-1)  # [batch_size, T+1]
    time_stack = torch.tensor(time_list)  # [T+1]
    instantaneous_hv_stack = torch.cat(instantaneous_hv_list, dim=-1)
    return {
        "hv": hv_stack.detach().cpu(),
        "regret": regret_stack.detach().cpu(),
        "time": time_stack.detach().cpu(),
        "instantaneous_hv": instantaneous_hv_stack.detach().cpu(),
    }


def _save_results(config, result_dict, log):
    """Write each entry of ``result_dict`` to ``<key>.pt`` via ``save_data``."""
    for key, val in result_dict.items():
        save_data(
            config=config, filename=key, data=val, override=config.override, log=log
        )


def main(config: Optional[GPBaselineConfig] = None):
    """Run one baseline optimization loop and save its per-step metrics.

    How the next query is chosen depends on ``config.acq_fn_name``:

    * ``"rs"`` draws uniform random points inside ``test_function.x_bounds``.
    * ``"BOFormer"`` replays points stored in a pickled BOFormer history.
    * Any other name fits a GP to the normalized data and optimizes the
      acquisition function chosen by ``optimize_acquisition_function``.

    The following metrics are recorded after every step:

    * hypervolume;
    * regret;
    * per-step inference time;
    * hypervolume of the latest candidates ("instantaneous" hv).

    Accumulated metrics are saved at the start of every loop iteration and
    once more after the loop. Figures are saved when ``config.plot_enabled``
    is set.

    The loop runs while ``get_cost_used(train_x) <= config.T``. It stops
    early if the GP branch raises or the BOFormer history runs out.

    Args:
        config: Run configuration; ``GPBaselineConfig()`` is used when None.
    """
    import traceback  # only needed to log failures in the GP branch

    config = GPBaselineConfig() if config is None else config

    tkwargs = {
        "dtype": torch.double,
        "device": config.device,
    }
    torch.set_default_dtype(torch.double)
    torch.set_default_device(config.device)
    set_all_seeds(config.seed)

    # Setup logging
    log_filename = get_log_filepath(group_name=config.acq_fn_name, expid=config.expid)
    logger = get_logger(file_name=log_filename, mode="w")
    log = log_fn(logger)
    log(f"log_filename:\t{log_filename}\nconfiguration:\n{asdict(config)}\n")

    train_x, train_y, train_x_bounds, train_y_bounds = get_opt_dataset(
        function_name=config.function_name,
        seed=config.seed,
        subfolder=config.subfolder,
        device=config.device,
    )
    if train_x is not None and train_y is not None:
        log(f"train_x[0]: {train_x[0]}, train_y[0]: {train_y[0]}")
    log(f"setup function: {config.function_name}")
    test_function = get_function_environment(
        function_name=config.function_name,
        sigma=config.sigma,
        train_x=train_x,
        train_y=train_y,
        train_x_bounds=train_x_bounds,
        train_y_bounds=train_y_bounds,
        dim=config.dim,
    )
    log(f"x_bounds:\t{test_function.x_bounds}\n")
    log(f"y_bounds:\t{test_function.y_bounds}\n")
    log(f"max_hv:\t{test_function.max_hv:.4f}\n\n")

    x_bounds = test_function.x_bounds
    ref_point = test_function.ref_point.to(**tkwargs)
    if NEGATE:
        # Keep the reference point in the same (negated) space as the outputs
        # returned by preprocess_datapoints.
        ref_point = -ref_point

    # Sample initial points. batch_size is fixed to 1 because
    # preprocess_datapoints only accepts a batch dimension of size 1.
    train_x, train_y, hv, regret = test_function.init(
        input_bounds=x_bounds,
        batch_size=1,
        num_initial_points=config.num_initial_points,
        regret_type=config.regret_type,
        compute_hv=True,
        compute_regret=True,
        device=config.device,
    )
    train_x_tfm, train_y_tfm = preprocess_datapoints(
        test_function=test_function,
        train_x=train_x,
        train_y=train_y,
        tkwargs=tkwargs,
    )
    log(f"Initial train_x_tfm: {train_x_tfm}")
    log(f"Initial train_x:\n{train_x}")
    log(f"Initial train_y:\n{train_y}")
    log(f"Initial hv:\t{hv}\n")
    log(f"Initial regret:\t{regret}\n")

    def _save_initializations(x_star, config):
        # Save the normalized initial points for BOFormer testing (the path is
        # hard-coded; change it in the future). Existing files are kept.
        path = f"BOFormer/initializations/{config.function_name}/"
        os.makedirs(path, exist_ok=True)
        filepath = f"{path}/seed_{config.seed}.pt"
        if not osp.exists(filepath):
            log(f"Saved initial points to {filepath}")
            torch.save(x_star, filepath)
        else:
            log(f"File {filepath} already exists. Not overriding.")

    _save_initializations(train_x_tfm, config)

    hv = expand_metrics(hv, train_x)
    regret = expand_metrics(regret, train_x)
    hv_list = [hv]
    regret_list = [regret]
    time_list = [0.0]

    # Instantaneous hypervolume
    instantaneous_hv_list = [hv.clone()]

    if config.acq_fn_name == "BOFormer":
        # Load the recorded BOFormer trajectory to replay.
        name_map = {
            "BraninCurrin": "BC",
            "AckleyRastrigin": "ARa",
            "AckleyRosenbrock": "AR",
        }
        function_type = name_map.get(config.function_name, config.function_name)

        # These function types use the model_3400 history files; all others
        # use model_3000.
        model_3400_types = {
            "OilSorbentContinuousMid",
            "dx2_dy3",
            "dx3_dy3",
            "dx3_dy2",
            "LaserPlasma",
        }
        model_id = 3400 if function_type in model_3400_types else 3000
        filepath = (
            f"BOFormer/testings/BOFormer_model_{model_id}_function_type_{function_type}"
            f"_dim_{test_function.x_dim}_N_m_1_N_local_1_ls_learned_freq_10"
            f"_initial_sample_1_online_ls_1_seed_{config.seed}_episode_0.pkl"
        )

        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally generated history files.
        with open(filepath, "rb") as f:
            history = pickle.load(f)

        history_x = np.array(history["x"])  # [n, x_dim] in [0, 1]
        history_time = np.array(history["inference_time"])  # [n, ]

    while get_cost_used(train_x) <= config.T:
        # Checkpoint the metrics gathered so far before taking the next step.
        _save_results(
            config,
            _collect_results(hv_list, regret_list, time_list, instantaneous_hv_list),
            log,
        )

        step = get_cost_used(train_x)
        if config.acq_fn_name == "rs":
            # Random search baseline
            t1 = time.time()
            x_next = torch.rand(
                (1, config.num_candidates, x_bounds.shape[0]),
                **tkwargs,
            )
            x_next = unnormalize(x_next, bounds=x_bounds.transpose(0, 1))
            inference_time = time.time() - t1
        elif config.acq_fn_name == "BOFormer":
            if step >= len(history_x):
                log(f"History does not have step {step}, stopping optimization.")
                break

            x_next = history_x[step : step + 1]  # [1, x_dim]
            x_next = (
                torch.from_numpy(x_next).to(**tkwargs).unsqueeze(1)
            )  # [1, 1, x_dim]
            x_next = unnormalize(x_next.detach(), bounds=x_bounds.transpose(0, 1))
            # NOTE(review): time is indexed with step - 1 while x uses step;
            # presumably the history has no time entry for the initial point.
            inference_time = history_time[step - 1]
        else:
            try:
                train_x_tfm, train_y_tfm = preprocess_datapoints(
                    test_function=test_function,
                    train_x=train_x,
                    train_y=train_y,
                    tkwargs=tkwargs,
                )
                standard_bounds = get_standard_bounds(train_x.shape[-1], tkwargs)

                t1 = time.time()
                mll, model = fit_gp(
                    train_x=train_x_tfm,
                    train_y=train_y_tfm,
                    normalize_outputs=config.normalize_outputs,
                )
                fit_gpytorch_mll(mll)
                gp_fitting_time = time.time() - t1

                sampler = SobolQMCNormalSampler(
                    sample_shape=torch.Size([config.mc_samples])
                )
                x_next, acq_optimize_time = optimize_acquisition_function(
                    model=model,
                    train_x=train_x_tfm,
                    train_y=train_y_tfm,
                    standard_bounds=standard_bounds,
                    ref_point=ref_point,
                    sampler=sampler,
                    acq_fn_name=config.acq_fn_name,
                    num_candidates=config.num_candidates,
                    num_restarts=config.num_restarts,
                    raw_samples=config.raw_samples,
                    tkwargs=tkwargs,
                )
                inference_time = gp_fitting_time + acq_optimize_time
                x_next = unnormalize(x_next.detach(), bounds=x_bounds.transpose(0, 1))

            except Exception as e:
                # Stop the run but keep the metrics gathered so far (they are
                # saved after the loop); log the traceback for debugging.
                log(f"Error at step {step}: {e}")
                log(traceback.format_exc())
                break

        # Observe new values, update training points, compute hypervolume
        train_x, train_y, hv, regret = test_function.step(
            input_bounds=x_bounds,
            x_new=x_next,
            x_ctx=train_x,
            y_ctx=train_y,
            compute_hv=True,
            compute_regret=True,
            regret_type=config.regret_type,
        )

        hv = expand_metrics(hv, x_next)
        regret = expand_metrics(regret, x_next)

        hv_list.append(hv)
        regret_list.append(regret)
        time_list.append(inference_time)

        # Hypervolume of only the candidates proposed in this step.
        instantaneous_hv, _, _ = test_function.compute_hv(
            solutions=train_y[:, -config.num_candidates :]
        )
        instantaneous_hv = expand_metrics(instantaneous_hv, x_next)
        instantaneous_hv_list.append(instantaneous_hv)

        # NOTE: train_x_tfm / train_y_tfm are refreshed only in the GP branch;
        # for "rs" and "BOFormer" they still hold the initial points.
        log(
            f"--- Step {step} ---"
            f"\nMetrics:"
            f"\n  Hypervolume: {hv}"
            f"\n  Regret: {regret}"
            f"\n  Cumulative time (s): {sum(time_list):.4f}"
            f"\nContext Points:"
            f"\n  train_x_tfm: {train_x_tfm}"
            f"\n  train_y_tfm: {train_y_tfm}"
            f"\n  Instantaneous hv: {instantaneous_hv}"
        )

    # Log results: [batch_size, ]
    hv_T = hv_list[-1]
    regret_T = regret_list[-1]
    time_T = time_list[-1]

    line = (
        f"[results, seed={config.seed}]\n"
        f"hv_T:\t{hv_T}\n"
        f"regret_T:\t{regret_T}\n"
        f"time_T (s):\t{time_T}"
    )

    log(line)

    # Save the final metrics (including the last step taken in the loop).
    _save_results(
        config,
        _collect_results(hv_list, regret_list, time_list, instantaneous_hv_list),
        log,
    )

    if config.plot_enabled:
        fig = plot_optimization_batch(
            test_function=test_function,
            x_query=train_x,
            y_query=train_y,
        )
        save_fig(config=config, fig=fig, override=config.override, log=log)

        # Optionally plot pareto front
        front_fig = plot_fronts(
            function_name=config.function_name,
            dim=test_function.x_dim,
            solutions=train_y[0],  # [b]
        )
        if front_fig is not None:
            save_fig(
                config=config,
                fig=front_fig,
                override=config.override,
                log=log,
                filename="pareto_front",
            )


def _get_suffix(function_name, subfolder=None, extra=None):
    """Build the result-path suffix ``function_name[/subfolder][/extra]``.

    Optional components that are ``None`` are skipped. When both are ``None``,
    ``function_name`` is returned unchanged.
    """
    optional_parts = [part for part in (subfolder, extra) if part is not None]
    if not optional_parts:
        return function_name
    return "/".join(f"{part}" for part in (function_name, *optional_parts))


def save_data(
    config,
    filename: str,
    data: Tensor,
    override: bool = False,
    log: callable = print,
):
    """Save ``data`` to ``<result dir>/<seed>/<filename>.pt`` and return the path.

    The result directory comes from ``get_result_data_path``. It is keyed by
    the acquisition-function name, the experiment id, and a suffix built from
    ``function_name``, ``subfolder`` and ``suffix``.

    If the target file already exists and ``override`` is False, the data is
    written into a new timestamped sub-directory instead of overwriting it.

    Args:
        config: Run configuration. Reads ``acq_fn_name``, ``expid``,
            ``function_name``, ``subfolder``, ``suffix`` and ``seed``.
        filename: File stem; ``.pt`` is appended.
        data: Object passed to ``torch.save``.
        override: Whether an existing file may be overwritten.
        log: Callable used for progress messages.

    Returns:
        The path the data was written to.
    """
    path = get_result_data_path(
        model_name=config.acq_fn_name,
        expid=config.expid,
        task_type="optimization",
        suffix=_get_suffix(
            function_name=config.function_name,
            subfolder=config.subfolder,
            extra=config.suffix,
        ),
    )
    path = osp.join(path, str(config.seed))
    data_path = osp.join(path, f"{filename}.pt")
    already_exists = osp.exists(data_path)

    if already_exists and not override:
        # Keep the old file; save the new data to a timestamped sub-folder.
        ts_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        path = osp.join(path, ts_str)
        data_path = osp.join(path, f"{filename}.pt")
        message = f"\nFile already exists. Saving to new path: {data_path}"
    elif already_exists:
        message = f"\nOverriding existing file at {data_path}"
    else:
        message = f"\nFile does not exist. Saving to {data_path}"

    os.makedirs(path, exist_ok=True)
    log(message)
    torch.save(data, data_path)

    return data_path


def save_fig(
    config, fig: Figure, override: bool = False, log: callable = print, filename=None
) -> str:
    """Save ``fig`` as ``<plot dir>/<seed>/<filename>.png`` and return the path.

    The plot directory comes from ``get_result_plot_path``. It is keyed like
    ``save_data`` (acquisition name, experiment id, function-name suffix).
    When ``filename`` is None, it is derived from the whole config via
    ``params_to_filename``.

    If the file exists and ``override`` is False, the figure is saved into a
    new timestamped sub-directory. Previously the function only logged the
    new path in this case and never saved the figure.

    Args:
        config: Run configuration (also used to build the default file name).
        fig: Matplotlib figure to save.
        override: Whether an existing file may be overwritten.
        log: Callable used for progress messages.
        filename: Optional file stem; ``.png`` is appended.

    Returns:
        The path the figure was written to.
    """
    path = get_result_plot_path(
        model_name=config.acq_fn_name,
        expid=config.expid,
        task_type="optimization",
        suffix=_get_suffix(
            function_name=config.function_name,
            subfolder=config.subfolder,
            extra=config.suffix,
        ),
    )
    path = osp.join(path, str(config.seed))

    filename = (
        params_to_filename(params=asdict(config)) if filename is None else filename
    )
    fig_path = osp.join(path, f"{filename}.png")

    if osp.exists(fig_path) and not override:
        # Keep the old figure; save the new one to a timestamped sub-folder.
        ts_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        path = osp.join(path, ts_str)
        fig_path = osp.join(path, f"{filename}.png")
        log(f"\nFigure already exists. New figure will be saved to {fig_path}")
    else:
        log(f"Saving figure to {fig_path}\n")

    os.makedirs(path, exist_ok=True)
    fig.savefig(fig_path)

    return fig_path


def preprocess_datapoints(test_function: TestFunction, train_x, train_y, tkwargs):
    """Map raw observations into the space used for GP modeling.

    Steps:
        1. Remove a leading batch dimension of size 1 (``[1, n, d] -> [n, d]``);
           2-D inputs pass through unchanged.
        2. Normalize ``train_x`` to the unit cube using ``test_function.x_bounds``.
        3. Scale ``train_y`` to ``SCALE_Y_RANGE`` via ``test_function.scale_outputs``.
        4. Negate the scaled outputs when ``NEGATE`` is set.

    Args:
        test_function: Provides ``x_bounds`` (``[dx, 2]``) and ``scale_outputs``.
        train_x: Inputs, ``[1, n, dx]`` or ``[n, dx]``.
        train_y: Outputs, ``[1, n, dy]`` or ``[n, dy]``.
        tkwargs: dtype/device keyword arguments forwarded to ``Tensor.to``.

    Returns:
        ``(train_x_scaled, train_y_scaled)``; ``train_x_scaled`` is ``[n, dx]``.

    Raises:
        ValueError: If an input is 3-D with a batch size other than 1, or is
            neither 2-D nor 3-D.
    """

    def _remove_batch_dim(tensor):
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if tensor.ndim == 3:
            if tensor.shape[0] != 1:
                raise ValueError(f"Expected batch size 1, got {tensor.shape[0]}")
            return tensor.squeeze(0)  # [n, d]

        if tensor.ndim != 2:
            raise ValueError(f"Expected 2D tensor, got {tensor.ndim}D")
        return tensor

    x_bounds = test_function.x_bounds.to(**tkwargs)  # [dx, 2]

    train_x = _remove_batch_dim(train_x).to(**tkwargs)  # [n, dx]
    train_y = _remove_batch_dim(train_y).to(**tkwargs)  # [n, dy]

    # normalize() expects bounds as [2, dx].
    train_x_scaled = normalize(train_x, bounds=x_bounds.transpose(0, 1))
    train_y_scaled = test_function.scale_outputs(train_y, SCALE_Y_RANGE)

    if NEGATE:
        train_y_scaled = -train_y_scaled
    return train_x_scaled, train_y_scaled


def get_standard_bounds(dim: int, tkwargs):
    """Return the ``[2, dim]`` bounds of the unit cube.

    Row 0 holds the lower bounds (all zeros) and row 1 the upper bounds
    (all ones). ``tkwargs`` (e.g. dtype/device) is forwarded to the tensor
    constructors.
    """
    lower = torch.zeros(dim, **tkwargs)
    upper = torch.ones(dim, **tkwargs)
    return torch.stack((lower, upper))


def optimize_acquisition_function(
    model: ModelListGP,
    train_x: Tensor,
    train_y: Tensor,
    standard_bounds: Tensor,
    ref_point: Tensor,
    sampler,
    acq_fn_name: str,
    num_candidates: int,
    num_restarts: int,
    raw_samples: int,
    tkwargs,
):
    """Build the acquisition function ``acq_fn_name`` and optimize it.

    Args:
        model: Fitted multi-output model (see ``fit_gp``).
        train_x: Observed inputs in the space described by ``standard_bounds``.
        train_y: Observed outputs; only the number of objectives is read.
        standard_bounds: ``[2, dx]`` optimization bounds.
        ref_point: Hypervolume reference point.
        sampler: MC sampler for the Monte-Carlo acquisition functions
            (not used by ``qHVKG``).
        acq_fn_name: One of ``"qEHVI"``, ``"qNEHVI"``, ``"qLogNEHVI"``,
            ``"qHVKG"`` or ``"qNParEGO"``.
        num_candidates: Number of points to propose (``q``). For qNParEGO,
            this is also the number of random Chebyshev scalarizations.
        num_restarts: Number of optimizer restarts.
        raw_samples: Raw samples for the restart initialization heuristic.
        tkwargs: dtype/device keyword arguments (used for ParEGO weights).

    Returns:
        ``(candidates, inference_time)``. ``candidates`` has shape
        ``[..., num_candidates, dx]``; a leading dimension is added when the
        optimizer returns a 2-D tensor. ``inference_time`` is the wall-clock
        seconds spent building and optimizing the acquisition function.

    Raises:
        NotImplementedError: For ``"JES"``.
        ValueError: For any other unsupported ``acq_fn_name``.
    """
    num_objectives = train_y.shape[-1]

    t1 = time.time()
    if acq_fn_name == "qEHVI":
        # Partition the objective space around the model's posterior mean
        # at the observed inputs.
        with torch.no_grad():
            pred = model.posterior(train_x).mean  # [num_init, dy]
        partitioning = FastNondominatedPartitioning(ref_point=ref_point, Y=pred)

        acq_fn = qExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            partitioning=partitioning,
            sampler=sampler,
        )
    elif acq_fn_name == "qNEHVI":
        acq_fn = qNoisyExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            X_baseline=train_x,
            prune_baseline=True,
            sampler=sampler,
        )
    elif acq_fn_name == "qLogNEHVI":
        acq_fn = qLogNoisyExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            X_baseline=train_x,
            prune_baseline=True,
            sampler=sampler,
        )
    elif acq_fn_name == "qHVKG":
        # NOTE: num_fantasies=1 applies to every function, not only OilSorbent.
        acq_fn = qHypervolumeKnowledgeGradient(
            model=model,
            ref_point=ref_point,
            num_pareto=1,
            num_fantasies=1,  # Only for OilSorbent; to speed up
        )
    elif acq_fn_name == "JES":
        raise NotImplementedError("JES acquisition function is not implemented.")
    elif acq_fn_name == "qNParEGO":
        with torch.no_grad():
            pred = model.posterior(train_x).mean  # [num_init, dy]

        # One qNEI per candidate, each with its own random Chebyshev weights.
        acq_fn = []
        for _ in range(num_candidates):
            weights = sample_simplex(
                num_objectives,
                **tkwargs,
            ).squeeze()
            objective = GenericMCObjective(
                get_chebyshev_scalarization(weights=weights, Y=pred)
            )
            fn = qNoisyExpectedImprovement(  # pyre-ignore: [28]
                model=model,
                objective=objective,
                X_baseline=train_x,
                sampler=sampler,
                prune_baseline=True,
            )
            acq_fn.append(fn)
    else:
        # Without this branch acq_fn would be unbound and fail below with a
        # confusing UnboundLocalError.
        raise ValueError(f"Unknown acquisition function: {acq_fn_name!r}")

    # Optimize acqf_fn and return [q, dx]
    if isinstance(acq_fn, list):
        candidates, _ = optimize_acqf_list(
            acq_function_list=acq_fn,
            bounds=standard_bounds,
            num_restarts=num_restarts,
            raw_samples=raw_samples,  # used for initialization heuristic
            options={"batch_limit": BATCH_LIMIT, "maxiter": MAXITER},
        )
    else:
        candidates, _ = optimize_acqf(
            acq_function=acq_fn,
            bounds=standard_bounds,
            q=num_candidates,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            options={"batch_limit": BATCH_LIMIT, "maxiter": MAXITER},
            sequential=True,
        )
    t2 = time.time()
    inference_time = t2 - t1

    if candidates.ndim == 2:
        candidates = candidates.unsqueeze(0)
    return candidates, inference_time


def fit_gp(train_x: Tensor, train_y: Tensor, normalize_outputs: bool = True):
    """Build one independent SingleTaskGP per objective for MOO.

    The GPs are wrapped in a ``ModelListGP`` and paired with a
    ``SumMarginalLogLikelihood`` so they can be fitted jointly. Both are
    moved to ``train_x``'s dtype/device. Hyperparameters are NOT fitted
    here; the caller runs the fitting.

    Args:
        train_x: Inputs ``[..., n, dx]``.
        train_y: Outputs ``[..., n, dy]``; column ``j`` trains the j-th GP.
        normalize_outputs: If True, each GP gets its own ``Standardize``
            outcome transform; otherwise ``outcome_transform=None`` is passed.

    Returns:
        ``(mll, model)``.
    """
    batch_shape = train_y.shape[:-2]

    def _outcome_transform():
        # A fresh transform per objective: Standardize holds per-output state.
        if not normalize_outputs:
            return None
        return Standardize(m=1, batch_shape=batch_shape)

    per_objective_gps = [
        SingleTaskGP(
            train_x,
            train_y[..., obj_idx : obj_idx + 1],
            outcome_transform=_outcome_transform(),
        )
        for obj_idx in range(train_y.shape[-1])
    ]
    model = ModelListGP(*per_objective_gps)
    mll = SumMarginalLogLikelihood(model.likelihood, model)

    model.to(train_x)
    mll.to(train_x)
    return mll, model


def str_to_bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True, so boolean options use this converter instead.
    Booleans are returned unchanged. The case-insensitive strings
    ``true/false``, ``t/f``, ``yes/no``, ``y/n``, ``1/0`` and ``on/off``
    are accepted.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--function_name", type=str, default="BraninCurrin")
    parser.add_argument("--acq_fn_name", type=str, default="qNEHVI")
    parser.add_argument("--num_candidates", type=int, default=1)
    parser.add_argument("--num_restarts", type=int, default=10)
    parser.add_argument("--raw_samples", type=int, default=512)
    parser.add_argument("--mc_samples", type=int, default=128)
    # Same auto-detection as GPBaselineConfig so CPU-only machines work by default.
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--expid", type=str, default=time.strftime("%Y%m%d-%H%M%S"))
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--sigma", type=float, default=0.0)
    parser.add_argument("--num_initial_points", type=int, default=1)
    parser.add_argument("--regret_type", type=str, default="ratio")
    parser.add_argument("--T", type=int, default=100)
    parser.add_argument("--normalize_outputs", type=str_to_bool, default=True)
    parser.add_argument("--plot_enabled", type=str_to_bool, default=PLOT_ENABLED)
    parser.add_argument("--dim", type=int, default=None)
    parser.add_argument("--override", type=str_to_bool, default=True)
    parser.add_argument("--subfolder", type=str, default=None)
    parser.add_argument("--suffix", type=str, default=None)

    args = parser.parse_args()
    arg_dict = vars(args)

    config = GPBaselineConfig(**arg_dict)
    main(config)
