# =============================================================================
# Bi-Level Benchmark Problems
# =============================================================================

import math

import torch
from torch import Tensor

from problems.base import BiLevelBaseProblem


# Utility functions -----------------------------------------------------------

def log1p(input: Tensor) -> Tensor:
    """Sign-preserving ``log1p``: ``sign(x) * log(1 + |x|)`` element-wise.

    Unlike ``torch.log1p`` this is defined for every real input, because the
    logarithm is applied to the magnitude and the sign is restored afterwards.
    """
    magnitude = torch.log1p(torch.abs(input))
    return torch.sign(input) * magnitude


# -----------------------------------------------------------------------------
# SMD01
# -----------------------------------------------------------------------------

class SMD01(BiLevelBaseProblem):
    """SMD01 bi-level benchmark evaluated on unit-box inputs.

    The last axis of an input is laid out as ``[u1 (p), u2 (r), l1 (q),
    l2 (r)]``: ``p + r`` upper-level entries followed by ``q + r``
    lower-level entries.  Both objectives are returned with their sign
    flipped, optionally after a sign-preserving ``log1p`` compression.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        ``noise_std``, ``has_candidates`` and ``num_discretize`` are passed
        positionally to ``BiLevelBaseProblem.__init__``.  ``log_transform``
        selects whether objectives go through the signed ``log1p``.
        """
        # Instance-level overrides of the class defaults, assigned before the
        # base initialiser is invoked.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1``, ``u2`` and ``l1`` are mapped to
        ``[-5, 10]``; ``l2`` is mapped to a symmetric interval of width
        ``pi - 0.01``, i.e. strictly inside ``(-pi/2, pi/2)``.
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        width = math.pi - 1e-2
        return (
            15 * xu1 - 5,
            15 * xu2 - 5,
            15 * xl1 - 5,
            width * xl2 - width / 2,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) + sum(l1^2) + sum(u2^2) + sum((u2 - tan(l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.tan(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) + gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum(l1^2) + sum((u2 - tan(l2))^2)``, optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.tan(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD02
# -----------------------------------------------------------------------------

class SMD02(BiLevelBaseProblem):
    """SMD02 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-5, 1]`` and ``l2`` to ``[0.01, e]`` (kept positive so
        ``log(l2)`` is finite).
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        eps = 1e-2
        return (
            15 * xu1 - 5,
            6 * xu2 - 5,
            15 * xl1 - 5,
            (math.e - eps) * xl2 + eps,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - sum(l1^2) + sum(u2^2) - sum((u2 - log(l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum(l1^2) + sum((u2 - log(l2))^2)``, optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD03
# -----------------------------------------------------------------------------

class SMD03(BiLevelBaseProblem):
    """SMD03 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1``, ``u2`` and ``l1`` map to
        ``[-5, 10]``; ``l2`` maps to a symmetric interval of width
        ``pi - 0.01``, i.e. strictly inside ``(-pi/2, pi/2)``.
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        width = math.pi - 1e-2
        return (
            15 * xu1 - 5,
            15 * xu2 - 5,
            15 * xl1 - 5,
            width * xl2 - width / 2,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) + sum(l1^2) + sum(u2^2) + sum((u2^2 - tan(l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2.pow(2) - torch.tan(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) + gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + q + sum(l1^2 - cos(2*pi*l1)) + sum((u2^2 - tan(l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2.pow(2) - torch.tan(xl2)
        ripple = xl1.pow(2) - torch.cos(2 * math.pi * xl1)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = self.q + ripple.sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD04
# -----------------------------------------------------------------------------

class SMD04(BiLevelBaseProblem):
    """SMD04 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-1, 1]`` and ``l2`` to ``[0, e]``.
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        return (
            15 * xu1 - 5,
            2 * xu2 - 1,
            15 * xl1 - 5,
            math.e * xl2,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - sum(l1^2) + sum(u2^2) - sum((|u2| - log(1 + l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = torch.abs(xu2) - torch.log(1 + xl2)
        gap_sq = gap.pow(2).sum(dim=-1)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap_sq
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + q + sum(l1^2 - cos(2*pi*l1))
        + sum((|u2| - log(1 + l2))^2)``, optionally log-compressed, then
        negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = torch.abs(xu2) - torch.log(1 + xl2)
        ripple = xl1.pow(2) - torch.cos(2 * math.pi * xl1)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = self.q + ripple.sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD05
# -----------------------------------------------------------------------------

class SMD05(BiLevelBaseProblem):
    """SMD05 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    The ``l1`` group enters through a chained term over neighbouring
    entries.  Both objectives are returned negated, optionally after the
    signed ``log1p`` compression.
    """

    num_dims = [2, 3]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 5
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 2,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)``; every group maps [0, 1] to [-5, 10]."""
        group_sizes = [self.p, self.r, self.q, self.r]
        pieces = torch.split(X, group_sizes, dim=-1)
        xu1, xu2, xl1, xl2 = (15 * piece - 5 for piece in pieces)
        return xu1, xu2, xl1, xl2

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - chain(l1) + sum(u2^2) - sum((|u2| - l2^2)^2)`` where
        ``chain(l1) = sum_i (l1[i+1] - l1[i]^2)^2 + (l1[i] - 1)^2`` over
        neighbouring pairs; optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        head, tail = xl1[..., :-1], xl1[..., 1:]
        chain = (tail - head.pow(2)).pow(2) + (head - 1).pow(2)
        gap = torch.abs(xu2) - xl2.pow(2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -chain.sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + chain(l1) + sum((|u2| - l2^2)^2)`` with the same
        neighbouring-pair ``chain`` term as the upper objective; optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        head, tail = xl1[..., :-1], xl1[..., 1:]
        chain = (tail - head.pow(2)).pow(2) + (head - 1).pow(2)
        gap = torch.abs(xu2) - xl2.pow(2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = chain.sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD06
# -----------------------------------------------------------------------------

class SMD06(BiLevelBaseProblem):
    """SMD06 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q + s),
    l2 (r)]``.  The ``l1`` group is further divided into its first ``q``
    and remaining ``s`` entries.  Both objectives are returned negated,
    optionally after the signed ``log1p`` compression.
    """

    num_dims = [2, 4]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 6
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        s: int = 2,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r``, ``s`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + s + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r
        self.s = s

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)``; every group maps [0, 1] to [-5, 10]."""
        group_sizes = [self.p, self.r, self.q + self.s, self.r]
        pieces = torch.split(X, group_sizes, dim=-1)
        xu1, xu2, xl1, xl2 = (15 * piece - 5 for piece in pieces)
        return xu1, xu2, xl1, xl2

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - sum(l1[:q]^2) + sum(l1[q:]^2) + sum(u2^2)
        - sum((u2 - l2)^2)``, optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        first_q = xl1[..., :self.q]
        last_s = xl1[..., self.q:]
        gap = xu2 - xl2
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -first_q.pow(2).sum(dim=-1) + last_s.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum(l1[:q]^2) + sum_i (l1[i+1] - l1[i])^2
        + sum((u2 - l2)^2)`` where ``i`` runs over neighbouring pairs within
        ``l1[q:]``; optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        first_q = xl1[..., :self.q]
        # NOTE(review): every consecutive pair inside l1[q:] is used; this
        # coincides with a stride-2 pairing only when s == 2 -- confirm the
        # intended pairing for larger s.
        left = xl1[..., self.q:-1]
        right = xl1[..., self.q + 1:]
        gap = xu2 - xl2
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = first_q.pow(2).sum(dim=-1) + (right - left).pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD07
# -----------------------------------------------------------------------------

class SMD07(BiLevelBaseProblem):
    """SMD07 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    The upper objective uses a Griewank-style term in ``u1``.  Both
    objectives are returned negated, optionally after the signed ``log1p``
    compression.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.  ``log_transform`` selects whether
        objectives go through the signed ``log1p``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        self.num_dims = [(p + r), (q + r)]
        self.bounds = [(0.0, 1.0)] * sum(self.num_dims)
        super().__init__(noise_std, has_candidates, num_discretize)
        self.p, self.q, self.r = p, q, r
        self.log_transform = log_transform

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-5, 1]`` and ``l2`` to ``[0.01, e]`` (kept positive so
        ``log(l2)`` is finite).
        """
        dims = [self.p, self.r, self.q, self.r]
        u1, u2, l1, l2 = X.split(dims, dim=-1)
        u1 = 15 * u1 - 5
        u2 = 6 * u2 - 5
        l1 = 15 * l1 - 5
        l2 = (math.e - 1e-2) * l2 + 1e-2
        return u1, u2, l1, l2

    # upper-level objective function
    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``1 + sum(u1^2) / 400 - prod_i cos(u1[i] / sqrt(i)) - sum(l1^2)
        + sum(u2^2) - sum((u2 - log(l2))^2)`` with ``i`` counted from 1;
        optionally log-compressed, then negated.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        # 1-based component index, built on the input's dtype/device so the
        # division below works for non-default devices as well.
        index = torch.arange(1, self.p + 1, dtype=u1.dtype, device=u1.device)
        # The product must run over the component axis only; reducing over
        # every axis would mix the values of different batch entries.
        t1 = torch.cos(u1 / index.sqrt()).prod(dim=-1)
        f1 = 1 + (u1**2).sum(dim=-1) / 400 - t1
        f2 = -(l1**2).sum(dim=-1)
        f3 = (u2**2).sum(dim=-1) - ((u2 - torch.log(l2))**2).sum(dim=-1)
        F = log1p(f1 + f2 + f3) if self.log_transform else (f1 + f2 + f3)
        return -F

    # lower-level objective function
    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^3) + sum(l1^2) + sum((u2 - log(l2))^2)``, optionally
        log-compressed, then negated.  The cubic term can make the sum
        negative, which the signed ``log1p`` handles.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        f1 = (u1**3).sum(dim=-1)
        f2 = (l1**2).sum(dim=-1)
        f3 = ((u2 - torch.log(l2))**2).sum(dim=-1)
        F = log1p(f1 + f2 + f3) if self.log_transform else (f1 + f2 + f3)
        return -F

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        F_upper = self._upper_obj(X=X)
        F_lower = self._lower_obj(X=X)
        return torch.stack([F_upper, F_lower], dim=-1)



# -----------------------------------------------------------------------------
# SMD08
# -----------------------------------------------------------------------------

class SMD08(BiLevelBaseProblem):
    """SMD08 bi-level benchmark evaluated on unit-box inputs.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    The upper objective uses an Ackley-style term in ``u1`` and both levels
    share a chained term over neighbouring ``l1`` entries.  Both objectives
    are returned negated, optionally after the signed ``log1p`` compression.
    """

    num_dims = [2, 3]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0)] * 5
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 2,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)``; every group maps [0, 1] to [-5, 10]."""
        group_sizes = [self.p, self.r, self.q, self.r]
        pieces = torch.split(X, group_sizes, dim=-1)
        xu1, xu2, xl1, xl2 = (15 * piece - 5 for piece in pieces)
        return xu1, xu2, xl1, xl2

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``20 + e - 20 * exp(-0.2 * sqrt(sum(u1^2) / p))
        - exp(sum(cos(2*pi*u1)) / p) - chain(l1) + sum(u2^2)
        - sum((u2 - l2^3)^2)`` where
        ``chain(l1) = sum_i (l1[i+1] - l1[i]^2)^2 + (l1[i] - 1)^2``;
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        head, tail = xl1[..., :-1], xl1[..., 1:]
        chain = (tail - head.pow(2)).pow(2) + (head - 1).pow(2)
        gap = xu2 - xl2.pow(3)
        radial = torch.exp(-0.2 * (xu1.pow(2).sum(dim=-1) / self.p).sqrt())
        periodic = torch.exp(torch.cos(2 * math.pi * xu1).sum(dim=-1) / self.p)
        upper_part = 20 + math.e - 20 * radial - periodic
        lower_part = -chain.sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(|u1|) + chain(l1) + sum((u2 - l2^3)^2)`` with the same
        neighbouring-pair ``chain`` term as the upper objective; optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        head, tail = xl1[..., :-1], xl1[..., 1:]
        chain = (tail - head.pow(2)).pow(2) + (head - 1).pow(2)
        gap = xu2 - xl2.pow(3)
        upper_part = torch.abs(xu1).sum(dim=-1)
        lower_part = chain.sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )



# -----------------------------------------------------------------------------
# SMD09
# -----------------------------------------------------------------------------

class SMD09(BiLevelBaseProblem):
    """SMD09 bi-level benchmark with one constraint value per level.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression; constraint values are returned untransformed.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [1, 1]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-5, 1]`` and ``l2`` to ``[-0.99, e]`` (so ``1 + l2``
        stays positive).
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        eps = 1e-2
        # NOTE(review): the upper end of this range is e, not e - 1 --
        # confirm against the intended SMD09 domain for l2.
        return (
            15 * xu1 - 5,
            6 * xu2 - 5,
            15 * xl1 - 5,
            (math.e + 1 - eps) * xl2 - (1 - eps),
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - sum(l1^2) + sum(u2^2) - sum((u2 - log(1 + l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(1 + xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum(l1^2) + sum((u2 - log(1 + l2))^2)``, optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(1 + xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )

    def _upper_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Upper constraint value ``t - floor(t + 0.5)`` with ``t = sum(u1^2) + sum(u2^2)``."""
        xu1, xu2, _, _ = self._separate(X=X)
        sq_norm = xu1.pow(2).sum(dim=-1) + xu2.pow(2).sum(dim=-1)
        return sq_norm - torch.floor(sq_norm + 0.5)

    def _lower_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Lower constraint value ``t - floor(t + 0.5)`` with ``t = sum(l1^2) + sum(l2^2)``."""
        _, _, xl1, xl2 = self._separate(X=X)
        sq_norm = xl1.pow(2).sum(dim=-1) + xl2.pow(2).sum(dim=-1)
        return sq_norm - torch.floor(sq_norm + 0.5)

    def evaluate_slack(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 2]
        """Stack the upper and lower constraint values along a new last axis."""
        return torch.stack(
            [self._upper_con(X=X), self._lower_con(X=X)],
            dim=-1,
        )


# -----------------------------------------------------------------------------
# SMD10
# -----------------------------------------------------------------------------

class SMD10(BiLevelBaseProblem):
    """SMD10 bi-level benchmark with ``p + r`` upper and ``q`` lower constraints.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression; constraint values are returned untransformed.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [2, 1]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.num_constraints = [p + r, q]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1``, ``u2`` and ``l1`` map to
        ``[-5, 10]``; ``l2`` maps to a symmetric interval of width
        ``pi - 0.01``, i.e. strictly inside ``(-pi/2, pi/2)``.
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        width = math.pi - 1e-2
        return (
            15 * xu1 - 5,
            15 * xu2 - 5,
            15 * xl1 - 5,
            width * xl2 - width / 2,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum((u1 - 2)^2) + sum(l1^2) + sum((u2 - 2)^2)
        + sum((u2 - tan(l2))^2)``, optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.tan(xl2)
        upper_part = (xu1 - 2).pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        # NOTE(review): the coupling term is added here, whereas SMD12's
        # otherwise similar upper objective subtracts it -- confirm the sign.
        coupled_part = (xu2 - 2).pow(2).sum(dim=-1) + gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum((l1 - 2)^2) + sum((u2 - tan(l2))^2)``, optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.tan(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = (xl1 - 2).pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )

    def _upper_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, p + r]
        """Upper constraint values, one per ``u1`` and ``u2`` component.

        Component ``j`` of a group equals its own value minus the cubes of
        every *other* component of that group minus the cubes of all
        components of the other group.
        """
        xu1, xu2, _, _ = self._separate(X=X)
        cubes_u1 = xu1.pow(3)
        cubes_u2 = xu2.pow(3)
        total_u1 = cubes_u1.sum(dim=-1, keepdim=True)
        total_u2 = cubes_u2.sum(dim=-1, keepdim=True)
        # Sum of cubes over all components except the one on the diagonal.
        others_u1 = total_u1 - cubes_u1
        others_u2 = total_u2 - cubes_u2
        con_u1 = xu1 - others_u1 - total_u2
        con_u2 = xu2 - others_u2 - total_u1
        return torch.cat([con_u1, con_u2], dim=-1)

    def _lower_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, q]
        """Lower constraint values: each ``l1`` component minus the cubes of the others."""
        _, _, xl1, _ = self._separate(X=X)
        cubes_l1 = xl1.pow(3)
        others_l1 = cubes_l1.sum(dim=-1, keepdim=True) - cubes_l1
        return xl1 - others_l1

    def evaluate_slack(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_constraints)]
        """Concatenate upper then lower constraint values along the last axis."""
        return torch.cat(
            [self._upper_con(X=X), self._lower_con(X=X)],
            dim=-1,
        )


# -----------------------------------------------------------------------------
# SMD11
# -----------------------------------------------------------------------------

class SMD11(BiLevelBaseProblem):
    """SMD11 bi-level benchmark with ``r`` upper and one lower constraint.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Both objectives are returned negated, optionally after the signed
    ``log1p`` compression; constraint values are returned untransformed.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [1, 1]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        n_upper, n_lower = p + r, q + r
        self.num_dims = [n_upper, n_lower]
        self.num_constraints = [r, 1]
        self.bounds = [(0.0, 1.0) for _ in range(n_upper + n_lower)]
        super().__init__(noise_std, has_candidates, num_discretize)
        self.log_transform = log_transform
        self.p = p
        self.q = q
        self.r = r

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-5, 1]`` and ``l2`` to ``[1/e, e]`` (so ``log(l2)``
        lies in ``[-1, 1]``).
        """
        group_sizes = [self.p, self.r, self.q, self.r]
        xu1, xu2, xl1, xl2 = torch.split(X, group_sizes, dim=-1)
        inv_e = 1 / math.e
        return (
            15 * xu1 - 5,
            6 * xu2 - 5,
            15 * xl1 - 5,
            (math.e - inv_e) * xl2 + inv_e,
        )

    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum(u1^2) - sum(l1^2) + sum(u2^2) - sum((u2 - log(l2))^2)``,
        optionally log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = -xl1.pow(2).sum(dim=-1)
        coupled_part = xu2.pow(2).sum(dim=-1) - gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum(l1^2) + sum((u2 - log(l2))^2)``, optionally
        log-compressed, then negated.
        """
        xu1, xu2, xl1, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(xl2)
        upper_part = xu1.pow(2).sum(dim=-1)
        lower_part = xl1.pow(2).sum(dim=-1)
        coupled_part = gap.pow(2).sum(dim=-1)
        value = upper_part + lower_part + coupled_part
        if self.log_transform:
            value = log1p(value)
        return -value

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        return torch.stack(
            [self._upper_obj(X=X), self._lower_obj(X=X)],
            dim=-1,
        )

    def _upper_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, r]
        """Upper constraint values ``u2 - 1/sqrt(r) - log(l2)``, one per component."""
        _, xu2, _, xl2 = self._separate(X=X)
        offset = 1 / math.sqrt(self.r)
        return xu2 - offset - torch.log(xl2)

    def _lower_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 1]
        """Lower constraint value ``sum((u2 - log(l2))^2) - 1`` with a trailing axis."""
        _, xu2, _, xl2 = self._separate(X=X)
        gap = xu2 - torch.log(xl2)
        slack = gap.pow(2).sum(dim=-1) - 1
        return slack.unsqueeze(-1)

    def evaluate_slack(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_constraints)]
        """Concatenate upper then lower constraint values along the last axis."""
        return torch.cat(
            [self._upper_con(X=X), self._lower_con(X=X)],
            dim=-1,
        )


# -----------------------------------------------------------------------------
# SMD12
# -----------------------------------------------------------------------------

class SMD12(BiLevelBaseProblem):
    """SMD12 bi-level benchmark with ``r + p + r`` upper and ``1 + q`` lower constraints.

    Input layout along the last axis: ``[u1 (p), u2 (r), l1 (q), l2 (r)]``.
    Objectives are returned negated; objectives and constraint values are
    both passed through the signed ``log1p`` when ``log_transform`` is set.
    """

    num_dims = [2, 2]
    num_objectives = [1, 1]
    num_constraints = [3, 2]
    bounds = [(0.0, 1.0)] * 4
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        p: int = 1,
        q: int = 1,
        r: int = 1,
        log_transform: bool = True,
    ) -> None:
        """Size the problem from ``p``, ``q``, ``r`` and initialise the base.

        The first three arguments are passed positionally to
        ``BiLevelBaseProblem.__init__``.  ``log_transform`` selects whether
        objectives and constraints go through the signed ``log1p``.
        """
        # Instance-level overrides, assigned before the base initialiser runs.
        self.num_dims = [(p + r), (q + r)]
        self.num_constraints = [(r + p + r), (1 + q)]
        self.bounds = [(0.0, 1.0)] * sum(self.num_dims)
        super().__init__(noise_std, has_candidates, num_discretize)
        self.p, self.q, self.r = p, q, r
        self.log_transform = log_transform

    def _separate(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> tuple[Tensor, ...]:
        """Split ``X`` into ``(u1, u2, l1, l2)`` and rescale each group.

        For inputs in ``[0, 1]``: ``u1`` and ``l1`` map to ``[-5, 10]``,
        ``u2`` to ``[-14.1, 14.1]`` and ``l2`` to ``[-1.495, 1.495]``.
        """
        dims = [self.p, self.r, self.q, self.r]
        u1, u2, l1, l2 = X.split(dims, dim=-1)
        u1 = 15 * u1 - 5
        u2 = 28.20 * u2 - 14.10
        l1 = 15 * l1 - 5
        # Rescale the l2 slice itself (not l1): l2 has r entries and must stay
        # inside (-1.5, 1.5) so that tan(l2) is finite.
        span = 3.0 - 1e-2
        l2 = span * l2 - span / 2
        return u1, u2, l1, l2

    # upper-level objective function
    def _upper_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated upper objective.

        ``sum((u1 - 2)^2) + sum(l1^2) + sum((u2 - 2)^2) + sum(tan(|l2|))
        - sum((u2 - tan(l2))^2)``, optionally log-compressed, then negated.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        t1 = torch.tan(l2.abs()).sum(dim=-1)
        t2 = ((u2 - torch.tan(l2))**2).sum(dim=-1)
        f1 = ((u1 - 2)**2).sum(dim=-1)
        f2 = (l1**2).sum(dim=-1)
        f3 = ((u2 - 2)**2).sum(dim=-1) + t1 - t2
        F = log1p(f1 + f2 + f3) if self.log_transform else (f1 + f2 + f3)
        return -F

    # lower-level objective function
    def _lower_obj(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated lower objective.

        ``sum(u1^2) + sum((l1 - 2)^2) + sum((u2 - tan(l2))^2)``, optionally
        log-compressed, then negated.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        f1 = (u1**2).sum(dim=-1)
        f2 = ((l1 - 2)**2).sum(dim=-1)
        f3 = ((u2 - torch.tan(l2))**2).sum(dim=-1)
        F = log1p(f1 + f2 + f3) if self.log_transform else (f1 + f2 + f3)
        return -F

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_objectives)]
        """Stack the upper and lower objective values along a new last axis."""
        F_upper = self._upper_obj(X=X)
        F_lower = self._lower_obj(X=X)
        return torch.stack([F_upper, F_lower], dim=-1)

    def _upper_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_constraints[0]]
        """Upper constraint values, ordered ``[g1 (r), g2 (p), g3 (r)]``.

        ``g1 = u2 - tan(l2)``; ``g2`` / ``g3`` are each ``u1`` / ``u2``
        component minus the cubes of the other components of its own group
        minus the cubes of all components of the other group.  The result is
        log-compressed (sign-preserving) when ``log_transform`` is set.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        # Sum of cubes over all components except the one on the diagonal.
        t1 = (u1**3).sum(dim=-1, keepdim=True) - u1**3
        t2 = (u2**3).sum(dim=-1, keepdim=True) - u2**3
        g1 = u2 - torch.tan(l2)
        g2 = u1 - t1 - (u2**3).sum(dim=-1, keepdim=True)
        g3 = u2 - t2 - (u1**3).sum(dim=-1, keepdim=True)
        G = torch.cat([g1, g2, g3], dim=-1)
        G = log1p(G) if self.log_transform else G
        return G

    def _lower_con(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, num_constraints[1]]
        """Lower constraint values, ordered ``[g1 (1), g2 (q)]``.

        ``g1 = sum((u2 - tan(l2))^2) - 1``; ``g2`` is each ``l1`` component
        minus the cubes of the other ``l1`` components.  The result is
        log-compressed (sign-preserving) when ``log_transform`` is set.
        """
        u1, u2, l1, l2 = self._separate(X=X)
        t1 = (l1**3).sum(dim=-1, keepdim=True) - l1**3
        g1 = (((u2 - torch.tan(l2))**2).sum(dim=-1) - 1).unsqueeze(-1)
        g2 = l1 - t1
        G = torch.cat([g1, g2], dim=-1)
        G = log1p(G) if self.log_transform else G
        return G

    def evaluate_slack(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, sum(num_constraints)]
        """Concatenate upper then lower constraint values along the last axis."""
        G_upper = self._upper_con(X=X)
        G_lower = self._lower_con(X=X)
        return torch.cat([G_upper, G_lower], dim=-1)


