import torch
import math
from typing import Callable, List
from .task_spec import TaskSpec


def _is_scalar_pair(bounds) -> bool:
    """Return True when ``bounds`` is one shared ``(low, high)`` pair of scalars.

    A length-2 tuple is ambiguous: it may be a single ``(low, high)`` pair that
    applies to every dimension, or two per-dimension pairs such as
    ``((l0, h0), (l1, h1))``. Only the first form has non-sequence elements,
    so the elements are inspected rather than just the length.
    """
    return (
        isinstance(bounds, tuple)
        and len(bounds) == 2
        and not any(isinstance(b, (tuple, list)) for b in bounds)
    )


def _as_per_dim_bounds(
    bounds, dim: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Convert ``bounds`` into a tensor whose rows are ``[low, high]``.

    Args:
        bounds: Either a scalar ``(low, high)`` tuple shared by all dimensions,
            or an iterable of per-dimension ``(low, high)`` pairs.
        dim: Number of rows to produce for the scalar-pair form. It is not
            consulted for the per-dimension form, whose row count follows
            the number of pairs given.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        A tensor with column 0 holding the lower bounds and column 1 holding
        the upper bounds.
    """
    if _is_scalar_pair(bounds):
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
    else:
        # Per-dimension form, including a 2-tuple of pairs, which must not be
        # mistaken for a scalar (low, high) pair.
        lows = torch.tensor([float(b[0]) for b in bounds], device=device, dtype=dtype)
        highs = torch.tensor([float(b[1]) for b in bounds], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    """Clamp ``X`` into ``bounds`` along its last axis.

    Args:
        X: Points to clamp; for per-dimension bounds the last axis must
            broadcast against the number of bound pairs.
        bounds: Either a scalar ``(low, high)`` tuple applied to every element,
            or an iterable of per-dimension ``(low, high)`` pairs.

    Returns:
        A tensor with the same shape as ``X`` whose values lie inside the
        bounds.
    """
    if _is_scalar_pair(bounds):
        return X.clamp(float(bounds[0]), float(bounds[1]))
    bd = _as_per_dim_bounds(
        bounds, dim=int(X.size(-1)), device=X.device, dtype=X.dtype
    )
    lows, highs = bd[:, 0], bd[:, 1]
    # Elementwise min/max broadcast the per-dimension limits over leading axes.
    return torch.max(torch.min(X, highs), lows)


def _ackley(X: torch.Tensor) -> torch.Tensor:
    """
    Negated Ackley function, reduced over the last axis of ``X``.

    Textbook (minimization) form, with d the size of the last axis:
      f(x) = -a * exp(-b * sqrt(1/d * sum(x^2))) - exp(1/d * sum(cos(c*x))) + a + exp(1)
    using a = 20, b = 0.2, c = 2*pi. Its global minimum is 0 at x = (0, ..., 0),
    and it is typically evaluated on [-32.768, 32.768] or similar.

    NOTE:
    This project (TransferRankBayesOpt) maximizes the objective, so the value
    returned here is -f(x): the optimum is a maximum of 0 at x = (0, ..., 0).
    """
    amplitude, decay, freq = 20.0, 0.2, 2.0 * math.pi
    n_dims = float(X.shape[-1])

    # Root-mean-square of the coordinates and mean cosine, both over the last axis.
    rms = X.pow(2).sum(dim=-1).div(n_dims).sqrt()
    mean_cos = torch.cos(freq * X).sum(dim=-1).div(n_dims)

    radial_term = -amplitude * torch.exp(-decay * rms)
    cosine_term = -torch.exp(mean_cos)

    minimization_value = radial_term + cosine_term + amplitude + math.exp(1.0)
    # Flip the sign so that larger is better.
    return -1.0 * minimization_value


def _make_objective(
    scale: float, shift_frac: float, bounds
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build a scaled and shifted variant of the negated 2-D Ackley objective.

    The returned callable converts its input to float32, offsets every
    coordinate by ``shift_frac`` times the width of that coordinate's bounds,
    clamps the offset points back into ``bounds``, and returns
    ``scale * _ackley(...)`` with a trailing singleton axis appended.

    Args:
        scale: Multiplier applied to the objective value.
        shift_frac: Offset expressed as a fraction of each bound width.
        bounds: Bounds passed through to the bound helpers (two dimensions).

    Returns:
        A function mapping a batch of points to objective values.
    """

    def _objective(X: torch.Tensor) -> torch.Tensor:
        points = torch.as_tensor(X, dtype=torch.float32)
        limits = _as_per_dim_bounds(
            bounds, dim=2, device=points.device, dtype=points.dtype
        )
        lower, upper = limits[:, 0], limits[:, 1]
        span = (upper - lower).view(1, 2)

        # Translate by a fraction of the box width, then pull any point that
        # left the box back onto its boundary.
        moved = _clamp_with_bounds(points + float(shift_frac) * span, bounds)

        gain = float(scale)
        return gain * _ackley(moved).unsqueeze(-1)

    return _objective


def build_history_tasks(dim: int = 2, bounds=((-32.768, 32.768),)) -> List[TaskSpec]:
    """
    Build the three history (source) tasks derived from the Ackley objective.

    Each task is the same objective with its own scale and shift fraction.

    Args:
        dim: Must be 2.
        bounds: Bounds for the tasks. A length-1 sequence is repeated so that
            both dimensions share it; any other length is used unchanged.

    Returns:
        The list of history tasks, in the order history_1 .. history_3.

    Raises:
        ValueError: If ``dim`` is not 2.
    """
    if int(dim) != 2:
        raise ValueError("Ackley2 is a 2D function; dim must be 2.")

    def _two_dim_bounds():
        # A single entry is duplicated for the second axis; everything else
        # (including an already length-2 value) passes through untouched.
        return bounds * 2 if len(bounds) == 1 else bounds

    # (name, scale, shift_fraction) -- similar to the Branin/Schwefel logic.
    variants = (
        ("history_1", 0.5, 0.1),
        ("history_2", 1.2, -0.2),
        ("history_3", 0.8, 0.3),
    )

    tasks: List[TaskSpec] = []
    for task_name, task_scale, task_shift in variants:
        tasks.append(
            TaskSpec(
                name=task_name,
                dim=2,
                bounds=_two_dim_bounds(),
                objective=_make_objective(
                    task_scale,
                    task_shift,
                    bounds=_two_dim_bounds(),
                ),
            )
        )
    return tasks


def build_real_task(dim: int = 2, bounds=((-32.768, 32.768),)) -> TaskSpec:
    """
    Build the target task: the standard Ackley objective (scale=1.0, shift=0.0).

    Args:
        dim: Must be 2.
        bounds: Bounds for the task. A length-1 sequence is repeated so that
            both dimensions share it; any other length is used unchanged.

    Returns:
        The task named "target_real".

    Raises:
        ValueError: If ``dim`` is not 2.
    """
    if int(dim) != 2:
        raise ValueError("Ackley2 is a 2D function; dim must be 2.")

    # Duplicate a lone entry for the second axis; otherwise keep as given.
    if len(bounds) == 1:
        real_bounds = bounds * 2
    else:
        real_bounds = bounds

    target_objective = _make_objective(1.0, 0.0, bounds=real_bounds)
    return TaskSpec(
        name="target_real",
        dim=2,
        bounds=real_bounds,
        objective=target_objective,
    )
