from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Optional
import copy

import numpy as np


@dataclass
class CallPlan:
    """Positional and keyword arguments bundled for one candidate call.

    The simulations in this file attach several ``CallPlan`` objects to every
    time step (see ``SimulationData.select_plans``); each one carries a
    different argument signature. NOTE(review): the consumer of these plans
    is not visible here -- presumably a runner that tries them against a
    policy's selection method; confirm against the caller.
    """

    # Positional arguments of the planned call; empty tuple when omitted.
    args: Tuple[Any, ...] = field(default_factory=tuple)
    # Keyword arguments of the planned call; each instance gets its own dict.
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SimulationData:
    """Container for everything a simulation run precomputes.

    The simulations in this file fill ``rewards`` and ``mean_rewards`` with
    arrays of shape ``(T, K)`` (time steps x arms) and ``best_rewards`` with
    an array of length ``T``.
    """

    # Sampled reward of every arm at every time step.
    rewards: np.ndarray
    # Mean (expected) reward of every arm at every time step.
    mean_rewards: np.ndarray
    # Best achievable reward at every time step.
    best_rewards: np.ndarray
    # Arm description handed to policies: a deep copy of ``bandit.arms`` or a
    # list of per-arm features, depending on the simulation.
    arms: Any
    # One list of alternative ``CallPlan`` signatures per time step.
    select_plans: List[List[CallPlan]]
    # Free-form side information (e.g. "PT", "change_points", "contexts").
    extras: Dict[str, Any] = field(default_factory=dict)


class BaseSimulation:
    """Common scaffolding for simulations that precompute a bandit run.

    Subclasses implement ``_run_simulation``; ``precompute`` seeds a random
    state, samples the change points and delegates to that hook.
    """

    def __init__(
        self,
        bandit_class,
        bandit_params: Dict[str, Any],
        T: int,
        change_prob: Optional[float],
        continuous: bool,
    ) -> None:
        """Store the simulation settings.

        Args:
            bandit_class: Callable used by subclasses to build the bandit.
            bandit_params: Keyword arguments for ``bandit_class`` (stored by
                reference, not copied here).
            T: Number of time steps; coerced with ``int``.
            change_prob: Per-step probability of a change point, or ``None``
                to disable change points.
            continuous: Coerced with ``bool``; subclasses use it to choose
                between gradual and abrupt changes.
        """
        self.bandit_class = bandit_class
        self.bandit_params = bandit_params
        self.change_prob = change_prob
        self.T = int(T)
        self.continuous = bool(continuous)

    def sample_change_points(self, rng: np.random.RandomState) -> List[int]:
        """Draw the change points for one run.

        Returns a list that always starts with ``0`` and ends with ``self.T``.
        When ``change_prob`` is ``None`` or ``0`` no random numbers are drawn;
        otherwise exactly one ``rng.rand()`` draw is made for each step in
        ``1 .. T-1`` and the step is kept when the draw is ``<= change_prob``.
        """
        horizon = self.T
        prob = self.change_prob
        if prob is None or prob == 0:
            return [0, horizon]
        # One draw per interior step, in increasing step order.
        interior = [step for step in range(1, horizon) if rng.rand() <= prob]
        return [0, *interior, horizon]

    def precompute(self, seed: int) -> SimulationData:
        """Run the simulation with a ``RandomState`` seeded by ``seed``.

        The sampled change points are stored under ``extras["change_points"]``
        unless the subclass already set that key.
        """
        rng = np.random.RandomState(seed)
        change_points = self.sample_change_points(rng)
        result = self._run_simulation(rng, change_points)
        if "change_points" not in result.extras:
            result.extras["change_points"] = change_points
        return result

    def _run_simulation(
        self, rng: np.random.RandomState, change_points: Sequence[int]
    ) -> SimulationData:
        """Produce the simulation data; must be overridden by subclasses."""
        raise NotImplementedError


class LinearBanditSimulation(BaseSimulation):
    """Precompute rewards for a bandit that exposes a ``theta`` parameter.

    Every arm is queried at every time step, and the step-to-step movement of
    ``bandit.theta`` (Euclidean norm) is summed into ``extras["PT"]``.
    """

    def _run_simulation(
        self, rng: np.random.RandomState, change_points: Sequence[int]
    ) -> SimulationData:
        """Simulate ``self.T`` steps and tabulate rewards for all arms.

        Args:
            rng: Random state supplied by ``precompute``; not used here.
            change_points: Sequence whose interior entries (first and last
                dropped) are the steps that trigger ``abrupt_change``.

        Returns:
            ``SimulationData`` with ``(T, K)`` reward/mean tables, the
            per-step best value, a deep copy of ``bandit.arms`` and three
            call plans per step.
        """
        horizon = self.T
        env = self.bandit_class(**self.bandit_params.copy())
        num_arms = env.num_actions

        reward_table = np.zeros((horizon, num_arms), dtype=float)
        mean_table = np.zeros((horizon, num_arms), dtype=float)
        best_per_step = np.zeros(horizon, dtype=float)
        step_drift = np.zeros(horizon, dtype=float)

        # Private snapshot so later updates of the bandit cannot alias it.
        last_theta = getattr(env, "theta", np.zeros(1)).copy()
        interior_changes = set(change_points[1:-1])

        for step in range(horizon):
            # Continuous mode drifts on every step; otherwise the environment
            # jumps only at the sampled interior change points.
            if self.continuous:
                env.gradual_change(change_rate=step / max(1, horizon))
            elif step in interior_changes:
                env.abrupt_change()

            # Sampled reward first, then mean reward, arm by arm.
            for arm in range(num_arms):
                reward_table[step, arm] = env.get_reward(arm)
                mean_table[step, arm] = env.get_mean_reward(arm)
            _, best_value = env.get_best_arm()
            best_per_step[step] = best_value

            # A bandit without ``theta`` contributes zero drift.
            theta_now = getattr(env, "theta", last_theta)
            step_drift[step] = float(np.linalg.norm(theta_now - last_theta))
            last_theta = theta_now.copy()

        arms = copy.deepcopy(env.arms)
        pt = float(step_drift.sum())

        # The same three alternative signatures are offered at every step.
        select_plans: List[List[CallPlan]] = [
            [
                CallPlan(args=(arms, pt)),
                CallPlan(args=(arms, change_points)),
                CallPlan(args=(arms,)),
            ]
            for _ in range(horizon)
        ]

        return SimulationData(
            rewards=reward_table,
            mean_rewards=mean_table,
            best_rewards=best_per_step,
            arms=arms,
            select_plans=select_plans,
            extras={"PT": pt},
        )


class KernelBanditSimulation(BaseSimulation):
    """Precompute rewards for a bandit that exposes ``reward_means``.

    Every arm is queried at every time step, and the largest absolute change
    of ``bandit.reward_means`` between consecutive steps is summed into
    ``extras["PT"]``.
    """

    def _run_simulation(
        self, rng: np.random.RandomState, change_points: Sequence[int]
    ) -> SimulationData:
        """Simulate ``self.T`` steps and tabulate rewards for all arms.

        Args:
            rng: Random state supplied by ``precompute``; not used here.
            change_points: Sequence whose interior entries (first and last
                dropped) are the steps that trigger ``abrupt_change``.

        Returns:
            ``SimulationData`` with ``(T, K)`` reward/mean tables, the
            per-step best value, a deep copy of ``bandit.arms`` and three
            call plans per step.

        Raises:
            AttributeError: If the bandit has no ``reward_means`` attribute
                and at least one step is simulated.
        """
        bandit = self.bandit_class(**self.bandit_params.copy())
        T = self.T
        K = bandit.num_actions

        rewards = np.zeros((T, K), dtype=float)
        means = np.zeros((T, K), dtype=float)
        best = np.zeros(T, dtype=float)
        reward_diffs = np.zeros(T, dtype=float)

        # Take a real copy of the starting means. ``np.asarray`` would hand
        # back the bandit's own array when it is already a float ndarray, so
        # an in-place update on the first step would be compared with itself
        # and the first drift term would silently be zero.
        prev_means = np.array(
            getattr(bandit, "reward_means", np.zeros(K)), dtype=float, copy=True
        )
        change_set = set(change_points[1:-1])

        for t in range(T):
            # Continuous mode drifts on every step; otherwise the environment
            # jumps only at the sampled interior change points.
            if not self.continuous and t in change_set:
                bandit.abrupt_change()
            elif self.continuous:
                bandit.gradual_change()

            # Sampled reward first, then mean reward, arm by arm.
            for a in range(K):
                rewards[t, a] = bandit.get_reward(a)
                means[t, a] = bandit.get_mean_reward(a)
            _, best_val = bandit.get_best_arm()
            best[t] = best_val

            current_means = np.asarray(bandit.reward_means, dtype=float)
            reward_diffs[t] = float(np.max(np.abs(current_means - prev_means)))
            # Copy so the next comparison is against this step's values even
            # if the bandit mutates its array in place.
            prev_means = current_means.copy()

        pt = float(reward_diffs.sum())
        arms = copy.deepcopy(bandit.arms)

        # The same three alternative signatures are offered at every step.
        select_plans: List[List[CallPlan]] = [
            [
                CallPlan(args=(arms, pt)),
                CallPlan(args=(arms, change_points)),
                CallPlan(args=(arms,)),
            ]
            for _ in range(T)
        ]

        return SimulationData(
            rewards=rewards,
            mean_rewards=means,
            best_rewards=best,
            arms=arms,
            select_plans=select_plans,
            extras={"PT": pt},
        )


def _arms_identity(K: int) -> List[Dict[str, str]]:
    """Build ``K`` placeholder arm descriptors ``{"aid": "a0"}``, ``{"aid": "a1"}``, ..."""
    placeholders: List[Dict[str, str]] = []
    for index in range(K):
        placeholders.append({"aid": "a" + str(index)})
    return placeholders


def _ctx_to_vw_dict(x: np.ndarray) -> Dict[str, float]:
    """Flatten a context into a named-feature dict.

    The result starts with ``"bias": 1.0`` and is followed by ``"x1"``,
    ``"x2"``, ... for the flattened entries of ``x`` (1-based names), each
    converted to a Python ``float``.
    """
    flat = np.asarray(x, float).reshape(-1)
    features: Dict[str, float] = {"bias": 1.0}
    for position, value in enumerate(flat, start=1):
        features[f"x{position}"] = float(value)
    return features


def _pack_reward_params_for_path_length(bandit) -> np.ndarray:
    """Concatenate the bandit's reward parameters into one flat float vector.

    The layout is ``U``, ``V``, ``BIAS`` (each flattened with ``ravel``)
    followed by the three scalars ``A_SIG``, ``A_SIN`` and ``A_XPR``.
    """
    # Read the array-valued attributes first, in layout order.
    pieces = [getattr(bandit, name).ravel() for name in ("U", "V", "BIAS")]
    scalars = np.array([bandit.A_SIG, bandit.A_SIN, bandit.A_XPR], dtype=float)
    pieces.append(scalars)
    packed = np.concatenate(pieces, axis=0)
    return packed.astype(float)


class ContextBanditSimulation(BaseSimulation):
    """Precompute contexts and rewards for a contextual bandit.

    A context is sampled at every time step and every arm is evaluated
    against it. The step-to-step movement of the packed reward parameters
    (Euclidean norm) is summed into ``extras["path_norm"]``.
    """

    def _run_simulation(
        self, rng: np.random.RandomState, change_points: Sequence[int]
    ) -> SimulationData:
        """Simulate ``self.T`` steps and tabulate contexts and rewards.

        Args:
            rng: Random state; used only to draw a fresh bandit seed when
                the bandit parameters contain a ``"seed"`` key.
            change_points: Sequence whose interior entries (first and last
                dropped) are the steps that trigger ``abrupt_change``.

        Returns:
            ``SimulationData`` with ``(T, K)`` reward/mean tables, the
            per-step maximum mean, the arm features and five call plans per
            step; ``extras`` holds the contexts in array and dict form.
        """
        params = self.bandit_params.copy()
        if "seed" in params:
            # Replace the configured seed with one drawn from this run's rng.
            params["seed"] = rng.randint(0, 2**31 - 1)
        env = self.bandit_class(**params)

        horizon = self.T
        num_arms = env.num_actions
        interior_changes = set(change_points[1:-1])

        # Best effort: use the bandit's own arm features when there is one
        # per arm and they convert to float vectors; otherwise fall back to
        # placeholder identifiers.
        try:
            raw_arms = getattr(env, "arms", None)
            if raw_arms is None or len(raw_arms) != num_arms:
                arms_features = _arms_identity(num_arms)
            else:
                arms_features = [
                    np.asarray(arm, dtype=float).ravel() for arm in raw_arms
                ]
        except Exception:
            arms_features = _arms_identity(num_arms)

        contexts: List[np.ndarray] = []
        context_dicts: List[Dict[str, float]] = []
        reward_table = np.zeros((horizon, num_arms), dtype=float)
        mean_table = np.zeros((horizon, num_arms), dtype=float)
        best_per_step = np.zeros(horizon, dtype=float)
        step_drift = np.zeros(horizon, dtype=float)

        params_before = _pack_reward_params_for_path_length(env)

        for step in range(horizon):
            # Continuous mode drifts on every step; otherwise the environment
            # jumps only at the sampled interior change points.
            if self.continuous:
                env.gradual_change(change_rate=step / max(1, horizon))
            elif step in interior_changes:
                env.abrupt_change()

            ctx = np.asarray(env.sample_context(), dtype=float)
            contexts.append(ctx.copy())
            context_dicts.append(_ctx_to_vw_dict(ctx))

            # Expected reward first, then sampled reward, arm by arm.
            # NOTE(review): get_reward receives no context -- presumably the
            # bandit remembers the one it just sampled; confirm.
            for arm in range(num_arms):
                mean_table[step, arm] = env.expected_reward(arm, ctx)
                reward_table[step, arm] = env.get_reward(arm)
            best_per_step[step] = float(mean_table[step].max())

            params_now = _pack_reward_params_for_path_length(env)
            step_drift[step] = float(np.linalg.norm(params_now - params_before))
            params_before = params_now

        # Five alternative signatures per step, mixing dict and array forms
        # of that step's context.
        select_plans: List[List[CallPlan]] = [
            [
                CallPlan(args=(arms_features,), kwargs={"context": ctx_dict}),
                CallPlan(args=(arms_features,), kwargs={"context": ctx_array}),
                CallPlan(args=(None,), kwargs={"context": ctx_dict}),
                CallPlan(args=(), kwargs={"context": ctx_dict}),
                CallPlan(args=(arms_features,)),
            ]
            for ctx_dict, ctx_array in zip(context_dicts, contexts)
        ]

        return SimulationData(
            rewards=reward_table,
            mean_rewards=mean_table,
            best_rewards=best_per_step,
            arms=arms_features,
            select_plans=select_plans,
            extras={
                "contexts": contexts,
                "context_dicts": context_dicts,
                "path_norm": float(step_drift.sum()),
            },
        )


# Names exported by a star-import; the underscore-prefixed helpers are
# intentionally left out.
__all__ = [
    "CallPlan",
    "SimulationData",
    "BaseSimulation",
    "LinearBanditSimulation",
    "KernelBanditSimulation",
    "ContextBanditSimulation",
]
