import torch

from .task_spec import TaskSpec


def _as_per_dim_bounds(bounds, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)
    lows = torch.tensor([float(b[0]) for b in bounds], device=device, dtype=dtype)
    highs = torch.tensor([float(b[1]) for b in bounds], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        return X.clamp(float(bounds[0]), float(bounds[1]))
    device, dtype = X.device, X.dtype
    bd = _as_per_dim_bounds(bounds, dim=int(X.size(-1)), device=device, dtype=dtype)
    lows, highs = bd[:, 0], bd[:, 1]
    return torch.max(torch.min(X, highs), lows)


def _hartmann3(X: torch.Tensor) -> torch.Tensor:
    alpha = torch.tensor([1.0, 1.2, 3.0, 3.2], device=X.device, dtype=X.dtype)
    A = torch.tensor(
        [
            [3.0, 10.0, 30.0],
            [0.1, 10.0, 35.0],
            [3.0, 10.0, 30.0],
            [0.1, 10.0, 35.0],
        ],
        device=X.device,
        dtype=X.dtype,
    )
    P = 1e-4 * torch.tensor(
        [
            [3689.0, 1170.0, 2673.0],
            [4699.0, 4387.0, 7470.0],
            [1091.0, 8732.0, 5547.0],
            [381.0, 5743.0, 8828.0],
        ],
        device=X.device,
        dtype=X.dtype,
    )
    diff = X.unsqueeze(-2) - P.unsqueeze(0)
    inner = torch.sum(A.unsqueeze(0) * diff.pow(2), dim=-1)
    return torch.sum(alpha.unsqueeze(0) * torch.exp(-inner), dim=-1)


def _make_objective(scale: float, shift_frac: float, bounds):
    def _objective(X: torch.Tensor) -> torch.Tensor:
        X = torch.as_tensor(X, dtype=torch.float32)
        device, dtype = X.device, X.dtype
        bd = _as_per_dim_bounds(bounds, dim=3, device=device, dtype=dtype)
        widths = (bd[:, 1] - bd[:, 0]).view(1, 3)
        X_shifted = X + float(shift_frac) * widths
        X_shifted = _clamp_with_bounds(X_shifted, bounds)
        return float(scale) * _hartmann3(X_shifted)

    return _objective


def build_history_tasks(dim: int = 3, bounds=(0.0, 1.0)):
    """Return the three "history" Hartmann3 tasks as a list of ``TaskSpec``.

    Each task is a scaled, shifted copy of Hartmann3 built by ``_make_objective``.
    All tasks share the given ``bounds``.

    Raises:
        ValueError: if ``dim`` is not 3.
    """
    if int(dim) != 3:
        raise ValueError("Hartmann3 is a 3D function; dim must be 3.")
    history = (
        # (task name, output scale, shift as a fraction of each dim's width)
        ("history_1", 0.5, 0.1),
        ("history_2", -0.5, 0.2),
        ("history_3", 1.5, 0.3),
    )
    return [
        TaskSpec(
            name=task_name,
            dim=3,
            bounds=bounds,
            objective=_make_objective(scale, shift_frac, bounds=bounds),
        )
        for task_name, scale, shift_frac in history
    ]


def build_real_task(dim: int = 3, bounds=(0.0, 1.0)) -> TaskSpec:
    """Return the ``"target_real"`` task: plain Hartmann3 (scale 1.0, no shift).

    Raises:
        ValueError: if ``dim`` is not 3.
    """
    if int(dim) != 3:
        raise ValueError("Hartmann3 is a 3D function; dim must be 3.")
    objective = _make_objective(1.0, 0.0, bounds=bounds)
    return TaskSpec(
        name="target_real",
        dim=3,
        bounds=bounds,
        objective=objective,
    )
