import math
from typing import Callable, List

import torch

from .task_spec import TaskSpec


def _as_per_dim_bounds(bounds, dim: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Normalize ``bounds`` into a ``(dim, 2)`` tensor of ``[low, high]`` rows.

    Accepted forms:
      * a scalar pair ``(low, high)``: the same interval is used for every dimension;
      * a sequence of ``dim`` ``(low, high)`` pairs, one per dimension;
      * a sequence holding a single ``(low, high)`` pair, repeated for all ``dim`` dimensions.

    Args:
        bounds: Bounds in one of the forms above.
        dim: Number of dimensions the returned tensor must cover.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(dim, 2)``; column 0 holds the lows, column 1 the highs.

    Raises:
        ValueError: If a per-dimension sequence has a length other than 1 or ``dim``.
    """
    # A scalar pair is a 2-tuple whose entries are numbers, not nested pairs. The
    # element check matters: for dim == 2, per-dimension bounds such as
    # ((-500, 500), (-500, 500)) are also a 2-tuple and must not be read as (low, high).
    if isinstance(bounds, tuple) and len(bounds) == 2 and not isinstance(bounds[0], (tuple, list)):
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)

    pairs = list(bounds)
    if len(pairs) == 1:
        # A single (low, high) pair applies to every dimension.
        pairs = pairs * dim
    elif len(pairs) != dim:
        raise ValueError(f"expected 1 or {dim} per-dimension bounds, got {len(pairs)}")

    lows = torch.tensor([float(b[0]) for b in pairs], device=device, dtype=dtype)
    highs = torch.tensor([float(b[1]) for b in pairs], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _clamp_with_bounds(X: torch.Tensor, bounds) -> torch.Tensor:
    """Clamp every coordinate of ``X`` (last dimension) into ``bounds``.

    Args:
        X: Tensor whose last dimension indexes the coordinates.
        bounds: Either a scalar pair ``(low, high)`` applied to all coordinates, or
            per-dimension ``(low, high)`` pairs resolved via ``_as_per_dim_bounds``.

    Returns:
        Tensor of the same shape as ``X`` with values clamped into the box.
    """
    # Only a 2-tuple of scalars is a shared (low, high) pair; a 2-tuple of pairs
    # (per-dimension bounds with dim == 2) must take the per-dimension path.
    is_scalar_pair = (
        isinstance(bounds, tuple)
        and len(bounds) == 2
        and not isinstance(bounds[0], (tuple, list))
    )
    if is_scalar_pair:
        return X.clamp(float(bounds[0]), float(bounds[1]))
    bd = _as_per_dim_bounds(bounds, dim=int(X.size(-1)), device=X.device, dtype=X.dtype)
    lows, highs = bd[:, 0], bd[:, 1]
    # Elementwise min/max broadcast the (dim,) bound vectors over any leading batch dims.
    return torch.max(torch.min(X, highs), lows)


def _schwefel(X: torch.Tensor) -> torch.Tensor:
    """
    Negated Schwefel function, reduced over the last axis of ``X``.

    Textbook (minimization) form:
      f(x) = 418.9829 * d - sum_i x_i * sin(sqrt(|x_i|))
    with its global minimum (approximately 0) at x_i = 420.9687 for every i.

    This project (TransferRankBayesOpt) maximizes objectives, so the value
    returned is -f(x); its maximum (approximately 0) lies at the same point.

    Args:
        X: Tensor whose last dimension holds the d coordinates.

    Returns:
        Tensor with the last dimension reduced away.
    """
    n_dims = X.shape[-1]
    ripple = X * X.abs().sqrt().sin()
    f_min_form = 418.9829 * n_dims - ripple.sum(dim=-1)
    return -f_min_form


def _make_objective(scale: float, shift_frac: float, bounds) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a closure that evaluates a scaled, input-shifted Schwefel variant.

    The closure computes ``scale * S(clip(X + shift_frac * width))``, where ``S`` is
    the negated Schwefel function, ``width`` is each dimension's ``high - low``, and
    ``clip`` clamps every coordinate back into ``bounds``.

    Args:
        scale: Multiplier applied to the function value.
        shift_frac: Input translation, as a fraction of each dimension's width.
        bounds: Scalar ``(low, high)`` pair or per-dimension ``(low, high)`` pairs.

    Returns:
        Callable mapping ``X`` (last dimension = coordinates) to a float32 tensor
        with a trailing singleton dimension.
    """
    def _objective(X: torch.Tensor) -> torch.Tensor:
        # Inputs are always evaluated in float32.
        X = torch.as_tensor(X, dtype=torch.float32)
        per_dim = _as_per_dim_bounds(bounds, dim=int(X.shape[-1]), device=X.device, dtype=X.dtype)
        # Row vector of box widths, so the shift broadcasts over the batch axis.
        span = (per_dim[:, 1] - per_dim[:, 0]).view(1, -1)

        # Translate inputs by a fixed fraction of the box width. Ignoring clamping,
        # this moves the optimum by -shift_frac * width relative to the unshifted function.
        moved = X + float(shift_frac) * span
        # Shift, then clamp back into the box (same pattern as the Branin variant):
        # points pushed past an edge are evaluated at the boundary.
        moved = _clamp_with_bounds(moved, bounds)

        # unsqueeze adds a trailing output dim, e.g. [N] -> [N, 1].
        return float(scale) * _schwefel(moved).unsqueeze(-1)

    return _objective


def build_history_tasks(dim: int = 6, bounds=((-500.0, 500.0),)) -> List[TaskSpec]:
    """
    Create the history (source) tasks: scaled and input-shifted Schwefel variants.

    Args:
        dim: Number of input dimensions.
        bounds: A sequence of ``dim`` per-dimension ``(low, high)`` pairs, or a
            sequence holding a single ``(low, high)`` pair that is repeated for every
            dimension. Any other form (e.g. a bare ``(low, high)`` pair) is passed
            through unchanged.

    Returns:
        One ``TaskSpec`` per entry of the spec table below. All tasks share the same bounds.
    """
    # Normalize the bounds once instead of re-evaluating the expression for every
    # field of every task (same rule as build_real_task).
    if len(bounds) == dim:
        task_bounds = bounds
    elif len(bounds) == 1:
        task_bounds = bounds * dim
    else:
        task_bounds = bounds

    # (name, output scale, input shift as a fraction of each dimension's width)
    specs = (
        ("history_1", 0.5, 0.1),
        ("history_2", 1.2, -0.1),
        ("history_3", 0.8, 0.2),
    )

    return [
        TaskSpec(
            name=name,
            dim=dim,
            bounds=task_bounds,
            objective=_make_objective(scale, shift_frac, bounds=task_bounds),
        )
        for name, scale, shift_frac in specs
    ]


def build_real_task(dim: int = 6, bounds=((-500.0, 500.0),)) -> TaskSpec:
    """
    Build the target task: the unmodified Schwefel function (scale 1.0, no shift).

    Args:
        dim: Number of input dimensions.
        bounds: A sequence of ``dim`` per-dimension ``(low, high)`` pairs, or a
            sequence holding a single pair that is repeated ``dim`` times. Any other
            form is passed through unchanged.

    Returns:
        A ``TaskSpec`` named ``"target_real"``.
    """
    n_bounds = len(bounds)
    # Only a lone pair that does not already match dim gets repeated.
    if n_bounds == 1 and n_bounds != dim:
        real_bounds = bounds * dim
    else:
        real_bounds = bounds

    target_objective = _make_objective(1.0, 0.0, bounds=real_bounds)
    return TaskSpec(name="target_real", dim=dim, bounds=real_bounds, objective=target_objective)
