from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable
import numpy as np

from src.core.bandit import BanditAlgorithm


try:
    from vowpalwabbit import pyvw
except Exception:
    pyvw = None


class Policy:
    """Abstract mapping from a context to a single action index.

    Subclasses must implement :meth:`predict`; :meth:`predict_many` has a
    generic fallback that simply applies ``predict`` to every context.
    """

    def predict(self, x: Any) -> int:
        """Return the action index chosen for context *x*."""
        raise NotImplementedError

    def predict_many(self, X: List[Any]) -> np.ndarray:
        """Return an int array with one predicted action per context in *X*."""
        chosen = list(map(self.predict, X))
        return np.array(chosen, dtype=int)

class LinearArgmaxPolicy(Policy):
    """Deterministic linear policy that plays ``argmax_a (W @ x)[a]``.

    ``W`` holds one row of weights per action, shape ``(K, d)``. A context may
    be an array-like with ``d`` entries or a dict. For a dict, the FIRST value
    is skipped and the remaining ``d`` values are used in insertion order.
    NOTE(review): presumably the first dict entry is an id/label column;
    confirm with the code that builds these dicts.
    """

    def __init__(self, W: np.ndarray):
        W = np.asarray(W, float)
        if W.ndim != 2:
            raise ValueError("W must be 2-D (K,d)")
        self.W = W
        self.K, self.d = W.shape

    def _to_vector(self, x: Any) -> np.ndarray:
        """Flatten a dict or array-like context into a length-``d`` float vector.

        Raises:
            ValueError: if the flattened context does not have ``d`` entries.
        """
        if isinstance(x, dict):
            vec = np.asarray(list(x.values())[1:], float).reshape(-1)
        else:
            vec = np.asarray(x, float).reshape(-1)
        if vec.shape[0] != self.d:
            raise ValueError(f"dim mismatch: got {vec.shape[0]}, expected {self.d}")
        return vec

    def predict(self, x: Any) -> int:
        """Return the greedy action for one context (dict or array-like)."""
        return int(np.argmax(self.W @ self._to_vector(x)))

    def predict_many(self, X: List[Any]) -> np.ndarray:
        """Return greedy actions for a batch of contexts, shape ``(len(X),)``.

        Uses exactly the same context handling as :meth:`predict`.
        """
        vectors = [self._to_vector(x) for x in X]
        # Contexts become columns of a (d, N) matrix; an empty batch gives (d, 0).
        Xmat = np.stack(vectors, axis=1) if vectors else np.zeros((self.d, 0))
        return np.argmax(self.W @ Xmat, axis=0).astype(int)

class ERMOracle:
    """Interface for an empirical-risk-minimisation oracle over policies.

    ``argmax`` receives an ``(N, K)`` matrix of per-action importance weights
    (one row per context) together with the ``N`` contexts. It returns the
    index of the policy whose chosen actions collect the largest total weight.
    """

    def argmax(self, weights_per_action: np.ndarray, contexts: List[Any], K: int) -> int:
        """Return the index of the empirically best policy (to be overridden)."""
        raise NotImplementedError

class FinitePolicyERM(ERMOracle):
    """Exact ERM by exhaustive scoring of an explicit list of policies.

    The list is held by reference, so policies appended to it later are
    considered by subsequent calls.
    """

    def __init__(self, policies: List[Policy]):
        self.policies = policies

    def argmax(self, weights_per_action: np.ndarray, contexts: List[Any], K: int) -> int:
        """Return the index of the policy with the largest summed weight.

        Returns 0 when there are no contexts or no policies.
        """
        n_ctx = len(contexts)
        if n_ctx == 0 or len(self.policies) == 0:
            return 0
        # actions[i, t] = action of policy i on context t, shape (P, N).
        actions = np.vstack([pol.predict_many(contexts) for pol in self.policies])
        rows = np.arange(n_ctx)[None, :]
        totals = weights_per_action[rows, actions].sum(axis=1)
        return int(np.argmax(totals))


def _fmt_ctx_ns(x: Union[Dict[str, float], np.ndarray]) -> str:
    if isinstance(x, dict):
        feats = [f"{k}:{float(v):.10g}" for k, v in x.items() if v is not None and float(v) != 0.0]
        return "|x " + " ".join(feats) if feats else "|x"
    arr = np.asarray(x, float).ravel()
    feats = [f"f{i}:{val:.10g}" for i, val in enumerate(arr) if val != 0.0]
    return "|x " + " ".join(feats) if feats else "|x"

def _fmt_action_ns(a: int) -> str:
    return f"|a aid=a{a}"

class VwCsoaaPolicy(Policy):
    """Greedy policy wrapping a VW workspace trained with ``--csoaa_ldf mc``.

    Each prediction builds one multi-line example: a ``shared`` line carrying
    the context namespace, followed by one line per action ``0..K-1``.
    """

    def __init__(self, vw, K: int):
        self.vw = vw
        self.K = int(K)

    def predict(self, x: Any) -> int:
        """Return the action VW predicts for context *x*.

        NOTE(review): assumes ``vw.predict`` returns the 0-based index of the
        chosen action line for ``csoaa_ldf``; confirm for the installed VW.
        """
        example = ["shared " + _fmt_ctx_ns(x)]
        example.extend(map(_fmt_action_ns, range(self.K)))
        return int(self.vw.predict(example))

    def predict_many(self, X: List[Any]) -> np.ndarray:
        """Predict every context independently (one VW call per context)."""
        return np.array([self.predict(ctx) for ctx in X], dtype=int)

class VWCsoaaERM(ERMOracle):
    """Approximate ERM oracle that trains a fresh VW ``--csoaa_ldf mc`` model per call.

    Each :meth:`argmax` call fits a cost-sensitive model whose cost for action
    ``a`` on context ``t`` is ``-weights_per_action[t, a]``. It then wraps the
    model in a :class:`VwCsoaaPolicy`. The result is not an index into an
    existing bank, so :meth:`argmax` returns ``-1``. When ``grow_bank`` is true
    and a ``register_policy`` callback is set, the new policy is handed to the
    callback; otherwise the model is discarded and its workspace released.
    """

    def __init__(
        self,
        K: int,
        interactions: Sequence[str] = ("ax",),
        seed: Optional[int] = 205,
        quiet: bool = True,
        passes: int = 1,
        l2: float = 1e-8,
        learning_rate: float = 0.5,
        grow_bank: bool = True,
        register_policy: Optional[Callable[[Optional[Policy]], None]] = None,
    ):
        """
        Args:
            K: number of actions.
            interactions: VW ``--interactions`` specs (namespace letters).
            seed: VW ``--random_seed``; ``None`` leaves VW's default seed.
            quiet: pass ``--quiet`` to VW.
            passes: number of sequential sweeps over the data per call.
            l2: ``--l2`` regularisation (omitted when not positive).
            learning_rate: ``--learning_rate`` (omitted when not positive).
            grow_bank: whether trained policies are handed to ``register_policy``.
            register_policy: callback receiving each new policy (or ``None``
                when there was no data).

        Raises:
            ImportError: if vowpalwabbit is not installed.
        """
        if pyvw is None:
            raise ImportError("VWCsoaaERM requires vowpalwabbit.")
        self.K = int(K)
        self.interactions = tuple(interactions)
        self.seed = seed
        self.quiet = quiet
        self.passes = int(passes)
        self.l2 = float(l2)
        self.lr = float(learning_rate)
        self.grow_bank = bool(grow_bank)
        self._register = register_policy

    def _build_workspace(self):
        """Create a new VW workspace configured from this oracle's settings."""
        args: List[str] = ["--csoaa_ldf", "mc", "--indexing", "0", "--holdout_off"]
        for inter in self.interactions:
            args += ["--interactions", inter]
        # seed is Optional: None means "let VW choose", not int(None).
        if self.seed is not None:
            args += ["--random_seed", str(int(self.seed))]
        if self.quiet:
            args += ["--quiet"]
        if self.l2 > 0:
            args += ["--l2", str(self.l2)]
        if self.lr > 0:
            args += ["--learning_rate", str(self.lr)]
        # VW >= 9 names the class Workspace; pyvw.vw is the older (deprecated) name.
        factory = getattr(pyvw, "Workspace", None) or pyvw.vw
        return factory(" ".join(args))

    def argmax(self, weights_per_action: np.ndarray, contexts: List[Any], K: int) -> int:
        """Train a VW policy on the weighted data and return ``-1``.

        Args:
            weights_per_action: ``(N, K)`` per-action weights; higher is better.
            contexts: the ``N`` contexts (dicts or array-likes).
            K: number of actions; must equal the configured ``K``.

        Returns:
            ``0`` when there is no data (after calling ``register_policy(None)``
            if set), otherwise ``-1``.

        Raises:
            ValueError: if ``K`` differs from the configured number of actions.
        """
        if K != self.K:
            raise ValueError(f"ERM K mismatch: got {K}, expected {self.K}")
        N = len(contexts)
        if N == 0:
            if self._register:
                self._register(None)
            return 0
        vw = self._build_workspace()
        try:
            for _ in range(self.passes):
                for t in range(N):
                    costs = -np.asarray(weights_per_action[t], float)
                    lines = ["shared " + _fmt_ctx_ns(contexts[t])]
                    lines.extend(
                        f"{a}:{float(costs[a]):.10g} " + _fmt_action_ns(a)
                        for a in range(self.K)
                    )
                    vw.learn(lines)
        except Exception:
            # Release the native workspace before propagating a training failure.
            vw.finish()
            raise
        if self.grow_bank and self._register is not None:
            # Ownership of the workspace moves to the registered policy.
            self._register(VwCsoaaPolicy(vw, self.K))
        else:
            # Nobody can ever call predict on this model; free its native resources.
            vw.finish()
        return -1


@dataclass
class IntervalView:
    """Half-open index range ``[start, end)`` into an epoch log."""

    start: int
    end: int

    def length(self) -> int:
        """Number of indices in the range, clamped at zero for inverted ranges."""
        span = self.end - self.start
        return span if span > 0 else 0

class EpochLog:
    """Append-only record of (context, action, reward, propensity) tuples for one epoch.

    The four columns are stored as parallel Python lists indexed from the
    start of the epoch.
    """

    def __init__(self):
        self.X: List[Any] = []
        self.a: List[int] = []
        self.r: List[float] = []
        self.p: List[float] = []

    def append(self, x: Any, a: int, r: float, p_chosen: float) -> None:
        """Record one round; ``p_chosen`` is the probability the action was played with."""
        self.X.append(x)
        self.a.append(int(a))
        self.r.append(float(r))
        self.p.append(float(p_chosen))

    def view(self, start: int, end: int) -> IntervalView:
        """Return an :class:`IntervalView` over ``[start, end)`` (no copying)."""
        return IntervalView(start, end)

    def slice_data(self, iv: IntervalView) -> Tuple[List[Any], np.ndarray, np.ndarray, np.ndarray]:
        """Materialise the rows in *iv* as ``(contexts, actions, rewards, propensities)``."""
        window = slice(iv.start, iv.end)
        return (
            self.X[window],
            np.array(self.a[window], int),
            np.array(self.r[window], float),
            np.array(self.p[window], float),
        )

class ADAILTCBPlusCore:
    """Core of ADA-ILTCB+: epochs of doubling blocks, random replays and restart tests.

    Within an epoch, blocks 0 and 1 last ``L`` rounds and block ``j >= 2``
    lasts ``2**(j-1) * L`` rounds, so blocks ``0..k`` together span
    ``2**k * L`` rounds. At the start of every block ``j >= 1`` a sparse
    distribution ``Q`` over the policy bank is fitted on the epoch's log (see
    ``_solve_OP_for_block``). While a block runs, replays of earlier blocks'
    distributions are started at random. When a replay or a block finishes,
    statistical tests compare reward and variance estimates across intervals;
    on failure the epoch is restarted from scratch (log and block counter reset).

    The class constants are the thresholds of the optimisation problem
    (``C_for_OP``), the end-of-replay test (``D1``, ``D2``) and the
    end-of-block test (``D4``, ``D5``). NOTE(review): the values presumably
    follow the ADA-ILTCB+ analysis; confirm before tuning.
    """

    C_for_OP: float = 1.2e7
    D1: float = 6400.0
    D2: float = 800.0
    D4: float = 6400.0
    D5: float = 800.0

    def __init__(self, K: int, policies: List[Policy], T: int,
                 delta: float = 0.05, erm: Optional[ERMOracle] = None,
                 rng: Optional[np.random.Generator] = None):
        """Set up the policy bank, confidence constants and the first epoch.

        Args:
            K: number of actions.
            policies: initial policy bank. It is copied into ``self.policies``,
                which may grow later via ``_register_new_policy``.
            T: horizon; only enters the confidence term ``C0``.
            delta: failure probability used in ``C0``.
            erm: ERM oracle. Defaults to exact ERM over ``self.policies``. A
                ``VWCsoaaERM`` is re-wired so the policies it trains are
                appended to this core's bank.
            rng: random generator; a fresh unseeded one if omitted.
        """
        self.K = int(K)
        self.policies: List[Policy] = list(policies)
        self.P = len(self.policies)
        self.T = int(T)
        self.delta = float(delta)
        self.rng = rng or np.random.default_rng()
        if erm is None:
            # Shares self.policies, so policies appended later are searched as well.
            self.erm = FinitePolicyERM(self.policies)
        else:
            if isinstance(erm, VWCsoaaERM):
                erm._register = self._register_new_policy
            self.erm = erm

        self.epoch_idx = 0
        self.block_idx = 0
        self.t_global = 0
        self.t_epoch_start = 0

        # C0 = log(8 T^3 |Pi|^2 / delta); L = ceil(4 K C0) is the base block length.
        P_for_C0 = max(1, len(self.policies))
        self.C0 = np.log(8 * (self.T ** 3) * (P_for_C0 ** 2) / max(1e-12, self.delta))
        self.L = int(np.ceil(4 * self.K * self.C0))
        self.Q_by_block: List[Dict[int, float]] = []
        self.nu_by_block: List[float] = []
        # Each active replay is (replayed block m, block j it started in, start time, end time).
        self.active_replays: List[Tuple[int, int, int, int]] = []
        self.log = EpochLog()
        self._start_new_epoch(reset_only=True)

    def _register_new_policy(self, pol: Optional[Policy]) -> None:
        """Append a policy produced by the ERM oracle to the bank (``None`` is ignored)."""
        if pol is None:
            return
        self.policies.append(pol)
        self.P = len(self.policies)

    def select_action(self, x: Any) -> Tuple[int, Dict[str, Any]]:
        """Sample an action for context *x*.

        Returns:
            ``(action, info)``. ``info`` holds:
            - ``"probs"``: the action probability vector;
            - ``"p_chosen"``: the probability of the sampled action, to be passed
              back to :meth:`observe`;
            - ``"source_block"``: the block whose ``Q`` produced the sample;
            - ``"nu"``: that block's exploration floor.
        """
        self._maybe_start_replay_now()
        active = sorted({m for (m, _j0, s, e) in self.active_replays if s <= self.t_global < e})
        if not active:
            m_used = self.block_idx
            Q = self.Q_by_block[m_used]
            nu = self.nu_by_block[m_used]
            probs = self._mixture_probs_single(Q, nu, x)
            a = int(self._sample(probs))
            return a, {"probs": probs, "p_chosen": float(probs[a]), "source_block": m_used, "nu": nu}

        # One active block is picked uniformly at random, so the propensity of an
        # action is the average of its probability under every active block.
        m_used = int(self.rng.choice(active))
        per_block = {
            m: self._mixture_probs_single(self.Q_by_block[m], self.nu_by_block[m], x)
            for m in active
        }
        mix = np.mean([per_block[m] for m in active], axis=0)
        a = int(self._sample(per_block[m_used]))
        return a, {"probs": mix, "p_chosen": float(mix[a]), "source_block": m_used, "nu": self.nu_by_block[m_used]}

    def observe(self, x: Any, a: int, r: float, p_chosen: float) -> None:
        """Record one interaction and advance replays, blocks and epochs.

        ``p_chosen`` should be the propensity reported by :meth:`select_action`;
        it is used for inverse-propensity weighting. A failed end-of-replay or
        end-of-block test restarts the epoch, which clears the log.
        """
        self.log.append(x, a, r, p_chosen)
        self.t_global += 1

        finished = [rep for rep in self.active_replays if rep[3] <= self.t_global]
        if finished:
            self.active_replays = [rep for rep in self.active_replays if rep[3] > self.t_global]
            for (m, _j0, s, e) in finished:
                # Replay interval (epoch-relative indices) vs. all data before the current block.
                A = self.log.view(s - self.t_epoch_start, e - self.t_epoch_start)
                B = self._union_block_view(self.block_idx - 1)
                if self._end_of_replay_test(A, B, m):
                    self._restart_epoch()
                    return

        if self.t_global < self._current_block_end_time():
            return
        if self.block_idx > 0:
            Bj = self._union_block_view(self.block_idx)
            if any(self._end_of_block_test(Bj, self._union_block_view(k), k)
                   for k in range(self.block_idx)):
                self._restart_epoch()
                return
        self.block_idx += 1
        self.active_replays.clear()
        self._ensure_block_params(self.block_idx)
        self._solve_OP_for_block(self.block_idx)

    def _start_new_epoch(self, reset_only: bool = False) -> None:
        """Reset block state and the log.

        Block 0 plays a single uniformly chosen policy (or an empty ``Q`` when
        the bank is empty). With ``reset_only`` the epoch counter is not
        incremented.
        """
        if not reset_only:
            self.epoch_idx += 1
        self.block_idx = 0
        self.t_epoch_start = self.t_global
        self.active_replays.clear()
        self.log = EpochLog()
        self._ensure_block_params(0)
        if len(self.policies) == 0:
            self._set_Q_for_block(0, {})
        else:
            pid0 = int(self.rng.integers(len(self.policies)))
            self._set_Q_for_block(0, {pid0: 1.0})

    def _restart_epoch(self) -> None:
        """Start a new epoch after a failed test."""
        self._start_new_epoch()

    def _current_block_end_time(self) -> int:
        """Global time at which the current block ends."""
        return self.t_epoch_start + sum(self._block_len(j) for j in range(0, self.block_idx + 1))

    def _block_len(self, j: int) -> int:
        """Length of block ``j``: ``L`` for j <= 1, else ``2**(j-1) * L``."""
        if j <= 1:
            return self.L
        return (2 ** (j - 1)) * self.L

    def _union_block_view(self, k: int) -> IntervalView:
        """View over blocks ``0..k`` of the epoch log, truncated to the data logged so far."""
        if k < 0:
            return IntervalView(0, 0)
        length = sum(self._block_len(j) for j in range(0, k + 1))
        end = min(self.log_len(), length)
        return self.log.view(0, end)

    def log_len(self) -> int:
        """Number of rounds logged in the current epoch."""
        return len(self.log.a)

    def _maybe_start_replay_now(self) -> None:
        """Possibly start a replay of an earlier block of the current epoch.

        With probability ``alpha_j = 2**(-j/2) / L * sum_{m<j} 2**(-m/2)`` a
        block ``m < j`` is drawn with probability proportional to
        ``2**(-m/2)``. It is then replayed for ``2**m * L`` rounds from now.
        """
        j = self.block_idx
        if j == 0:
            return
        terms = np.array([2 ** (-m / 2) for m in range(0, j)], float)
        alpha_j = (1.0 / self.L) * (2.0 ** (-j / 2.0)) * float(terms.sum())
        if self.rng.random() < alpha_j:
            probs = terms / terms.sum()
            m = int(self.rng.choice(np.arange(0, j), p=probs))
            start = self.t_global
            end = start + (2 ** m) * self.L
            self.active_replays.append((m, j, start, end))

    def _ensure_block_params(self, j: int) -> None:
        """Make sure ``nu`` and a ``Q`` slot exist for blocks ``0..j``.

        The exploration floor is ``nu_j = min(sqrt(C0 / (K * 2**j * L)), 1/K)``.
        """
        while len(self.nu_by_block) <= j:
            jj = len(self.nu_by_block)
            scale = (2 ** jj) * self.L
            nu = float(np.sqrt(self.C0 / (self.K * max(1.0, scale))))
            nu = min(nu, 1.0 / max(1, self.K))
            self.nu_by_block.append(nu)
        while len(self.Q_by_block) <= j:
            self.Q_by_block.append({})

    def _set_Q_for_block(self, j: int, Q: Dict[int, float]) -> None:
        """Store a normalised ``Q`` for block ``j``, dropping invalid ids and non-positive weights."""
        Q = {pid: w for pid, w in Q.items() if 0 <= pid < len(self.policies) and w > 0}
        if len(Q) == 0:
            self.Q_by_block[j] = {}
            return
        total = float(sum(Q.values()))
        self.Q_by_block[j] = {pid: float(w) / total for pid, w in Q.items()}

    def _policy_predictions(self, X: List[Any]) -> np.ndarray:
        """Return ``(P, N)`` int actions of every bank policy on contexts *X* (requires P > 0)."""
        return np.vstack([pol.predict_many(X) for pol in self.policies]).astype(int)

    def _erm_best_policy_on_interval(self, iv: IntervalView) -> Optional[int]:
        """Ask the ERM oracle for the best policy on *iv* under IPS-weighted rewards.

        Returns:
            A valid bank index, or ``None`` when there is no data or the oracle
            produced no usable index. An oracle result of ``-1`` means "a new
            policy was registered". It is only honoured if the bank actually
            grew during the call; otherwise (e.g. ``VWCsoaaERM`` with
            ``grow_bank=False``) the last policy in the bank would be an
            unrelated stale policy.
        """
        X, a, r, p = self.log.slice_data(iv)
        N = len(X)
        if N == 0:
            return None
        weights = np.zeros((N, self.K), float)
        weights[np.arange(N), a] = r / np.maximum(p, 1e-12)
        n_before = len(self.policies)
        idx = self.erm.argmax(weights, X, self.K)
        if idx == -1:
            if len(self.policies) > n_before:
                return len(self.policies) - 1
            return None
        return int(idx)

    def _solve_OP_for_block(self, j: int, max_iter: int = 2000) -> None:
        """Fit the sparse distribution ``Q`` for block ``j`` on blocks ``0..j-1``.

        Starts from the ERM (or IPS-best) policy and repeatedly mixes in, with a
        small step, either the policy with the worst variance-constraint
        violation or the best policy when the expected regret is too large.
        This continues until both constraints hold or ``max_iter`` is reached.
        """
        if j == 0:
            return
        Iprev = self._union_block_view(j - 1)
        if Iprev.length() == 0:
            self._set_Q_for_block(j, {})
            return

        idx_erm = self._erm_best_policy_on_interval(Iprev)

        if len(self.policies) == 0:
            self._set_Q_for_block(j, {})
            return
        # Policy predictions do not depend on Q, so compute them once instead of
        # once per iteration (assumes deterministic policies, as the IPS estimates do).
        X_prev = self.log.slice_data(Iprev)[0]
        preds = self._policy_predictions(X_prev) if len(X_prev) > 0 else None
        Rhat, best_idx_ipw = self._ips_rewards_on_interval(Iprev, preds=preds)
        if Rhat.size == 0:
            self._set_Q_for_block(j, {})
            return

        best_idx = int(idx_erm) if (idx_erm is not None and 0 <= int(idx_erm) < len(self.policies)) else int(best_idx_ipw)
        nu = self.nu_by_block[j]
        C = self.C_for_OP
        reg = Rhat[best_idx] - Rhat
        Q: Dict[int, float] = {best_idx: 1.0}

        for _ in range(max_iter):
            RegExp = float(sum(Q.get(pid, 0.0) * reg[pid] for pid in Q))
            worst_pid, worst_violation = self._max_variance_violation(Iprev, Q, nu, reg, C, preds=preds)
            need_reg = (RegExp > 2 * C * self.K * nu)
            need_var = (worst_violation > 0.0)
            if not need_reg and not need_var:
                break
            add_pid = int(worst_pid) if need_var else int(best_idx)
            step = max(0.005, 0.05 / max(1.0, len(Q)))
            Q = {pid: (1.0 - step) * w for pid, w in Q.items()}
            Q[add_pid] = Q.get(add_pid, 0.0) + step

        self._set_Q_for_block(j, Q)

    def _ips_rewards_on_interval(self, iv: IntervalView,
                                 preds: Optional[np.ndarray] = None) -> Tuple[np.ndarray, int]:
        """IPS reward estimate of every bank policy on *iv*, plus the argmax index.

        ``preds`` may carry precomputed ``(P, N)`` predictions for the same interval.
        """
        X, a, r, p = self.log.slice_data(iv)
        N = len(X); P = len(self.policies)
        if N == 0 or P == 0:
            return np.zeros(P, float), 0
        if preds is None:
            preds = self._policy_predictions(X)
        w = (r / np.maximum(1e-12, p)).astype(float)
        M = (preds == a[None, :])
        Rhat = (M * w[None, :]).sum(axis=1) / float(N)
        best_idx = int(np.argmax(Rhat)) if Rhat.size > 0 else 0
        return Rhat, best_idx

    def _variance_terms(self, iv: IntervalView, Q: Dict[int, float], nu: float,
                        preds: Optional[np.ndarray] = None) -> np.ndarray:
        """Per-policy empirical mean of ``1 / Q^nu(pi(x) | x)`` on *iv*.

        ``Q^nu(a|x) = nu + (1 - K nu) * Q-mass on a``. ``preds`` may carry
        precomputed ``(P, N)`` predictions for the same interval.
        """
        X, a, r, p = self.log.slice_data(iv)
        N = len(X); P = len(self.policies)
        if N == 0 or P == 0:
            return np.zeros(P, float)
        if preds is None:
            preds = self._policy_predictions(X)
        supp = [pid for pid in Q.keys() if 0 <= pid < P]
        wQ = np.array([Q[pid] for pid in supp], float)
        mass = np.zeros((N, self.K), float)
        idx = np.arange(N)
        for si, pid in enumerate(supp):
            mass[idx, preds[pid]] += wQ[si]
        mass_for_pi = mass[idx[None, :], preds]
        qnu = nu + (1.0 - self.K * nu) * mass_for_pi
        V = (1.0 / np.maximum(1e-12, qnu)).mean(axis=1)
        return V

    def _max_variance_violation(self, iv: IntervalView, Q: Dict[int, float], nu: float,
                                reg: np.ndarray, C: float,
                                preds: Optional[np.ndarray] = None) -> Tuple[int, float]:
        """Return ``(policy, amount)`` of the largest violation of ``V <= 2K + reg/(C nu)`` (0 if none)."""
        V = self._variance_terms(iv, Q, nu, preds=preds)
        thresholds = 2.0 * self.K + reg / (C * max(1e-12, nu))
        violations = V - thresholds
        worst_pid = int(np.argmax(violations)) if violations.size > 0 else 0
        return worst_pid, float(max(0.0, violations[worst_pid] if violations.size > 0 else 0.0))

    def _end_of_replay_test(self, A: IntervalView, B_prev: IntervalView, m: int) -> bool:
        """Return True if the replay interval *A* disagrees with *B_prev* under block ``m``'s ``Q``.

        Flags a policy whose regret differs by more than a factor-4 margin plus
        ``D1 K nu_m``, or whose variance term on *A* exceeds ``41x`` its
        variance on *B_prev* by ``D2 K``.
        """
        if A.length() == 0 or B_prev.length() == 0:
            return False
        RA, bestA = self._ips_rewards_on_interval(A)
        RB, bestB = self._ips_rewards_on_interval(B_prev)
        regA = RA[bestA] - RA; regB = RB[bestB] - RB
        Qm = self.Q_by_block[m]; nu_m = self.nu_by_block[m]
        VA = self._variance_terms(A, Qm, nu_m)
        VB = self._variance_terms(B_prev, Qm, nu_m)
        D1, D2, K = self.D1, self.D2, self.K
        cond1 = (regA - 4.0 * regB) >= (D1 * K * nu_m)
        cond2 = (regB - 4.0 * regA) >= (D1 * K * nu_m)
        cond3 = (VA - 41.0 * VB) >= (D2 * K)
        return bool(np.any(cond1 | cond2 | cond3))

    def _end_of_block_test(self, Bj: IntervalView, Bk: IntervalView, k: int) -> bool:
        """Return True if blocks ``0..j`` (*Bj*) disagree with blocks ``0..k`` (*Bk*).

        Uses the same regret/variance conditions as the replay test, with
        thresholds ``D4``/``D5`` and block ``k+1``'s ``Q``/``nu`` for the variance terms.
        """
        if Bj.length() == 0 or Bk.length() == 0:
            return False
        Rj, bestJ = self._ips_rewards_on_interval(Bj)
        Rk, bestK = self._ips_rewards_on_interval(Bk)
        regJ = Rj[bestJ] - Rj; regK = Rk[bestK] - Rk
        nu_k = self.nu_by_block[k]
        Qkp1 = self.Q_by_block[min(k + 1, len(self.Q_by_block) - 1)]
        nu_kp1 = self.nu_by_block[min(k + 1, len(self.nu_by_block) - 1)]
        Vj = self._variance_terms(Bj, Qkp1, nu_kp1)
        Vk = self._variance_terms(Bk, Qkp1, nu_kp1)
        D4, D5, K = self.D4, self.D5, self.K
        cond1 = (regJ - 4.0 * regK) >= (D4 * K * nu_k)
        cond2 = (regK - 4.0 * regJ) >= (D4 * K * nu_k)
        cond3 = (Vj - 41.0 * Vk) >= (D5 * K)
        return bool(np.any(cond1 | cond2 | cond3))

    def _mixture_probs_single(self, Q: Dict[int, float], nu: float, x: Any) -> np.ndarray:
        """Action probabilities ``nu + (1 - K nu) * Q-mass`` for context *x*, clipped and normalised."""
        mass = np.zeros(self.K, float)
        for pid, w in Q.items():
            if 0 <= pid < len(self.policies):
                a = self.policies[pid].predict(x)
                mass[a] += w
        probs = nu + (1.0 - self.K * nu) * mass
        probs = np.clip(probs, 1e-12, 1.0)
        probs /= probs.sum()
        return probs

    def _sample(self, probs: np.ndarray) -> int:
        """Draw one action index from *probs* using the core's generator."""
        return int(self.rng.choice(self.K, p=probs))


class ADAILTCBPlusBandit(BanditAlgorithm):
    """``BanditAlgorithm`` front-end for :class:`ADAILTCBPlusCore`.

    ``erm_mode="finite"`` runs exact ERM over the given ``policies``.
    ``erm_mode="vw"`` uses Vowpal Wabbit cost-sensitive training as the ERM
    oracle, with ``vw_erm_kwargs`` overriding the defaults below. Per-arm
    reward statistics are kept in ``SUMS``, ``TotalNumber`` and ``TotalSum``.
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        erm_mode: str = "finite",
        policies: Optional[List[Policy]] = None,
        vw_erm_kwargs: Optional[Dict[str, Any]] = None,
        rng: Optional[np.random.Generator] = None,
        delta: float = 0.05,
        featurize_ctx: Optional[Callable[[Any], np.ndarray]] = None,
    ):
        """
        Args:
            num_actions: number of arms ``K``.
            horizon: total number of rounds ``T``.
            erm_mode: ``"finite"`` or ``"vw"``.
            policies: policy bank; required for ``"finite"``, optional seed
                bank for ``"vw"``.
            vw_erm_kwargs: overrides for the ``VWCsoaaERM`` keyword arguments.
            rng: random generator, shared with the core.
            delta: confidence parameter forwarded to the core.
            featurize_ctx: optional transform applied to every raw context.

        Raises:
            ValueError: unknown ``erm_mode``, or ``"finite"`` without policies.
            ImportError: ``"vw"`` mode without vowpalwabbit installed.
        """
        super().__init__(num_actions, horizon)
        self.K = int(num_actions)
        self.delta = float(delta)
        self._feat = featurize_ctx
        self._rng = rng or np.random.default_rng()
        # Constructor configuration, kept so reset() can rebuild an identical
        # untrained learner (without relying on attributes of the base class).
        self._horizon = horizon
        self._erm_mode = erm_mode
        self._seed_policies: List[Policy] = list(policies) if policies else []
        self._vw_erm_kwargs: Optional[Dict[str, Any]] = dict(vw_erm_kwargs) if vw_erm_kwargs else None
        if erm_mode not in {"finite", "vw"}:
            raise ValueError("erm_mode must be 'finite' or 'vw'")
        if erm_mode == "finite":
            if not policies:
                raise ValueError("erm_mode='finite' requires `policies`.")
            erm = FinitePolicyERM(policies)
            self._core = ADAILTCBPlusCore(K=self.K, policies=policies, T=horizon, delta=delta, erm=erm, rng=self._rng)
        else:
            if pyvw is None:
                raise ImportError("VW ERM selected but vowpalwabbit is not installed.")
            # NOTE(review): with grow_bank=False the trained VW model is never added
            # to the bank, so with no seed policies the learner plays uniformly.
            # Confirm this default is intended.
            vw_kwargs = dict(
                interactions=("ax",),
                seed=1,
                quiet=True,
                passes=1,
                l2=1e-8,
                learning_rate=0.5,
                grow_bank=False,
                register_policy=None,
            )
            if vw_erm_kwargs:
                vw_kwargs.update(vw_erm_kwargs)
            erm = VWCsoaaERM(K=self.K, **vw_kwargs)
            seed_bank = policies or []
            self._core = ADAILTCBPlusCore(K=self.K, policies=seed_bank, T=horizon, delta=delta, erm=erm, rng=self._rng)

        self.SUMS = {i: [] for i in range(self.K)}
        self.TotalNumber = {i: 0 for i in range(self.K)}
        self.TotalSum = {i: 0.0 for i in range(self.K)}
        self.chosen_arm = 0
        self._last_x: Any = None
        self._last_p: float = 1.0

    def _to_vec(self, context: Any) -> Any:
        """Apply ``featurize_ctx`` if given, otherwise pass the context through."""
        if self._feat is not None:
            return self._feat(context)
        return context

    def select_arm(self, arms: Sequence[Any] = None, context: Any = None) -> int:
        """Choose an arm for *context*; ``arms`` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: if no context is given.
        """
        if context is None:
            raise ValueError("ADAILTCBPlusBandit requires a `context`.")
        x = self._to_vec(context)
        a, info = self._core.select_action(x)
        # Remember the context and propensity for the matching update_statistics call.
        self._last_x = x
        self._last_p = float(info["p_chosen"])
        self.chosen_arm = int(a)
        return int(a)

    def update_statistics(self, arm: int, reward: float) -> None:
        """Feed the reward of the last selected round to the core and update per-arm stats.

        Raises:
            RuntimeError: if called before any :meth:`select_arm`.
        """
        if self._last_x is None:
            raise RuntimeError("update_statistics() called before select_arm()")
        a = int(arm); r = float(reward)
        self._core.observe(self._last_x, a, r, self._last_p)
        self.TotalNumber[a] += 1
        self.TotalSum[a] += r
        self.SUMS[a].append(r)

    def reset(self):
        """Rebuild a fresh, untrained learner with the original constructor configuration.

        The original (seed) policies and VW settings are reused. Policies learned
        by the VW oracle during the previous run are not carried over.
        """
        self.__init__(
            num_actions=self.K,
            horizon=self._horizon,
            erm_mode=self._erm_mode,
            policies=list(self._seed_policies),
            vw_erm_kwargs=self._vw_erm_kwargs,
            rng=self._rng,
            delta=self.delta,
            featurize_ctx=self._feat,
        )

    def __str__(self):
        return "ADA-ILTCB+ (finite-Π, VW-ERM supported)"


# ERMOracle is exported because ADAILTCBPlusCore accepts any ERMOracle implementation.
__all__ = [
    "Policy", "LinearArgmaxPolicy",
    "ERMOracle", "FinitePolicyERM", "VWCsoaaERM",
    "ADAILTCBPlusCore", "ADAILTCBPlusBandit",
]
