from __future__ import annotations

import numpy as np
from typing import Optional, List, Tuple, Literal, Union

from src.utils import sigmoid

from src.bandits.NonStationaryBandit import NonStationaryBandit

# Context-distribution modes understood by NSContextBandit: a finite table of
# contexts drawn with categorical probabilities, or a per-coordinate Gaussian.
ContextMode = Literal[
    "finite",
    "gaussian",
]


class NSContextBandit(NonStationaryBandit):
    """Non-stationary contextual bandit with a nonlinear per-action mean reward.

    A context ``x`` of length ``d_c`` is drawn either from a finite table of
    contexts with categorical probabilities (``"finite"`` mode) or from a
    Gaussian with per-coordinate mean and scale (``"gaussian"`` mode). The
    mean reward of action ``a`` under ``x`` is::

        clip(BIAS[a] + A_SIG * sigmoid(U[a] . x) + A_SIN * sin(V[a] . x)
             + A_XPR * x[1] * x[2], 0, 1)

    where the cross term is only used when ``d_c >= 3``. Non-stationarity is
    introduced by :meth:`abrupt_change` (resample reward parameters and/or the
    context distribution) and :meth:`gradual_change` (interpolate reward
    parameters toward a target).
    """

    def __init__(
        self,
        *,
        num_actions: int,
        noise_variance: float,
        d: int,
        d_c: int,
        reward_bound: float,
        context_mode: Optional[ContextMode] = None,
        finite_contexts: Optional[np.ndarray] = None,
        finite_probs: Optional[np.ndarray] = None,
        gaussian_mean: Optional[np.ndarray] = None,
        gaussian_scale: Optional[np.ndarray] = None,
        action_bound: float = 1.0,
        continuous: Union[bool, str] = False,
        finite_N_min: int = 1000,
        finite_N_max: int = 1000,
        ctx_mean_std: float = 0.5,
        ctx_scale_low: float = 0.3,
        ctx_scale_high: float = 1.0,
        seed: Optional[int] = None,
        mean_low: float = 0.0,
        mean_high: float = 1.0,
    ):
        """Build the context distribution, the reward parameters and the parent bandit.

        Args:
            num_actions: number of arms.
            noise_variance: forwarded to ``NonStationaryBandit``.
            d: dimension of the arm vectors built by :meth:`set_arms`.
            d_c: context dimension.
            reward_bound: forwarded to ``NonStationaryBandit``.
            context_mode: ``"finite"``, ``"gaussian"``, or ``None`` to pick one
                with a coin flip from the instance RNG.
            finite_contexts: optional ``(N, d_c)`` context table (finite mode).
                Sampled as random unit vectors when omitted. Copied.
            finite_probs: optional length-``N`` weights over the table. These
                are clipped to ``>= 1e-12`` and renormalized. Drawn from a flat
                Dirichlet when omitted.
            gaussian_mean, gaussian_scale: optional length-``d_c`` mean and
                positive per-coordinate scale (gaussian mode). Sampled when
                omitted. Copied.
            action_bound: arm vectors get norm (approximately)
                ``sqrt(action_bound)``.
            continuous: forwarded to the parent. The string ``"AGGRESIVE"``
                also selects the adversarial target in :meth:`gradual_change`.
            finite_N_min, finite_N_max: range for the number of sampled finite
                contexts. The minimum is raised to at least 2.
            ctx_mean_std: std used to sample the Gaussian context mean.
            ctx_scale_low, ctx_scale_high: range used to sample Gaussian scales.
            seed: seed for the instance RNG (``np.random.default_rng``).
            mean_low, mean_high: forwarded to ``NonStationaryBandit``.

        Raises:
            ValueError: if ``context_mode`` is not ``None``, ``"finite"`` or
                ``"gaussian"``.
        """
        self.num_actions = int(num_actions)
        self.d = int(d)
        self.d_c = int(d_c)
        self.action_bound = float(action_bound)
        self.continuous = continuous
        self.rng = np.random.default_rng(seed)

        self.finite_N_min = int(max(2, finite_N_min))
        self.finite_N_max = int(max(self.finite_N_min, finite_N_max))
        self.ctx_mean_std = float(ctx_mean_std)
        self.ctx_scale_low = float(ctx_scale_low)
        self.ctx_scale_high = float(ctx_scale_high)

        if context_mode is None:
            # Chosen with the instance RNG, before any other draw.
            self.context_mode: ContextMode = (
                "finite" if bool(self.rng.integers(0, 2)) else "gaussian"
            )
        elif context_mode in ("finite", "gaussian"):
            self.context_mode = context_mode
        else:
            # Reject typos instead of silently falling back to gaussian mode.
            raise ValueError(
                f"context_mode must be 'finite', 'gaussian' or None, got {context_mode!r}"
            )

        if self.context_mode == "finite":
            self._init_finite_context(finite_contexts, finite_probs)
        else:
            self._init_gaussian_context(gaussian_mean, gaussian_scale)

        self._sample_reward_params()
        self._cache_init_reward_params()

        # Placeholder context until sample_context() / set_context() is called.
        self.current_context = np.ones(self.d_c, dtype=np.float64)
        self.last_context = self.current_context.copy()

        # NOTE(review): the reward parameters (U, V, BIAS, A_*) are not recorded
        # here. Rebuilding from these kwargs also skips the RNG draws made above
        # for the context distribution, so a rebuilt instance will generally get
        # different reward parameters. Confirm this is intended.
        self.init_params = dict(
            num_actions=num_actions, noise_variance=noise_variance, d=d, d_c=d_c,
            reward_bound=reward_bound, context_mode=self.context_mode,
            finite_contexts=(self.finite_contexts if self.context_mode == "finite" else None),
            finite_probs=(self.finite_probs if self.context_mode == "finite" else None),
            gaussian_mean=(self.gaussian_mean if self.context_mode == "gaussian" else None),
            gaussian_scale=(self.gaussian_scale if self.context_mode == "gaussian" else None),
            action_bound=action_bound, continuous=continuous,
            finite_N_min=self.finite_N_min, finite_N_max=self.finite_N_max,
            ctx_mean_std=self.ctx_mean_std, ctx_scale_low=self.ctx_scale_low,
            ctx_scale_high=self.ctx_scale_high, seed=seed,
            mean_low=mean_low, mean_high=mean_high,
        )

        super().__init__(
            arm_type="bernoulli",
            num_actions=num_actions,
            noise_variance=noise_variance,
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            continuous=continuous,
        )

    def _init_finite_context(
        self,
        finite_contexts: Optional[np.ndarray],
        finite_probs: Optional[np.ndarray],
    ) -> None:
        """Set ``finite_contexts`` (N, d_c) and matching ``finite_probs`` (N,)."""
        if finite_contexts is None:
            n_contexts = int(self.rng.integers(self.finite_N_min, self.finite_N_max + 1))
            self.finite_contexts = self._sample_finite_contexts(n_contexts)
        else:
            # Copy so later mutation of the caller's array cannot alter the bandit.
            self.finite_contexts = np.array(finite_contexts, dtype=np.float64)
            assert self.finite_contexts.ndim == 2 and self.finite_contexts.shape[1] == self.d_c, \
                "finite_contexts must be (N, d_c)"
        N = self.finite_contexts.shape[0]
        if finite_probs is None:
            self.finite_probs = self.rng.dirichlet(np.ones(N, dtype=float))
        else:
            p = np.asarray(finite_probs, float).reshape(-1)
            assert p.shape[0] == N, "finite_probs length must match number of finite_contexts"
            # Keep every context reachable, then renormalize (np.clip returns a copy).
            p = np.clip(p, 1e-12, None)
            self.finite_probs = p / p.sum()

    def _init_gaussian_context(
        self,
        gaussian_mean: Optional[np.ndarray],
        gaussian_scale: Optional[np.ndarray],
    ) -> None:
        """Set ``gaussian_mean`` / ``gaussian_scale`` (length ``d_c``), sampling missing ones."""
        # Draw order (mean, then scale) matches the original for seeded runs.
        if gaussian_mean is None:
            gaussian_mean = self.rng.normal(0.0, self.ctx_mean_std, size=self.d_c)
        if gaussian_scale is None:
            gaussian_scale = self.rng.uniform(self.ctx_scale_low, self.ctx_scale_high, size=self.d_c)
        self.gaussian_mean = np.array(gaussian_mean, dtype=float).reshape(-1)
        self.gaussian_scale = np.array(gaussian_scale, dtype=float).reshape(-1)
        assert self.gaussian_mean.shape[0] == self.d_c, "gaussian_mean must have length d_c"
        assert self.gaussian_scale.shape[0] == self.d_c, "gaussian_scale must have length d_c"
        assert np.all(self.gaussian_scale > 0), "gaussian_scale must be > 0"

    def set_arms(self) -> List[np.ndarray]:
        """Return ``num_actions`` random float64 vectors in R^d.

        Each vector has norm approximately ``sqrt(action_bound)``.
        """
        arms: List[np.ndarray] = []
        for _ in range(self.num_actions):
            v = self.rng.normal(size=self.d)
            v /= (np.linalg.norm(v) + 1e-12)  # epsilon guards a zero draw
            v *= np.sqrt(self.action_bound)
            arms.append(v.astype(np.float64))
        return arms

    def get_mean_reward(self, action: int) -> float:
        """Mean reward of ``action`` under the current context."""
        x = self.current_context
        return float(self._nonlinear_mean(action, x))

    def sample_context(self) -> np.ndarray:
        """Draw a context from the current distribution and make it current.

        Returns:
            A fresh length-``d_c`` array. It is also installed via
            :meth:`set_context`, and a copy is stored in ``last_context``.
        """
        if self.context_mode == "finite":
            idx = int(self.rng.choice(len(self.finite_contexts), p=self.finite_probs))
            # Row indexing returns a view. Copy so callers cannot mutate the table.
            x = self.finite_contexts[idx].copy()
        else:
            x = self.gaussian_mean + self.gaussian_scale * self.rng.normal(size=self.d_c)
        self.set_context(x)
        self.last_context = x.copy()
        return x

    def set_context(self, x: np.ndarray) -> None:
        """Install ``x`` (flattened to length ``d_c``) as the current context.

        Also recomputes ``reward_means``. The input is copied.
        """
        # np.array copies, so the stored context never aliases the caller's array.
        x = np.array(x, dtype=np.float64).reshape(-1)
        assert x.shape[0] == self.d_c, f"context dim mismatch: got {x.shape[0]}, expected {self.d_c}"
        self.current_context = x
        self.reward_means = self.get_reward_means()

    def expected_reward(self, action: int, x: np.ndarray) -> float:
        """Mean reward of ``action`` under an arbitrary context ``x``.

        The current context is not changed.
        """
        return float(self._nonlinear_mean(int(action), np.asarray(x, float).reshape(-1)))

    def abrupt_change(
        self,
        new_reward_params: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float]] = None,
        *,
        change_reward: Optional[bool] = None,
        change_context: Optional[bool] = None,
    ) -> None:
        """Abruptly change the reward function and/or the context distribution.

        Args:
            new_reward_params: optional ``(U, V, BIAS, A_SIG, A_SIN, A_XPR)``
                with ``U``, ``V`` of shape ``(num_actions, d_c)`` and ``BIAS``
                of shape ``(num_actions,)``. Array-likes are accepted and
                copied. Resampled from the RNG when omitted.
            change_reward: change reward parameters. ``None`` means a coin flip.
            change_context: change the context distribution. ``None`` means a
                coin flip.

        If both flags end up False, one of them is forced True at random, so
        something always changes.
        """
        if change_reward is None:
            change_reward = bool(self.rng.integers(0, 2))
        if change_context is None:
            change_context = bool(self.rng.integers(0, 2))
        if (not change_reward) and (not change_context):
            if bool(self.rng.integers(0, 2)):
                change_reward = True
            else:
                change_context = True

        if change_reward:
            if new_reward_params is not None:
                U, V, BIAS, A_SIG, A_SIN, A_XPR = new_reward_params
                # Convert before validating so list/tuple inputs are accepted.
                U = np.array(U, dtype=float)
                V = np.array(V, dtype=float)
                BIAS = np.array(BIAS, dtype=float)
                assert U.shape == (self.num_actions, self.d_c), "U must be (num_actions, d_c)"
                assert V.shape == (self.num_actions, self.d_c), "V must be (num_actions, d_c)"
                assert BIAS.shape == (self.num_actions,), "BIAS must be (num_actions,)"
                self.U = U
                self.V = V
                self.BIAS = BIAS
                self.A_SIG = float(A_SIG)
                self.A_SIN = float(A_SIN)
                self.A_XPR = float(A_XPR)
            else:
                self._sample_reward_params()
            # The new parameters become the start point for later gradual changes.
            self._cache_init_reward_params()
            self._clear_target_reward_params()

        if change_context:
            if self.context_mode == "finite":
                N = self.finite_contexts.shape[0]
                p = self.rng.dirichlet(np.ones(N, dtype=float))
                self.finite_probs = np.clip(p, 1e-12, None)
                self.finite_probs /= self.finite_probs.sum()
            else:
                self.gaussian_mean = self.rng.normal(0.0, self.ctx_mean_std, size=self.d_c)
                self.gaussian_scale = self.rng.uniform(self.ctx_scale_low, self.ctx_scale_high, size=self.d_c)

        self.reward_means = self.get_reward_means()

    def gradual_change(self, change_rate: float = 0.01) -> None:
        """Set reward parameters to ``(1 - change_rate) * start + change_rate * target``.

        The target is drawn lazily on the first call and reused until cleared.
        It is adversarial (:meth:`_worst_direction_target`) when
        ``continuous == "AGGRESIVE"`` and a fresh random draw otherwise.
        Interpolation always starts from the cached ``_init_reward``, not the
        current values, so ``change_rate`` is an absolute fraction rather than
        an increment. Callers presumably pass a growing schedule such as
        ``t / T`` (cf. :meth:`get_P_T`); confirm. The amplitudes are clipped to
        ``[0, 1]``.
        """
        if self._target_reward is None:
            if self.continuous == "AGGRESIVE":
                U_t, V_t, BIAS_t, AS_t, AN_t, AX_t = self._worst_direction_target()
            else:
                U_t, V_t, BIAS_t, AS_t, AN_t, AX_t = self._sample_reward_params(return_only=True)
            self._target_reward = (U_t, V_t, BIAS_t, AS_t, AN_t, AX_t)

        U0, V0, B0, AS0, AN0, AX0 = self._init_reward
        Ut, Vt, Bt, ASt, ANt, AXt = self._target_reward

        self.U     = (1.0 - change_rate) * U0 + change_rate * Ut
        self.V     = (1.0 - change_rate) * V0 + change_rate * Vt
        self.BIAS  = (1.0 - change_rate) * B0 + change_rate * Bt
        self.A_SIG = (1.0 - change_rate) * AS0 + change_rate * ASt
        self.A_SIN = (1.0 - change_rate) * AN0 + change_rate * ANt
        self.A_XPR = (1.0 - change_rate) * AX0 + change_rate * AXt

        self.A_SIG = float(np.clip(self.A_SIG, 0.0, 1.0))
        self.A_SIN = float(np.clip(self.A_SIN, 0.0, 1.0))
        self.A_XPR = float(np.clip(self.A_XPR, 0.0, 1.0))

        self.reward_means = self.get_reward_means()

    def get_P_T(self, T: int) -> float:
        """Path length of the reward parameters along a ``T``-step interpolation.

        This sums the per-step distances (Frobenius norm for U and V, L2 norm
        for BIAS, absolute difference for amplitudes) between consecutive
        points ``start + (t / T) * (target - start)`` for ``t = 1 .. T-1``,
        starting from the start point itself. Returns ``0.0`` when no
        gradual-change target has been drawn.
        """
        if self._target_reward is None:
            return 0.0
        U0, V0, B0, AS0, AN0, AX0 = self._init_reward
        Ut, Vt, Bt, ASt, ANt, AXt = self._target_reward
        res = 0.0
        Uc, Vc, Bc, ASc, ANc, AXc = U0.copy(), V0.copy(), B0.copy(), AS0, AN0, AX0
        for t in range(1, int(T)):
            alpha = t / T
            Utmp = (1.0 - alpha) * U0 + alpha * Ut
            Vtmp = (1.0 - alpha) * V0 + alpha * Vt
            Btmp = (1.0 - alpha) * B0 + alpha * Bt
            AStmp = (1.0 - alpha) * AS0 + alpha * ASt
            ANtmp = (1.0 - alpha) * AN0 + alpha * ANt
            AXtmp = (1.0 - alpha) * AX0 + alpha * AXt
            res += (np.linalg.norm(Utmp - Uc)
                    + np.linalg.norm(Vtmp - Vc)
                    + np.linalg.norm(Btmp - Bc)
                    + abs(AStmp - ASc) + abs(ANtmp - ANc) + abs(AXtmp - AXc))
            Uc, Vc, Bc, ASc, ANc, AXc = Utmp, Vtmp, Btmp, AStmp, ANtmp, AXtmp
        return float(res)

    def re_init(self) -> None:
        """Re-initialize via the parent, then reset the gradual-change state.

        After the parent call, the current reward parameters are re-cached as
        the start point, any target is dropped, and reward means are
        recomputed for the current context.
        """
        super().re_init()
        self._cache_init_reward_params()
        self._clear_target_reward_params()
        self.set_context(self.current_context)

    def _nonlinear_mean(self, a: int, x: np.ndarray) -> float:
        """Mean reward of action ``a`` under context ``x``, clipped to ``[0, 1]``."""
        z1 = float(np.dot(self.U[a], x))
        z2 = float(np.dot(self.V[a], x))
        # Fixed pairwise interaction between context coordinates 1 and 2.
        cross = float(x[1] * x[2]) if self.d_c >= 3 else 0.0
        mu = float(self.BIAS[a] + self.A_SIG * sigmoid(z1)
                   + self.A_SIN * np.sin(z2) + self.A_XPR * cross)
        return float(np.clip(mu, 0.0, 1.0))

    def _sample_reward_params(self, return_only: bool = False):
        """Draw fresh ``(U, V, BIAS, A_SIG, A_SIN, A_XPR)``.

        Returns the tuple when ``return_only`` is True. Otherwise the values
        are assigned to ``self`` and ``None`` is returned.
        """
        U = self.rng.normal(size=(self.num_actions, self.d_c))
        V = self.rng.normal(size=(self.num_actions, self.d_c))
        BIAS = self.rng.uniform(0.3, 0.7, size=self.num_actions)
        A_SIG = float(self.rng.uniform(0.25, 0.45))
        A_SIN = float(self.rng.uniform(0.15, 0.35))
        A_XPR = float(self.rng.uniform(0.10, 0.25))
        if return_only:
            return U, V, BIAS, A_SIG, A_SIN, A_XPR
        self.U, self.V, self.BIAS = U, V, BIAS
        self.A_SIG, self.A_SIN, self.A_XPR = A_SIG, A_SIN, A_XPR

    def _worst_direction_target(self):
        """Build an adversarial target from the cached start parameters.

        U and V are negated, BIAS is reflected to ``1 - BIAS``, and each
        amplitude is set to 0.9 if it was below 0.5 and to 0.3 otherwise.
        """
        U0, V0, B0, AS0, AN0, AX0 = self._init_reward
        U_t = -U0
        V_t = -V0
        B_t = 1.0 - B0
        AS_t = 0.9 if AS0 < 0.5 else 0.3
        AN_t = 0.9 if AN0 < 0.5 else 0.3
        AX_t = 0.9 if AX0 < 0.5 else 0.3
        return U_t, V_t, B_t, AS_t, AN_t, AX_t

    def _cache_init_reward_params(self):
        """Snapshot the current reward parameters as the gradual-change start point."""
        self._init_reward = (self.U.copy(), self.V.copy(), self.BIAS.copy(),
                             float(self.A_SIG), float(self.A_SIN), float(self.A_XPR))

    def _clear_target_reward_params(self):
        """Forget the gradual-change target so the next call draws a new one."""
        self._target_reward: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float]] = None

    @property
    def _target_reward(self):
        # Stored under the literal key "__target_reward" (setattr with a string
        # is not name-mangled). Defaults to None when it was never set.
        return getattr(self, "__target_reward", None)

    @_target_reward.setter
    def _target_reward(self, val):
        setattr(self, "__target_reward", val)

    def _sample_finite_contexts(self, N: int) -> np.ndarray:
        """Return ``N`` random (approximately) unit-norm contexts as an ``(N, d_c)`` array."""
        X = self.rng.normal(size=(N, self.d_c))
        norms = np.linalg.norm(X, axis=1, keepdims=True) + 1e-12
        return (X / norms).astype(np.float64)
