import numpy as np
import random
import math

class RelToStoElm:
    """Arm-elimination bandit that mixes online rewards with offline pairwise data.

    The algorithm keeps an active set ``A`` of arms. After every single pull it
    removes each active arm ``i`` for which an upper confidence bound on
    "``i`` beats ``j``" is at most 0.5 for some other active arm ``j``. The
    bound comes from online rewards (``compute_ucb``) and, when offline
    comparisons are available, additionally from a blend of offline and online
    evidence (``compute_ucb_s``).
    """

    def __init__(self, K, T, num_per_phase, T_S, means_offline, means_online, offline_data, V_matrix, delta, env, seed):
        """
        Initialize the Rel_to_Sto_Elm algorithm.

        Parameters:
        K: number of arms
        T: total number of rounds
        num_per_phase: total number of pulls per round
        T_S: number of comparisons for each pair of arms in offline data
        means_offline: offline mean (array or dictionary); stored but not read by this class
        means_online: online mean (array or dictionary); a dictionary is read
            by position (insertion order) when computing the per-pull regret
        offline_data: offline data, iterated as (arm_i, arm_j, outcome) triples
        V_matrix: bias matrix, indexed as V_matrix[i, j]
        delta: confidence parameter
        env: environment (SyntheticRelStoEnv or RealDataRelStoEnv); only
            env.sample_reward(arm) is called
        seed: if not None, seeds numpy's global RNG and the private RNG used
            to break ties between equally-pulled arms
        """
        self.K = K
        self.T = T
        self.num_per_phase = num_per_phase
        self.T_S = T_S
        self.means_offline = means_offline
        self.means_online = means_online
        self.offline_data = offline_data
        self.V_matrix = V_matrix
        self.delta = delta
        self.env = env
        self.A = set(range(K))
        self.X = {i: [] for i in range(K)}
        self.pull_counts = {i: 0 for i in range(K)}

        if seed is not None:
            np.random.seed(seed)
            # Tie-breaking used to rely on the stdlib `random` module, which the
            # seed never reached, so seeded runs were not reproducible. A
            # dedicated RandomState accepts the same seed values as
            # np.random.seed and does not disturb the global numpy stream.
            self._tie_rng = np.random.RandomState(seed)
        else:
            # Unseeded: keep the original behaviour (stdlib global RNG).
            self._tie_rng = None

        if isinstance(means_online, dict):
            self.best_arm = max(means_online, key=means_online.get)
            # Positional view of the dict, built once instead of on every pull.
            self._means_by_position = list(means_online.values())
        else:
            self.best_arm = np.argmax(means_online)
            self._means_by_position = means_online
        self.optimal_mean = means_online[self.best_arm]

        self.pull_rewards_list = []
        self.pull_regrets_list = []
        self.round_arm_counts = []

        # (i, j) -> sum of offline outcomes from i's point of view. Filled
        # lazily; assumes offline_data is not mutated after first use.
        self._offline_sums = {}

    def pull_arms(self, arm, pull_count):
        """Pull `arm` `pull_count` times, recording reward, regret and active-set size per pull."""
        regret = self.optimal_mean - self._means_by_position[arm]
        for _ in range(pull_count):
            reward = self.env.sample_reward(arm)
            self.X[arm].append(reward)
            self.pull_counts[arm] += 1
            self.pull_rewards_list.append(reward)
            self.pull_regrets_list.append(regret)
            self.round_arm_counts.append(len(self.A))

    def compute_ucb(self, i, j):
        """Return the online-only UCB for "arm i beats arm j".

        The value is 0.5 * (mean_i - mean_j + 1) plus
        sqrt((m + n) / (m * n) * log(1 / delta)), where m and n are the pull
        counts of i and j. Returns +inf while either arm is unpulled.
        """
        m = self.pull_counts[i]
        n = self.pull_counts[j]
        if m == 0 or n == 0:
            return float('inf')
        avg_i = np.mean(self.X[i]) if self.X[i] else 0
        avg_j = np.mean(self.X[j]) if self.X[j] else 0
        ucb = 0.5 * (avg_i - avg_j + 1) + np.sqrt((m + n) / (m * n) * np.log(1.0 / self.delta))
        return ucb

    def _offline_win_rate(self, i, j):
        """Return sum(offline outcomes for (i, j)) / T_S, or 0.5 when T_S <= 0."""
        if not self.T_S > 0:
            return 0.5
        key = (i, j)
        if key not in self._offline_sums:
            # One scan of offline_data per ordered pair instead of one per call.
            self._offline_sums[key] = sum(self.get_offline_preference(i, j))
        return self._offline_sums[key] / self.T_S

    def compute_ucb_s(self, i, j):
        """Return the UCB for "i beats j" that blends offline and online evidence.

        The offline win rate is weighted by T_S and the online estimate by
        m * n / (m + n); the bonus shrinks with the combined weight and the
        bias term scales V_matrix[i, j] by the offline share of that weight.
        Returns 1.0 when the combined weight is not positive.
        """
        m = self.pull_counts[i]
        n = self.pull_counts[j]
        avg_i = np.mean(self.X[i]) if self.X[i] else 0
        avg_j = np.mean(self.X[j]) if self.X[j] else 0
        t_ij = self.T_S
        p_ij = self._offline_win_rate(i, j)
        mn_over = (m * n) / (m + n) if (m + n) > 0 else 0
        denominator = t_ij + mn_over
        if not denominator > 0:
            # No evidence at all: estimate 1.0 with zero bonus and zero bias.
            return 1.0 + 0 + 0
        numerator = t_ij * p_ij + mn_over * 0.5 * (avg_i - avg_j + 1)
        base_estimate = numerator / denominator
        bonus = np.sqrt((1.0 / denominator) * np.log(1.0 / self.delta))
        bias_term = (t_ij / denominator) * self.V_matrix[i, j]
        ucb_s = base_estimate + bonus + bias_term
        return ucb_s

    def compute_ucb_s_initial(self, i, j):
        """Return the offline-only UCB for "i beats j" used before any online pull."""
        t_ij = self.T_S
        base_estimate = self._offline_win_rate(i, j)
        bonus = np.sqrt((1.0 / t_ij) * np.log(1.0 / self.delta)) if t_ij > 0 else 0
        bias_term = self.V_matrix[i, j]
        ucb_s = base_estimate + bonus + bias_term
        return ucb_s

    def get_offline_preference(self, i, j):
        """Collect offline outcomes for the pair (i, j) from i's point of view.

        Records stored as (j, i, outcome) are flipped to 1 - outcome.
        """
        preferences = []
        for arm_i, arm_j, outcome in self.offline_data:
            if arm_i == i and arm_j == j:
                preferences.append(outcome)
            elif arm_i == j and arm_j == i:
                preferences.append(1 - outcome)
        return preferences

    def initial_elimination(self):
        """Drop arms whose offline-only UCB against some other arm is <= 0.5.

        Does nothing when T_S is 0 or offline_data is empty.
        """
        if self.T_S == 0 or not self.offline_data:
            return
        current_arms = list(self.A)
        to_remove = set()
        for i in current_arms:
            for j in current_arms:
                if i == j:
                    continue
                if self.compute_ucb_s_initial(i, j) <= 0.5:
                    to_remove.add(i)
                    break
        self.A -= to_remove

    def eliminate_arms(self, current_arms):
        """Remove from the active set every arm that some other arm dominates.

        Arm i is removed when, for some j in current_arms, the online UCB is
        <= 0.5 and (if offline data is in use) max(ucb, ucb_s) is <= 0.5 too.
        Removals are applied after all pairs were examined.
        """
        use_offline = not (self.T_S == 0 or not self.offline_data)
        to_remove = set()
        for i in current_arms:
            for j in current_arms:
                if i == j:
                    continue
                ucb = self.compute_ucb(i, j)
                # max(ucb, ucb_s) <= 0.5 can only hold when ucb <= 0.5, so the
                # offline-blended bound is evaluated only in that case.
                if ucb <= 0.5 and (not use_offline or max(ucb, self.compute_ucb_s(i, j)) <= 0.5):
                    to_remove.add(i)
                    break
        self.A -= to_remove

    def get_least_pulled_arm(self, current_arms):
        """Return an arm of current_arms with the fewest pulls, ties broken at random.

        With a seed the tie is broken by the instance's private RNG, so runs
        are reproducible; without one the stdlib global RNG is used.
        """
        min_pulls = min(self.pull_counts[i] for i in current_arms)
        least_pulled_arms = [i for i in current_arms if self.pull_counts[i] == min_pulls]
        if self._tie_rng is None:
            return random.choice(least_pulled_arms)
        return least_pulled_arms[self._tie_rng.randint(len(least_pulled_arms))]

    def run(self):
        """Run T rounds of num_per_phase pulls each, eliminating after every pull.

        Returns:
        (A, pull_rewards_list, pull_regrets_list, round_arm_counts), where A is
        the set of arms still active at the end.

        Raises:
        ValueError: if num_per_phase <= K.
        """
        if self.num_per_phase <= self.K:
            raise ValueError(
                f"num_per_phase ({self.num_per_phase}) must be >= K + 1 ({self.K + 1}) "
                "to ensure all arms are pulled at least once in the first round for immediate elimination."
            )
        self.initial_elimination()
        for t in range(self.T):
            print(f"\n[Alg1] Round {t+1}")
            current_arms = list(self.A)
            if not current_arms:
                print(f"[Alg1] Terminating at round {t+1} since no arms remain.")
                # Pad the histories with zeros for this and all later rounds.
                remaining_pulls = (self.T - t) * self.num_per_phase
                self.pull_rewards_list.extend([0.0] * remaining_pulls)
                self.pull_regrets_list.extend([0.0] * remaining_pulls)
                self.round_arm_counts.extend([0] * remaining_pulls)
                break
            pulls_remaining = self.num_per_phase
            while pulls_remaining > 0 and current_arms:
                arm = self.get_least_pulled_arm(current_arms)
                self.pull_arms(arm, 1)
                self.eliminate_arms(current_arms)
                current_arms = list(self.A)
                pulls_remaining -= 1
        return (
            self.A,
            self.pull_rewards_list,
            self.pull_regrets_list,
            self.round_arm_counts
        )

class ETC:
    """Explore-then-commit baseline: pull every arm m times, then commit to the empirical best."""

    def __init__(self, K, T, m, means_online, env, seed):
        """
        ETC algorithm implementation.

        Parameters:
        K: number of arms
        T: total number of rounds
        m: number of explorations per arm
        means_online: online mean (array or dictionary, used to calculate regret)
        env: environment (SyntheticRelStoEnv or RealDataRelStoEnv)
        seed: if not None, seeds numpy's global RNG
        """
        self.K = K
        self.T = T
        self.m = m
        self.env = env
        if seed is not None:
            np.random.seed(seed)

        self.means_online = means_online
        if isinstance(means_online, dict):
            self.best_arm = max(means_online, key=means_online.get)
        else:
            self.best_arm = np.argmax(means_online)
        self.optimal_mean = means_online[self.best_arm]

        self.X = {arm: [] for arm in range(K)}
        self.pull_counts = {arm: 0 for arm in range(K)}
        self.pull_rewards_list = []
        self.pull_regrets_list = []
        self.round_arm_counts = []

    def _record_pull(self, arm):
        """Sample one reward for `arm` and append reward, regret and K to the histories."""
        reward = self.env.sample_reward(arm)
        self.X[arm].append(reward)
        self.pull_counts[arm] += 1
        # A dictionary of means is read by position, not by key.
        if isinstance(self.means_online, dict):
            arm_mean = list(self.means_online.values())[arm]
        else:
            arm_mean = self.means_online[arm]
        self.pull_rewards_list.append(reward)
        self.pull_regrets_list.append(self.optimal_mean - arm_mean)
        self.round_arm_counts.append(self.K)

    def run(self):
        """Explore each arm m times, then pull the empirically best arm T - K * m times.

        Returns:
        ({chosen_arm}, pull_rewards_list, pull_regrets_list, round_arm_counts)
        """
        # Exploration phase.
        for arm in range(self.K):
            for _ in range(self.m):
                self._record_pull(arm)

        # Commit phase.
        empirical_means = [np.mean(self.X[arm]) for arm in range(self.K)]
        chosen_arm = np.argmax(empirical_means)
        for _ in range(self.T - self.K * self.m):
            self._record_pull(chosen_arm)

        return (
            {chosen_arm},
            self.pull_rewards_list,
            self.pull_regrets_list,
            self.round_arm_counts
        )

class UCB:
    """UCB baseline with a fixed confidence parameter delta."""

    def __init__(self, K, T, means_online, env, delta, seed):
        """
        UCB algorithm implementation.

        Parameters:
        K: number of arms
        T: total number of rounds
        means_online: online mean (array or dictionary, used to calculate regret)
        env: environment (SyntheticRelStoEnv or RealDataRelStoEnv)
        delta: confidence parameter
        seed: if not None, seeds numpy's global RNG
        """
        self.K = K
        self.T = T
        self.env = env
        self.delta = delta
        if seed is not None:
            np.random.seed(seed)

        self.means_online = means_online
        if isinstance(means_online, dict):
            self.best_arm = max(means_online, key=means_online.get)
        else:
            self.best_arm = np.argmax(means_online)
        self.optimal_mean = means_online[self.best_arm]

        self.X = {arm: [] for arm in range(K)}
        self.pull_counts = {arm: 0 for arm in range(K)}
        self.pull_rewards_list = []
        self.pull_regrets_list = []
        self.round_arm_counts = []

    def compute_ucb(self, i, t):
        """Compute the UCB value of arm i.

        Returns +inf for an arm that has never been pulled; otherwise the
        empirical mean plus sqrt(2 * log(1 / delta) / pulls). The round index
        `t` is accepted but not used.
        """
        pulls = self.pull_counts[i]
        if pulls == 0:
            return float('inf')
        bonus = np.sqrt(2 * np.log(1/self.delta) / pulls)
        return np.mean(self.X[i]) + bonus

    def _record_pull(self, arm):
        """Sample one reward for `arm` and append reward, regret and K to the histories."""
        reward = self.env.sample_reward(arm)
        self.X[arm].append(reward)
        self.pull_counts[arm] += 1
        # A dictionary of means is read by position, not by key.
        if isinstance(self.means_online, dict):
            arm_mean = list(self.means_online.values())[arm]
        else:
            arm_mean = self.means_online[arm]
        self.pull_rewards_list.append(reward)
        self.pull_regrets_list.append(self.optimal_mean - arm_mean)
        self.round_arm_counts.append(self.K)

    def run(self):
        """Run the UCB algorithm.

        Every arm is pulled once, then rounds K + 1 .. T pull the arm with the
        largest UCB value.

        Returns:
        ({final_arm}, pull_rewards_list, pull_regrets_list, round_arm_counts),
        where final_arm has the highest empirical mean reward.
        """
        # One initial pull per arm.
        for arm in range(self.K):
            self._record_pull(arm)

        for t in range(self.K + 1, self.T + 1):
            bounds = [self.compute_ucb(arm, t) for arm in range(self.K)]
            self._record_pull(np.argmax(bounds))

        empirical_means = [
            np.mean(self.X[arm]) if self.X[arm] else -float('inf')
            for arm in range(self.K)
        ]
        final_arm = np.argmax(empirical_means)
        return (
            {final_arm},
            self.pull_rewards_list,
            self.pull_regrets_list,
            self.round_arm_counts
        )

class ThompsonSampling:
    """Thompson Sampling baseline with Gaussian posteriors over the arm means."""

    def __init__(self, K, T, means_online, env, prior_mean, init_pulls, seed):
        """
        Implement Algorithm 2: Thompson Sampling using Gaussian priors in the paper.

        Parameters:
        K: number of arms
        T: total number of rounds
        means_online: online mean (array or dictionary, used to calculate regret)
        env: environment (SyntheticRelStoEnv or RealDataRelStoEnv)
        prior_mean: prior mean (initialized to 0 in the paper)
        init_pulls: number of initialization pulls for each arm (the paper has no mandatory initialization pulls)
        seed: if not None, seeds numpy's global RNG
        """
        self.K = K
        self.T = T
        self.env = env
        self.prior_mean = prior_mean
        self.init_pulls = init_pulls
        if seed is not None:
            np.random.seed(seed)

        self.means_online = means_online
        if isinstance(means_online, dict):
            self.best_arm = max(means_online, key=means_online.get)
        else:
            self.best_arm = np.argmax(means_online)
        self.optimal_mean = means_online[self.best_arm]

        self.X = {i: [] for i in range(K)}
        self.pull_counts = {i: 0 for i in range(K)}
        self.posterior_means = {i: self.prior_mean for i in range(K)}
        self.pull_rewards_list = []
        self.pull_regrets_list = []
        self.round_arm_counts = []

    def _true_mean(self, arm):
        """Return the true online mean of `arm`; a dictionary is read by position."""
        if isinstance(self.means_online, dict):
            return list(self.means_online.values())[arm]
        return self.means_online[arm]

    def sample_posterior(self, arm):
        r"""
        Sample from the posterior distribution of an arm, the Gaussian
        \mathcal{N}(\hat{\mu}_i, 1/(k_i + 1)), where k_i is the arm's pull count.

        Parameters:
        arm: arm index
        Returns:
        The sampled mean \theta_i(t)
        """
        mean = self.posterior_means[arm]
        variance = 1.0 / (self.pull_counts[arm] + 1)
        return np.random.normal(loc=mean, scale=np.sqrt(variance))

    def update_posterior(self, arm, reward):
        """Fold `reward` into the arm's running mean and increment its pull count.

        This is the only place where pull_counts is incremented. With a pull
        count of 0 the previous mean (the prior) gets weight zero.
        """
        k = self.pull_counts[arm]
        current_mean = self.posterior_means[arm]
        self.posterior_means[arm] = (current_mean * k + reward) / (k + 1)
        self.pull_counts[arm] += 1

    def run(self):
        """
        Run the Thompson Sampling algorithm (Algorithm 2).

        Returns:
        (final_arm_set, pull_rewards_list, pull_regrets_list, round_arm_counts)
        - final_arm_set: the arm that is finally selected (based on the arm with the highest average reward)
        - pull_rewards_list: the reward list for each round
        - pull_regrets_list: the regret list for each round
        - round_arm_counts: the number of arms in each round (consistent with the input, fixed to K)
        """
        # Initialization phase: init_pulls forced pulls per arm.
        for i in range(self.K):
            for _ in range(self.init_pulls):
                reward = self.env.sample_reward(i)
                self.X[i].append(reward)
                # update_posterior already increments pull_counts[i]. The extra
                # increment that used to follow counted each initialization
                # pull twice, which skewed the running mean and shrank the
                # sampling variance.
                self.update_posterior(i, reward)
                true_mean_i = self._true_mean(i)
                # NOTE(review): unlabeled debug output kept so stdout is
                # unchanged; consider routing it through logging.
                print(true_mean_i)
                print(self.optimal_mean)
                regret = self.optimal_mean - true_mean_i

                self.pull_rewards_list.append(reward)
                self.pull_regrets_list.append(regret)
                self.round_arm_counts.append(self.K)

        # Sampling phase: play the arm with the largest posterior sample.
        for t in range(self.K * self.init_pulls + 1, self.T + 1):
            sampled_means = [self.sample_posterior(i) for i in range(self.K)]
            chosen_arm = np.argmax(sampled_means)
            reward = self.env.sample_reward(chosen_arm)
            self.X[chosen_arm].append(reward)
            self.update_posterior(chosen_arm, reward)
            regret = self.optimal_mean - self._true_mean(chosen_arm)
            print(f"t={t}, sampled_means={sampled_means}, chosen_arm={chosen_arm}")
            self.pull_rewards_list.append(reward)
            self.pull_regrets_list.append(regret)
            self.round_arm_counts.append(self.K)

        avg_rewards = [np.mean(self.X[i]) if self.X[i] else -float('inf') for i in range(self.K)]
        final_arm = np.argmax(avg_rewards)

        return (
            {final_arm},
            self.pull_rewards_list,
            self.pull_regrets_list,
            self.round_arm_counts
        )