# =============================================================================
# Runner
# =============================================================================

import os
import argparse
from pathlib import Path
from types import SimpleNamespace

import yaml
import torch
import botorch
from botorch.utils.transforms import unnormalize

from acquisitions import get_acqf
from problems import get_problem
from utils import (
    BiLevelBODataset, ExperimentLogger, RFFHybridModel, RFFModelList
)
from utils.models import ExactGPModel, ExactModelList


# -----------------------------------------------------------------------------
# Bi-Level BO Runner
# -----------------------------------------------------------------------------

class BORunner:

    def __init__(
        self,
        config: SimpleNamespace,
        local_path: Path,
    ) -> None:

        self.config = config
        self.local_path = local_path
        self._set_seed(0)
        self.logger = ExperimentLogger(config, local_path)


    def _set_seed(
        self,
        seed: int,
    ) -> None:

        os.environ["PYTHONHASHSEED"] = str(seed)
        torch.manual_seed(seed)
        botorch.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True


    def run(self) -> None:

        self.logger.start()
        self.logger.info("Generating initial data ...")

        ### Generate initial data ---------------------------------------------
        problem = get_problem(**self.config.problem)
        dataset = BiLevelBODataset(
            num_dims=problem.num_dims,
            num_objectives=problem.num_objectives,
            num_constraints=problem.num_constraints,
        )
        n_init = self.config.n_init
        if problem.has_candidates:
            mask_padding = problem.candidates.isnan().any(dim=-1)
            valid_idx = (~mask_padding).nonzero(as_tuple=False)
            indices = valid_idx[torch.randperm(valid_idx.size(0))[:n_init]]
            init_X = problem.candidates[*indices.chunk(2, dim=1)]
            problem.mask_evaluated[*indices.chunk(2, dim=1), :] = True
        else:
            bounds = torch.tensor(problem.bounds, dtype=torch.double).T
            init_X = torch.rand(self.config.n_init, bounds.size(-1))
            init_X = unnormalize(init_X, bounds=bounds)
        init_X = init_X.view(self.config.n_init, sum(problem.num_dims))
        init_outputs = problem(init_X, noise=self.config.noise)
        dataset.add(init_X, init_outputs, metadata={"init": True})
        ### -------------------------------------------------------------------

        self.logger.obs(dataset.data)

        self._set_seed(self.config.seed)
        # Bayesian Optimization Loop
        n_iter = self.config.n_iter
        for it in range(1, n_iter+1):
            self.logger.info("")
            self.logger.info(f"<< Iteration {it:3d}/{n_iter:3d} >>")
            self.logger.info("")
            self.logger.info("Fitting the model ...")

            ### Fit model -----------------------------------------------------
            train_data = dataset.get()
            rff_models = []
            # exact_models = []
            for train_X, train_Y in train_data:
                model = RFFHybridModel(train_X, train_Y, **self.config.model)
                rff_models.append(model)
                # model_exact = ExactGPModel(train_X, train_Y, **self.config.model)
                # exact_models.append(model)
            model = RFFModelList(*rff_models)
            model.fit()
            # model_exact = ExactModelList(*exact_models)
            # model_exact.fit()
            """
            x = torch.linspace(0, 1, 100)
            y = torch.linspace(0, 1, 100)
            grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
            print(grid.shape)
            mean_rff = model.mean(grid).detach()
            mean_exact = model_exact.mean(grid).detach()
            import matplotlib.pyplot as plt
            plt.figure(figsize=(6, 5))
            c = plt.contourf(grid[..., 0], grid[..., 1], mean_rff[..., 0], levels=50)
            plt.colorbar(c)
            plt.savefig("mean_rff0.png")
            plt.clf();plt.close()
            plt.figure(figsize=(6,5))
            c = plt.contourf(grid[..., 0], grid[..., 1], mean_rff[..., 1], levels=50)
            plt.colorbar(c)
            plt.savefig("mean_rff1.png")
            plt.figure(figsize=(6, 5))
            c = plt.contourf(grid[..., 0], grid[..., 1], mean_exact[..., 0], levels=50)
            plt.colorbar(c)
            plt.savefig("mean_exact0.png")
            plt.clf();plt.close()
            plt.figure(figsize=(6,5))
            c = plt.contourf(grid[..., 0], grid[..., 1], mean_exact[..., 1], levels=50)
            plt.colorbar(c)
            plt.savefig("mean_exact1.png")

            return
            """
            num_objectives = problem.num_objectives
            num_constraints = problem.num_constraints
            obj_models = model.rff_models[:sum(num_objectives)]
            con_models = model.rff_models[sum(num_objectives):]
            model_Y_upper = RFFModelList(*obj_models[:num_objectives[0]])
            model_Y_lower = RFFModelList(*obj_models[num_objectives[0]:])
            models = [model_Y_upper, model_Y_lower]
            model_C_upper, model_C_lower = None, None
            if num_constraints[0] > 0:
                model_C_upper = RFFModelList(*con_models[:num_constraints[0]])
                models.append(model_C_upper)
            if num_constraints[1] > 0:
                model_C_lower = RFFModelList(*con_models[num_constraints[0]:])
                models.append(model_C_lower)
            ### ---------------------------------------------------------------

            self.logger.model(models)
            self.logger.info("Acquiring next observation point ...")

            ### Get next data ------------------------------------------------
            acqf = get_acqf(
                **self.config.acquisition,
                num_dims=problem.num_dims,
                model_Y_upper=model_Y_upper,
                model_Y_lower=model_Y_lower,
                model_C_upper=model_C_upper,
                model_C_lower=model_C_lower,
            )
            if problem.has_candidates:
                indices, Y_mask = acqf.optimize_pool(
                    candidates=problem.candidates,
                    mask_evaluated=problem.mask_evaluated,
                    decoupled=self.config.decoupled,
                )
                indices = indices.chunk(2, dim=0)
                next_X = problem.candidates[*indices]
                problem.mask_evaluated[*indices, :] |= Y_mask
            else:
                next_X, Y_mask = acqf.optimize_query(
                    bounds=bounds,
                    decoupled=self.config.decoupled,
                )
            outputs = problem(next_X, noise=self.config.noise).squeeze()
            next_outputs = torch.where(Y_mask, outputs, float("nan"))
            dataset.add(next_X, next_outputs, metadata={"init": False})
            ### --------------------------------------------------------------
            
            self.logger.obs(dataset.data[-1:])

        self.logger.info("")
        self.logger.info("Saving results ...")
        dataset.save(self.local_path / "dataset.pt")
        self.logger.end()



if __name__ == "__main__":
    # Command-line entry point: load the YAML config, merge its global and
    # per-run sections, and launch a single BO run.

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", "-c", type=str, required=True,
        help="stem of the YAML file under experiments/configs",
    )
    parser.add_argument(
        "--name", "-n", type=str, required=True,
        help="key of the entry to use from the config's 'local' section",
    )
    parser.add_argument(
        "--seed", "-s", type=int, required=True,
        help="seed stored in the run configuration",
    )
    args = parser.parse_args()

    config_path = Path("experiments/configs") / f"{args.config}.yaml"
    result_path = Path("experiments/results") / args.config
    with open(config_path, mode="r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    global_cfg = cfg["global"]
    if args.name not in cfg["local"]:
        # Report an unknown run name as a usage error rather than a KeyError.
        parser.error(
            f"unknown --name {args.name!r}; "
            f"available: {', '.join(map(str, cfg['local']))}"
        )
    local_cfg = cfg["local"][args.name]
    local_cfg["seed"] = args.seed

    # Local settings override global ones. The merged dict is expanded into
    # keyword arguments because SimpleNamespace only accepts a positional
    # mapping from Python 3.13 onwards.
    config = SimpleNamespace(**{**global_cfg, **local_cfg})
    local_path = result_path / config.name / str(config.seed)
    local_path.mkdir(parents=True, exist_ok=True)

    runner = BORunner(config=config, local_path=local_path)
    runner.run()