# toy_functions.py
from __future__ import annotations

from typing import Callable, Dict, Tuple

import numpy as np
import torch


# -------------------------
# Scale factor of the optional random term in canyon_waterfall_torch.
# Any value > 0 makes that objective stochastic (a new torch.randn_like
# draw on every call); 0.0 disables the random term entirely.
# For stable Optuna tuning, keep it at 0.0.
# -------------------------
CANYON_NOISE_SCALE: float = 0.0


# ============================================================
# Torch versions (for autograd / optimization)
# Signature: f(z) where z.shape = (..., 2), returns shape (...).
# ============================================================

def deceptive_landscape_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Deceptive landscape built from four additive pieces:

        exp(-x^2/0.1) * (y-1)^2        narrow ridge in x that pulls y toward 1
        0.1 * x^2                      shallow quadratic in x
        10 * relu(0.5 - |x|)           nonzero only for |x| < 0.5, peaks at 5 for x = 0
        (x^2 + (y+3)^2) * sigmoid(10*(x-1))
                                       paraboloid centered at (0, -3), gated on for x > 1

    Design notes inherited with the formula (not verified against it here):
      • Local minimum at (0, 1) - wide valley
      • Plateau at |x| < 0.5
      • Global minimum at (±2, -3)  (note: the formula is not symmetric in x)

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x, y = z[..., 0], z[..., 1]
    x_sq = x**2

    ridge = torch.exp(-x_sq / 0.1) * (y - 1.0) ** 2
    bowl = 0.1 * x_sq
    wall = 10.0 * torch.relu(0.5 - torch.abs(x))

    # Smooth switch: ~0 for x well below 1, ~1 for x well above 1.
    switch = torch.sigmoid(10.0 * (x - 1.0))
    gated_basin = (x_sq + (y + 3.0) ** 2) * switch

    return ridge + bowl + wall + gated_basin


def canyon_waterfall_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Canyon that follows y = sin(x), cut by a very steep sigmoid step at x = 3.

        0.01*x^2 + (y - sin(x))^2 + 10*sigmoid(100*(x-3))*(y+5)^2 [+ noise]

    Past the step the (y+5)^2 term is switched on with weight ~10, so the
    gradient scale changes abruptly there and small steps are needed.

    When the module-level CANYON_NOISE_SCALE is > 0, a random term
    CANYON_NOISE_SCALE * randn * exp(-(x-5)^2) is added (re-drawn on every
    call). Keep the constant at 0.0 for deterministic tuning.

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x, y = z[..., 0], z[..., 1]

    valley = 0.01 * x**2 + (y - torch.sin(x)) ** 2
    cliff = 10.0 * torch.sigmoid(100.0 * (x - 3.0)) * (y + 5.0) ** 2
    deterministic = valley + cliff

    if CANYON_NOISE_SCALE > 0.0:
        # Random perturbation under a Gaussian envelope centered at x = 5.
        envelope = torch.exp(-(x - 5.0) ** 2)
        return deterministic + CANYON_NOISE_SCALE * torch.randn_like(x) * envelope

    return deterministic + torch.zeros_like(x)


def concentric_barriers_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Quadratic bowl centered at the origin, ringed by five Gaussian barriers.

        f = r^2 + sum_{i=1..5} (2/i) * exp(-(r - i)^2 / 0.05),  r = sqrt(x^2 + y^2)

    The barriers sit at radii 1..5 and get weaker with distance from the
    center (heights 2, 1, 2/3, 1/2, 2/5).

    The radius is computed so that backpropagation through the exact origin
    yields a zero gradient instead of NaN; forward values are the same as
    those of a plain ``sqrt(x**2 + y**2)``.

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x = z[..., 0]
    y = z[..., 1]

    r_sq = x**2 + y**2

    # sqrt has an infinite derivative at 0; autograd would multiply it by
    # d(r_sq)/dx == 0 and produce NaN at the origin. Feed sqrt a harmless
    # value there and select a constant 0 for the result, so the gradient
    # that flows back through the masked entries is exactly 0.
    at_origin = r_sq == 0
    safe_r_sq = torch.where(at_origin, torch.ones_like(r_sq), r_sq)
    r = torch.where(at_origin, torch.zeros_like(r_sq), torch.sqrt(safe_r_sq))

    target = r**2  # bowl with its minimum at (0, 0)
    barriers = torch.zeros_like(r)
    for i in range(1, 6):
        strength = 2.0 / float(i)
        barriers = barriers + strength * torch.exp(-((r - float(i)) ** 2) / 0.05)

    return target + barriers


def local_min_plateau_deep_min_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Sharp local well -> ridge -> central plateau -> very narrow deep well.

    Sum of:
        -5 * exp(-|p - (-2,-2)|^2 / 0.05)     local well
        +3 * exp(-|p - (-1,-1)|^2 / 0.2)      ridge to climb
        +2 * (1 - tanh(|p|^2 / 0.3))          raised region around the origin
        -8 * exp(-|p - (2,2)|^2 / 0.01)       deep, very narrow well
        +0.05 * |p|^2                         gentle global bowl

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x, y = z[..., 0], z[..., 1]
    radius_sq = x**2 + y**2

    def bell(dx: torch.Tensor, dy: torch.Tensor, width: float) -> torch.Tensor:
        # Unnormalized Gaussian of the offset (dx, dy).
        return torch.exp(-(dx**2 + dy**2) / width)

    shallow_well = -5.0 * bell(x + 2.0, y + 2.0, 0.05)
    ridge = 3.0 * bell(x + 1.0, y + 1.0, 0.2)
    shelf = 2.0 * (1.0 - torch.tanh(radius_sq / 0.3))
    deep_well = -8.0 * bell(x - 2.0, y - 2.0, 0.01)
    bowl = 0.05 * radius_sq

    return shallow_well + ridge + shelf + deep_well + bowl


def plateau_with_traps_deep_min_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Bowl -> central dome dotted with three traps -> very narrow deep well.

    Sum of:
        0.1 * |p|^2                              global bowl
        2 * exp(-|p|^2 / 0.5)                    dome around the origin
        -2 * exp(-|p - c|^2 / 0.05)              traps at c = (1, 0.5), (-0.5, 1), (0.5, -1)
        -10 * exp(-|p - (2.5,2.5)|^2 / 0.005)    deep, very narrow well
        0.01 * (x + y)                           small linear tilt

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x, y = z[..., 0], z[..., 1]
    radius_sq = x**2 + y**2

    def bell(dx: torch.Tensor, dy: torch.Tensor, width: float) -> torch.Tensor:
        # Unnormalized Gaussian of the offset (dx, dy).
        return torch.exp(-(dx**2 + dy**2) / width)

    total = 0.1 * radius_sq + 2.0 * torch.exp(-radius_sq / 0.5)

    # Three equally deep traps, accumulated in their original order.
    trap_offsets = (
        (x - 1.0, y - 0.5),
        (x + 0.5, y - 1.0),
        (x - 0.5, y + 1.0),
    )
    for dx, dy in trap_offsets:
        total = total + -2.0 * bell(dx, dy, 0.05)

    deep_well = -10.0 * bell(x - 2.5, y - 2.5, 0.005)
    tilt = 0.01 * (x + y)

    return total + deep_well + tilt


def complex_journey_torch(z: torch.Tensor) -> torch.Tensor:
    """
    Scenario 4:
    Descent -> plateau with local minima -> wide descent ->
    sharp minimum -> climb -> wide minimum, plus a small deterministic
    sin/cos texture of amplitude 0.01.

    Features and where they are centered:
        bowl 0.2*|p|^2; dome (height 3) at the origin;
        traps (depth 1.5) at (1, 0) and (0, 1);
        wide-descent term around (1.5, 1.5), weighted by distance to (2.5, 2.5);
        sharp well (depth 12) at (3, 3); bump (height 8) at (3.5, 3.5);
        wide well (depth 6) at (4, 4).

    Args:
        z: tensor whose last axis is indexed at 0 (x) and 1 (y).

    Returns:
        Tensor of losses with the last axis of ``z`` removed.
    """
    x, y = z[..., 0], z[..., 1]
    radius_sq = x**2 + y**2

    def bell(dx: torch.Tensor, dy: torch.Tensor, width: float) -> torch.Tensor:
        # Unnormalized Gaussian of the offset (dx, dy).
        return torch.exp(-(dx**2 + dy**2) / width)

    bowl = 0.2 * radius_sq
    dome = 3.0 * torch.exp(-radius_sq / 0.8)
    pit_east = -1.5 * bell(x - 1.0, y, 0.1)
    pit_north = -1.5 * bell(x, y - 1.0, 0.1)

    dist_sq_to_target = (x - 2.5) ** 2 + (y - 2.5) ** 2
    ramp = 0.5 * dist_sq_to_target * bell(x - 1.5, y - 1.5, 0.3)

    spike = -12.0 * bell(x - 3.0, y - 3.0, 0.005)
    hump = 8.0 * bell(x - 3.5, y - 3.5, 0.2)
    basin = -6.0 * bell(x - 4.0, y - 4.0, 0.5)

    smooth_part = bowl + dome + pit_east + pit_north + ramp + spike + hump + basin
    texture = 0.01 * torch.sin(5.0 * x) * torch.cos(5.0 * y)  # deterministic ripple
    return smooth_part + texture


# Name -> torch objective lookup. Each value takes a single tensor z and
# returns a tensor. The keys are the same strings used by BAD_CENTERS and
# NP_FUNCTIONS, so one name selects all three.
TORCH_FUNCTIONS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "deceptive_landscape": deceptive_landscape_torch,
    "canyon_waterfall": canyon_waterfall_torch,
    "concentric_barriers": concentric_barriers_torch,
    "local_min_plateau_deep_min": local_min_plateau_deep_min_torch,
    "plateau_with_traps_deep_min": plateau_with_traps_deep_min_torch,
    "complex_journey": complex_journey_torch,
}


# ============================================================
# BAD_CENTERS: per-function center points for "hard region + jitter" tuning.
# The defaults are the supplied "start:" points. Keys are the same names
# used by TORCH_FUNCTIONS and NP_FUNCTIONS.
# NOTE(review): each pair is presumably an (x, y) coordinate -- confirm
# against the code that consumes this table.
# ============================================================

BAD_CENTERS: Dict[str, Tuple[float, float]] = {
    "deceptive_landscape": (-2.5, 2.5),
    "canyon_waterfall": (2.8, -0.5),
    "concentric_barriers": (4.5, 0.0),
    "local_min_plateau_deep_min": (-2.0, -2.0),
    "plateau_with_traps_deep_min": (0.8, 0.4),
    "complex_journey": (0.5, 1.2),
}


# ============================================================
# Numpy versions (optional; useful for plotting contour maps)
# Signature: f(x, y) -> scalar/array
# Deterministic (no RNG).
# ============================================================

def deceptive_landscape_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the deceptive-landscape loss (e.g. for contour plots).

        exp(-x^2/0.1) * (y-1)^2 + 0.1*x^2 + 10*max(0, 0.5-|x|)
        + (x^2 + (y+3)^2) * sigmoid(10*(x-1))

    The sigmoid gate is evaluated in an overflow-free form, so strongly
    negative x (or float32 grids) do not trigger NumPy overflow warnings.

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    x_sq = x**2

    term1 = np.exp(-x_sq / 0.1) * (y - 1.0) ** 2
    term2 = 0.1 * x_sq
    term3 = 10.0 * np.maximum(0.0, 0.5 - np.abs(x))

    # sigmoid(t) == exp(-log(1 + exp(-t))). logaddexp(0, -t) evaluates the
    # log term without forming exp(-t) directly, so it cannot overflow the
    # way 1 / (1 + exp(-t)) does for large negative t.
    gate = np.exp(-np.logaddexp(0.0, -10.0 * (x - 1.0)))
    basin = (x_sq + (y + 3.0) ** 2) * gate

    return term1 + term2 + term3 + basin


def canyon_waterfall_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the canyon/waterfall loss (e.g. for contour plots).

        0.01*x^2 + (y - sin(x))^2 + 10*sigmoid(100*(x-3))*(y+5)^2

    Always deterministic: no random term is added here. The steep sigmoid
    is evaluated in an overflow-free form; the naive 1 / (1 + exp(-t))
    overflows in float64 once x drops below roughly -4.1.

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    canyon = 0.01 * x**2 + (y - np.sin(x)) ** 2

    # sigmoid(t) == exp(-log(1 + exp(-t))). logaddexp(0, -t) evaluates the
    # log term without forming exp(-t) directly, so it cannot overflow.
    step = np.exp(-np.logaddexp(0.0, -100.0 * (x - 3.0)))
    waterfall = 10.0 * step * (y + 5.0) ** 2

    return canyon + waterfall  # deterministic (no noise)


def concentric_barriers_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the concentric-barriers loss (e.g. for contour plots).

        r^2 + sum_{i=1..5} (2/i) * exp(-(r - i)^2 / 0.05),  r = sqrt(x^2 + y^2)

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    radius = np.sqrt(x**2 + y**2)

    # Gaussian rings at radii 1..5 with heights 2/1 .. 2/5, summed from a
    # 0.0 start in increasing-radius order.
    rings = sum(
        (
            (2.0 / float(k)) * np.exp(-((radius - float(k)) ** 2) / 0.05)
            for k in range(1, 6)
        ),
        0.0,
    )

    return radius**2 + rings


def local_min_plateau_deep_min_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the local-well / plateau / deep-well loss.

    Sum of a local well (depth 5) at (-2, -2), a ridge (height 3) at
    (-1, -1), a raised region 2*(1 - tanh(|p|^2/0.3)) around the origin,
    a very narrow deep well (depth 8) at (2, 2), and a bowl 0.05*|p|^2.

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    radius_sq = x**2 + y**2

    def bell(dx, dy, width):
        # Unnormalized Gaussian of the offset (dx, dy).
        return np.exp(-(dx**2 + dy**2) / width)

    shallow_well = -5.0 * bell(x + 2.0, y + 2.0, 0.05)
    ridge = 3.0 * bell(x + 1.0, y + 1.0, 0.2)
    shelf = 2.0 * (1.0 - np.tanh(radius_sq / 0.3))
    deep_well = -8.0 * bell(x - 2.0, y - 2.0, 0.01)
    bowl = 0.05 * radius_sq

    return shallow_well + ridge + shelf + deep_well + bowl


def plateau_with_traps_deep_min_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the plateau-with-traps loss.

    Sum of a bowl 0.1*|p|^2, a dome (height 2) at the origin, three traps
    (depth 2) at (1, 0.5), (-0.5, 1) and (0.5, -1), a very narrow deep well
    (depth 10) at (2.5, 2.5), and a small linear tilt 0.01*(x + y).

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    radius_sq = x**2 + y**2

    def bell(dx, dy, width):
        # Unnormalized Gaussian of the offset (dx, dy).
        return np.exp(-(dx**2 + dy**2) / width)

    total = 0.1 * radius_sq + 2.0 * np.exp(-radius_sq / 0.5)

    # Three equally deep traps, accumulated in their original order.
    trap_offsets = (
        (x - 1.0, y - 0.5),
        (x + 0.5, y - 1.0),
        (x - 0.5, y + 1.0),
    )
    for dx, dy in trap_offsets:
        total = total + -2.0 * bell(dx, dy, 0.05)

    deep_well = -10.0 * bell(x - 2.5, y - 2.5, 0.005)
    tilt = 0.01 * (x + y)

    return total + deep_well + tilt


def complex_journey_np(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    NumPy evaluation of the "complex journey" loss.

    Sum of a bowl 0.2*|p|^2, a dome (height 3) at the origin, traps
    (depth 1.5) at (1, 0) and (0, 1), a wide-descent term around (1.5, 1.5)
    weighted by the squared distance to (2.5, 2.5), a sharp well (depth 12)
    at (3, 3), a bump (height 8) at (3.5, 3.5), a wide well (depth 6) at
    (4, 4), and a deterministic texture 0.01*sin(5x)*cos(5y).

    Args:
        x: x coordinates (array or scalar).
        y: y coordinates, combined elementwise with ``x``.

    Returns:
        Loss values, elementwise over the inputs.
    """
    radius_sq = x**2 + y**2

    def bell(dx, dy, width):
        # Unnormalized Gaussian of the offset (dx, dy).
        return np.exp(-(dx**2 + dy**2) / width)

    bowl = 0.2 * radius_sq
    dome = 3.0 * np.exp(-radius_sq / 0.8)
    pit_east = -1.5 * bell(x - 1.0, y, 0.1)
    pit_north = -1.5 * bell(x, y - 1.0, 0.1)

    dist_sq_to_target = (x - 2.5) ** 2 + (y - 2.5) ** 2
    ramp = 0.5 * dist_sq_to_target * bell(x - 1.5, y - 1.5, 0.3)

    spike = -12.0 * bell(x - 3.0, y - 3.0, 0.005)
    hump = 8.0 * bell(x - 3.5, y - 3.5, 0.2)
    basin = -6.0 * bell(x - 4.0, y - 4.0, 0.5)

    smooth_part = bowl + dome + pit_east + pit_north + ramp + spike + hump + basin
    texture = 0.01 * np.sin(5.0 * x) * np.cos(5.0 * y)
    return smooth_part + texture


# Name -> NumPy objective lookup. Each value takes (x, y) and returns the
# loss elementwise. The keys are the same strings used by TORCH_FUNCTIONS
# and BAD_CENTERS, so one name selects all three.
NP_FUNCTIONS: Dict[str, Callable[[np.ndarray, np.ndarray], np.ndarray]] = {
    "deceptive_landscape": deceptive_landscape_np,
    "canyon_waterfall": canyon_waterfall_np,
    "concentric_barriers": concentric_barriers_np,
    "local_min_plateau_deep_min": local_min_plateau_deep_min_np,
    "plateau_with_traps_deep_min": plateau_with_traps_deep_min_np,
    "complex_journey": complex_journey_np,
}
