from __future__ import annotations

import time
from typing import Dict, List, Tuple, Any
import numpy as np
from tqdm import tqdm
from joblib import Parallel, delayed

def _arms_identity(K: int) -> List[Dict[str, str]]:
    """Return ``K`` placeholder arm descriptors.

    Element ``i`` is ``{"aid": "a<i>"}``, e.g. ``[{"aid": "a0"}, {"aid": "a1"}]``
    for ``K == 2``. A non-positive ``K`` yields an empty list.
    """
    return [dict(aid="a" + str(arm_index)) for arm_index in range(K)]

def _ctx_to_vw_dict(x: np.ndarray) -> Dict[str, float]:
    """Convert a context vector into a named-feature dict.

    The input is flattened to a 1-D float array. The result always starts with a
    constant ``"bias": 1.0`` entry, followed by ``"x1"``, ``"x2"``, ... holding
    the flattened elements in order (1-based names).
    """
    flat = np.asarray(x, float).reshape(-1)
    features: Dict[str, float] = {"bias": 1.0}
    features.update((f"x{position}", float(value)) for position, value in enumerate(flat, start=1))
    return features

def _pack_reward_params_for_path_length(bandit) -> np.ndarray:
    """Flatten the bandit's reward parameters into a single float vector.

    Layout: ``U`` raveled, then ``V`` raveled, then ``BIAS`` raveled, then the
    three scalars ``A_SIG``, ``A_SIN``, ``A_XPR``. Differences between two such
    vectors are used to measure how far the reward model moved between steps.
    """
    segments = [getattr(bandit, attr_name).ravel() for attr_name in ("U", "V", "BIAS")]
    segments.append(np.array([bandit.A_SIG, bandit.A_SIN, bandit.A_XPR], dtype=float))
    return np.concatenate(segments, axis=0).astype(float)

def _arm_features_for(bandit, K: int) -> list:
    """Return ``bandit.arms`` as flat float vectors, or identity descriptors if unusable."""
    try:
        bandit_arms = getattr(bandit, "arms", None)
        if bandit_arms is not None and len(bandit_arms) == K:
            return [np.asarray(arm, dtype=float).ravel() for arm in bandit_arms]
    except Exception:
        # Best-effort: any failure reading/converting ``arms`` uses the fallback below.
        pass
    return _arms_identity(K)


def _sample_change_points(T: int, change_prob: float | None) -> List[int]:
    """Return change times ``[0, t_1, ..., t_m, T]`` drawn from the global NumPy RNG.

    Each step ``1 <= t < T`` independently becomes a change point when
    ``np.random.random() <= change_prob``. ``None`` or ``0`` gives ``[0, T]``
    and consumes no random numbers.
    """
    if change_prob is None or change_prob == 0:
        return [0, T]
    points = [0]
    for t in range(1, T):
        if np.random.random() <= change_prob:
            points.append(t)
    points.append(T)
    return points


def _select_arm_any_signature(alg, arms_features, ctx_dict, x_vec):
    """Ask ``alg`` for an arm, trying each supported ``select_arm`` calling convention.

    Conventions are tried in order and the first non-``None`` result is returned:
    ``(arms_features, context=ctx_dict)``, ``(None, context=x_vec)``,
    ``(arms_features, x_vec)``. Any exception from these probes is swallowed
    (best-effort signature detection). The final fallback ``(arms_features)``
    is not guarded, so its errors propagate to the caller.
    """
    # NOTE(review): failing conventions are re-probed every step, and that cost is
    # included in the caller's per-algorithm timing.
    attempts = (
        ((arms_features,), {"context": ctx_dict}),
        ((None,), {"context": x_vec}),
        ((arms_features, x_vec), {}),
    )
    for args, kwargs in attempts:
        try:
            arm = alg.select_arm(*args, **kwargs)
        except Exception:
            continue
        if arm is not None:
            return arm
    return alg.select_arm(arms_features)


def _detection_stats(alg, change_points: List[int]) -> Tuple[int, List[int]]:
    """Return ``(count, delays)`` for the detections listed in ``alg.ChangePoints``.

    Algorithms without a list-valued ``ChangePoints`` attribute report ``(0, [])``.
    Each delay is the detection time minus the latest true change point at or
    before it. ``change_points[0] == 0``, so only negative detection times lack
    such a point; those are reported as-is.
    """
    detected = getattr(alg, "ChangePoints", None)
    if not isinstance(detected, list):
        return 0, []
    delays = []
    for detect_t in detected:
        idx = np.searchsorted(change_points, detect_t, side="right") - 1
        delays.append(int(detect_t - change_points[idx]) if idx >= 0 else int(detect_t))
    return len(detected), delays


def single_run_context(
    bandit_class,
    bandit_params: Dict[str, Any],
    algorithms_classes: List[type],
    algorithms_params: List[Dict[str, Any]],
    algorithm_names: List[str],
    T: int,
    change_prob: float | None,
    continuous: bool,
    seed: int,
):
    """Simulate one seeded run of every algorithm against the same contextual bandit.

    The environment is fully precomputed first: contexts, realized and expected
    rewards for every arm, and reward-parameter path lengths. Every algorithm
    then faces the identical sequence.

    Args:
        bandit_class: Environment class. It is built from ``bandit_params``, with
            ``"seed"`` overridden when present. It must provide ``num_actions``,
            ``sample_context()``, ``expected_reward(a, x)``, ``get_reward(a)`` and
            ``abrupt_change()`` / ``gradual_change(change_rate=...)``, plus the
            attributes read by ``_pack_reward_params_for_path_length``.
        bandit_params: Constructor kwargs for ``bandit_class``. Not mutated.
        algorithms_classes / algorithms_params / algorithm_names: Parallel lists.
            Each algorithm is built as ``cls(**params)``, ``re_init()``-ed if
            available, and has its ``name`` attribute set to the display name.
        T: Horizon (number of steps).
        change_prob: Per-step probability of an abrupt change. Abrupt changes
            are used only when ``continuous`` is false.
        continuous: If true, ``gradual_change(change_rate=t / max(1, T))`` is
            applied at every step instead of abrupt changes.
        seed: Seeds NumPy's global RNG and, if present, the bandit's ``seed`` kwarg.

    Returns:
        ``(results, timings, detections, path_norms, detection_delays)``:
        ``results[name]`` is the cumulative regret curve (length ``T``) of
        expected rewards versus the best expected reward at each step.
        ``timings[name]`` is the seconds spent in arm selection plus update.
        ``detections[name]`` / ``detection_delays[name]`` come from the
        algorithm's ``ChangePoints`` list. ``path_norms[t]`` is the L2 distance
        between the packed reward parameters at step ``t`` and at the previous
        step (or the freshly built bandit, for ``t == 0``).

    Raises:
        IndexError: if an algorithm selects an arm outside ``[0, K)``.
    """
    np.random.seed(seed)

    run_bandit_params = bandit_params.copy()
    if "seed" in run_bandit_params:
        run_bandit_params["seed"] = seed
    precompute_bandit = bandit_class(**run_bandit_params)

    K = precompute_bandit.num_actions
    arms_features = _arm_features_for(precompute_bandit, K)

    # Drawn after the bandit is built, to keep the global RNG stream order stable.
    change_points = _sample_change_points(T, change_prob)
    cps_set = set(change_points[1:-1])

    # Phase 1: precompute the environment trajectory shared by all algorithms.
    precomputed_contexts = []
    precomputed_rewards = np.zeros((T, K), dtype=float)
    precomputed_mean_rewards = np.zeros((T, K), dtype=float)
    precomputed_best_rewards = np.zeros(T, dtype=float)
    precomputed_path_norms = np.zeros(T, dtype=float)
    prev_param_vec = _pack_reward_params_for_path_length(precompute_bandit)

    for t in range(T):
        if not continuous and change_prob is not None and t in cps_set:
            precompute_bandit.abrupt_change()
        elif continuous:
            precompute_bandit.gradual_change(change_rate=(t / max(1, T)))

        x_t = precompute_bandit.sample_context()
        precomputed_contexts.append(x_t.copy())  # snapshot; the bandit may reuse its buffer

        for a in range(K):
            precomputed_mean_rewards[t, a] = precompute_bandit.expected_reward(a, x_t)
            precomputed_rewards[t, a] = precompute_bandit.get_reward(a)

        precomputed_best_rewards[t] = float(precomputed_mean_rewards[t].max())

        cur_param_vec = _pack_reward_params_for_path_length(precompute_bandit)
        precomputed_path_norms[t] = float(np.linalg.norm(cur_param_vec - prev_param_vec))
        prev_param_vec = cur_param_vec

    # Phase 2: build fresh algorithm instances keyed by display name.
    algorithms = []  # (display name, instance) pairs
    algo_rewards_expected = {}
    algo_times = {}
    for alg_class, alg_params, disp_name in zip(algorithms_classes, algorithms_params, algorithm_names):
        alg = alg_class(**alg_params)
        if hasattr(alg, "re_init"):
            alg.re_init()
        setattr(alg, "name", disp_name)
        algorithms.append((disp_name, alg))
        algo_rewards_expected[disp_name] = np.zeros(T, dtype=float)
        algo_times[disp_name] = 0.0

    # Phase 3: play every algorithm on the precomputed trajectory.
    for t in tqdm(range(T), desc="Single Run (CB)", leave=False):
        x_vec = precomputed_contexts[t]
        ctx_dict = _ctx_to_vw_dict(x_vec)
        for name, alg in algorithms:
            # perf_counter is monotonic; time.time() can jump with clock adjustments.
            start_time = time.perf_counter()

            a = int(_select_arm_any_signature(alg, arms_features, ctx_dict, x_vec))
            if not 0 <= a < K:
                # Negative indices would otherwise silently wrap around in NumPy.
                raise IndexError(f"{name}: select_arm returned arm {a}, expected 0..{K - 1}")
            reward_realized = precomputed_rewards[t, a]

            if hasattr(alg, "update"):
                alg.update(a, reward_realized)
            else:
                alg.update_statistics(a, reward_realized)

            algo_times[name] += time.perf_counter() - start_time
            algo_rewards_expected[name][t] = precomputed_mean_rewards[t, a]

    # Phase 4: change-point detection statistics and regret curves.
    detections = {}
    detection_delays = {}
    for name, alg in algorithms:
        detections[name], detection_delays[name] = _detection_stats(alg, change_points)

    cumulative_optimal = np.cumsum(precomputed_best_rewards)
    results: Dict[str, np.ndarray] = {
        name: cumulative_optimal - np.cumsum(algo_rewards_expected[name]) for name, _ in algorithms
    }
    timings: Dict[str, float] = {name: algo_times[name] for name, _ in algorithms}

    return results, timings, detections, precomputed_path_norms, detection_delays

class EnvironmentContext:
    """Run repeated, independently seeded contextual-bandit simulations and average them.

    ``bandit`` and ``algorithms`` serve as templates. Each Monte Carlo run
    rebuilds fresh instances from their classes and ``init_params`` inside
    ``single_run_context`` rather than running the passed-in objects directly.
    """

    def __init__(self, bandit, algorithms, T: int, reward_bound: float,
                 change_prob: float | None = None, continuous: bool = False):
        self.bandit = bandit
        self.algorithms = algorithms
        self.T = int(T)
        self.change_prob = change_prob
        self.continuous = bool(continuous)
        # NOTE(review): not read by any code visible in this file section.
        self.reward_bound = float(reward_bound)

    def run_experiment(self, n_mc: int, n_jobs: int = -1):
        """Run ``n_mc`` seeded simulations (seeds ``0..n_mc-1``) in parallel and aggregate them.

        Args:
            n_mc: Number of Monte Carlo runs.
            n_jobs: Passed to ``joblib.Parallel`` (``-1`` = all cores).

        Returns:
            ``(avg_regret, std_regret, avg_timings, avg_detections, avg_detection_delays)``,
            each a dict keyed by algorithm name. See ``_aggregate_runs``.
        """
        bandit_class = self.bandit.__class__
        bandit_params = self.bandit.init_params.copy()
        algorithms_classes = [alg.__class__ for alg in self.algorithms]
        algorithms_params = [getattr(alg, "init_params", {}).copy() for alg in self.algorithms]
        algorithm_names = [getattr(alg, "name", alg.__class__.__name__) for alg in self.algorithms]

        # joblib returns results in submission order, so entry i is the run seeded with i.
        results_list = Parallel(n_jobs=n_jobs)(
            delayed(single_run_context)(
                bandit_class, bandit_params,
                algorithms_classes, algorithms_params, algorithm_names,
                self.T, self.change_prob, self.continuous,
                seed,
            )
            for seed in tqdm(range(n_mc), desc="MC Trials (CB)")
        )
        return self._aggregate_runs(results_list, algorithm_names, n_mc, self.T)

    @staticmethod
    def _aggregate_runs(results_list, algorithm_names, n_mc: int, T: int):
        """Combine per-run outputs of ``single_run_context`` into summary statistics.

        Returns:
            ``avg_regret[name]`` / ``std_regret[name]``: per-step mean / population
            std of the cumulative regret curves across runs.
            ``avg_timings[name]`` / ``avg_detections[name]``: totals divided by
            ``max(1, n_mc)``.
            ``avg_detection_delays[name]``: mean over all detection delays pooled
            across runs, or ``inf`` when there were none.
        """
        cumulative_regrets = {name: np.zeros((n_mc, T), dtype=float) for name in algorithm_names}
        timings = {name: 0.0 for name in algorithm_names}
        detections = {name: 0.0 for name in algorithm_names}
        detection_delays_all = {name: [] for name in algorithm_names}

        # Index rows by run position. A regret curve can legitimately start at 0,
        # so row occupancy cannot be inferred from the stored data.
        for run_idx, (results, run_timings, run_detections, _path, det_delays) in enumerate(results_list):
            for name in results:
                cumulative_regrets[name][run_idx, :] = results[name]
                timings[name] += run_timings[name]
                detections[name] += run_detections[name]
                detection_delays_all[name].extend(det_delays[name])

        avg_regret = {name: regrets.mean(axis=0) for name, regrets in cumulative_regrets.items()}
        std_regret = {name: regrets.std(axis=0) for name, regrets in cumulative_regrets.items()}
        avg_timings = {name: total / max(1, n_mc) for name, total in timings.items()}
        avg_detections = {name: total / max(1, n_mc) for name, total in detections.items()}
        avg_detection_delays = {
            name: float(np.mean(delays)) if delays else float("inf")
            for name, delays in detection_delays_all.items()
        }

        return avg_regret, std_regret, avg_timings, avg_detections, avg_detection_delays
