import numpy as np
import time
from tqdm import tqdm
from joblib import Parallel, delayed
import copy


def _generate_change_points(T: int, change_prob: float) -> list[int]:
    
    if change_prob is None or change_prob == 0:
        return [0, T]

    cps = [0]
    rnd = np.random.random
    for t in range(1, T):
        if rnd() <= change_prob:
            cps.append(t)
    cps.append(T)
    return cps




def _simulate_bandit(bandit, T: int, change_points: list, change_prob: float, continuous: bool):
    """Roll ``bandit`` forward ``T`` steps and record what the agents will see.

    Returns:
        ``(rewards, means, best_rewards, PT)``:
        - ``rewards[t, a]`` is the reward drawn via ``get_reward(a)`` at step ``t``.
        - ``means[t, a]`` is ``get_mean_reward(a)`` at step ``t``.
        - ``best_rewards[t]`` is the second element returned by ``get_best_arm()``.
        - ``PT`` is the sum over steps of the largest absolute change in
          ``reward_means`` between consecutive steps.
    """
    n_actions = bandit.num_actions
    rewards = np.zeros((T, n_actions))
    means = np.zeros((T, n_actions))
    best_rewards = np.zeros(T)
    variation = np.zeros(T)

    # Steps at which an abrupt change fires: every sampled change point except
    # the initial 0 (T is never reached by t).  A set keeps lookups O(1).
    abrupt_steps = set(change_points)
    abrupt_steps.discard(0)

    previous_means = np.array(bandit.reward_means)
    for t in range(T):
        if not continuous and change_prob is not None and t in abrupt_steps:
            bandit.abrupt_change()
        elif continuous:
            bandit.gradual_change()

        for a in range(n_actions):
            rewards[t, a] = bandit.get_reward(a)
            means[t, a] = bandit.get_mean_reward(a)

        _, best = bandit.get_best_arm()
        best_rewards[t] = best

        # np.array copies, so previous_means is a snapshot, not an alias.
        current_means = np.array(bandit.reward_means)
        variation[t] = np.max(np.abs(current_means - previous_means))
        previous_means = current_means

    return rewards, means, best_rewards, float(np.sum(variation))


def _call_select_arm(alg, bandit_arms, PT, change_points):
    """Ask ``alg`` for an arm, probing its ``select_arm`` calling convention.

    The probes run in the legacy order: first with the total variation ``PT``,
    then with the true ``change_points``, and finally with the arms only.  The
    arm from the ``change_points`` call is returned if that call succeeds;
    otherwise the arms-only call is made and its result (or exception) is
    returned.

    NOTE(review): for an algorithm whose ``select_arm`` accepts a second
    positional argument, both the ``PT`` probe and the ``change_points`` probe
    run, so ``select_arm`` executes twice per step and the ``PT`` result is
    discarded.  This is kept as-is to preserve existing results; confirm the
    intended dispatch before changing it.
    """
    try:
        alg.select_arm(bandit_arms, PT)
    except Exception:
        # Not a PT-consuming algorithm (or it failed); fall through.
        pass
    try:
        return alg.select_arm(bandit_arms, change_points)
    except Exception:
        return alg.select_arm(bandit_arms)


def single_run(
    bandit_class,
    bandit_params: dict,
    algorithms_classes: list,
    algorithms_params: list,
    T: int,
    change_prob: float,
    continuous: bool,
    seed: int,
):
    """Run one Monte-Carlo trial of every algorithm on a freshly built bandit.

    The environment is simulated up front for the whole horizon.  Every
    algorithm is then replayed on that same pre-drawn trajectory, so all
    algorithms in a trial see identical rewards.

    Args:
        bandit_class: Bandit class, built as ``bandit_class(**bandit_params)``.
        bandit_params: Constructor kwargs for the bandit.
        algorithms_classes: Algorithm classes, paired positionally with
            ``algorithms_params``.
        algorithms_params: Constructor kwargs for each algorithm.
        T: Horizon (number of steps).
        change_prob: Per-step abrupt-change probability (``None``/0 disables).
        continuous: If true, ``gradual_change`` is called on every step instead
            of applying abrupt changes.
        seed: Seed for numpy's global RNG, so the trial is reproducible.

    Returns:
        ``(results, algo_time, detections, best_rewards, detection_delays)``,
        with dicts keyed by algorithm class name:
        - ``results[name]``: cumulative sum of ``best_rewards`` minus the
          cumulative mean reward of the chosen arms (length ``T``).
        - ``algo_time[name]``: total seconds spent choosing arms.
        - ``detections[name]``: ``len(alg.ChangePoints)`` (0 if absent).
        - ``best_rewards``: per-step best value from ``get_best_arm()``.
        - ``detection_delays[name]``: for each reported detection ``d``,
          ``d`` minus the latest true change point ``<= d``.
    """
    # Seed the global legacy RNG.  _generate_change_points draws from it, and
    # the bandit/algorithms presumably do too (TODO confirm).  Previously a
    # RandomState(seed) was created but never used, so `seed` had no effect.
    np.random.seed(seed)
    change_points = _generate_change_points(T, change_prob)

    bandit = bandit_class(**bandit_params)
    rewards, means, best_rewards, PT = _simulate_bandit(
        bandit, T, change_points, change_prob, continuous
    )
    bandit_arms = copy.deepcopy(bandit.arms)

    algorithms = []
    algo_r_means = {}
    algo_time = {}
    detections = {}
    detection_delays = {}
    for cls, params in zip(algorithms_classes, algorithms_params):
        alg = cls(**params)
        alg.re_init()
        name = alg.__class__.__name__
        algorithms.append(alg)
        algo_r_means[name] = np.zeros(T)
        algo_time[name] = 0.0
        detections[name] = 0
        detection_delays[name] = []

    for t in tqdm(range(T), leave=False):
        for alg in algorithms:
            name = alg.__class__.__name__
            # perf_counter is monotonic and high-resolution, unlike time.time.
            tic = time.perf_counter()
            a = _call_select_arm(alg, bandit_arms, PT, change_points)
            algo_time[name] += time.perf_counter() - tic

            algo_r_means[name][t] = means[t, a]
            alg.update(a, rewards[t, a])

    cum_opt = np.cumsum(best_rewards)
    results = {}
    for alg in algorithms:
        name = alg.__class__.__name__
        results[name] = cum_opt - np.cumsum(algo_r_means[name])

        reported = getattr(alg, "ChangePoints", [])
        detections[name] = len(reported)
        for d in reported:
            # Index of the last true change point at or before detection d.
            idx = np.searchsorted(change_points, d, side="right") - 1
            detection_delays[name].append(d - change_points[idx])

    return results, algo_time, detections, best_rewards, detection_delays




class Environment:
    """Monte-Carlo driver that benchmarks bandit algorithms on one bandit setup.

    Each trial is delegated to ``single_run`` through joblib.  The bandit and
    the algorithms are rebuilt there from their classes and ``init_params``,
    so the prototype objects passed here are never mutated by a trial.
    """

    def __init__(
        self,
        bandit,
        algorithms: list,
        T: int,
        reward_bound,
        change_prob: float,
        continuous: bool = False,
    ) -> None:
        """Store the experiment configuration.

        Args:
            bandit: Prototype bandit.  ``run_experiment`` uses only its class
                and its ``init_params`` dict.
            algorithms: Prototype algorithm instances.  ``run_experiment`` uses
                only their classes and ``init_params`` dicts.
            T: Horizon (steps per trial).
            reward_bound: Stored as ``self.reward_bound``; not read by this class.
            change_prob: Per-step abrupt-change probability, forwarded to
                ``single_run`` (``None``/0 disables).
            continuous: Forwarded to ``single_run``; selects gradual changes.
        """
        self.bandit = bandit
        self.algorithms = algorithms
        self.T = T
        self.change_prob = change_prob
        self.continuous = continuous
        self.reward_bound = reward_bound

    def run_experiment(self, n_mc: int, q: int = 25, n_jobs: int = -1):
        """Run ``n_mc`` independent trials in parallel and aggregate them.

        Args:
            n_mc: Number of Monte-Carlo trials; trial ``i`` uses seed ``i``.
            q: Unused; kept for backward compatibility.
            n_jobs: Number of joblib workers (``-1`` uses all cores).

        Returns:
            ``(avg_regret, std_regret, avg_timing, avg_dets, avg_delays)``, each
            a dict keyed by algorithm class name:
            - ``avg_regret`` / ``std_regret``: mean / std over trials of the
              cumulative regret curve (arrays of length ``T``).
            - ``avg_timing``: mean per-trial arm-selection time in seconds.
            - ``avg_dets``: mean number of reported detections per trial.
            - ``avg_delays``: mean detection delay over all trials, or ``inf``
              if the algorithm never reported a detection.

        Raises:
            ValueError: If ``n_mc`` is smaller than 1 (averages would be
                undefined).
        """
        if n_mc < 1:
            raise ValueError(f"n_mc must be a positive integer, got {n_mc!r}")

        # NOTE(review): results are keyed by class name, so two algorithms of
        # the same class (e.g. different hyper-parameters) collide here.
        alg_names = [alg.__class__.__name__ for alg in self.algorithms]

        cumulative_regrets = {n: np.zeros((n_mc, self.T)) for n in alg_names}
        timings = {n: 0.0 for n in alg_names}
        detections = {n: 0 for n in alg_names}
        det_delays_all = {n: [] for n in alg_names}

        bandit_cls = self.bandit.__class__
        bandit_params = self.bandit.init_params.copy()
        alg_cls = [alg.__class__ for alg in self.algorithms]
        alg_params = [alg.init_params.copy() for alg in self.algorithms]

        # The progress bar tracks task dispatch; joblib consumes the generator.
        worker_out = Parallel(n_jobs=n_jobs)(
            delayed(single_run)(
                bandit_cls,
                bandit_params,
                alg_cls,
                alg_params,
                self.T,
                self.change_prob,
                self.continuous,
                seed,
            )
            for seed in tqdm(range(n_mc), desc="MC Trials")
        )

        for i, (res, time_dict, dets, _, d_delays) in enumerate(worker_out):
            for n in alg_names:
                cumulative_regrets[n][i, :] = res[n]
                timings[n] += time_dict[n]
                detections[n] += dets[n]
                det_delays_all[n].extend(d_delays[n])

        avg_regret = {n: cumulative_regrets[n].mean(axis=0) for n in alg_names}
        std_regret = {n: cumulative_regrets[n].std(axis=0) for n in alg_names}
        avg_timing = {n: timings[n] / n_mc for n in alg_names}
        avg_dets = {n: detections[n] / n_mc for n in alg_names}
        avg_delays = {
            n: (float("inf") if len(dl) == 0 else np.mean(dl))
            for n, dl in det_delays_all.items()
        }

        return avg_regret, std_regret, avg_timing, avg_dets, avg_delays
