"""Evaluation function environment.

NOTE: All objectives are assumed to be minimized by default.

Supports:
    - sampling
    - evaluation
    - reward / regret computation

Classes:
    - TestFunction: Base class
    - BenchmarkFunction: Continuous synthetic benchmark function environment
    - IntepolatorFunction: Interpolator function environment
"""

import torch
from torch import Tensor
import numpy as np
from botorch.test_functions import (
    Branin,
    BraninCurrin,
    EggHolder,
    ZDT1,
    ZDT2,
    ZDT3,
)
from data.function_preprocessing import (
    tuple_list_to_nested_list,
    make_range_nested_list,
    make_range_tensor,
    scale_between_bounds,
    estimate_objective_bounds,
    BatchNDInterpolatorExtTensor,
)
from data.function_sampling import sample_full_inputs_from_subspaces
from data.botorch_utils import (
    Forrester,
    AckleyRosenbrock,
    AckleyRastrigin,
    OilSorbent,
    OilSorbentContinuousMid,
)
from data.data_masking import (
    take_along_valid_dims,
    restore_full_dims,
)
from utils.moo import norm_compute_hv_batch
from utils.types import FloatListOrNestedOrTensor, NestedFloatList, FloatListOrNested
from typing import List, Dict, Tuple, Optional
from einops import repeat

# Default observation-noise level; used as the class-level default of `TestFunction.sigma`
# and as the default `sigma` of `get_function_environment`.
SIGMA = 0.0

# Registries mapping a function name to the constructor that builds it.
# Multi-objective synthetic benchmarks.
MO_BENCHMARK = {
    "BraninCurrin": BraninCurrin,
    "AckleyRosenbrock": AckleyRosenbrock,
    "AckleyRastrigin": AckleyRastrigin,
    "ZDT1": ZDT1,
    "ZDT2": ZDT2,
    "ZDT3": ZDT3,
}
# Multi-objective real-world problems.
MO_REALWORLD = {
    "OilSorbentContinuousMid": OilSorbentContinuousMid,
    "OilSorbent": OilSorbent,
}
# Single-objective synthetic benchmarks.
SO_BENCHMARK = {"Branin": Branin, "Forrester": Forrester, "EggHolder": EggHolder}
# Union of all registries; this is what `BenchmarkFunction` looks names up in.
TESTFUNCTIONS = {**MO_BENCHMARK, **SO_BENCHMARK, **MO_REALWORLD}

# Hard-coded objective bounds per single-objective function, as `[[y_min, y_max]]`.
# Some keys (Currin, Ackley, Rosenbrock, Rastrigin) have no entry in TESTFUNCTIONS and
# are used below as the per-objective components of MO_Y_BOUNDS.
# NOTE(review): values presumably refer to each function's default input domain — confirm.
SO_Y_BOUNDS = {
    "Forrester": [[-6.020740, 16.0]],
    "Branin": [[0.397887, 309.0]],
    "Currin": [[1.18, 14.0]],
    "Ackley": [[0.0, 23.0]],
    "Rosenbrock": [[0.0, 3907.0]],
    "Rastrigin": [[0.0, 81.0]],
    "EggHolder": [[-959.6407, 1050.0]],
}

# Objective bounds of multi-objective functions, one `[y_min, y_max]` pair per objective,
# assembled from the single-objective entries above.
MO_Y_BOUNDS = {
    "AckleyRastrigin": [SO_Y_BOUNDS["Ackley"][0], SO_Y_BOUNDS["Rastrigin"][0]],
    "AckleyRosenbrock": [SO_Y_BOUNDS["Ackley"][0], SO_Y_BOUNDS["Rosenbrock"][0]],
    "BraninCurrin": [SO_Y_BOUNDS["Branin"][0], SO_Y_BOUNDS["Currin"][0]],
}
# Objective bounds of real-world problems; currently empty (the only entry is disabled).
RW_Y_BOUNDS = {
    # "OilSorbentContinuousMid": [
    #     [-170.0, -70.0],
    #     [-235.0, -10.0],
    #     [-10.0, 15.0],
    # ]  # Very loose approx bounds
}
# Name -> objective bounds lookup used by `BenchmarkFunction.get_metadata`.
# NOTE: this dict is also written to at runtime: bounds estimated for functions that
# have no entry here are stored under the function name for later reuse.
BENCHMARK_Y_BOUNDS = {**SO_Y_BOUNDS, **MO_Y_BOUNDS, **RW_Y_BOUNDS}


def get_ref_point(
    bounds: "Tensor | NestedFloatList", candidates: Optional[Tensor] = None
) -> List:
    """Define the reference point (minimization setting).

    - If `candidates` is provided, use the per-objective maximum over `candidates`.
    - Otherwise, use the upper bounds.

    Args:
        bounds: [dy, 2] or dy x [y_min, y_max]
        candidates: Optional objective values, [n, dy]

    Returns:
        ref_point: [dy], a list of Python numbers in both cases (0-d tensors taken
            from Tensor `bounds` are unwrapped so the result does not depend on the
            container type of `bounds`).
    """
    if candidates is not None:
        assert candidates.shape[-1] == len(bounds)
        # Reduce over the sample axis: [n, dy] -> [dy]
        return candidates.max(dim=0).values.tolist()

    def _as_scalar(value):
        # Indexing a Tensor yields 0-d tensors; unwrap them to plain numbers.
        return value.item() if isinstance(value, Tensor) else value

    return [_as_scalar(pair[1]) for pair in bounds]


def get_max_hv(
    ref_point: List,
    bounds: "Tensor | NestedFloatList",
    candidates: Optional[Tensor] = None,
) -> float:
    """Define the maximum hypervolume w.r.t. `ref_point` (minimization setting).

    The hypervolume is the volume of the box spanned by `ref_point` and the ideal
    point, where the ideal point is

    - the per-objective minimum over `candidates` if provided,
    - otherwise the lower bounds.

    Per-objective extents are clipped at zero in both cases, so a reference point
    that lies below the ideal point in some objective yields zero volume instead of
    a negative (or, with two such objectives, spuriously positive) product.

    Args:
        ref_point: [dy]
        bounds: [dy, 2] or dy x [y_min, y_max]
        candidates: Optional objective values, [n, dy]

    Returns:
        max_hv: float
    """

    def _as_scalar(value):
        # Unwrap 0-d tensors (e.g. from indexing Tensor bounds) to plain numbers.
        return value.item() if isinstance(value, Tensor) else value

    ref = np.array([_as_scalar(v) for v in ref_point], dtype=float)

    if candidates is None:
        ideal = np.array([_as_scalar(pair[0]) for pair in bounds], dtype=float)
    else:
        assert candidates.shape[-1] == len(bounds)
        # [n, dy] -> [dy]; detach so tensors that require grad can be converted.
        ideal = candidates.min(dim=0).values.detach().cpu().numpy()

    return np.prod((ref - ideal).clip(min=0.0))


class TestFunction:
    """Test function environment base class.

    Bundles a callable objective with its input / output bounds and offers
    sampling, evaluation, hypervolume and regret utilities on top of it.

    `get_metadata()` must be implemented in subclasses; `__init__` forwards its
    keyword arguments to it and hands the result to `init_from_metadata()`.

    Attrs:
        function_name (str): Function name
        func (callable): Callable function instance
        x_dim (int): Dimension of the input space
        y_dim (int): Dimension of the output space
        x_bounds (Tensor): Input bounds in the truth domain, shape [dx, 2]
        y_bounds (Tensor): Output bounds in the truth domain, shape [dy, 2]
        ref_point (Tensor): Reference point in the truth domain, shape [dy]
        max_hv (float): Maximum hypervolume value from `ref_point` in the truth domain
        sigma (float): Noise level for observations
    """

    func: callable = None
    x_dim: int = None
    y_dim: int = None
    x_bounds: Tensor = None
    y_bounds: Tensor = None
    ref_point: Tensor = None
    max_hv: float = None
    sigma: float = SIGMA
    function_name: str = "unknown_function"

    def __init__(self, **kwargs):
        """Build the environment from the metadata returned by `get_metadata(**kwargs)`."""
        metadata = self.get_metadata(**kwargs)
        self.init_from_metadata(metadata)

    def get_metadata(self, **kwargs) -> Dict:
        """NOTE Get a dictionary of metadata of the function.

        Must be overridden; see `init_from_metadata` for the keys that are read.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement this.")

    def init_from_metadata(self, metadata: Dict) -> None:
        """Initialize function environment from metadata.

        Args:
            metadata: Must contain non-None values for "x_bounds", "y_bounds",
                "ref_point", "func" and "max_hv". "function_name" and "sigma" are
                optional and fall back to the current attribute values.

        Raises:
            ValueError: If a required key is missing or maps to None.
        """

        def _get_required(key: str):
            val = metadata.get(key)
            if val is None:
                raise ValueError(f"Function metadata must contain '{key}'.")
            return val

        self.x_bounds = torch.tensor(_get_required("x_bounds"))
        self.y_bounds = torch.tensor(_get_required("y_bounds"))
        self.ref_point = torch.tensor(_get_required("ref_point"))
        self.func = _get_required("func")
        self.max_hv = _get_required("max_hv")

        self.function_name = metadata.get("function_name", self.function_name)
        self.sigma = metadata.get("sigma", self.sigma)

        # Dimensions are derived from the number of bound pairs.
        self.x_dim = len(self.x_bounds)
        self.y_dim = len(self.y_bounds)

    def get_max_hv(self) -> float:
        """Get maximum hypervolume."""
        return self.max_hv

    def scale_inputs(
        self, inputs: Tensor, input_bounds: FloatListOrNestedOrTensor
    ) -> Tensor:
        """Scale inputs from its original domain (`input_bounds`) to function input domain (`x_bounds`).

        Args:
            inputs: [..., DX]
            input_bounds: [dx, 2] or [2]

        Returns:
            scaled_inputs: [..., DX]
        """
        # Inputs are always rescaled with sigma=0.0, independent of `self.sigma`.
        return scale_between_bounds(
            data=inputs, inp_bounds=input_bounds, out_bounds=self.x_bounds, sigma=0.0
        )

    def scale_outputs(
        self,
        outputs: Tensor,
        output_bounds: FloatListOrNestedOrTensor,
        sigma: Optional[float] = None,
    ) -> Tensor:
        """Scale outputs from function output domain (`y_bounds`) to target domain (`output_bounds`).

        Args:
            outputs: [..., DY]
            output_bounds: [dy, 2] or [2]
            sigma: Optional noise level, defaults to self.sigma

        Returns:
            scaled_outputs [..., DY]
        """
        if sigma is None:
            sigma = self.sigma
        else:
            assert sigma >= 0.0, "`sigma` must be non-negative."

        return scale_between_bounds(
            data=outputs,
            inp_bounds=self.y_bounds,
            out_bounds=output_bounds,
            sigma=sigma,
        )

    @staticmethod
    def _update_context(new: Tensor, old: Optional[Tensor]) -> Tensor:
        """Update context with new data points.
        [B, num_old | None, DY] -> [B, num_old + num_new | num_new, DY]

        Returns a clone of `new` when there is no previous context, otherwise the
        concatenation of `old` and `new` along dim 1.
        """
        if old is None:
            return new.clone()
        else:
            B, _, DY = old.shape
            assert (
                new.shape[0] == B and new.shape[2] == DY
            ), f"{new.shape} != {old.shape}"
            return torch.cat((old, new), dim=1)

    def _sample(
        self,
        num_subspace_points: int,
        input_bounds: FloatListOrNestedOrTensor,
        use_grid_sampling: bool,
        use_factorized_policy: bool = False,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        x_mask: Optional[Tensor] = None,
        y_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Draw one (unbatched) sample set and evaluate the function on it.

        Args mirror `sample`. `device` is only used to place the default all-ones
        `x_mask` that is created when `x_mask` is None.

        Returns:
            (x, y, chunks, chunk_mask): `x`, `chunks` and `chunk_mask` come from
            `sample_full_inputs_from_subspaces`; `y` is `evaluate(x)`.
        """
        if x_mask is None:
            x_mask = torch.ones((self.x_dim,), dtype=torch.bool, device=device)

        # Sample input points in the subspace
        x, chunks, chunk_mask = sample_full_inputs_from_subspaces(
            d=num_subspace_points,
            x_mask=x_mask,
            input_bounds=input_bounds,
            use_grid_sampling=use_grid_sampling,
            use_factorized_policy=use_factorized_policy,
        )

        # Evaluate at x: [m, dy]
        y = self.evaluate(x=x, input_bounds=input_bounds, x_mask=x_mask, y_mask=y_mask)

        return (x, y, chunks, chunk_mask)

    def sample(
        self,
        input_bounds: FloatListOrNestedOrTensor,
        num_subspace_points: int,
        use_grid_sampling: bool,
        batch_size: int,
        use_factorized_policy: bool = False,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        x_mask: Optional[Tensor] = None,
        y_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Sample data batch on the function.

        Args:
            input_bounds: Bounds for input space
            num_subspace_points: Number of points for a subspace
            use_grid_sampling: Sample from a grid if True
            batch_size (B): Batch size
            use_factorized_policy: Factorize search space if True
            device: Defaults to 'cuda' if available, else 'cpu'
            x_mask: Optional mask for valid x dims, [dx_max]
            y_mask: Optional mask for valid y dims, [dy_max]

        Returns:
            batch_x: [B, M, dx], batch_y: [B, M, dy],
            batch_chunks: [B, d, dx], batch_chunk_mask: [B, n_chunks, dx]
        """
        assert batch_size > 0, "`batch_size` must be a positive integer."

        x_list, y_list, chunks_list = [], [], []
        for _ in range(batch_size):
            x, y, chunks, chunk_mask = self._sample(
                num_subspace_points=num_subspace_points,
                input_bounds=input_bounds,
                use_grid_sampling=use_grid_sampling,
                use_factorized_policy=use_factorized_policy,
                device=device,
                x_mask=x_mask,
                y_mask=y_mask,
            )

            x_list.append(x)
            y_list.append(y)
            chunks_list.append(chunks)

        # Stack results to create batch
        batch_x = torch.stack(x_list, dim=0)
        batch_y = torch.stack(y_list, dim=0)
        batch_chunks = torch.stack(chunks_list, dim=0)

        # NOTE For efficiency: chunk_mask is shared in batch so can just expand it
        # (the mask of the last draw is used for every batch element).
        batch_chunk_mask = repeat(chunk_mask, "n d -> b n d", b=batch_size)

        return (batch_x, batch_y, batch_chunks, batch_chunk_mask)

    def evaluate(
        self,
        x: Tensor,
        input_bounds: FloatListOrNestedOrTensor,
        x_mask: Optional[Tensor] = None,
        y_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Evaluate at x.

        Args:
            x: x in bounds `input_bounds`, [m, dx_max]
            input_bounds: bounds of `x`, [dx_max, 2]
            x_mask: Optional mask for valid x dims, [dx_max]
            y_mask: Optional mask for valid y dims, [dy_max]

        Returns:
            y: Function values at x, [m, dy_max]
        """
        if x_mask is None and y_mask is None:
            # No mask provided, return directly
            return self(x=x, input_bounds=input_bounds)
        else:
            # NOTE x_mask / y_mask can be None
            # None check in `take_along_valid_dims` and `restore_full_dims`

            # Take valid dimensions of x
            x_valid = take_along_valid_dims(data=x, mask=x_mask, dim=-1)
            num_dim = x_valid.shape[-1]

            # Take valid dimensions of input bounds
            # NOTE(review): bounds are expanded to the number of *valid* dims and then
            # masked with the full-length `x_mask`; confirm the helpers handle the
            # case dx_valid != dx_max as intended.
            input_bounds = make_range_tensor(input_bounds, num_dim=num_dim).to(x.device)
            input_bounds_valid = take_along_valid_dims(
                data=input_bounds, mask=x_mask, dim=0
            )

            # Evaluate on valid x
            y_valid = self(x=x_valid, input_bounds=input_bounds_valid)

            # Restore full dimensions of y
            y = restore_full_dims(data=y_valid, mask=y_mask, dim=-1)

            del x_valid, y_valid, input_bounds_valid

            return y

    def compute_hv(
        self,
        solutions: Tensor,
        normalize: bool = False,
        y_mask: Optional[Tensor] = None,
    ) -> Tuple:
        """Compute hypervolume from `solutions`.

        Args:
            solutions: shape [B, N, dy_max]
            normalize: Whether to normalize solutions and reference point before hypervolume computation
            y_mask: Optional mask for valid y dims, [dy_max]

        Returns:
            A 3-tuple, passed through from `norm_compute_hv_batch`:
            reward: shape [B]
            reward_sol: (Optionally normalized) solutions used for reward computation, shape [B, N, dy_valid]
            reward_ref_points: (Optionally normalized) reference points used for reward computation, shape [B, dy_valid]
        """
        # `.to(solutions)` matches dtype and device of the solutions.
        bounds = self.y_bounds.to(solutions)  # [dy_valid, 2]
        ref_point = self.ref_point.to(solutions)  # [dy_valid]

        # If y_mask is provided, restore full dims of bounds and ref_point
        if y_mask is not None:
            dy_valid = bounds.shape[0]
            dy_max = y_mask.shape[0]

            if dy_valid != dy_max:
                # Restore full dims of bounds and ref_point
                bounds = restore_full_dims(data=bounds, mask=y_mask, dim=0)
                ref_point = restore_full_dims(data=ref_point, mask=y_mask, dim=0)

        mins = bounds[:, 0]
        maxs = bounds[:, 1]

        reward, reward_solutions, reward_ref_points = norm_compute_hv_batch(
            solutions=solutions,
            minimum=mins,
            maximum=maxs,
            ref_point=ref_point,
            y_mask=y_mask,
            normalize=normalize,
        )

        return reward, reward_solutions, reward_ref_points

    def compute_regret(
        self,
        solutions: Tensor,
        solution_candidate_set: Optional[Tensor] = None,
        regret_type: str = "ratio",
        y_mask: Optional[Tensor] = None,
    ) -> np.ndarray:
        """Compute regret from `solutions`.

        Regret types:
            - "value": `max_hv - hv`
            - "ratio": "value" divided by `max_hv` (by 1.0 if `max_hv` is zero)
            - "norm_ratio": relative gap between the normalized hypervolume of
              `solution_candidate_set` and that of `solutions`
            - "norm_ratio_dim": not implemented
            - "simple": best (minimum) observed value minus the optimum; the optimum
              is `func._optimal_value` if present, else the lower output bound.
              Single-objective only.

        Args:
            solutions: shape [B, N, dy_max]
            regret_type: in ["value", "ratio", "norm_ratio", "simple"]
            solution_candidate_set: Optional set of candidate solutions, required for `norm_ratio` regret type.
            y_mask: Optional mask for valid y dims, [dy_max]

        Returns:
            regret: shape [B]

        Raises:
            ValueError: Unknown `regret_type`, or `solution_candidate_set` missing
                for the `norm_ratio` family.
            NotImplementedError: For `regret_type="norm_ratio_dim"`.
        """
        if regret_type in ["value", "ratio"]:
            hv_np = self.compute_hv(solutions, normalize=False, y_mask=y_mask)[0]
            regret_np = self.max_hv - hv_np

            if regret_type == "ratio":
                # Avoid division by zero
                norm_term = self.max_hv if self.max_hv != 0 else 1.0
                regret_np /= norm_term

        elif regret_type in ["norm_ratio", "norm_ratio_dim"]:
            if solution_candidate_set is None:
                raise ValueError(
                    "`solution_candidate_set` must be provided for `norm_ratio` regret type"
                )
            if regret_type == "norm_ratio_dim":
                # Fail fast: this variant is unimplemented, so do not spend two
                # hypervolume computations before reporting that.
                raise NotImplementedError("Taking sqrt?...")

            hv_np = self.compute_hv(solutions, normalize=True, y_mask=y_mask)[0]
            max_hv_on_set_np = self.compute_hv(
                solution_candidate_set, normalize=True, y_mask=y_mask
            )[0]

            # NOTE(review): unlike the "ratio" branch there is no guard against a
            # zero denominator here; a candidate set with zero hypervolume yields
            # nan / inf.
            regret_np = (max_hv_on_set_np - hv_np) / max_hv_on_set_np
        elif regret_type == "simple":
            assert (
                self.y_dim == 1
            ), "`simple` regret only supports single-objective functions."
            regret = solutions.min(dim=1).values  # [B, dy_max]
            regret = take_along_valid_dims(
                data=regret, mask=y_mask, dim=-1
            )  # [B, dy_valid]
            assert regret.shape[1] == 1

            if hasattr(self.func, "_optimal_value"):
                regret -= self.func._optimal_value
            else:
                regret -= self.y_bounds[0][0]
            regret_np = regret.squeeze(-1).cpu().numpy()  # [B,]
        else:
            raise ValueError(f"Unknown regret type: {regret_type}")

        return regret_np

    def step(
        self,
        x_new: Tensor,
        input_bounds: FloatListOrNestedOrTensor,
        x_ctx: Optional[Tensor] = None,
        y_ctx: Optional[Tensor] = None,
        x_mask: Optional[Tensor] = None,
        y_mask: Optional[Tensor] = None,
        compute_hv: bool = True,
        compute_regret: bool = True,
        regret_type: str = "ratio",
        solution_candidate_set: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[np.ndarray], Optional[np.ndarray]]:
        """Evaluate function input `x_new`, update context, optionally compute reward / regret.

        Args:
            x_new: Input points, shape [B, num_new, dx_max]
            input_bounds: Input bounds for scaling
            x_ctx: Optional context input points, shape [B, num_ctx, dx_max]
            y_ctx: Optional context output points, shape [B, num_ctx, dy_max]
            x_mask: Optional mask for valid x dims, [dx_max]
            y_mask: Optional mask for valid y dims, [dy_max]
            compute_hv: compute reward (hypervolume of the updated context) if True
            compute_regret: compute regret on the updated context if True
            regret_type: Type of regret to compute, defaults to "ratio"
            solution_candidate_set: Optional set of candidate solutions for regret computation,
                forwarded to `compute_regret`

        Returns:
            x_ctx: Updated context input points, shape [B, num_ctx + num_new, dx_max]
            y_ctx: Updated context output points, shape [B, num_ctx + num_new, dy_max]
            reward (np.ndarray): Reward from choosing `x_new`, shape [B] or None if not computed
            regret (np.ndarray): Regret from choosing `x_new`, shape [B] or None if not computed
        """
        # Evaluate at x_new: [B, num_new, dy_max]
        y_new = self.evaluate(
            x=x_new, input_bounds=input_bounds, x_mask=x_mask, y_mask=y_mask
        )

        # Update context
        x_ctx = TestFunction._update_context(new=x_new, old=x_ctx)
        y_ctx = TestFunction._update_context(new=y_new, old=y_ctx)

        if compute_hv:
            reward = self.compute_hv(solutions=y_ctx, y_mask=y_mask)[0]
        else:
            reward = None

        if compute_regret:
            regret = self.compute_regret(
                y_ctx,
                regret_type=regret_type,
                y_mask=y_mask,
                solution_candidate_set=solution_candidate_set,
            )
        else:
            regret = None

        return x_ctx, y_ctx, reward, regret

    def init(
        self,
        input_bounds: FloatListOrNestedOrTensor,
        batch_size: int,
        num_initial_points: int,
        regret_type: str,
        compute_hv: bool = True,
        compute_regret: bool = True,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        x_mask: Optional[Tensor] = None,
        y_mask: Optional[Tensor] = None,
        solution_candidate_set: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[np.ndarray], Optional[np.ndarray]]:
        """Sample initial samples with x_ctx scaled to `input_bounds`.

        The initial points are drawn once (grid sampling, no factorized policy) and
        repeated across the batch, then passed through `step` with an empty context.

        Returns:
            x_ctx: [B, num_initial_points, dx_max]
            y_ctx: [B, num_initial_points, dy_max]
            reward (np.ndarray): [B] or None
            regret (np.ndarray): [B] or None
        """
        if x_mask is None:
            # All-ones mask for x dims
            x_mask = torch.ones((self.x_dim,), dtype=torch.bool, device=device)

        x_init = sample_full_inputs_from_subspaces(
            d=num_initial_points,
            x_mask=x_mask,
            input_bounds=input_bounds,
            use_grid_sampling=True,  # Use grid sampling for initial points
            use_factorized_policy=False,
        )[0]

        # Repeat across batch
        x_init = repeat(x_init, "n dx -> b n dx", b=batch_size)

        # Get initial samples, reward and regret
        x_ctx, y_ctx, reward, regret = self.step(
            x_new=x_init,
            input_bounds=input_bounds,
            x_ctx=None,
            y_ctx=None,
            x_mask=x_mask,
            y_mask=y_mask,
            compute_hv=compute_hv,
            compute_regret=compute_regret,
            regret_type=regret_type,
            solution_candidate_set=solution_candidate_set,
        )

        return x_ctx, y_ctx, reward, regret

    def __call__(self, x: Tensor, input_bounds: FloatListOrNestedOrTensor) -> Tensor:
        """Function forward pass: scale x, evaluate function, and return y.

        Args:
            x: Input points, shape [N, dx_valid] or [B, N, dx_valid]
            input_bounds: Input bounds for scaling

        Returns:
            y: Function values, [N, dy_valid] or [B, N, dy_valid] (matching `x`)
        """
        # Add batch dim and reduce later if x has no batch dim
        reduce_batch_dim = x.ndim == 2
        if reduce_batch_dim:
            x = x.unsqueeze(0)

        B, N, DX = x.shape
        assert DX == self.x_dim, f"Input dimension mismatch: {DX} != {self.x_dim}"

        # Scale x from input_bounds to x_bounds: [B, N, DX]
        x_scaled = self.scale_inputs(inputs=x, input_bounds=input_bounds)

        # Evaluate function at x_scaled: [B, N, DY]
        y = self.func(x_scaled).view(B, N, self.y_dim)

        if reduce_batch_dim:
            y = y.squeeze(0)

        return y


class BenchmarkFunction(TestFunction):
    """Synthetic benchmark function environment based on botorch implementation.

    Args:
        function_name: Synthetic function name, must be a key of `TESTFUNCTIONS`
        sigma: Noise level for function value observations, default to 0.0
        **kwargs: Forwarded to `get_metadata` (e.g. `x_range_list`, `dim`)
    """

    # Estimated y-bounds for non-default configurations (custom `x_range_list` or a
    # non-default ZDT `dim`), keyed by (function_name, repr(x_bounds)).
    # `BENCHMARK_Y_BOUNDS` is keyed by name only, so it may only hold bounds that are
    # valid for the default configuration; everything else is cached here instead.
    _custom_y_bounds_cache: Dict = {}

    def __init__(
        self,
        function_name: str,
        sigma: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(function_name=function_name, sigma=sigma, **kwargs)

    @staticmethod
    def get_function_constructor(function_name: str) -> Optional[callable]:
        """Get function constructor from `TESTFUNCTIONS` by name.

        Returns:
            The registered constructor, or None if `function_name` is not registered.
        """
        return TESTFUNCTIONS.get(function_name)

    def get_metadata(
        self,
        function_name: str,
        sigma: Optional[float] = None,
        x_range_list: Optional[FloatListOrNested] = None,
        **kwargs,
    ) -> Dict:
        """Get synthetic benchmark function and related metadata by name.

        Args:
            function_name: Name of the synthetic function
            sigma: Optional noise level
            x_range_list: Optional x_range_list, in case we don't want to use the default bounds
            **kwargs: `dim` (optional int) sets the input dimension of ZDT functions
                (defaults to 2); other entries are ignored.

        Returns:
            Dict: A dictionary with the following keys:
                function_name (str): The given name
                func: botorch function instance
                x_bounds (List): Input bounds for the function, DX x [[x_min, x_max]]
                y_bounds (List): Output bounds for the function, DY x [[y_min, y_max]]
                ref_point (List): botorch defined if attribute found, otherwise upper bounds
                max_hv (float): botorch defined if attribute found, otherwise computed from reference point and lower bounds
                sigma (float): Only present if `sigma` is not None

        Raises:
            ValueError: If `function_name` is not registered in `TESTFUNCTIONS`.
        """
        # Get function instance
        func_constructor = self.get_function_constructor(function_name)

        if func_constructor is None:
            raise ValueError(
                "Function not found in SYNTHETIC_BENCHMARK. "
                f"Only {list(TESTFUNCTIONS.keys())} are available."
            )

        default_zdt_dim = 2
        custom_dim = False
        if function_name in ["ZDT1", "ZDT2", "ZDT3"]:
            # NOTE only use 2D input for ZDT functions (unless `dim` is given)
            dim = kwargs.get("dim", None)
            dim = default_zdt_dim if dim is None else dim
            custom_dim = dim != default_zdt_dim
            func = func_constructor(negate=False, dim=dim)
        else:
            func = func_constructor(negate=False)

        # Prepare bounds: dim x [[min, max]]
        x_bounds = tuple_list_to_nested_list(func._bounds)
        x_dim = len(x_bounds)

        if x_range_list is not None:
            # Use provided x_range_list
            x_bounds = make_range_nested_list(range_list=x_range_list, num_dim=x_dim)

        # Estimated bounds depend on the input domain, so only estimates made for
        # the default domain may be stored under the bare function name.
        is_default_config = x_range_list is None and not custom_dim

        # Name-keyed bounds (hard-coded tables and default-configuration estimates)
        # take precedence, as before.
        # NOTE(review): this means a custom `x_range_list` still receives the
        # name-keyed bounds whenever an entry for the name exists.
        y_bounds = BENCHMARK_Y_BOUNDS.get(function_name, None)
        if y_bounds is None:
            cache_key = (function_name, repr(x_bounds))
            if not is_default_config:
                y_bounds = BenchmarkFunction._custom_y_bounds_cache.get(cache_key)

            if y_bounds is None:
                num_objectives = getattr(func, "num_objectives", 1)
                y_bounds = estimate_objective_bounds(
                    func=func,
                    num_objectives=num_objectives,
                    x_bounds=x_bounds,
                )
                # Cache for future use, without leaking custom-domain estimates
                # into lookups for other configurations of the same function.
                if is_default_config:
                    BENCHMARK_Y_BOUNDS[function_name] = y_bounds
                else:
                    BenchmarkFunction._custom_y_bounds_cache[cache_key] = y_bounds

        # NOTE: an earlier variant always re-estimated `y_bounds` and pinned y_min to
        # `func._optimal_value` on the default domain; it is disabled.

        ref_point = getattr(func, "_ref_point", None)
        if ref_point is None:
            ref_point = get_ref_point(y_bounds)  # [dy]

        max_hv = getattr(func, "_max_hv", None)
        if max_hv is None:
            max_hv = get_max_hv(ref_point, y_bounds)

        results = {
            "function_name": function_name,
            "func": func,
            "x_bounds": x_bounds,
            "y_bounds": y_bounds,
            "ref_point": ref_point,
            "max_hv": max_hv,
        }

        if sigma is not None:
            results["sigma"] = sigma

        return results


class IntepolatorFunction(TestFunction):
    """Interpolator function environment based on training data.

    Args:
        function_name: Name of the function - only for identification purposes
        train_x: Training input points, shape [n, dx]
        train_y: Training output points, shape [n, dy]
        train_x_bounds: Input bounds for training data, shape [dx, 2]
        train_y_bounds: Output bounds for training data, shape [dy, 2]
        sigma: Noise level for function value observations, default to 0.0
    """

    def __init__(
        self,
        train_x: Tensor,
        train_y: Tensor,
        train_x_bounds: FloatListOrNested,
        train_y_bounds: FloatListOrNested,
        function_name: Optional[str] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(
            train_x=train_x,
            train_y=train_y,
            train_x_bounds=train_x_bounds,
            train_y_bounds=train_y_bounds,
            function_name=function_name,
            sigma=sigma,
            **kwargs,
        )

    @staticmethod
    def is_valid_input(train_x, train_y, train_x_bounds, train_y_bounds) -> bool:
        """Validate training data and bounds.

        Returns:
            True if all checks pass (invalid input raises instead of returning False).

        Raises:
            ValueError: If data or bounds are missing, if `train_x` / `train_y` are
                not 2D, or if their leading (sample) dimensions differ.
        """
        if train_x is None or train_y is None:
            raise ValueError("Training data must be provided.")

        if train_x_bounds is None or train_y_bounds is None:
            raise ValueError("Training bounds must be provided.")

        if train_x.ndim != 2 or train_y.ndim != 2:
            raise ValueError("Training data must be 2D Tensors. ")

        if train_x.shape[:-1] != train_y.shape[:-1]:
            raise ValueError(
                f"Data shapes mismatch: {train_x.shape[:-1]} != {train_y.shape[:-1]}"
            )

        return True

    def get_metadata(
        self,
        train_x: Tensor,
        train_y: Tensor,
        train_x_bounds: FloatListOrNestedOrTensor,
        train_y_bounds: FloatListOrNestedOrTensor,
        function_name: Optional[str] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Dict:
        """Build interpolator function and related metadata from training data.

        Args:
            train_x: Input training points, shape [n, dx]
            train_y: Output training values, shape [n, dy]
            train_x_bounds: Input bounds for training data, DX x [[x_min, x_max]]
            train_y_bounds: Output bounds for training data, DY x [[y_min, y_max]]
            function_name: Optional name, defaults to "interpolator_function"
            sigma: Optional noise level for observations

        Returns: A dictionary with the following keys:
            func: BatchNDInterpolatorExtTensor instance for interpolation
            x_bounds (List) = train_x_bounds (expanded to one pair per input dim)
            y_bounds (List) = train_y_bounds (expanded to one pair per output dim)
            ref_point (List): Per-objective maximum of `train_y`
            max_hv (float): Hypervolume of the box between `ref_point` and the
                per-objective minimum of `train_y`
            function_name (str): Given name or "interpolator_function"
            sigma (float): Only present if `sigma` is not None

        Raises:
            ValueError: If the training data or bounds are invalid, see `is_valid_input`.
        """
        self.is_valid_input(train_x, train_y, train_x_bounds, train_y_bounds)

        x_dim = train_x.shape[-1]
        y_dim = train_y.shape[-1]
        train_x_bounds = make_range_nested_list(train_x_bounds, num_dim=x_dim)
        train_y_bounds = make_range_nested_list(train_y_bounds, num_dim=y_dim)

        # Define interpolator function that can take batch inputs: [..., DX] -> [..., DY]
        func = BatchNDInterpolatorExtTensor(points=train_x, values=train_y)

        # Both quantities are derived from the observed training outputs, not from
        # `train_y_bounds` (the bounds are only used for the dimension check).
        ref_point = get_ref_point(bounds=train_y_bounds, candidates=train_y)  # [dy]
        max_hv = get_max_hv(
            ref_point=ref_point, bounds=train_y_bounds, candidates=train_y
        )

        results = {
            "func": func,
            "x_bounds": train_x_bounds,
            "y_bounds": train_y_bounds,
            "ref_point": ref_point,
            "max_hv": max_hv,
        }

        results["function_name"] = (
            function_name if function_name else "interpolator_function"
        )
        if sigma is not None:
            results["sigma"] = sigma

        return results


def get_function_environment(
    function_name: str,
    sigma: float = SIGMA,
    train_x: Optional[Tensor] = None,
    train_y: Optional[Tensor] = None,
    train_x_bounds: Optional[FloatListOrNested] = None,
    train_y_bounds: Optional[FloatListOrNested] = None,
    dim: Optional[int] = None,
) -> TestFunction:
    """Build the function environment that matches `function_name`.

    Dispatch rule:
        - `function_name` is a registered benchmark function: a `BenchmarkFunction`
          is returned and the training arguments are ignored.
        - otherwise: an `IntepolatorFunction` is built from the training data
          (`dim` is ignored in this case).

    Args:
        function_name: Function name
        sigma: Noise level
        train_x: Optional training inputs, shape [n, dx]
        train_y: Optional training outputs, shape [n, dy]
        train_x_bounds: Input bounds for training data, shape [dx, 2]
        train_y_bounds: Output bounds for training data, shape [dy, 2]
        dim: Optional input dimension, forwarded to `BenchmarkFunction`

    Returns:
        A `BenchmarkFunction` or an `IntepolatorFunction` instance.
    """
    constructor = BenchmarkFunction.get_function_constructor(function_name)

    # Unknown name: fall back to interpolating the provided training data.
    if constructor is None:
        interpolator_kwargs = {
            "function_name": function_name,
            "train_x": train_x,
            "train_y": train_y,
            "train_x_bounds": train_x_bounds,
            "train_y_bounds": train_y_bounds,
            "sigma": sigma,
        }
        return IntepolatorFunction(**interpolator_kwargs)

    return BenchmarkFunction(function_name=function_name, sigma=sigma, dim=dim)
