from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Iterable, Protocol, Optional, Union

import numpy as np

try:
    from numba import njit

    NUMBA_AVAILABLE = True
except Exception:
    # Deliberately broad: any failure while importing numba (not installed,
    # broken install, ...) falls back to plain-Python execution.
    NUMBA_AVAILABLE = False

    def njit(*args, **kwargs):
        """No-op stand-in for ``numba.njit`` used when numba cannot be imported.

        Supports both decorator spellings:

        * bare ``@njit`` -- called with the function itself, which is
          returned unchanged;
        * factory ``@njit(...)`` -- called with options (ignored), returning
          a decorator that hands the function back unchanged.
        """
        # Bare usage: the only positional argument is the decorated function.
        # Without this branch ``@njit`` would replace the function by the
        # inner ``decorator`` and every later call would misbehave.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator


class DivergenceFn(Protocol):
    """Structural type for divergence callables.

    Any callable accepting two floats plus an optional ``variance`` and
    returning a float satisfies this protocol.
    """

    def __call__(
        self,
        p: float,
        q: float,
        variance: Optional[float] = None,
    ) -> float:
        """Return the divergence between ``p`` and ``q``."""


@njit
def _kl_bernoulli(p: float, q: float, _: float = 0.0) -> float:
    """Return the Bernoulli KL divergence of ``p`` from ``q``.

    Both arguments are clipped into ``[1e-12, 1 - 1e-12]`` before use so the
    logarithms and ratios stay finite. The third argument is accepted but
    never read.
    """
    eps = 1e-12
    upper = 1.0 - eps
    p = min(max(p, eps), upper)
    q = min(max(q, eps), upper)
    success_term = p * np.log(p / q)
    failure_term = (1.0 - p) * np.log((1.0 - p) / (1.0 - q))
    return success_term + failure_term


@njit
def _quadratic_proxy(p: float, q: float, variance: Optional[float] = None) -> float:
    """Return the squared gap between ``p`` and ``q`` over twice ``variance``.

    Raises:
        ValueError: if ``variance`` is ``None`` or not strictly positive.
    """
    if variance is None or variance <= 0.0:
        raise ValueError("variance must be positive for quadratic proxy divergence.")
    gap = p - q
    scale = 2.0 * variance
    return gap ** 2 / scale


def kl_divergence(
    p: float,
    q: float,
    *,
    mode: str = "bernoulli",
    variance: Optional[float] = None,
) -> float:
    """Compute a divergence between ``p`` and ``q`` under the chosen ``mode``.

    Args:
        p: First value.
        q: Second (reference) value.
        mode: ``"bernoulli"`` selects the Bernoulli KL divergence;
            ``"gaussian"`` and ``"quadratic"`` both select the quadratic
            proxy, which needs ``variance``.
        variance: Forwarded to the quadratic proxy; unused for
            ``"bernoulli"``.

    Returns:
        The divergence as a built-in ``float``.

    Raises:
        ValueError: if ``mode`` is not one of the supported names (the
            quadratic proxy also raises it for a missing or non-positive
            ``variance``).
    """
    if mode == "bernoulli":
        value = _kl_bernoulli(p, q)
    elif mode in {"gaussian", "quadratic"}:
        value = _quadratic_proxy(p, q, variance)
    else:
        raise ValueError(f"Unsupported divergence mode: {mode}")
    return float(value)


def beta_threshold(n: int, delta: float, *, exponent: float = 1.5) -> float:
    """Return ``log(n ** exponent / delta)`` as a built-in ``float``.

    Args:
        n: Sample count; must be positive.
        delta: Value strictly between 0 and 1.
        exponent: Power applied to ``n`` inside the logarithm.

    Raises:
        ValueError: if ``n`` is not positive or ``delta`` is outside (0, 1).
    """
    if n <= 0:
        raise ValueError("n must be positive.")
    if delta <= 0 or delta >= 1:
        raise ValueError("delta must lie in (0,1).")
    ratio = (n ** exponent) / delta
    return float(np.log(ratio))


def detect_change(
    count: int,
    cumulative_rewards: Iterable[float],
    delta: float,
    *,
    divergence: Optional[DivergenceFn] = None,
    variance: Optional[float] = None,
    threshold_fn: Optional[Callable[[int, float], float]] = None,
) -> bool:
    """Scan every split point and report whether any statistic beats the threshold.

    For each ``split`` in ``1 .. count - 1`` the mean of the first ``split``
    observations and the mean of the remaining ones are each compared with
    the overall mean through ``divergence``; the two divergences are weighted
    by their segment lengths and summed.

    Args:
        count: Number of observations.
        cumulative_rewards: Running sums; entry ``i`` is used as the total of
            the first ``i + 1`` observations. Must contain ``count`` values.
        delta: Passed to the threshold function as its second argument.
        divergence: Callable ``(p, q, variance) -> float``. Defaults to the
            Bernoulli mode of ``kl_divergence`` (which ignores ``variance``).
        variance: Forwarded unchanged to ``divergence``.
        threshold_fn: Callable ``(count, delta) -> float``. Defaults to
            ``beta_threshold``. It is evaluated once per call.

    Returns:
        ``True`` as soon as one split's statistic exceeds the threshold,
        ``False`` otherwise (always ``False`` when ``count <= 2``).

    Raises:
        ValueError: if the number of cumulative rewards differs from
            ``count`` (checked only when ``count > 2``).
    """
    if count <= 2:
        return False
    sums = np.asarray(tuple(cumulative_rewards), dtype=float)
    if sums.size != count:
        raise ValueError("cumulative_rewards must match the supplied count.")

    # Compare against None explicitly so a callable object that happens to be
    # falsy (e.g. defines __len__ returning 0) is not silently replaced.
    if divergence is None:
        def div(p, q, variance=None):
            return kl_divergence(p, q, mode="bernoulli")
    else:
        div = divergence

    # Neither the threshold nor the overall mean depends on the split, so
    # compute them once instead of on every iteration.
    threshold = (beta_threshold if threshold_fn is None else threshold_fn)(count, delta)
    total = sums[count - 1]
    mu = total / count

    for split in range(1, count):
        head = sums[split - 1]
        tail_len = count - split
        mu1 = head / split
        mu2 = (total - head) / tail_len
        stat = split * div(mu1, mu, variance) + tail_len * div(mu2, mu, variance)
        if stat > threshold:
            return True
    return False


def _beta_gsr(n: int, delta: float) -> float:
    """Return ``n ** 2.5 / delta`` as a built-in ``float``.

    Raises:
        ValueError: if ``n`` is not positive or ``delta`` is outside (0, 1).
    """
    if n <= 0:
        raise ValueError("n must be positive.")
    if delta <= 0 or delta >= 1:
        raise ValueError("delta must lie in (0,1).")
    growth = n ** 2.5
    return float(growth / delta)


def detect_change_gsr(
    count: int,
    cumulative_rewards: Iterable[float],
    delta: float,
    *,
    divergence: Optional[DivergenceFn] = None,
    variance: Optional[float] = None,
    threshold_fn: Optional[Callable[[int, float], float]] = None,
) -> bool:
    """Accumulate exponentiated split statistics and compare with a threshold.

    For each ``split`` in ``1 .. count - 1`` the same length-weighted
    divergence statistic as in ``detect_change`` is computed, but instead of
    testing each one individually, ``exp(statistic)`` is added to a running
    total which is compared with the threshold after every split.

    Args:
        count: Number of observations.
        cumulative_rewards: Running sums; entry ``i`` is used as the total of
            the first ``i + 1`` observations. Must contain ``count`` values.
        delta: Passed to the threshold function as its second argument.
        divergence: Callable ``(p, q, variance) -> float``. Defaults to the
            Bernoulli mode of ``kl_divergence`` (which ignores ``variance``).
        variance: Forwarded unchanged to ``divergence``.
        threshold_fn: Callable ``(count, delta) -> float``. Defaults to
            ``_beta_gsr``. It is evaluated once per call.

    Returns:
        ``True`` as soon as the running total exceeds the threshold,
        ``False`` otherwise (always ``False`` when ``count <= 2``).

    Raises:
        ValueError: if the number of cumulative rewards differs from
            ``count`` (checked only when ``count > 2``).
    """
    if count <= 2:
        return False
    sums = np.asarray(tuple(cumulative_rewards), dtype=float)
    if sums.size != count:
        raise ValueError("cumulative_rewards must match the supplied count.")

    # Compare against None explicitly so a callable object that happens to be
    # falsy (e.g. defines __len__ returning 0) is not silently replaced.
    if divergence is None:
        def div(p, q, variance=None):
            return kl_divergence(p, q, mode="bernoulli")
    else:
        div = divergence

    # The threshold and the overall mean do not depend on the split.
    threshold = (_beta_gsr if threshold_fn is None else threshold_fn)(count, delta)
    total_sum = sums[count - 1]
    mu = total_sum / count

    change_stat = 0.0
    for split in range(1, count):
        draw1 = split
        draw2 = count - split
        prefix = sums[split - 1]
        mu1 = prefix / draw1
        mu2 = (total_sum - prefix) / draw2
        stat = draw1 * div(mu1, mu, variance) + draw2 * div(mu2, mu, variance)
        # A very large statistic overflows exp() to +inf, which is exactly
        # the value wanted here; silence the overflow warning so it cannot be
        # promoted to an error by the caller's warning filters.
        with np.errstate(over="ignore"):
            change_stat += float(np.exp(stat))
        if change_stat > threshold:
            return True
    return False


# Names exported by ``from <module> import *``; the underscore-prefixed
# helpers defined in this module are intentionally left out.
__all__ = [
    "kl_divergence",
    "beta_threshold",
    "detect_change",
    "detect_change_gsr",
]
