import torch
import math
from typing import Callable, List
from .task_spec import TaskSpec


def _as_per_dim_bounds(bounds, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)
    lows = torch.tensor([float(b[0]) for b in bounds], device=device, dtype=dtype)
    highs = torch.tensor([float(b[1]) for b in bounds], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    if isinstance(bounds, tuple) and len(bounds) == 2:
        return X.clamp(float(bounds[0]), float(bounds[1]))
    device, dtype = X.device, X.dtype
    bd = _as_per_dim_bounds(bounds, dim=int(X.size(-1)), device=device, dtype=dtype)
    lows, highs = bd[:, 0], bd[:, 1]
    return torch.max(torch.min(X, highs), lows)


def _ackley(X: torch.Tensor) -> torch.Tensor:
    """
    Ackley function.
    Standard definition (Minimization):
      f(x) = -a * exp(-b * sqrt(1/d * sum(x^2))) - exp(1/d * sum(cos(c*x))) + a + exp(1)
    Global minimum is 0 at x = (0, ..., 0)
    Typically evaluated on [-32.768, 32.768] or similar.
    
    NOTE: 
    This project (TransferRankBayesOpt) maximizes the objective.
    So we return the NEGATED value.
    Max value is 0 at x = (0, ..., 0).
    """
    a = 20.0
    b = 0.2
    c = 2.0 * math.pi
    d = float(X.shape[-1])
    
    sum_sq = torch.sum(X**2, dim=-1)
    term1 = -a * torch.exp(-b * torch.sqrt(sum_sq / d))
    
    sum_cos = torch.sum(torch.cos(c * X), dim=-1)
    term2 = -torch.exp(sum_cos / d)
    
    val = term1 + term2 + a + math.exp(1.0)
    return -1.0 * val


def _make_objective(
    scale: float, shift_frac: float, bounds, dim: int = 50
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build a scaled, shifted negated-Ackley objective over a box.

    The returned callable computes

        scale * _ackley(clamp(X + shift_frac * (high - low), bounds))

    and appends a trailing singleton axis, so inputs of shape ``(..., dim)``
    produce outputs of shape ``(..., 1)``.

    Args:
        scale: Multiplier on the negated Ackley value (a negative value flips its sign).
        shift_frac: Per-coordinate shift, as a fraction of that dimension's bound
            width ``high - low``.
        bounds: Box bounds, in any form accepted by ``_as_per_dim_bounds``; also
            used to clamp the shifted inputs.
        dim: Number of input dimensions (defaults to 50 for the Ackley50 benchmark).
    """
    # Convert once at construction instead of on every evaluation.
    scale = float(scale)
    shift_frac = float(shift_frac)
    dim = int(dim)

    def _objective(X: torch.Tensor) -> torch.Tensor:
        # NOTE(review): inputs are always cast to float32, so float64 candidates lose
        # precision and the result is float32 — confirm downstream code expects this.
        X = torch.as_tensor(X, dtype=torch.float32)
        bd = _as_per_dim_bounds(bounds, dim=dim, device=X.device, dtype=X.dtype)
        widths = (bd[:, 1] - bd[:, 0]).view(1, dim)

        # Shift every coordinate by a fixed fraction of its width, then project the
        # result back into the box before evaluating Ackley.
        X_shifted = _clamp_with_bounds(X + shift_frac * widths, bounds)

        return scale * _ackley(X_shifted).unsqueeze(-1)

    return _objective


def build_history_tasks(dim: int = 50, bounds=((-32.768, 32.768),)) -> List[TaskSpec]:
    """
    Build the source ("history") tasks: scaled and shifted variants of 50-D Ackley.

    Args:
        dim: Must be 50; accepted only for interface compatibility with other task modules.
        bounds: Box bounds. A single ``(low, high)`` pair wrapped in a 1-tuple is
            replicated for all 50 dimensions; any other form is passed through unchanged.

    Returns:
        One ``TaskSpec`` per entry in ``specs`` below.

    Raises:
        ValueError: If ``dim`` is not 50.
    """
    if int(dim) != 50:
        raise ValueError("Ackley50 is a 50D function; dim must be 50.")

    # Expand once and share it between each TaskSpec and its objective so the two
    # always describe the same box.
    task_bounds = bounds * 50 if len(bounds) == 1 else bounds

    # (name, scale, shift_fraction); similar to the Branin/Schwefel task modules.
    # A negative scale flips the sign of the objective.
    specs = [
        ("history_1", -0.5, 0.025),
        ("history_2", 0.5, 0.05),
    ]

    return [
        TaskSpec(
            name=name,
            dim=50,
            bounds=task_bounds,
            objective=_make_objective(scale, shift_frac, bounds=task_bounds),
        )
        for name, scale, shift_frac in specs
    ]

def build_real_task(dim: int = 50, bounds=((-32.768, 32.768),)) -> TaskSpec:
    """
    Build the target task: the standard negated Ackley in 50 dimensions.

    The objective uses scale 1.0 and shift 0.0, i.e. no rescaling or translation.

    Args:
        dim: Must be 50; accepted only for interface compatibility with other task modules.
        bounds: Box bounds. A single ``(low, high)`` pair wrapped in a 1-tuple is
            replicated for all 50 dimensions; any other form is passed through unchanged.

    Raises:
        ValueError: If ``dim`` is not 50.
    """
    if int(dim) != 50:
        raise ValueError("Ackley50 is a 50D function; dim must be 50.")

    if len(bounds) == 1:
        target_bounds = bounds * 50
    else:
        target_bounds = bounds

    target_objective = _make_objective(1.0, 0.0, bounds=target_bounds)
    return TaskSpec(
        name="target_real",
        dim=50,
        bounds=target_bounds,
        objective=target_objective,
    )
