from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Union, Callable
import numpy as np

from src.core.bandit import BanditAlgorithm
from src.core.detection import detect_change, kl_divergence


# Vowpal Wabbit is optional at import time: if the import fails for any
# reason, ``pyvw`` is set to ``None`` so this module can still be imported.
# ``_VWCBADF.__init__`` checks for ``None`` and raises ``ImportError`` then.
try:
    from vowpalwabbit import pyvw
except Exception:
    pyvw = None


# Accepted shapes for one context or one action description:
#   * dict      -> named features (string values become ``name=value`` tokens,
#                  numeric values become ``name:value``),
#   * sequence  -> positional numeric features ``f0, f1, ...``,
#   * number    -> a single ``id`` feature,
#   * string    -> a single ``tok=<string>`` token.
# See ``_fmt_namespace_cb`` for the exact rendering.
Feature = Union[Dict[str, Union[float, int, str]], Sequence[float], int, float, str]

# Characters with structural meaning in VW's text input format (token
# separator, name/value separator, namespace marker).  They are replaced inside
# feature names and categorical values so one token cannot split or corrupt
# the example line.
_VW_RESERVED_CHARS = str.maketrans({ch: "_" for ch in " \t\n\r\v\f:|"})


def _vw_safe_token(value: object) -> str:
    """Render ``value`` as text with VW's reserved characters replaced by ``_``."""
    return f"{value}".translate(_VW_RESERVED_CHARS)


def _fmt_namespace_cb(ns: str, x: Optional[Feature]) -> str:
    """Render one feature description as a VW namespace string ``|ns ...``.

    Args:
        ns: Namespace name placed right after the ``|`` marker (used as-is).
        x: Features to render.
            * ``None`` -> empty namespace ``|ns``.
            * ``dict`` -> one token per entry; ``None`` values are skipped,
              string values become ``key=value``, anything else is converted
              with ``float`` and emitted as ``key:value`` unless it is zero.
            * ``list`` / ``tuple`` / ``ndarray`` -> flattened, non-zero entries
              emitted as ``f<i>:value``.
            * number -> ``id:value``.
            * anything else -> ``tok=<text>``.

    Returns:
        The namespace string.  Numbers use the ``.10g`` format.  Whitespace,
        ``:`` and ``|`` inside feature names and string values are replaced
        by ``_`` so the token stays a single VW feature.

    Raises:
        TypeError, ValueError: If a non-string dict value or a sequence entry
            cannot be converted to ``float``.
    """
    header = f"|{ns}"
    if x is None:
        return header

    if isinstance(x, dict):
        feats: List[str] = []
        for k, v in x.items():
            if v is None:
                continue
            name = _vw_safe_token(k)
            if isinstance(v, str):
                # Categorical feature: name and value form one hashed token.
                feats.append(f"{name}={_vw_safe_token(v)}")
            else:
                fv = float(v)
                if fv != 0.0:  # zero-valued features carry no information
                    feats.append(f"{name}:{fv:.10g}")
        return header + " " + " ".join(feats) if feats else header

    if isinstance(x, (list, tuple, np.ndarray)):
        arr = np.asarray(x, dtype=float).ravel()
        feats = [f"f{i}:{val:.10g}" for i, val in enumerate(arr) if val != 0.0]
        return header + " " + " ".join(feats) if feats else header

    if isinstance(x, (int, float, np.integer, np.floating)):
        return f"{header} id:{float(x):.10g}"

    return f"{header} tok={_vw_safe_token(x)}"

def _build_adf_example_cb(context: Optional[Feature], actions: Sequence[Feature]) -> List[str]:
    """Build the unlabeled multi-line ADF example used for prediction.

    The optional first line is the ``shared`` line carrying ``context`` in
    namespace ``x``; it is followed by one line per action in namespace ``a``,
    in the order given.
    """
    shared_part: List[str] = []
    if context is not None:
        shared_part = ["shared " + _fmt_namespace_cb("x", context)]
    return shared_part + [_fmt_namespace_cb("a", action) for action in actions]

def _build_adf_labeled_cb(
    context: Optional[Feature],
    actions: Sequence[Feature],
    chosen_index: int,
    cost: float,
    prob: float,
) -> List[str]:
    """Build the labeled multi-line ADF example used for learning.

    Same layout as the prediction example (optional ``shared`` line, then one
    line per action), except that the line at ``chosen_index`` is prefixed
    with the label ``0:<cost>:<prob>``.  The probability is clamped into
    ``[1e-12, 1.0]`` before being written.
    """
    lines: List[str] = (
        [] if context is None else ["shared " + _fmt_namespace_cb("x", context)]
    )
    for position, action in enumerate(actions):
        label = ""
        if position == chosen_index:
            # Only the chosen action carries a label; the clamp keeps the
            # written probability strictly positive and at most one.
            clamped = max(min(prob, 1.0), 1e-12)
            label = f"0:{cost:.10g}:{clamped:.12g} "
        lines.append(label + _fmt_namespace_cb("a", action))
    return lines


@dataclass
class _VWConfig:
    """Settings shared by the Vowpal Wabbit wrappers in this module.

    ``_VWCBADF.__init__`` translates these fields into VW command-line flags;
    ``_VWCBADF.update_statistics`` uses the cost-related fields when turning a
    reward into a cost label.
    """

    # Interaction specs; each string entry of length >= 2 is passed to VW as
    # ``--interactions <spec>``.
    interactions: Sequence[str] = ("ax",)
    # Passed as ``--random_seed`` (after ``int()``); ``None`` omits the flag.
    seed: Optional[int] = 1
    # When true, ``--quiet`` is added to the command line.
    quiet: bool = True
    # When true the learner is fed ``cost = 1 - reward``; otherwise the reward
    # value itself is used as the cost.
    reward_as_cost: bool = True
    # Passed as ``--cb_min_cost`` / ``--cb_max_cost``; costs are also clipped
    # into this range before learning.
    cb_min_cost: float = 0.0
    cb_max_cost: float = 1.0
    # Passed as ``--cb_type``.
    cb_type: str = "mtr"

class _VWCBADF(BanditAlgorithm):
    """Base class for bandits backed by a Vowpal Wabbit ``--cb_explore_adf`` model.

    Each round is a two-step protocol: :meth:`select_arm` asks VW for an
    action distribution over the offered arms, samples from it and caches the
    decision; :meth:`update_statistics` converts the observed reward into a
    cost-labelled example for that cached decision and feeds it back to VW.

    Args:
        num_actions: Forwarded to ``BanditAlgorithm.__init__``; also sizes the
            initial per-arm reward bookkeeping.
        horizon: Forwarded to ``BanditAlgorithm.__init__``.
        vw_args: Extra VW flags, whitespace-split and appended to the command.
        config: Common VW settings; ``_VWConfig()`` is used when omitted.

    Raises:
        ImportError: If ``vowpalwabbit`` could not be imported.
    """

    def __init__(self, num_actions: int, horizon: int, vw_args: str, config: Optional[_VWConfig] = None):
        super().__init__(num_actions, horizon)
        if pyvw is None:
            raise ImportError("vowpalwabbit is required. Install with `pip install vowpalwabbit`.")
        self.cfg = config or _VWConfig()
        # The command string is kept so that reset() can rebuild an identical model.
        self._vw_cmd: str = self._build_command(self.cfg, vw_args)
        self._vw = pyvw.vw(self._vw_cmd)

        self._clear_cached_decision()
        self._reset_reward_stats()
        self.chosen_arm: int = 0

    @staticmethod
    def _build_command(cfg: _VWConfig, vw_args: str) -> str:
        """Assemble the VW command line from ``cfg`` plus the extra ``vw_args``."""
        args = ["--cb_explore_adf", "--cb_type", cfg.cb_type]
        interactions = cfg.interactions or ()
        if isinstance(interactions, str):
            # A bare string would otherwise be iterated character by character
            # and every one-letter piece dropped by the length check below.
            interactions = (interactions,)
        for inter in interactions:
            if isinstance(inter, str) and len(inter) >= 2:
                args += ["--interactions", inter]
        if cfg.seed is not None:
            args += ["--random_seed", str(int(cfg.seed))]
        if cfg.quiet:
            args.append("--quiet")
        args += ["--cb_min_cost", str(cfg.cb_min_cost), "--cb_max_cost", str(cfg.cb_max_cost)]
        if vw_args:
            args += vw_args.split()
        return " ".join(args)

    def _clear_cached_decision(self) -> None:
        """Forget the decision cached by the last ``select_arm`` call."""
        self._last_context: Optional[Feature] = None
        self._last_actions: List[Feature] = []
        self._last_pmf: Optional[np.ndarray] = None
        self._last_choice: Optional[int] = None

    def _reset_reward_stats(self) -> None:
        """Re-create the per-arm reward bookkeeping for ``num_actions`` arms."""
        self.SUMS = {i: [] for i in range(self.num_actions)}
        self.TotalNumber = {i: 0 for i in range(self.num_actions)}
        self.TotalSum = {i: 0.0 for i in range(self.num_actions)}

    def select_arm(self, arms: Sequence[Feature], context: Optional[Feature] = None) -> int:
        """Sample an arm index from VW's predicted action distribution.

        Args:
            arms: Non-empty list or tuple with one feature description per arm.
            context: Optional shared context features.

        Returns:
            Index into ``arms`` of the sampled arm.  The decision (context,
            arms, distribution, choice) is cached for ``update_statistics``.

        Raises:
            ValueError: If ``arms`` is not a non-empty list or tuple.
        """
        if not isinstance(arms, (list, tuple)) or len(arms) == 0:
            raise ValueError("`arms` must be a non-empty sequence of per-action features.")
        ex = _build_adf_example_cb(context, arms)
        pmf = np.maximum(np.asarray(self._vw.predict(ex), dtype=float), 0.0)
        total = pmf.sum()
        if total > 0:
            pmf = pmf / total
        else:
            # Degenerate prediction (all zero or NaN): fall back to uniform.
            pmf = np.ones(len(arms)) / float(len(arms))
        choice = int(np.random.choice(len(arms), p=pmf))
        self._last_context = context
        self._last_actions = list(arms)
        self._last_pmf = pmf
        self._last_choice = choice
        self.chosen_arm = choice
        return choice

    def update_statistics(self, arm: int, reward: float) -> None:
        """Feed the reward of the cached decision back to VW and record it.

        Args:
            arm: Index of the played arm within the arms of the last
                ``select_arm`` call; ``None`` means "the arm that was sampled".
            reward: Observed reward.  It is turned into a cost
                (``1 - reward`` when ``cfg.reward_as_cost``) and clipped to
                ``[cfg.cb_min_cost, cfg.cb_max_cost]``.

        Raises:
            RuntimeError: If there is no cached decision.
            IndexError: If ``arm`` is outside the last offered arms.
        """
        if self._last_pmf is None:
            raise RuntimeError("update_statistics called before select_arm; no cached decision.")
        idx = int(arm) if arm is not None else int(self._last_choice)
        if not (0 <= idx < len(self._last_actions)):
            raise IndexError("arm index out of range for last actions.")
        prob = float(self._last_pmf[idx])
        reward = float(reward)
        cost = 1.0 - reward if self.cfg.reward_as_cost else reward
        cost = float(np.clip(cost, self.cfg.cb_min_cost, self.cfg.cb_max_cost))
        labeled = _build_adf_labeled_cb(self._last_context, self._last_actions, idx, cost, prob)
        self._vw.learn(labeled)
        # A round may offer more arms than ``num_actions``; grow the
        # bookkeeping on demand instead of raising KeyError after the model
        # has already been updated.
        self.TotalNumber[idx] = self.TotalNumber.get(idx, 0) + 1
        self.TotalSum[idx] = self.TotalSum.get(idx, 0.0) + reward
        self.SUMS.setdefault(idx, []).append(reward)
        self._clear_cached_decision()

    def reset(self):
        """Discard the learned model and all statistics, starting from scratch."""
        self._vw.finish()
        self._vw = pyvw.vw(self._vw_cmd)
        self._clear_cached_decision()
        self._reset_reward_stats()
        # NOTE(review): ``t`` and ``is_reset`` are not defined in this class;
        # presumably they belong to BanditAlgorithm -- confirm.
        self.t = 0
        self.is_reset = True

    def re_init(self):
        """Alias of :meth:`reset`."""
        self.reset()

    def finish(self):
        """Finish the underlying VW workspace."""
        self._vw.finish()


class RegCB(_VWCBADF):
    """VW contextual bandit using the RegCB exploration flags.

    ``mode="elimination"`` selects ``--regcb`` and ``mode="optimistic"``
    selects ``--regcbopt``; ``mellowness`` is passed as ``--mellowness``.
    The remaining arguments populate the shared ``_VWConfig`` (with
    ``cb_type="mtr"``).

    Raises:
        ValueError: If ``mode`` is neither ``"elimination"`` nor ``"optimistic"``.
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        mode: str = "elimination",
        mellowness: float = 0.01,
        cb_min_cost: float = 0.0,
        cb_max_cost: float = 1.0,
        interactions: Sequence[str] = ("ax",),
        seed: Optional[int] = 1,
    ):
        flag_by_mode = {"elimination": "--regcb", "optimistic": "--regcbopt"}
        if mode not in flag_by_mode:
            raise ValueError("mode must be 'elimination' or 'optimistic'")
        extra_flags = " ".join([flag_by_mode[mode], "--mellowness", f"{mellowness}"])
        settings = _VWConfig(
            interactions=interactions,
            seed=seed,
            cb_min_cost=cb_min_cost,
            cb_max_cost=cb_max_cost,
            cb_type="mtr",
        )
        super().__init__(num_actions, horizon, extra_flags, settings)

    def __str__(self) -> str:
        return "RegCB"


class SquareCB(_VWCBADF):
    """VW contextual bandit using the ``--squarecb`` exploration flag.

    ``gamma_scale`` and ``gamma_exponent`` are passed as ``--gamma_scale`` and
    ``--gamma_exponent``.  When ``elim`` is true, ``--elim`` and
    ``--mellowness <mellowness>`` are added as well.  The remaining arguments
    populate the shared ``_VWConfig`` (with ``cb_type="mtr"``).
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        gamma_scale: float = 10.0,
        gamma_exponent: float = 0.5,
        elim: bool = True,
        mellowness: float = 0.01,
        cb_min_cost: float = 0.0,
        cb_max_cost: float = 1.0,
        interactions: Sequence[str] = ("ax",),
        seed: Optional[int] = 1,
    ):
        flags = [
            "--squarecb",
            "--gamma_scale", f"{gamma_scale}",
            "--gamma_exponent", f"{gamma_exponent}",
        ]
        if elim:
            flags += ["--elim", "--mellowness", f"{mellowness}"]
        settings = _VWConfig(
            interactions=interactions,
            seed=seed,
            cb_min_cost=cb_min_cost,
            cb_max_cost=cb_max_cost,
            cb_type="mtr",
        )
        super().__init__(num_actions, horizon, " ".join(flags), settings)

    def __str__(self) -> str:
        return "SquareCB"


class Cover(_VWCBADF):
    """VW contextual bandit using the ``--cover`` exploration flag.

    ``m`` is passed as ``--cover <int(m)>`` and ``psi`` as ``--psi``.
    ``nounif`` and ``first_only`` add the flags of the same name, and
    ``epsilon`` (when not ``None``) adds ``--epsilon``.  ``cb_type`` is placed
    both in the extra flags and in the shared ``_VWConfig``, whose cost range
    is fixed to ``[0.0, 1.0]``.

    Raises:
        ValueError: If ``cb_type`` is not one of ``"mtr"``, ``"dr"``, ``"ips"``.
    """

    def __init__(
        self,
        num_actions: int,
        horizon: int,
        m: int = 8,
        psi: float = 1.0,
        nounif: bool = False,
        first_only: bool = False,
        epsilon: Optional[float] = None,
        cb_type: str = "mtr",
        interactions: Sequence[str] = ("ax",),
        seed: Optional[int] = 1,
    ):
        if cb_type not in {"mtr", "dr", "ips"}:
            raise ValueError("cb_type must be one of {'mtr','dr','ips'}")
        flags = [
            "--cover", f"{int(m)}",
            "--cb_type", f"{cb_type}",
            "--psi", f"{float(psi)}",
        ]
        if nounif:
            flags.append("--nounif")
        if first_only:
            flags.append("--first_only")
        if epsilon is not None:
            flags += ["--epsilon", f"{float(epsilon)}"]
        settings = _VWConfig(
            interactions=interactions,
            seed=seed,
            cb_min_cost=0.0,
            cb_max_cost=1.0,
            cb_type=cb_type,
        )
        super().__init__(num_actions, horizon, " ".join(flags), settings)

    def __str__(self) -> str:
        return "Cover"



def _finite_float(value: object) -> Optional[float]:
    """Return ``value`` as a finite float, or ``None`` if that is impossible."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if np.isfinite(number) else None


def _finite_flat_array(values: object) -> Optional[np.ndarray]:
    """Return ``values`` as a flat, non-empty, all-finite float array, else ``None``."""
    try:
        arr = np.asarray(values, dtype=float).ravel()
    except Exception:
        # Best effort: whatever NumPy cannot convert is simply "not numeric".
        return None
    if arr.size == 0 or not np.all(np.isfinite(arr)):
        return None
    return arr


def _extract_numeric_vector(feat: Feature) -> Optional[np.ndarray]:
    """Extract a finite numeric vector from a feature description.

    Args:
        feat: A feature in any of the shapes accepted by ``Feature``.
            * list / tuple / ndarray -> flattened to a float array.
            * dict -> entries are visited in key order; numeric scalars and
              numeric arrays are concatenated, everything else (strings,
              non-finite numbers, unconvertible arrays) is skipped.
            * number -> a one-element array.

    Returns:
        A 1-D float array containing only finite values, or ``None`` when no
        such vector can be built (``None`` input, empty or non-finite or
        non-numeric sequence, dict without usable entries, non-finite scalar,
        or any other type such as ``str``).
    """
    if feat is None:
        return None

    if isinstance(feat, (list, tuple, np.ndarray)):
        return _finite_flat_array(feat)

    if isinstance(feat, dict):
        try:
            ordered = sorted(feat.items(), key=lambda kv: kv[0])
        except TypeError:
            # Keys of mixed, non-comparable types: fall back to a deterministic
            # order based on type name and repr instead of failing.
            ordered = sorted(
                feat.items(),
                key=lambda kv: (type(kv[0]).__name__, repr(kv[0])),
            )
        items: List[float] = []
        for _, value in ordered:
            if isinstance(value, (int, float, np.integer, np.floating)):
                scalar = _finite_float(value)
                if scalar is not None:
                    items.append(scalar)
            elif isinstance(value, (list, tuple, np.ndarray)):
                sub = _finite_flat_array(value)
                if sub is not None:
                    items.extend(sub.tolist())
        return np.asarray(items, dtype=float) if items else None

    if isinstance(feat, (int, float, np.integer, np.floating)):
        scalar = _finite_float(feat)
        return None if scalar is None else np.asarray([scalar], dtype=float)

    return None

def _independent_indices_from_matrix(M: np.ndarray, tol: float = 1e-8) -> List[int]:
    """Greedily pick row indices of ``M`` whose rows are linearly independent.

    Rows are scanned top to bottom.  Each row is reduced by subtracting its
    projections onto the residuals of the rows accepted so far; it is accepted
    when the norm of what remains exceeds ``tol``.  Scanning stops once as
    many rows as ``M`` has columns were accepted.

    Args:
        M: 2-D array with one candidate vector per row.
        tol: Absolute threshold on the residual norm.

    Returns:
        Accepted row indices in increasing order; ``[0]`` if no row qualified.
    """
    _, n_cols = M.shape
    accepted: List[int] = []
    # (residual, squared norm) for every accepted row; residuals are never
    # mutated, so the squared norm can be computed once at acceptance time.
    directions: List[tuple] = []
    for row_idx, row in enumerate(M):
        residual = row.astype(float, copy=True)
        for direction, sq_norm in directions:
            if sq_norm > 0.0:
                coeff = float(np.dot(residual, direction)) / sq_norm
                residual = residual - coeff * direction
        if np.linalg.norm(residual) > tol:
            accepted.append(row_idx)
            directions.append((residual, float(np.dot(residual, residual))))
            if len(directions) >= n_cols:
                break
    return accepted if accepted else [0]

class DALContext(BanditAlgorithm):
    """Change-detection wrapper around a contextual bandit built by ``base_factory``.

    Arm selection is delegated to the wrapped ``base`` learner.  After every
    update the cumulative (clipped) rewards of the played arm are handed to a
    change detector; when it fires, the current time step is recorded in
    ``self.ChangePoints`` and the wrapper restarts with a fresh base learner.

    NOTE(review): ``self.t``, ``self.T`` and ``self.ChangePoints`` are read
    here but never assigned in this class -- presumably they are provided by
    ``BanditAlgorithm``; confirm.
    """

    def __init__(
        self,
        T: int,
        delta: float,
        noise_variance: float,
        *,
        base_factory: Callable[[], BanditAlgorithm],
        rng: Optional[np.random.Generator] = None,
        explore_coef: float = 1e-3,
        change_detector: Optional[Callable[..., bool]] = None,
    ):
        """Create the wrapper.

        Args:
            T: Horizon, forwarded to ``BanditAlgorithm.__init__``.
            delta: Passed to the change detector on every test.
            noise_variance: Stored as ``self.noise_variance``; not used
                elsewhere in this class.
            base_factory: Zero-argument callable returning a new base learner;
                called here and again on every restart.
            rng: Stored as ``self.rng`` (a fresh default generator when
                omitted); not used elsewhere in this class.
            explore_coef: Scale factor used by ``_exploration_frequency``.
            change_detector: Callable deciding whether a change occurred;
                defaults to ``detect_change``.
        """
        self._factory = base_factory
        # The base learner is built first because its ``num_actions`` is
        # needed for the parent constructor.
        self.base = self._factory()
        super().__init__(self.base.num_actions, T)

        self.delta = float(delta)
        self.noise_variance = float(noise_variance)
        self.rng = rng or np.random.default_rng()
        self.explore_coef = float(explore_coef)
        self.change_detector = change_detector or detect_change

        K = self.num_actions
        # Per-arm pull counts and running cumulative sums of clipped rewards
        # since the last restart.
        self.arm_counts = np.zeros(K, dtype=np.int64)
        self.arm_cums: List[List[float]] = [[] for _ in range(K)]

        # Arms used by the exploration schedule and their number.
        self.indep_arms: Optional[List[int]] = None
        self.N_e: int = 0
        # Episode counter (incremented on every restart) and the time step of
        # the last restart.
        self.k = 1
        self.tau = 0

        # Constructor arguments kept so that reset() can re-run __init__.
        self.init_params = dict(
            T=T, delta=delta, noise_variance=noise_variance,
            base_factory=base_factory, rng=self.rng,
            explore_coef=self.explore_coef, change_detector=self.change_detector,
        )

        self._last_logging_prob = 1.0

    def _maybe_init_indep_arms(self, arms: Sequence[Feature]) -> None:
        """Populate ``indep_arms`` / ``N_e`` once per episode (no-op if already set).

        Sources, in order of preference:
        1. the base learner's own ``get_indep_arms()`` / ``indep_arms``,
        2. a linearly independent subset of the numeric arm vectors,
        3. all arm indices ``0 .. num_actions - 1``.
        """
        if self.indep_arms is not None:
            return
        try:
            # Side effect: the offered arms are attached to the base learner
            # even when it has no ``get_indep_arms`` method.
            self.base.all_arms = arms
            if hasattr(self.base, "get_indep_arms"):
                self.base.get_indep_arms()
                if getattr(self.base, "indep_arms", None) is not None:
                    self.indep_arms = [int(i) for i in self.base.indep_arms]
        except Exception:
            # Best effort: any failure falls through to the numeric fallback.
            pass
        if self.indep_arms is None:
            rows: List[np.ndarray] = []
            for a in arms:
                v = _extract_numeric_vector(a)
                if v is None:
                    # One non-numeric arm disables the numeric fallback entirely.
                    rows = []
                    break
                rows.append(v.astype(float, copy=False).ravel())
            if rows:
                # Vectors of different lengths are zero-padded on the right.
                d = max(r.size for r in rows)
                M = np.zeros((len(rows), d), dtype=float)
                for i, r in enumerate(rows):
                    M[i, :r.size] = r
                self.indep_arms = _independent_indices_from_matrix(M, tol=1e-8)
        if self.indep_arms is None:
            self.indep_arms = list(range(self.num_actions))
        self.N_e = max(1, len(self.indep_arms))

    def _exploration_frequency(self) -> int:
        """Return the period, in rounds, of the exploration schedule.

        Computes ``ceil(N_e / alpha)`` (at least 1) with
        ``alpha = (explore_coef / logT) * sqrt(1000 * k * max(1, N_e) * logT / max(1, T))``
        floored at ``1e-12`` and ``logT = max(log(max(T, 2)), 1)``.
        """
        logT = max(np.log(max(float(self.T), 2.0)), 1.0)
        alpha = float(self.explore_coef/logT) * np.sqrt(1000*self.k * max(1, self.N_e) * (logT / max(1.0, float(self.T))))
        alpha = max(alpha, 1e-12)
        return max(1, int(np.ceil(self.N_e / alpha)))

    def select_arm(self, arms, context=None):
        """Return the arm chosen by the base learner for this round.

        Also records ``_last_logging_prob``: the base learner's probability of
        the returned arm when the base exposes a usable ``_last_pmf``,
        otherwise ``1 / num_actions``.
        """
        self._maybe_init_indep_arms(arms)

        base_choice = int(self.base.select_arm(arms, context=context))
        # NOTE(review): the next two lines overwrite whatever
        # _maybe_init_indep_arms computed, so every offered arm is treated as
        # an exploration arm -- confirm this is intended.
        self.N_e=len(arms)
        self.indep_arms=list(range(self.N_e))
        explor_freq = self._exploration_frequency()
        # Position inside the current exploration period, counted from the
        # last restart.
        phase = (self.t - self.tau) % explor_freq
        if phase < self.N_e:
            chosen = int(self.indep_arms[phase])
        else:
            chosen = base_choice
        # NOTE(review): this unconditional assignment discards the choice made
        # just above, so the scheduled exploration never takes effect and the
        # base learner's arm is always played -- confirm this is intended.
        chosen=base_choice

        if hasattr(self.base, "_last_pmf") and self.base._last_pmf is not None:
            pmf = np.asarray(self.base._last_pmf, float)
            if 0 <= chosen < pmf.size and np.isfinite(pmf[chosen]) and pmf[chosen] > 0:
                self._last_logging_prob = float(pmf[chosen])
            else:
                self._last_logging_prob = 1.0 / float(self.num_actions)
        else:
            self._last_logging_prob = 1.0 / float(self.num_actions)

        self.chosen_arm = chosen
        return chosen

    def update_statistics(self, arm, reward):
        """Update the base learner, then test the played arm for a change.

        The base learner receives the raw reward.  The wrapper's own
        statistics use the reward clipped to ``[0, 1]``.  Once the arm has
        more than two recorded pulls, the change detector is called with the
        pull count, the arm's cumulative-reward list, ``delta`` and a
        divergence that calls ``kl_divergence`` with ``mode="bernoulli"``.
        On detection the current ``self.t`` is appended to
        ``self.ChangePoints`` and the wrapper restarts.
        """
        self.base.update_statistics(int(arm), float(reward))

        a = int(arm)
        r = float(np.clip(reward, 0.0, 1.0))

        # arm_cums[a] holds running totals, so the newest entry is the
        # previous total plus the new reward.
        if self.arm_cums[a]:
            self.arm_cums[a].append(self.arm_cums[a][-1] + r)
        else:
            self.arm_cums[a].append(r)
        self.arm_counts[a] += 1
        nb = int(self.arm_counts[a])

        if nb > 2:
            changed = self.change_detector(
                nb,
                self.arm_cums[a],
                self.delta,
                divergence=lambda p, q, var=None: kl_divergence(p, q, mode="bernoulli"),
            )
            if changed:
                self.ChangePoints.append(self.t)
                self._restart()
                return

    def _restart(self):
        """Start a new episode: new base learner, cleared per-arm statistics."""
        self.k += 1
        self.tau = self.t

        # NOTE(review): the previous base learner is dropped without calling
        # its ``finish()`` (if it has one); cleanup is left to the object itself.
        self.base = self._factory()

        K = self.num_actions
        self.arm_counts = np.zeros(K, dtype=np.int64)
        self.arm_cums = [[] for _ in range(K)]

        # Forces _maybe_init_indep_arms to run again on the next select_arm.
        self.indep_arms = None
        self.N_e = 0

        self._last_logging_prob = 1.0

    def reset(self):
        """Re-run ``__init__`` with the stored constructor arguments."""
        self.__init__(**self.init_params)

    def re_init(self):
        """Alias of :meth:`reset`."""
        self.reset()

    def __str__(self):
        # NOTE(review): the label says "DAB" while the class is named
        # ``DALContext`` -- confirm which spelling is intended.
        return f"DAB({self.base})"


# Public API of this module.  Every name listed here must exist at module
# level, otherwise ``from <module> import *`` raises AttributeError; the
# change-detection wrapper defined in this file is called ``DALContext``.
__all__ = [
    "RegCB", "SquareCB", "Cover",
    "DALContext",
]
