# =============================================================================
# Bi-Level Synthetic Problems
# =============================================================================

import math
from typing import Literal

import torch
from torch import Tensor

from problems.base import BiLevelBaseProblem
from utils import RFFHybridModel



# Utility functions -----------------------------------------------------------

def log1p(input: Tensor) -> Tensor:
    """Return the sign-preserving ``log(1 + |input|)``, element-wise.

    Unlike ``torch.log1p`` this is defined for every real input: the
    logarithm is applied to the magnitude and the original sign is restored
    afterwards, so negative inputs map to negative outputs.
    """
    magnitude = torch.log1p(torch.abs(input))
    return torch.sign(input) * magnitude


# -----------------------------------------------------------------------------
# GPPriorSamples
# -----------------------------------------------------------------------------

class GPPriorSamples(BiLevelBaseProblem):
    """Bi-level problem whose two objectives are fixed random-feature functions.

    Each level evaluates ``phi(X) @ w`` with
    ``phi(X) = sqrt(2 * oscale / num_features) * cos(X @ o.T + b)``, where
    ``o`` holds standard-normal frequencies divided by the lengthscale, ``b``
    holds phases uniform on ``[0, 2*pi)`` and ``w`` holds standard-normal
    weights. This is the random-Fourier-feature form commonly used to
    approximate a draw from a GP prior. All random tensors come from a
    seeded ``torch.Generator``, so the two functions are fully determined by
    ``seed_upper`` / ``seed_lower``.

    Every random tensor is created explicitly in ``torch.double``. The
    lengthscale and outputscale tensors are double, so drawing the features
    in the global default dtype (float32 unless changed) would leave the
    frequencies promoted to float64 while phases and weights stay float32,
    and the final matrix product would fail with a dtype mismatch.
    """

    num_dims = [1, 1]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0), (0.0, 1.0)]
    candidates_path = None

    def __init__(
        self,
        noise_std: float | list[float] | None = None,
        has_candidates: bool = False,
        num_discretize: int | list[int] | None = None,
        num_dims: list[int] | None = None,
        num_features: int = 1000,
        lscale_upper: float = 0.15,
        oscale_upper: float = 1.0,
        lscale_lower: float = 0.15,
        oscale_lower: float = 1.0,
        seed_upper: int = 1,
        seed_lower: int = 3,
    ) -> None:
        """Build the two random-feature objectives.

        Args:
            noise_std: Forwarded unchanged to ``BiLevelBaseProblem.__init__``.
            has_candidates: Forwarded unchanged to the base constructor.
            num_discretize: Forwarded unchanged to the base constructor.
            num_dims: Per-level input dimensions; defaults to ``[1, 1]``.
            num_features: Number of random features per level.
            lscale_upper: Lengthscale dividing the upper-level frequencies.
            oscale_upper: Output scale of the upper level; enters as
                ``sqrt(2 * oscale / num_features)``.
            lscale_lower: Lengthscale dividing the lower-level frequencies.
            oscale_lower: Output scale of the lower level.
            seed_upper: Generator seed for the upper-level features.
            seed_lower: Generator seed for the lower-level features.
        """
        # Instance-level overrides are set before the base constructor runs;
        # presumably it reads them (e.g. to derive ``d_in``) -- confirm
        # against BiLevelBaseProblem.
        self.num_dims = num_dims if num_dims is not None else [1, 1]
        self.bounds = [(0.0, 1.0)] * sum(self.num_dims)
        super().__init__(noise_std, has_candidates, num_discretize)

        self.num_features = num_features
        self.lscale_upper = torch.tensor([lscale_upper], dtype=torch.double)
        self.oscale_upper = torch.tensor([oscale_upper], dtype=torch.double)
        self.lscale_lower = torch.tensor([lscale_lower], dtype=torch.double)
        self.oscale_lower = torch.tensor([oscale_lower], dtype=torch.double)

        self.o_upper, self.b_upper, self.w_upper = self._draw_rff_params(
            seed_upper, self.lscale_upper
        )
        self.o_lower, self.b_lower, self.w_lower = self._draw_rff_params(
            seed_lower, self.lscale_lower
        )

    def _draw_rff_params(
        self,
        seed: int,
        lscale: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Draw (frequencies, phases, weights) for one level from ``seed``."""
        gen = torch.Generator()
        gen.manual_seed(seed)
        # The draw order (frequencies, phases, weights) fixes the random
        # stream and must not be rearranged. ``self.d_in`` is not defined in
        # this class; it is expected to be provided by the base class.
        raw_freqs = torch.randn(
            self.num_features, self.d_in, generator=gen, dtype=torch.double
        )
        raw_phases = torch.rand(
            self.num_features, generator=gen, dtype=torch.double
        )
        weights = torch.randn(
            self.num_features, generator=gen, dtype=torch.double
        )
        return raw_freqs / lscale, raw_phases * 2 * math.pi, weights

    def _rff_eval(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
        freqs: Tensor,  # shape: [num_features, d_in]
        phases: Tensor,  # shape: [num_features]
        weights: Tensor,  # shape: [num_features]
        oscale: Tensor,  # shape: [1]
    ) -> Tensor:  # shape: [*batch_shape]
        """Evaluate one random-feature function at ``X``."""
        norm = torch.sqrt(2 * oscale / self.num_features)
        phi_X = norm * torch.cos(X @ freqs.T + phases)
        # Multiplying by the 1-D weight vector contracts the feature axis.
        return phi_X @ weights

    # upper-level objective function
    def _upper(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in], double precision
    ) -> Tensor:  # shape: [*batch_shape]
        """Evaluate the upper-level random-feature function."""
        return self._rff_eval(
            X, self.o_upper, self.b_upper, self.w_upper, self.oscale_upper
        )

    # lower-level objective function
    def _lower(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in], double precision
    ) -> Tensor:  # shape: [*batch_shape]
        """Evaluate the lower-level random-feature function."""
        return self._rff_eval(
            X, self.o_lower, self.b_lower, self.w_lower, self.oscale_lower
        )

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 2]
        """Return the noise-free (upper, lower) values stacked on the last axis."""
        F_upper = self._upper(X=X)
        F_lower = self._lower(X=X)
        return torch.stack([F_upper, F_lower], dim=-1)



# -----------------------------------------------------------------------------
# BraninGoldstein (UL: Branin, LL: Goldstein-Price)
# -----------------------------------------------------------------------------

class BraninGoldstein(BiLevelBaseProblem):
    """Bi-level problem: Branin at the upper level, Goldstein-Price below.

    Both functions read the first two input coordinates, map them from the
    unit interval onto their own domains, and return a shifted, scaled and
    negated value.
    """

    num_dims = [1, 1]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0), (0.0, 1.0)]
    candidates_path = None

    # upper-level objective function
    def _rescaled_branin(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated, normalized Branin evaluated on the first two coordinates."""
        # [0, 1] -> [-5, 10] for the first coordinate, [0, 15] for the second
        u = 15 * X[..., 0] - 5
        v = 15 * X[..., 1]

        quadratic = (5.1 * u ** 2) / (4 * math.pi ** 2)
        linear = (5.0 * u) / math.pi
        residual = v - quadratic + linear - 6

        cos_weight = 10 * (1 - 1 / (8 * math.pi))
        cos_part = cos_weight * torch.cos(u)

        centered = residual ** 2 + cos_part - 44.81
        return -centered / 51.95

    # lower-level objective function
    def _rescaled_goldstein_price(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated, normalized log Goldstein-Price on the first two coordinates."""
        # [0, 1] -> [-2, 2] for both coordinates
        u = 4 * X[..., 0] - 2
        v = 4 * X[..., 1] - 2

        sum_sq = (u + v + 1) ** 2
        poly_a = (
            3 * u ** 2 + 3 * v ** 2 + 6 * u * v
            - 14 * u - 14 * v + 19
        )
        first_factor = 1 + sum_sq * poly_a

        diff_sq = (2 * u - 3 * v) ** 2
        poly_b = (
            12 * u ** 2 + 27 * v ** 2 - 36 * u * v
            - 32 * u + 48 * v + 18
        )
        second_factor = 30 + diff_sq * poly_b

        log_value = torch.log(first_factor * second_factor)
        return -(log_value - 8.693) / 2.427

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 2]
        """Return the (upper, lower) objective values stacked on the last axis."""
        upper_vals = self._rescaled_branin(X=X)
        lower_vals = self._rescaled_goldstein_price(X=X)
        return torch.stack((upper_vals, lower_vals), dim=-1)



# -----------------------------------------------------------------------------
# CamelBranin (UL: Six-Hump Camel, LL: Branin)
# -----------------------------------------------------------------------------

class CamelBranin(BiLevelBaseProblem):
    """Bi-level problem: six-hump camel at the upper level, Branin below.

    Both functions read the first two input coordinates, map them from the
    unit interval onto their own domains, and return a negated value.
    """

    num_dims = [1, 1]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0), (0.0, 1.0)]
    candidates_path = None

    # upper-level objective function
    def _rescaled_sixhump_camel(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated six-hump camel, compressed by the module's signed ``log1p``."""
        # [0, 1] -> [-3, 3] for the first coordinate, [-2, 2] for the second
        u = 6 * X[..., 0] - 3
        v = 4 * X[..., 1] - 2

        u_sq = u ** 2
        v_sq = v ** 2
        u_part = (4 - 2.1 * u_sq + (u ** 4) / 3) * u_sq
        cross = u * v
        v_part = (-4 + 4 * v_sq) * v_sq

        # The raw function can be negative, hence the sign-preserving log.
        return -log1p(u_part + cross + v_part)

    # lower-level objective function
    def _rescaled_branin(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated, normalized Branin evaluated on the first two coordinates."""
        # [0, 1] -> [-5, 10] for the first coordinate, [0, 15] for the second
        u = 15 * X[..., 0] - 5
        v = 15 * X[..., 1]

        quadratic = (5.1 * u ** 2) / (4 * math.pi ** 2)
        linear = (5.0 * u) / math.pi
        residual = v - quadratic + linear - 6

        cos_weight = 10 * (1 - 1 / (8 * math.pi))
        cos_part = cos_weight * torch.cos(u)

        centered = residual ** 2 + cos_part - 44.81
        return -centered / 51.95

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 2]
        """Return the (upper, lower) objective values stacked on the last axis."""
        upper_vals = self._rescaled_sixhump_camel(X=X)
        lower_vals = self._rescaled_branin(X=X)
        return torch.stack((upper_vals, lower_vals), dim=-1)



# -----------------------------------------------------------------------------
# DixonBranin (UL: Dixon-Price, LL: Branin)
# -----------------------------------------------------------------------------

class DixonBranin(BiLevelBaseProblem):
    """Bi-level problem: Dixon-Price at the upper level, Branin below.

    Both functions read the first two input coordinates, map them from the
    unit interval onto their own domains, and return a negated value.
    """

    num_dims = [1, 1]
    num_objectives = [1, 1]
    num_constraints = [0, 0]
    bounds = [(0.0, 1.0), (0.0, 1.0)]
    candidates_path = None

    # upper-level objective function
    def _rescaled_dixon_price(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated two-variable Dixon-Price, compressed by the signed ``log1p``."""
        # [0, 1] -> [-10, 10] for both coordinates
        u = 20 * X[..., 0] - 10
        v = 20 * X[..., 1] - 10

        first_term = (u - 1) ** 2
        second_term = 2 * (2 * v ** 2 - u) ** 2
        return -log1p(first_term + second_term)

    # lower-level objective function
    def _rescaled_branin(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape]
        """Negated, normalized Branin evaluated on the first two coordinates."""
        # [0, 1] -> [-5, 10] for the first coordinate, [0, 15] for the second
        u = 15 * X[..., 0] - 5
        v = 15 * X[..., 1]

        quadratic = (5.1 * u ** 2) / (4 * math.pi ** 2)
        linear = (5.0 * u) / math.pi
        residual = v - quadratic + linear - 6

        cos_weight = 10 * (1 - 1 / (8 * math.pi))
        cos_part = cos_weight * torch.cos(u)

        centered = residual ** 2 + cos_part - 44.81
        return -centered / 51.95

    def evaluate_true(
        self,
        X: Tensor,  # shape: [*batch_shape, d_in]
    ) -> Tensor:  # shape: [*batch_shape, 2]
        """Return the (upper, lower) objective values stacked on the last axis."""
        upper_vals = self._rescaled_dixon_price(X=X)
        lower_vals = self._rescaled_branin(X=X)
        return torch.stack((upper_vals, lower_vals), dim=-1)

