import functools
from typing import Tuple, Callable

import torch

# Objective: called with a (batch_size, num_points) tensor of sample points; its
# result is reduced with argmin along dim 1 (presumably same shape as the input).
EvalFunction = Callable[[torch.Tensor], torch.Tensor]
Interval = Tuple[torch.Tensor, torch.Tensor]  # (lower, upper)


class Optimizer1D:
    """Batched, derivative-free 1-D minimizer based on repeated grid refinement.

    Each step samples evenly spaced interior points of every interval in the
    batch, picks the sample with the smallest objective value, and shrinks the
    interval to that sample's two neighbours. When the best sample is the
    first or last one, the current bound on that side is kept instead. For a
    unimodal objective the shrunk interval still brackets the minimum.
    """

    @staticmethod
    def batched_linspace(lower: torch.Tensor, upper: torch.Tensor, num_points: int) -> torch.Tensor:
        """Return ``num_points`` evenly spaced points from ``lower`` to ``upper`` for each row.

        Args:
            lower: Lower bounds, shape ``(batch_size, 1)``.
            upper: Upper bounds, shape ``(batch_size, 1)``.
            num_points: Number of points per row, both endpoints included.

        Returns:
            Tensor of shape ``(batch_size, num_points)`` with the device and
            dtype of ``lower``.
        """
        # (1, num_points) fractions in [0, 1]; broadcasts against the (batch_size, 1) bounds.
        steps = torch.linspace(0, 1, num_points, device=lower.device, dtype=lower.dtype).unsqueeze(0)
        return lower + (upper - lower) * steps

    @staticmethod
    def _opt_step(func: EvalFunction, x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> Interval:
        """Shrink each interval to the neighbours of its best sample.

        Args:
            func: Objective evaluated once on ``x``.
            x: Sample points, shape ``(batch_size, num_points)``; assumed sorted
                ascending along dim 1 and lying inside ``[lower, upper]``.
            lower: Current lower bounds, shape ``(batch_size, 1)``.
            upper: Current upper bounds, shape ``(batch_size, 1)``.

        Returns:
            ``(new_lower, new_upper)``, each of shape ``(batch_size, 1)``.
        """
        batch_size, num_points = x.shape
        # Build the row index once, on the same device as x, so indexing needs no transfer.
        rows = torch.arange(batch_size, device=x.device)
        y = func(x)
        min_idx = torch.argmin(y, dim=1)
        l_idx = min_idx - 1
        u_idx = min_idx + 1
        # Clamp so the gather never goes out of bounds; the clamped values are
        # discarded by torch.where below when the neighbour does not exist.
        l_tmp = x[rows, torch.clamp(l_idx, min=0, max=num_points - 1)]
        u_tmp = x[rows, torch.clamp(u_idx, min=0, max=num_points - 1)]
        # Best sample at an edge of the grid: keep the current bound on that side.
        new_lower = torch.where(l_idx < 0, lower.squeeze(1), l_tmp)
        new_upper = torch.where(u_idx >= num_points, upper.squeeze(1), u_tmp)
        return new_lower.unsqueeze(-1), new_upper.unsqueeze(-1)

    @staticmethod
    def optimize(func: EvalFunction,
                 lower: torch.Tensor,
                 upper: torch.Tensor,
                 num_points: int,
                 max_steps: int,
                 x_threshold: float) -> Tuple[torch.Tensor, dict]:
        """Minimize ``func`` independently on each interval of the batch.

        Args:
            func: Objective, called with a ``(batch_size, num_points)`` tensor.
            lower: Initial lower bounds, shape ``(batch_size, 1)``.
            upper: Initial upper bounds, shape ``(batch_size, 1)``.
            num_points: Interior samples per interval and step. A value of 1
                never shrinks the interval; use at least 2 to make progress.
            max_steps: Maximum number of refinement steps. A value of 0 or
                less returns the midpoint of the initial interval.
            x_threshold: Stop early once every interval is narrower than this.

        Returns:
            ``(x, info)``. ``x`` is the midpoint of the final interval, with
            shape ``(batch_size, 1)``. ``info["optim_steps"]`` is the number
            of refinement steps actually performed.

        Raises:
            ValueError: If ``num_points < 1``.
        """
        if num_points < 1:
            raise ValueError(f"num_points must be >= 1, got {num_points}")

        lo, hi = lower, upper
        # Explicit counter: with max_steps <= 0 the loop never runs, and the
        # old `step + 1` then referenced an unbound loop variable.
        steps_taken = 0

        for _ in range(max_steps):
            # Sample num_points + 2 grid points and drop both endpoints, so
            # every sample lies strictly inside the current interval.
            x = Optimizer1D.batched_linspace(lo, hi, num_points + 2)[:, 1:-1]

            lo, hi = Optimizer1D._opt_step(func, x, lo, hi)
            steps_taken += 1

            if torch.all(torch.abs(lo - hi) < x_threshold):
                break

        return (lo + hi) / 2, {"optim_steps": steps_taken}
