import math
from typing import Callable, List

import torch

from .task_spec import TaskSpec


def _as_per_dim_bounds(bounds, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)
    lows = torch.tensor([float(b[0]) for b in bounds], device=device, dtype=dtype)
    highs = torch.tensor([float(b[1]) for b in bounds], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        return X.clamp(float(bounds[0]), float(bounds[1]))
    device, dtype = X.device, X.dtype
    bd = _as_per_dim_bounds(bounds, dim=int(X.size(-1)), device=device, dtype=dtype)
    lows, highs = bd[:, 0], bd[:, 1]
    return torch.max(torch.min(X, highs), lows)


def _branin(X: torch.Tensor) -> torch.Tensor:
    x1 = X[..., 0]
    x2 = X[..., 1]
    t1 = x2 - 5.1 / (4.0 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0
    t2 = 10.0 * (1.0 - 1.0 / (8.0 * math.pi)) * torch.cos(x1)
    return -1*(t1**2 + t2 + 10.0)


def _make_objective(scale: float, shift_frac: float, bounds) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a scaled, input-shifted, negated Branin objective.

    The returned callable computes
    ``scale * _branin(clamp(X + shift_frac * (high - low), bounds))``.
    The shift is applied per dimension using the widths of ``bounds``, and the
    shifted points are clamped back into ``bounds`` before evaluation.

    Args:
        scale: Multiplier applied to the negated Branin value.
        shift_frac: Input offset, expressed as a fraction of each dimension's
            bound width.
        bounds: Either a shared ``(low, high)`` pair or one ``(low, high)`` pair
            per dimension (2 dimensions).

    Returns:
        A function mapping ``X`` with a trailing dimension of size 2 to values
        of shape ``X.shape[:-1]``.
    """
    factor = float(scale)
    shift = float(shift_frac)

    def _objective(X: torch.Tensor) -> torch.Tensor:
        X = torch.as_tensor(X)
        # Keep the caller's floating dtype (e.g. float64) instead of silently
        # downcasting to float32; only promote integer/bool inputs.
        if not X.is_floating_point():
            X = X.to(torch.float32)
        bd = _as_per_dim_bounds(bounds, dim=2, device=X.device, dtype=X.dtype)
        # Shape (2,) broadcasts against any (..., 2) input without adding a
        # leading dimension to a single unbatched point.
        widths = bd[:, 1] - bd[:, 0]
        X_shifted = _clamp_with_bounds(X + shift * widths, bounds)
        return factor * _branin(X_shifted)

    return _objective


def build_history_tasks(dim: int = 2, bounds=((-5.0, 10.0), (0.0, 15.0))) -> List[TaskSpec]:
    if int(dim) != 2:
        raise ValueError("Branin is a 2D function; dim must be 2.")
    specs = [
        ("history_1", 0.5, 0.1),
        ("history_2", -0.5, 0.2),
        ("history_3", 1.5, 0.3),
    ]
    return [TaskSpec(name=n, dim=2, bounds=bounds, objective=_make_objective(s, d, bounds=bounds)) for n, s, d in specs]


def build_real_task(dim: int = 2, bounds=((-5.0, 10.0), (0.0, 15.0))) -> TaskSpec:
    if int(dim) != 2:
        raise ValueError("Branin is a 2D function; dim must be 2.")
    return TaskSpec(name="target_real", dim=2, bounds=bounds, objective=_make_objective(1.0, 0.0, bounds=bounds))
