from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Sequence, Protocol
import math
import numpy as np

class Policy(Protocol):
    """Structural interface for one policy in the bank scored by ``MASTERContext``."""

    def predict(self, x: Any) -> int:
        """Return the action index this policy recommends for context ``x``."""
        ...

class BanditAlgorithm:
    """Minimal base class for bandit learners driven by a round counter ``t``.

    Subclasses provide ``select_arm`` and ``update_statistics``.  ``update``
    wraps the latter and advances ``t``, except for the first update after
    ``re_init``, which only consumes the reset flag.
    """

    def __init__(self, num_actions: int, horizon: int):
        self.num_actions = int(num_actions)
        self.T = int(horizon)
        self.t = 0
        self.is_reset = False
        self.ChangePoints: List[int] = []

    def select_arm(self, arms, context=None):
        """Pick an arm for this round; must be overridden by subclasses."""
        raise NotImplementedError

    def update_statistics(self, arm, reward):
        """Absorb the observed ``reward`` for ``arm``; must be overridden."""
        raise NotImplementedError

    def update(self, arm, reward):
        """Record feedback, then advance ``t`` unless a reset is pending."""
        self.update_statistics(arm, reward)
        if not self.is_reset:
            self.t += 1
            return
        # First update after re_init(): keep t where the reset left it.
        self.is_reset = False

    def re_init(self):
        """Rewind the round counter and flag the reset for the next ``update``."""
        self.t = 0
        self.is_reset = True

@dataclass
class _Slice:
    """Bookkeeping for one scheduled base-learner instance within an epoch.

    ``s``/``e`` are the first/last epoch-relative rounds (inclusive) of the
    instance's block and ``length`` its block size.  ``L`` counts observed
    rounds, ``sum_R`` sums observed rewards, and ``S_pi`` holds one
    accumulated reward estimate per policy (allocated lazily).
    """

    base: BanditAlgorithm
    s: int
    e: int
    length: int
    L: int = 0
    sum_R: float = 0.0
    S_pi: Optional[np.ndarray] = None

    def ensure_arrays(self, pi_size: int):
        """Allocate ``S_pi`` as a float64 zero vector of ``pi_size`` if still unset."""
        if self.S_pi is not None:
            return
        self.S_pi = np.zeros(pi_size, dtype=np.float64)

class MASTERContext(BanditAlgorithm):
    """Multi-scale wrapper around a base contextual bandit with change detection.

    Time is split into epochs; an epoch of order ``n`` spans ``2**n`` rounds,
    indexed by the epoch-relative round ``tau = t - tn``.  At each ``tau`` and
    for every block size ``2**m`` (``m <= n``) dividing ``tau``, a fresh base
    instance covering rounds ``[tau, tau + 2**m - 1]`` is scheduled with
    probability ``rho(2**n) / rho(2**m)`` (always for ``m == n``).  Each round
    one active instance plays (see ``choose_rule``); its action distribution is
    mixed with uniform exploration, and every active instance accumulates
    inverse-propensity reward estimates for each policy in the bank.  If
    ``_test1`` or ``_test2`` fails, the change point is recorded and the
    schedule restarts from epoch order 0; otherwise a finished epoch is
    followed by one of order ``n + 1``.
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        *,
        delta: float,
        base_factory: Callable[[], BanditAlgorithm],
        policies: Sequence[Policy],
        c2: float = 2.0,
        choose_rule: str = "minlen",
        seed: Optional[int] = None,
        n_init: int = 0,
    ):
        """Configure the wrapper and schedule the first epoch.

        Args:
            num_actions: number of arms ``A``.
            horizon: total number of rounds ``T`` (enters the log terms).
            delta: confidence parameter used in ``c0`` and ``cT``.
            base_factory: zero-argument callable returning a fresh base learner.
            policies: finite, non-empty policy bank for the IPS estimates.
            c2: scale of the exploration bonus in ``_compute_f_tilde``.
            choose_rule: ``"minlen"`` (shortest active block, latest on ties)
                or ``"uniform_m"`` (uniform block size, then uniform instance).
            seed: seed for the internal ``numpy`` random generator.
            n_init: order of the first epoch (restarts always go back to 0).

        Raises:
            ValueError: if ``policies`` is empty.
        """
        super().__init__(num_actions, horizon)  # also sets self.T
        if len(policies) == 0:
            raise ValueError("Provide a non-empty finite policy bank `policies`.")
        self.A = int(num_actions)
        self.delta = float(delta)
        self.base_factory = base_factory
        self.policies: List[Policy] = list(policies)
        self.Pi = len(self.policies)
        self.c2 = float(c2)
        self.choose_rule = choose_rule
        self.rng = np.random.default_rng(seed)

        # Per-epoch statistics used by the change-detection tests.
        self.U_min: float = float("+inf")
        self.g_minus_r_sum: float = 0.0
        self.epoch_len: int = 0

        self.tn = self.t
        self.n = max(0, int(n_init))
        self.active: List[_Slice] = []
        self.alg_index: int = 0

        # Per-round scratch state shared between select_arm and update_statistics.
        self._last_context: Any = None
        self._last_arms: Optional[Sequence[Any]] = None
        self._last_choice: Optional[int] = None
        self._last_pmf: Optional[np.ndarray] = None
        self._last_p: float = 1.0
        self._last_gtilde: float = 1.0

        # Epoch anchoring and spawn de-duplication (re-initialised every epoch).
        self._epoch_anchor_pending: bool = True
        self._last_spawn_tau: Optional[int] = None

        # Log terms; delta is floored to avoid division by zero.
        self.c0 = math.log(max(1.0, self.Pi * self.T / max(self.delta, 1e-300)))
        self.cT = math.log(max(1.0, self.T / max(self.delta, 1e-300)))
        self.n_hat = math.log2(max(2, self.T)) + 1.0

        self.name = f"MASTER (|Π|={self.Pi})"
        self.init_params = dict(
            num_actions=num_actions, horizon=horizon, delta=delta,
            base_factory=base_factory, policies=self.policies,
            c2=c2, choose_rule=choose_rule, seed=seed, n_init=n_init
        )

        self._reset_epoch_state()
        self._spawn_at_tau(0)

    def select_arm(self, arms: Sequence[Any], context: Any = None) -> int:
        """Update the schedule for this round and sample an arm.

        Returns:
            The sampled arm index, drawn from the chosen instance's
            distribution mixed with uniform exploration.
        """
        self._last_arms = arms
        self._last_context = context
        if self._epoch_anchor_pending:
            # Anchor the epoch at the first round actually played in it, so that
            # round has tau == 0.  Restarts happen inside update_statistics,
            # before BanditAlgorithm.update advances self.t.  Anchoring at restart
            # time would make every later epoch start at tau == 1 and last only
            # 2**n - 1 rounds.
            self.tn = self.t
            self._epoch_anchor_pending = False
        tau = self.t - self.tn
        self._spawn_at_tau(tau)
        self._prune_finished(tau)

        if not self.active:
            # Safety net: always keep at least one instance able to play.
            N = 2 ** self.n
            self.active.append(_Slice(self.base_factory(), 0, N - 1, N))

        self.alg_index = self._choose_active_index()
        sl = self.active[self.alg_index]

        a, pmf_smoothed = self._sample_with_mu(sl.base, arms, context)
        self._last_choice = a
        self._last_pmf = pmf_smoothed
        # Propensity of the played arm; used for the IPS estimates.
        self._last_p = float(pmf_smoothed[a])

        g = self._compute_f_tilde(sl)
        self._last_gtilde = g
        self.U_min = min(self.U_min, g)

        return int(a)

    def update_statistics(self, arm: int, reward: float) -> None:
        """Feed the reward to the playing instance, update estimates, run tests.

        ``reward`` is clipped to [0, 1].  If a test fails, ``self.t`` is
        appended to ``ChangePoints`` and the schedule restarts from epoch 0.
        Otherwise, once the epoch's ``2**n`` rounds are used up, an epoch of
        order ``n + 1`` begins.
        """
        a = int(arm)
        r = float(np.clip(reward, 0.0, 1.0))
        ctx = self._last_context

        active = self.active[self.alg_index]
        if hasattr(active.base, "_last_pmf") and self._last_pmf is not None:
            # Hand the smoothed distribution actually sampled from to the base
            # (presumably for its own importance weighting; TODO confirm).
            try:
                active.base._last_pmf = np.asarray(self._last_pmf, dtype=float)
            except Exception:
                pass  # best effort: never let a base's attribute handling abort the round
        active.base.update_statistics(a, r)

        # IPS estimate: each policy that would have played `a` is credited r / p(a).
        preds = np.fromiter((pi.predict(ctx) for pi in self.policies), dtype=np.int32, count=self.Pi)
        contrib_vec = (preds == a).astype(np.float64) * (r / max(self._last_p, 1e-12))

        for sl in self.active:
            sl.ensure_arrays(self.Pi)
            sl.S_pi += contrib_vec
            sl.sum_R += r
            sl.L += 1

        self.epoch_len += 1
        self.g_minus_r_sum += (self._last_gtilde - r)

        if (not self._test1()) or (not self._test2()):
            self.ChangePoints.append(self.t)
            self._restart_master()
            return

        # The next round would fall outside this epoch: double the epoch length.
        if (self.t + 1) >= (self.tn + 2 ** self.n):
            self.n += 1
            self._restart_epoch(keep_n=True)

    def re_init(self):
        """Rewind to round 0, forget recorded change points and restart from epoch 0."""
        super().re_init()
        self.ChangePoints = []
        self._restart_master()

    def _restart_master(self):
        """Restart the whole schedule from epoch order 0."""
        self.n = 0
        self._restart_epoch(keep_n=False)

    def _restart_epoch(self, keep_n: bool):
        """Begin a new epoch of the current order ``self.n``.

        ``keep_n`` is accepted for call-site readability only; callers set
        ``self.n`` themselves.
        """
        self._reset_epoch_state()
        self._spawn_at_tau(0)

    def _reset_epoch_state(self):
        """Retire every active instance and clear the per-epoch statistics."""
        for sl in getattr(self, "active", []):
            self._finish_base(sl.base)
        # Provisional origin; select_arm re-anchors it at the epoch's first round.
        self.tn = self.t
        self._epoch_anchor_pending = True
        self._last_spawn_tau = None
        self.active = []
        self.alg_index = 0
        self.U_min = float("+inf")
        self.g_minus_r_sum = 0.0
        self.epoch_len = 0
        self._last_context = None
        self._last_arms = None
        self._last_choice = None
        self._last_pmf = None
        self._last_p = 1.0
        self._last_gtilde = 1.0

    def _spawn_at_tau(self, tau: int):
        """Randomly schedule new instances whose blocks start at epoch round ``tau``.

        Scheduling is done at most once per ``tau`` per epoch.  Without this,
        the eager spawn at restart and the spawn in the epoch's first
        ``select_arm`` would each add instances at ``tau == 0``.  That would
        always create duplicate full-length instances.
        """
        if tau == self._last_spawn_tau:
            return
        self._last_spawn_tau = tau
        N = 2 ** self.n
        rho_N = self._rho(N)
        for m in range(self.n, -1, -1):
            block = 2 ** m
            if tau % block != 0:
                continue
            pr = rho_N / max(self._rho(block), 1e-12)
            if self._bernoulli(pr):
                self.active.append(_Slice(self.base_factory(), tau, tau + block - 1, block))
        # Guarantee one instance spanning the whole epoch.
        if tau == 0 and not any(sl.s == 0 and sl.length == N for sl in self.active):
            self.active.append(_Slice(self.base_factory(), 0, N - 1, N))

    def _prune_finished(self, tau: int):
        """Retire instances whose block ended before epoch round ``tau``."""
        if not self.active:
            return
        keep: List[_Slice] = []
        for sl in self.active:
            if sl.e >= tau:
                keep.append(sl)
            else:
                self._finish_base(sl.base)
        self.active = keep

    def _finish_base(self, base: BanditAlgorithm):
        """Best-effort teardown of a retired base learner.

        Calls ``finish()`` and ``reset()`` when present, then finishes and
        drops a non-None ``_vw`` handle.  Errors are deliberately suppressed
        so that a misbehaving base cannot abort a restart.
        """
        for meth in ("finish", "reset"):
            fn = getattr(base, meth, None)
            if fn is None:
                continue
            try:
                fn()
            except Exception:
                pass
        vw = getattr(base, "_vw", None)
        if vw is not None:
            try:
                vw.finish()
                base._vw = None
            except Exception:
                pass

    def _choose_active_index(self) -> int:
        """Return the index in ``self.active`` of the instance that plays this round."""
        if self.choose_rule == "uniform_m":
            lengths = sorted(set(sl.length for sl in self.active))
            L_choice = int(self.rng.choice(lengths))
            cand = [i for i, sl in enumerate(self.active) if sl.length == L_choice]
            return int(self.rng.choice(cand))
        # "minlen": shortest block; ties go to the most recently added instance.
        best_i, best_len = 0, self.active[0].length
        for i, sl in enumerate(self.active):
            if sl.length <= best_len:
                best_i, best_len = i, sl.length
        return best_i

    def _mu(self, t_pos: int) -> float:
        """Per-arm uniform exploration rate ``sqrt(c0 / (A * t))`` (t floored at 1)."""
        t_pos = max(1, int(t_pos))
        return math.sqrt(self.c0 / (self.A * t_pos))

    def _rho(self, t_pos: int) -> float:
        """Rate ``sqrt(A * c0 / t)`` (t floored at 1)."""
        t_pos = max(1, int(t_pos))
        return math.sqrt(self.A * self.c0 / t_pos)

    def _rho_hat(self, t_pos: int) -> float:
        """Threshold scale ``6 * n_hat * cT * rho(t)`` used by the tests."""
        t_pos = max(1, int(t_pos))
        return 6.0 * self.n_hat * self.cT * self._rho(t_pos)

    def _compute_f_tilde(self, sl: _Slice) -> float:
        """Optimistic estimate: best per-policy average on ``sl`` plus a bonus."""
        sl.ensure_arrays(self.Pi)
        if sl.L > 0:
            best_R = float(np.max(sl.S_pi / float(sl.L)))
        else:
            # No observations yet: all policy sums are zero.
            best_R = 0.0
        return float(best_R + self.c2 * self.A * self._mu(sl.L))

    def _test1(self) -> bool:
        """At an instance's last round, fail if its average reward exceeds ``U_min + 9*rho_hat(length)``."""
        t_rel = self.t - self.tn
        for sl in self.active:
            if t_rel == sl.e and sl.L > 0:
                avg = sl.sum_R / float(sl.L)
                if avg >= (self.U_min + 9.0 * self._rho_hat(sl.length)):
                    return False
        return True

    def _test2(self) -> bool:
        """Pass while the epoch's mean of ``(g_tilde - r)`` stays below ``3*rho_hat(epoch_len)``."""
        L = self.epoch_len
        if L <= 0:
            return True
        avg_gap = self.g_minus_r_sum / float(L)
        return bool(avg_gap < 3.0 * self._rho_hat(L))

    def _sample_with_mu(self, base: BanditAlgorithm, arms: Sequence[Any], context: Any):
        """Query ``base`` and sample an arm from its distribution mixed with uniform.

        The base's exposed ``_last_pmf`` is used when it is finite, has one
        entry per arm and positive mass.  Otherwise the base's own pick is
        followed (a point mass), or uniform when that pick is out of range.

        Returns:
            ``(arm, pmf_smoothed)``: the sampled index and the distribution
            it was drawn from.
        """
        K = len(arms)
        choice = int(base.select_arm(arms, context=context))
        pmf = None
        if hasattr(base, "_last_pmf"):
            try:
                pmf = np.asarray(base._last_pmf, dtype=float).ravel()
            except Exception:
                pmf = None
        if pmf is None or pmf.size != K or not np.isfinite(pmf).all() or pmf.sum() <= 0:
            if 0 <= choice < K:
                # No usable distribution: follow the base's decision instead of
                # discarding it in favour of uniform play.
                pmf = np.zeros(K, dtype=float)
                pmf[choice] = 1.0
            else:
                pmf = np.ones(K, dtype=float) / float(K)
        pmf = np.clip(pmf, 0.0, None)
        s = pmf.sum()
        pmf = pmf / s if s > 0 else np.ones(K, dtype=float) / float(K)
        mu = self._mu(max(1, self.epoch_len))
        mu_eff = min(mu, 1.0 / self.A - 1e-12)  # keep the base's share positive
        pmf_smoothed = (1.0 - self.A * mu_eff) * pmf + mu_eff
        pmf_smoothed = np.clip(pmf_smoothed, 1e-15, 1.0)
        pmf_smoothed /= pmf_smoothed.sum()
        a = int(self.rng.choice(K, p=pmf_smoothed))
        return a, pmf_smoothed

    def _bernoulli(self, p: float) -> bool:
        """Draw True with probability ``p`` (clipped to [0, 1])."""
        return bool(self.rng.random() < float(np.clip(p, 0.0, 1.0)))

    def __str__(self) -> str:
        return self.name
