from dataclasses import dataclass
from math import ceil
from typing import Callable, Literal, Optional, cast


@dataclass(frozen=True, kw_only=True)
class Step:
    step: int
    lo: int
    hi: int
    mid: int


@dataclass(kw_only=True, frozen=True)
class SearchMetrics:
    """Ordered record of the steps taken during one search run."""

    steps: list[Step]

    @property
    def total_steps(self) -> int:
        """Step counter stored on the last recorded step."""
        final = self.steps[-1]
        return final.step

    def __repr__(self) -> str:
        # One indented line per step, wrapped in a dict(...)-looking literal.
        body = ",\n".join("  " + repr(entry) for entry in self.steps)
        return f"dict(steps=[\n{body}\n])"


# Aliases of `int`. Presumably meant to label the lower / upper bracket
# endpoints in annotations; they are not referenced in the visible code,
# so verify against callers before relying on that reading.
Lo = int
Hi = int


# Use ceiling rather than truncation
# because that's what we do for our CDF.
# i.e. for CDF search, we use the first whole
# number with >= 50% CDF. The analog here is
# using the first whole number with >= 50% of
# the whole numbers to its left.
def default_mid(lo: int, hi: int) -> int:
    """
    Return the midpoint of `lo` and `hi`, rounded up to a whole number.

    Computed with integer arithmetic only: `(lo + hi + 1) // 2` equals
    `ceil((lo + hi) / 2)` for every pair of integers, but never passes
    through a float, so it stays exact for arbitrarily large values.
    """
    return (lo + hi + 1) // 2


def default_sign(target: float) -> Callable[[float], Literal[-1, 1]]:
    """
    Build a sign function bound to `target`.

    The returned function yields -1 for values at or below `target`
    and 1 for everything else.
    """

    def sign(val: float) -> Literal[-1, 1]:
        if val <= target:
            return -1
        return 1

    return sign


# Signature of a sign function: takes a probed value and returns -1 or 1.
# binary_search lowers `hi` to the probe on 1 and raises `lo` to it otherwise.
SignFunc = Callable[[float], Literal[-1, 1]]


def binary_search(
    lo: int,
    hi: int,
    target: Optional[int] = None,
    epsilon: int = 1,  # bracket size
    mid_func: Callable[[int, int], int] = default_mid,
    sign: Optional[SignFunc] = None,
) -> SearchMetrics:
    """
    Narrow the integer bracket [lo, hi] until it is at most `epsilon`
    wide, recording every iteration.

    Assumes the target is within [lo, hi]
    And sign(lo) == -1 and sign(hi) == 1

    `mid_func` must return an integer `x` such that `lo <= x <= hi`.
    A probe that lands exactly on either endpoint is moved one step
    inward so that every iteration strictly shrinks the bracket.

    `sign` takes precedence over `target`

    Returns the recorded steps. The last step holds the final bracket;
    its `mid` repeats the previous probe, or is None when the initial
    bracket is already within `epsilon` and nothing was probed.

    Raises Exception when neither an integer `target` nor a `sign`
    function is given, when `epsilon < 1`, or when the bracket is (or
    becomes) empty, i.e. `lo > hi`.
    """

    if not sign:
        if not isinstance(target, int):
            raise Exception("Either a target or sign function must be given.")
        sign = default_sign(target)

    if epsilon < 1:
        raise Exception("ε must be >= 1")

    steps = -1  # First step is 0th step because we haven't probed yet.
    metrics = SearchMetrics(steps=[])

    last_mid = None
    while lo <= hi:
        steps += 1

        if hi - lo <= epsilon:
            # Bracket is tight enough: record it and stop without probing.
            last_step = Step(step=steps, lo=lo, hi=hi, mid=cast(int, last_mid))
            metrics.steps.append(last_step)
            return metrics

        m = mid_func(lo, hi)

        # Here hi - lo > epsilon >= 1, so lo + 1 <= hi - 1 and an endpoint
        # probe can always be moved inward. Without this, a probe equal to
        # `hi` (seen with CDF-based mid functions on skewed distributions)
        # or equal to `lo` would leave the bracket unchanged and the loop
        # would never terminate.
        if m == hi:
            m -= 1
        elif m == lo:
            m += 1

        s = sign(m)
        metrics.steps.append(Step(step=steps, lo=lo, hi=hi, mid=m))

        if s == 1:
            hi = m
        else:
            lo = m
        last_mid = m

    raise Exception("Search bracket is empty (lo > hi).")
