import torch

from .task_spec import TaskSpec


def _is_scalar_pair(bounds) -> bool:
    """Return True if ``bounds`` is a single ``(low, high)`` pair of scalars."""
    try:
        if len(bounds) != 2:
            return False
        float(bounds[0])
        float(bounds[1])
    except (TypeError, ValueError, RuntimeError):
        # Not sized, or entries are themselves sequences/arrays (per-dim rows).
        return False
    return True


def _as_per_dim_bounds(bounds, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Normalize ``bounds`` into a ``(dim, 2)`` tensor of ``[low, high]`` rows.

    Args:
        bounds: Either one ``(low, high)`` pair of scalars shared by every
            dimension (tuple or list), or a sequence of ``dim`` per-dimension
            ``(low, high)`` pairs.
        dim: Number of input dimensions.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(dim, 2)``; column 0 holds the lows, column 1 the highs.

    Raises:
        ValueError: If per-dimension bounds do not provide exactly ``dim`` pairs.
    """
    dim = int(dim)
    # Decide by the *contents*, not the container type. A list [low, high] is
    # then accepted, and a 2-D problem's ((l0, h0), (l1, h1)) is not mistaken
    # for a single scalar pair.
    if _is_scalar_pair(bounds):
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
    else:
        pairs = list(bounds)
        if len(pairs) != dim:
            raise ValueError(
                f"Expected {dim} per-dimension (low, high) pairs, got {len(pairs)}."
            )
        lows = torch.tensor([float(b[0]) for b in pairs], device=device, dtype=dtype)
        highs = torch.tensor([float(b[1]) for b in pairs], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    """Clamp ``X`` elementwise into ``bounds`` along its last dimension.

    Args:
        X: Tensor whose last dimension indexes input coordinates.
        bounds: Either a single ``(low, high)`` tuple of scalars applied to
            every coordinate, or per-dimension ``(low, high)`` pairs, one per
            entry of ``X``'s last dimension.

    Returns:
        Tensor with the same shape, device and dtype as ``X``.
    """
    # Fast path for one shared scalar pair. Requiring non-sequence entries
    # keeps a 2-D problem's per-dimension ((l0, h0), (l1, h1)) from being
    # mistaken for a single pair.
    if (
        isinstance(bounds, tuple)
        and len(bounds) == 2
        and not any(isinstance(b, (tuple, list)) for b in bounds)
    ):
        return X.clamp(float(bounds[0]), float(bounds[1]))
    bd = _as_per_dim_bounds(bounds, dim=int(X.size(-1)), device=X.device, dtype=X.dtype)
    # Tensor-valued torch.clamp gives the same semantics as the scalar branch
    # (inverted bounds resolve to `max`; gradient passes at the exact boundary).
    # max(min(X, high), low) would differ on both counts.
    return torch.clamp(X, min=bd[:, 0], max=bd[:, 1])


def _hartmann6(X: torch.Tensor) -> torch.Tensor:
    """Evaluate the 6-D Hartmann function over the last axis of ``X``.

    Computes ``sum_i alpha_i * exp(-sum_j A_ij * (x_j - P_ij)**2)`` for every
    point. The textbook Hartmann6 is the *negative* of this sum (global
    minimum about -3.32237 on [0, 1]^6). This version returns the positive
    sum, so larger is better and the optimum is about +3.32237. No clamping is
    applied here.

    Args:
        X: Tensor whose last dimension holds the 6 coordinates; any leading
            dimensions are treated as batch dimensions.

    Returns:
        Tensor with the last dimension of ``X`` reduced away, on the same
        device and dtype as ``X``. A 1-D input yields shape ``(1,)``.
    """
    factory = {"device": X.device, "dtype": X.dtype}
    # Outer weights alpha_i, one per exponential term.
    weights = torch.tensor([1.0, 1.2, 3.0, 3.2], **factory)
    # Per-term, per-coordinate curvature coefficients A_ij.
    curvature = torch.tensor(
        [
            [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
            [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
            [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
            [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
        ],
        **factory,
    )
    # Term centres P_ij, written in units of 1e-4.
    centres = torch.tensor(
        [
            [1312.0, 1696.0, 5569.0, 124.0, 8283.0, 5886.0],
            [2329.0, 4135.0, 8307.0, 3736.0, 1004.0, 9991.0],
            [2348.0, 1451.0, 3522.0, 2883.0, 3047.0, 6650.0],
            [4047.0, 8828.0, 8732.0, 5743.0, 1091.0, 381.0],
        ],
        **factory,
    ) * 1e-4

    # Add a term axis so every point is compared against all four centres.
    offsets = X[..., None, :] - centres[None]
    exponents = (curvature[None] * offsets.square()).sum(dim=-1)
    return (weights[None] * torch.exp(-exponents)).sum(dim=-1)


def _make_objective(scale: float, shift_frac: float, bounds):
    """Build a scaled, input-shifted Hartmann6 objective over ``bounds``.

    The returned callable takes the following steps:
      1. It shifts every input coordinate by ``shift_frac`` times that
         dimension's width (``high - low``).
      2. It clamps the shifted point back into ``bounds``.
      3. It returns ``scale * _hartmann6(shifted)``.

    Args:
        scale: Multiplier applied to the Hartmann6 output.
        shift_frac: Input shift expressed as a fraction of each dimension's width.
        bounds: A single ``(low, high)`` pair or 6 per-dimension pairs.

    Returns:
        A callable mapping inputs with a trailing dimension of 6 (tensor or
        tensor-like) to objective values.
    """
    scale = float(scale)
    shift_frac = float(shift_frac)

    def _objective(X: torch.Tensor) -> torch.Tensor:
        X = torch.as_tensor(X)
        # Keep double-precision inputs in float64 instead of silently
        # truncating them to float32; any other dtype is cast to float32 as before.
        X = X.to(torch.float64 if X.dtype == torch.float64 else torch.float32)
        bd = _as_per_dim_bounds(bounds, dim=6, device=X.device, dtype=X.dtype)
        widths = (bd[:, 1] - bd[:, 0]).view(1, 6)
        X_shifted = _clamp_with_bounds(X + shift_frac * widths, bounds)
        return scale * _hartmann6(X_shifted)

    return _objective


def build_history_tasks(dim: int = 6, bounds=(0.0, 1.0)):
    """Create the three auxiliary ("history") Hartmann6 tasks.

    Each task wraps a scaled and input-shifted Hartmann6 objective produced
    by ``_make_objective``.

    Args:
        dim: Input dimensionality; only 6 is accepted.
        bounds: Search-space bounds passed to every ``TaskSpec`` and objective.

    Returns:
        List of ``TaskSpec`` objects named ``history_1`` through ``history_3``.

    Raises:
        ValueError: If ``dim`` is not 6.
    """
    if int(dim) != 6:
        raise ValueError("Hartmann6 is a 6D function; dim must be 6.")
    # (task name, output scale, input shift as a fraction of each dimension's width)
    variants = (
        ("history_1", 0.5, 0.1),
        ("history_2", -0.5, 0.2),
        ("history_3", 1.5, 0.3),
    )
    tasks = []
    for task_name, out_scale, frac in variants:
        objective = _make_objective(out_scale, frac, bounds=bounds)
        tasks.append(TaskSpec(name=task_name, dim=6, bounds=bounds, objective=objective))
    return tasks


def build_real_task(dim: int = 6, bounds=(0.0, 1.0)) -> TaskSpec:
    """Create the target Hartmann6 task (unit scale, no input shift).

    The objective still clamps inputs into ``bounds`` via ``_make_objective``.

    Args:
        dim: Input dimensionality; only 6 is accepted.
        bounds: Search-space bounds passed to the ``TaskSpec`` and objective.

    Returns:
        A ``TaskSpec`` named ``target_real``.

    Raises:
        ValueError: If ``dim`` is not 6.
    """
    if int(dim) != 6:
        raise ValueError("Hartmann6 is a 6D function; dim must be 6.")
    target_objective = _make_objective(1.0, 0.0, bounds=bounds)
    return TaskSpec(
        name="target_real",
        dim=6,
        bounds=bounds,
        objective=target_objective,
    )
