from __future__ import annotations

import time
from typing import Any, Callable, Dict, Iterable, List, Sequence, Optional, Union

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm

from .simulations import BaseSimulation, SimulationData, CallPlan


def _call_select_arm(algorithm, plans: Sequence[CallPlan]):
    for plan in plans:
        try:
            return algorithm.select_arm(*plan.args, **plan.kwargs)
        except TypeError:
            continue
        except ValueError:
            continue
        except RuntimeError:
            continue
    if plans and plans[-1].args:
        return algorithm.select_arm(plans[-1].args[0])
    return algorithm.select_arm()


def _call_update(algorithm, arm: int, reward: float):
    if hasattr(algorithm, "update"):
        algorithm.update(int(arm), float(reward))
    else:
        algorithm.update_statistics(int(arm), float(reward))


class ExperimentEnvironment:
    """Run Monte-Carlo bandit experiments for several algorithms on shared data.

    Every Monte-Carlo trial precomputes one ``SimulationData`` with
    ``simulation_cls``. All algorithms are then played against that same
    trial data, so they are compared on identical draws.

    Args:
        bandit: Bandit instance. Its class and a copy of its ``init_params``
            attribute (``{}`` when absent) are forwarded to ``simulation_cls``.
        algorithm_factories: Zero-argument callables, each returning a fresh
            algorithm instance. A new instance is built for every trial.
        algorithm_names: Optional names, one per factory, used as result keys.
            When omitted (or empty), each factory is called once and the
            instance's ``name`` attribute is used, falling back to its class
            name. Repeated auto-generated names get a `` (k)`` suffix so
            results do not collide.
        T: Horizon forwarded to ``simulation_cls``. It is also the row length
            of the regret arrays.
        change_prob: Forwarded unchanged to ``simulation_cls``.
        continuous: Forwarded (as ``bool``) to ``simulation_cls``.
        reward_bound: Stored as ``self.reward_bound``; not read by this class.
        simulation_cls: ``BaseSimulation`` subclass used to precompute trials.
        simulation_kwargs: Extra keyword arguments for ``simulation_cls``.

    Raises:
        ValueError: If ``algorithm_names`` is given but its length differs
            from ``algorithm_factories``, or it contains duplicates.
            Duplicate names would make algorithms overwrite each other's
            results.
    """

    def __init__(
        self,
        *,
        bandit,
        algorithm_factories: Sequence[Callable[[], Any]],
        algorithm_names: Optional[Sequence[str]] = None,
        T: int,
        change_prob: Optional[float],
        continuous: bool,
        reward_bound: float,
        simulation_cls: type[BaseSimulation],
        simulation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.bandit = bandit
        self._algorithm_factories = list(algorithm_factories)
        if algorithm_names:
            names = list(algorithm_names)
            if len(names) != len(self._algorithm_factories):
                raise ValueError(
                    f"got {len(names)} algorithm_names for "
                    f"{len(self._algorithm_factories)} algorithm_factories"
                )
            if len(set(names)) != len(names):
                raise ValueError(f"algorithm_names must be unique, got {names!r}")
            self.algorithm_names = names
        else:
            # Instantiate each factory once, only to discover its display name.
            discovered = []
            for factory in self._algorithm_factories:
                alg = factory()
                discovered.append(getattr(alg, "name", alg.__class__.__name__))
            self.algorithm_names = self._unique_names(discovered)
        self.T = int(T)
        self.change_prob = change_prob
        self.continuous = bool(continuous)
        self.reward_bound = float(reward_bound)
        self.simulation_cls = simulation_cls
        self.simulation_kwargs = simulation_kwargs or {}

    @staticmethod
    def _unique_names(names: Sequence[str]) -> List[str]:
        """Return ``names`` with later repeats suffixed `` (2)``, `` (3)``, ... so all differ."""
        used = set()
        unique = []
        for base in names:
            candidate, k = base, 1
            while candidate in used:
                k += 1
                candidate = f"{base} ({k})"
            used.add(candidate)
            unique.append(candidate)
        return unique

    def _simulation(self) -> BaseSimulation:
        """Build the simulation object that precomputes per-trial data."""
        # Copy so the simulation cannot mutate the bandit's own parameter dict.
        bandit_params = getattr(self.bandit, "init_params", {}).copy()
        return self.simulation_cls(
            self.bandit.__class__,
            bandit_params,
            self.T,
            self.change_prob,
            self.continuous,
            **self.simulation_kwargs,
        )

    def run_experiment(self, n_mc: int, n_jobs: int = -1):
        """Run ``n_mc`` Monte-Carlo trials and aggregate per-algorithm statistics.

        Trial ``i`` is precomputed with ``simulation.precompute(i)``.

        Args:
            n_mc: Number of Monte-Carlo trials.
            n_jobs: Worker count passed to ``joblib.Parallel``. ``1`` runs the
                trials sequentially in this process.

        Returns:
            A 5-tuple ``(avg_regret, std_regret, avg_timing, avg_detections,
            avg_delays)`` of dicts keyed by algorithm name:

            * ``avg_regret`` — mean cumulative regret per round (length-``T``
              array).
            * ``std_regret`` — population standard deviation of cumulative
              regret per round (length-``T`` array).
            * ``avg_timing`` — mean seconds per trial spent in
              ``select_arm`` + ``update``.
            * ``avg_detections`` — mean number of reported change points per
              trial.
            * ``avg_delays`` — mean detection delay pooled over all trials, or
              ``inf`` when no change point was ever reported.
        """
        alg_names = self.algorithm_names
        cumulative_regrets = {name: np.zeros((n_mc, self.T), dtype=float) for name in alg_names}
        timings = {name: 0.0 for name in alg_names}
        detections = {name: 0 for name in alg_names}
        delays_all = {name: [] for name in alg_names}

        simulation = self._simulation()

        def worker(seed: int):
            data = simulation.precompute(seed)
            return self._single_trial(data, self._algorithm_factories, alg_names)

        seeds = range(n_mc)
        if n_jobs == 1:
            results = [worker(seed) for seed in tqdm(seeds, desc="MC Trials", leave=False)]
        else:
            # NOTE: here the bar wraps the dispatch generator, so it may
            # advance as tasks are handed to joblib rather than as they finish.
            results = Parallel(n_jobs=n_jobs)(
                delayed(worker)(seed)
                for seed in tqdm(seeds, desc="MC Trials", leave=False)
            )

        for idx, (regrets, times, dets, delays) in enumerate(results):
            for name in alg_names:
                cumulative_regrets[name][idx, :] = regrets[name]
                timings[name] += times[name]
                detections[name] += dets[name]
                delays_all[name].extend(delays[name])

        n_trials = max(n_mc, 1)  # avoid dividing by zero when n_mc == 0
        avg_regret = {name: cumulative_regrets[name].mean(axis=0) for name in alg_names}
        std_regret = {name: cumulative_regrets[name].std(axis=0) for name in alg_names}
        avg_timing = {name: timings[name] / n_trials for name in alg_names}
        avg_detections = {name: detections[name] / n_trials for name in alg_names}
        avg_delays = {
            name: (float("inf") if not delays_all[name] else float(np.mean(delays_all[name])))
            for name in alg_names
        }

        return avg_regret, std_regret, avg_timing, avg_detections, avg_delays

    @staticmethod
    def _build_algorithms(
        algorithm_factories: Sequence[Callable[[], Any]],
        names: Sequence[str],
    ) -> List[Any]:
        """Instantiate one fresh algorithm per factory, call ``re_init`` if present, and set ``name``."""
        algorithms = []
        for factory, name in zip(algorithm_factories, names):
            alg = factory()
            if hasattr(alg, "re_init"):
                alg.re_init()
            alg.name = name
            algorithms.append(alg)
        return algorithms

    @staticmethod
    def _checked_arm(arm, n_arms: int, name: str):
        """Return ``arm`` unchanged after checking that ``0 <= arm < n_arms``.

        Without this check a negative index such as ``-1`` would silently
        select an arm counted from the end, through NumPy negative indexing.

        Raises:
            ValueError: If ``arm`` is outside ``[0, n_arms)``.
        """
        if not 0 <= arm < n_arms:
            raise ValueError(
                f"algorithm {name!r} selected arm {arm!r}, expected 0 <= arm < {n_arms}"
            )
        return arm

    @staticmethod
    def _detection_delays(detected: Iterable, change_points) -> List[int]:
        """Measure each detected change time against the latest true change point at or before it.

        A detection that precedes every true change point (or any detection
        when there are none) is measured from time 0. ``np.searchsorted``
        requires ``change_points`` to be sorted ascending.
        """
        delays = []
        for cp in detected:
            idx = np.searchsorted(change_points, cp, side="right") - 1
            delays.append(int(cp - change_points[idx]) if idx >= 0 else int(cp))
        return delays

    def _single_trial(
        self,
        data: SimulationData,
        algorithm_factories: Sequence[Callable[[], Any]],
        names: Sequence[str],
    ):
        """Play every algorithm through one precomputed trial.

        ``data.rewards[t, arm]`` is the observed reward fed to the algorithm.
        ``data.mean_rewards[t, arm]`` is accumulated as the algorithm's
        expected reward. ``data.best_rewards`` provides the per-round optimum
        used for regret.

        Returns:
            ``(regrets, algo_times, detections, detection_delays)``, dicts
            keyed by name:

            * ``regrets`` — cumulative regret arrays.
            * ``algo_times`` — seconds spent in ``select_arm`` + ``update``.
            * ``detections`` — ``len(alg.ChangePoints)``, or 0 when the
              attribute is absent.
            * ``detection_delays`` — per-detection delays (see
              ``_detection_delays``).

        Raises:
            ValueError: If an algorithm selects an arm outside
                ``[0, n_arms)``.
        """
        algorithms = self._build_algorithms(algorithm_factories, names)

        rewards = data.rewards
        means = data.mean_rewards
        best = data.best_rewards
        change_points = data.extras.get("change_points", [])
        T, n_arms = rewards.shape[0], rewards.shape[1]

        algo_means = {name: np.zeros(T, dtype=float) for name in names}
        algo_times = {name: 0.0 for name in names}

        # No per-step progress bar: this runs once per trial (possibly inside
        # joblib workers) and would flood the output next to the trial bar.
        for t in range(T):
            plans = data.select_plans[t]
            for alg, name in zip(algorithms, names):
                start = time.perf_counter()
                arm = _call_select_arm(alg, plans)
                select_elapsed = time.perf_counter() - start

                arm = self._checked_arm(arm, n_arms, name)
                reward = rewards[t, arm]
                algo_means[name][t] = means[t, arm]

                start = time.perf_counter()
                _call_update(alg, arm, reward)
                algo_times[name] += select_elapsed + (time.perf_counter() - start)

        cumulative_optimal = np.cumsum(best)
        regrets = {name: cumulative_optimal - np.cumsum(algo_means[name]) for name in names}

        detections = {}
        detection_delays = {}
        for alg, name in zip(algorithms, names):
            cps = getattr(alg, "ChangePoints", [])
            detections[name] = len(cps)
            detection_delays[name] = self._detection_delays(cps, change_points)

        return regrets, algo_times, detections, detection_delays


# Public API of this module: only the experiment runner is exported.
__all__ = ["ExperimentEnvironment"]
