import numpy as np
import time
from tqdm import tqdm
from joblib import Parallel, delayed

def _create_change_points(T, change_prob):
    """Draw the abrupt-change schedule for one run.

    Returns a sorted list that always starts with ``0`` and ends with ``T``;
    each ``t`` in ``1..T-1`` is included when a uniform draw from NumPy's
    global RNG is ``<= change_prob``.  When ``change_prob`` is ``None`` or
    ``0`` no random numbers are drawn and ``[0, T]`` is returned.
    """
    if change_prob is None or change_prob == 0:
        return [0, T]
    change_points = [0]
    for t in range(1, T):
        # Exactly one draw per candidate round keeps the global RNG stream
        # (and thus every later draw in the run) identical for a given seed.
        if np.random.random() <= change_prob:
            change_points.append(t)
    change_points.append(T)
    return change_points


def _precompute_environment(bandit, T, change_prob, continuous, change_points):
    """Advance ``bandit`` through ``T`` rounds and record what every arm pays.

    ``bandit`` is mutated in place via ``gradual_change`` (continuous mode)
    or ``abrupt_change`` (at each interior change point).

    Returns ``(theta_norms, rewards, mean_rewards, best_rewards)``:
    ``rewards`` / ``mean_rewards`` have shape ``(T, bandit.num_actions)``,
    ``best_rewards`` has shape ``(T,)``, and ``theta_norms[t]`` is the
    Euclidean norm of the change in ``bandit.theta`` during round ``t``
    (0 at ``t == 0``).
    """
    num_actions = bandit.num_actions
    theta_norms = np.zeros(T)
    rewards = np.zeros((T, num_actions))
    mean_rewards = np.zeros((T, num_actions))
    best_rewards = np.zeros(T)

    # Round 0 is the initial environment and T is only an end sentinel
    # (never reached by range(T)), so only interior points trigger changes.
    # A set makes the per-round membership test O(1).
    change_times = set(change_points) - {0}

    previous_theta = bandit.theta.copy()
    for t in range(T):
        if continuous:
            bandit.gradual_change(change_rate=t / T)
        elif change_prob is not None and t in change_times:
            bandit.abrupt_change()

        # Keep the reward/mean interleaving: get_reward may consume the RNG.
        for a in range(num_actions):
            rewards[t, a] = bandit.get_reward(a)
            mean_rewards[t, a] = bandit.get_mean_reward(a)

        _, best_reward = bandit.get_best_arm()
        best_rewards[t] = best_reward

        current_theta = bandit.theta.copy()
        theta_norms[t] = np.linalg.norm(current_theta - previous_theta)
        previous_theta = current_theta

    return theta_norms, rewards, mean_rewards, best_rewards


def _probe_select_arm(algorithm, arms, candidate_args):
    """Find which ``select_arm`` calling convention ``algorithm`` accepts.

    Each tuple in ``candidate_args`` is tried in order as extra positional
    arguments after ``arms``.  A ``TypeError`` moves on to the next form; the
    last form is called unguarded so a genuine failure still propagates.

    Returns ``(action, extra_args)`` for the first form that succeeded.
    """
    *guarded, last = candidate_args
    for extra in guarded:
        try:
            return algorithm.select_arm(arms, *extra), extra
        except TypeError:
            continue
    return algorithm.select_arm(arms, *last), last


def single_run(bandit_class, bandit_params, algorithms_classes, algorithms_params,
               T, change_prob, continuous, seed):
    """Run one seeded trial of every algorithm on a freshly built bandit.

    The global NumPy RNG is seeded with ``seed``.  Then the change points
    are drawn, the bandit is built from ``bandit_params``, and all rewards
    for ``T`` rounds are precomputed, so every algorithm sees the same
    reward sequence.  Each algorithm is then instantiated from its params,
    ``re_init()``-ed, and run for ``T`` rounds.

    ``select_arm`` is called with ``(arms, PT)``, ``(arms, change_points)``
    or ``(arms,)``: whichever form is accepted first on round 0 is reused
    for every later round.  ``PT`` is the summed per-round norm of the
    change in ``theta``.

    Returns ``(results, timings, detections, theta_norms, detection_delays)``,
    where every dict is keyed by algorithm class name:

    * ``results[name]``: length-``T`` array of the cumulative best reward
      minus the cumulative mean reward of the chosen arms.
    * ``timings[name]``: total wall-clock seconds spent in
      ``select_arm`` + ``update``.
    * ``detections[name]``: ``len(algorithm.ChangePoints)``, or 0 when the
      algorithm has no such attribute.
    * ``theta_norms``: per-round norm of the change in ``theta``.
    * ``detection_delays[name]``: for each detection, its distance from the
      most recent change point at or before it.
    """
    np.random.seed(seed)

    # RNG order matters for reproducibility: change points are drawn before
    # the bandit is constructed, then rewards, then the algorithms.
    change_points = _create_change_points(T, change_prob)

    bandit = bandit_class(**bandit_params.copy())
    theta_norms, rewards, mean_rewards, best_rewards = _precompute_environment(
        bandit, T, change_prob, continuous, change_points)

    PT = np.sum(theta_norms)

    algorithms = []
    algo_rewards = {}
    algo_times = {}
    detections = {}
    detection_delays = {}
    for alg_class, alg_params in zip(algorithms_classes, algorithms_params):
        alg = alg_class(**alg_params.copy())
        alg.re_init()
        algorithms.append(alg)
        # NOTE(review): everything is keyed by class name, so two instances of
        # the same class share (and overwrite) a single entry.
        name = alg.__class__.__name__
        algo_rewards[name] = np.zeros(T)
        algo_times[name] = 0
        detections[name] = 0
        detection_delays[name] = []

    cumulative_optimal = np.cumsum(best_rewards)

    # Candidate extra arguments for select_arm, in priority order.  The form
    # that works is cached per algorithm, so select_arm runs exactly once per
    # round (the old code called it twice and discarded the PT result).
    candidate_args = ((PT,), (change_points,), ())
    call_args = [None] * len(algorithms)

    for t in tqdm(range(T), desc="Single Run", leave=False):
        for i, algorithm in enumerate(algorithms):
            name = algorithm.__class__.__name__

            start_time = time.time()
            if call_args[i] is None:
                action, call_args[i] = _probe_select_arm(
                    algorithm, bandit.arms, candidate_args)
            else:
                action = algorithm.select_arm(bandit.arms, *call_args[i])

            algorithm.update(action, rewards[t, action])
            algo_times[name] += time.time() - start_time

            # Regret is measured on the mean reward, not the noisy sample.
            algo_rewards[name][t] = mean_rewards[t, action]

    for algorithm in algorithms:
        name = algorithm.__class__.__name__
        # Algorithms without change detection simply report none.
        detected = getattr(algorithm, "ChangePoints", [])
        detections[name] = len(detected)

        for detect_time in detected:
            # Index of the last change point <= detect_time.  change_points[0]
            # is 0, so idx < 0 only happens for a negative detect_time.
            idx = np.searchsorted(change_points, detect_time, side='right') - 1
            if idx >= 0:
                detection_delays[name].append(detect_time - change_points[idx])
            else:
                detection_delays[name].append(detect_time)

    results = {}
    timings = {}
    for algorithm in algorithms:
        name = algorithm.__class__.__name__
        results[name] = cumulative_optimal - np.cumsum(algo_rewards[name])
        timings[name] = algo_times[name]

    return results, timings, detections, theta_norms, detection_delays


class Environment(object):
    """Repeat ``single_run`` over independent seeds and aggregate the results.

    Parameters
    ----------
    bandit : object
        Template bandit.  Only its class and ``init_params`` are used, so each
        trial builds a fresh instance.
    algorithms : list
        Template algorithms.  Likewise, only their classes and ``init_params``
        are used.  Results are keyed by class name.
    T : int
        Number of rounds per trial.
    reward_bound
        Stored on the instance; not read by ``run_experiment``.
    change_prob : float or None
        Per-round probability of an abrupt change when ``continuous`` is false.
    continuous : bool
        If true, each trial uses gradual drift instead of abrupt changes.
    """

    def __init__(self, bandit, algorithms, T, reward_bound,
                 change_prob=None, continuous=False):
        self.bandit = bandit
        self.algorithms = algorithms
        self.T = T
        self.change_prob = change_prob
        self.continuous = continuous
        self.reward_bound = reward_bound

    def run_experiment(self, n_mc, q=25):
        """Run ``n_mc`` trials in parallel (trial ``i`` seeded with ``i``).

        Parameters
        ----------
        n_mc : int
            Number of Monte-Carlo trials; must be at least 1.
        q : int
            Unused; kept for backward compatibility.

        Returns
        -------
        tuple
            ``(avg_regret, std_regret, avg_timings, avg_detections,
            avg_detection_delays)``, each a dict keyed by algorithm class name.

            * The two regret dicts hold length-``T`` arrays: the mean and std
              over trials of the cumulative regret curve.
            * Timings and detections are per-trial averages.
            * The delay is the mean over all detections in all trials, or
              ``inf`` when an algorithm detected nothing.

        Raises
        ------
        ValueError
            If ``n_mc`` is less than 1 (previously this surfaced as NumPy
            empty-mean warnings followed by a ``ZeroDivisionError``).
        """
        if n_mc < 1:
            raise ValueError(f"n_mc must be at least 1, got {n_mc!r}")

        names = [alg.__class__.__name__ for alg in self.algorithms]
        cumulative_regrets = {name: np.zeros((n_mc, self.T)) for name in names}
        timings = {name: 0.0 for name in names}
        detections = {name: 0.0 for name in names}
        detection_delays_all = {name: [] for name in names}

        # Ship classes + params rather than live objects so every worker
        # builds fresh, independent instances.
        bandit_class = self.bandit.__class__
        bandit_params = self.bandit.init_params.copy()
        algorithms_classes = [alg.__class__ for alg in self.algorithms]
        algorithms_params = [alg.init_params.copy() for alg in self.algorithms]

        # tqdm wraps the task generator, so the bar advances as joblib
        # dispatches tasks, which may run ahead of completed trials.
        results_list = Parallel(n_jobs=-1)(
            delayed(single_run)(
                bandit_class, bandit_params,
                algorithms_classes, algorithms_params,
                self.T, self.change_prob, self.continuous,
                seed,
            )
            for seed in tqdm(range(n_mc), desc="MC Trials")
        )

        for run_idx, (results, run_timings, run_detections, _, delays) in enumerate(results_list):
            for alg_name in results:
                cumulative_regrets[alg_name][run_idx, :] = results[alg_name]
                timings[alg_name] += run_timings[alg_name]
                detections[alg_name] += run_detections[alg_name]
                detection_delays_all[alg_name].extend(delays[alg_name])

        avg_regret = {
            alg_name: np.mean(curves, axis=0)
            for alg_name, curves in cumulative_regrets.items()
        }
        std_regret = {
            alg_name: np.std(curves, axis=0)
            for alg_name, curves in cumulative_regrets.items()
        }
        avg_timings = {
            alg_name: total / n_mc for alg_name, total in timings.items()
        }
        avg_detections = {
            alg_name: total / n_mc for alg_name, total in detections.items()
        }
        avg_detection_delays = {
            alg_name: np.mean(delays_list) if delays_list else float('inf')
            for alg_name, delays_list in detection_delays_all.items()
        }

        return avg_regret, std_regret, avg_timings, avg_detections, avg_detection_delays
