import numpy as np

class Sto_to_Rel_UCB:
    """UCB-style dueling learner that can blend offline rewards with online duels.

    Two kinds of statistics are kept:

    * ``sum_offline`` / ``count_offline``: per-arm sum and number of offline
      rewards, filled by :meth:`fit_offline_data`.
    * ``pos_count_on`` / ``total_count_on``: ``K x K`` matrices of online duel
      results.  ``total_count_on[i, j]`` is the number of recorded duels for
      the pair and ``pos_count_on[i, j]`` the number credited to ``i``
      (presumably "i beat j"; the meaning of ``outcome`` is defined by the
      environment, which is not visible here).

    Args:
        K: Number of arms.
        delta: Confidence parameter; every radius uses ``log(1 / delta)``.
        V: Optional matrix indexed as ``V[i, j]``; it is added to the offline
            index weighted by ``alpha / (T_ij + alpha)``.  ``None`` means 0.
        use_offline: When false, offline data is ignored everywhere.
        seed: If not ``None``, seeds NumPy's *global* random generator.
    """

    def __init__(self, K, delta, V, use_offline, seed):
        self.K = K
        self.V = V
        self.delta = delta
        self.use_offline = use_offline
        # Online pairwise duel statistics.
        self.total_count_on = np.zeros((K, K))
        self.pos_count_on = np.zeros((K, K))
        # Offline per-arm reward statistics.
        self.count_offline = np.zeros(K)
        self.sum_offline = np.zeros(K)
        if seed is not None:
            np.random.seed(seed)  # Set global random seed

    def fit_offline_data(self, offline_data):
        """Accumulate ``(arm, reward)`` pairs; a no-op when offline use is off."""
        if self.use_offline:
            for arm, reward in offline_data:
                self.count_offline[arm] += 1
                self.sum_offline[arm] += reward

    def update_online(self, i, j, outcome):
        """Record one duel of ``i`` against ``j``.

        ``outcome == 1`` credits ``i``; any other value credits ``j``.
        """
        # Two separate increments on purpose: for i == j the cell grows by 2.
        self.total_count_on[i, j] += 1
        self.total_count_on[j, i] += 1
        credited, other = (i, j) if outcome == 1 else (j, i)
        self.pos_count_on[credited, other] += 1

    def calc_UCB_off(self, i, j):
        """Upper confidence bound for the pair ``(i, j)`` using offline data.

        Falls back to :meth:`calc_UCB_on` when offline data is disabled or
        when either arm has no offline sample.  Otherwise the offline mean
        difference is mapped through a logistic function and mixed with the
        online counts as ``alpha = m * n / (m + n)`` pseudo-observations.
        """
        if not self.use_offline:
            return self.calc_UCB_on(i, j)

        m, n = self.count_offline[i], self.count_offline[j]
        if (m < 1e-9) or (n < 1e-9):
            # No offline information for one of the arms: online index only.
            return self.calc_UCB_on(i, j)

        alpha = (m * n) / (m + n)
        mean_gap = (self.sum_offline[i] / m) - (self.sum_offline[j] / n)
        offline_pref = 1.0 / (1.0 + np.exp(-mean_gap))

        denominator = self.total_count_on[i, j] + alpha
        if denominator < 1e-9:
            return np.inf
        numerator = self.pos_count_on[i, j] + alpha * offline_pref

        radius = np.sqrt((0.5 / denominator) * np.log(1.0 / self.delta))
        bias = 0.0 if self.V is None else self.V[i, j]
        return numerator / denominator + radius + (alpha / denominator) * bias

    def calc_UCB_on(self, i, j):
        """Online-only upper confidence bound; ``inf`` for an unplayed pair."""
        duels = self.total_count_on[i, j]
        if duels < 1e-9:
            return np.inf
        radius = np.sqrt(np.log(1 / self.delta) / (2 * duels))
        return self.pos_count_on[i, j] / duels + radius

    def select_pair(self):
        """Pick the next pair of arms to duel.

        An arm is a candidate when its index against every arm is at least
        0.5 (for both indices when offline data is used).  The pair with the
        largest index inside the candidate set is returned (the smaller of
        the two indices when offline data is used); all arms are searched if
        no candidate exists.  Ties are broken uniformly at random.
        """
        arms = range(self.K)

        def survives(index_fn, i):
            return all(index_fn(i, j) >= 0.5 for j in arms)

        C_t = [i for i in arms if survives(self.calc_UCB_on, i)]
        if self.use_offline:
            C_off = [i for i in arms if survives(self.calc_UCB_off, i)]
            C_t = list(set(C_t).intersection(set(C_off)))

        pool = arms if len(C_t) == 0 else C_t

        top_score = -np.inf
        tied = []
        for i in pool:
            for j in pool:
                score = self.calc_UCB_on(i, j)
                if self.use_offline:
                    score = min(self.calc_UCB_off(i, j), score)
                if score > top_score:
                    top_score, tied = score, [(i, j)]
                elif score == top_score:
                    tied.append((i, j))

        return tied[np.random.choice(len(tied))]

    def run_online(self, env, T, num_relative_draws, true_mu_on):
        """Run ``T`` rounds against ``env``.

        Each round selects a pair, plays ``num_relative_draws`` duels via
        ``env.play_duel`` and then samples one reward per arm via
        ``env.sample_reward``.

        Returns:
            ``(regret_history, reward_history)``.  Regret is
            ``max(true_mu_on)`` minus the mean of the two arms' true values,
            or ``None`` per round when ``true_mu_on`` is ``None``.
        """
        regret_history, reward_history = [], []
        for _round in range(T):
            c, d = self.select_pair()
            for _draw in range(num_relative_draws):
                self.update_online(d, c, env.play_duel(d, c))

            r_c = env.sample_reward(c)
            r_d = env.sample_reward(d)
            reward_history.append((r_c + r_d) / 2.0)

            if true_mu_on is None:
                regret_history.append(None)
                continue
            mu_star = np.max(true_mu_on)
            regret_history.append(mu_star - (true_mu_on[c] + true_mu_on[d]) / 2)
        return regret_history, reward_history

    def get_best_arm(self):
        """Return the arm with an empirical rate above 0.5 against the most opponents."""
        def majority_wins(i):
            return sum(
                1
                for j in range(self.K)
                if j != i
                and self.total_count_on[i, j] > 0
                and self.pos_count_on[i, j] / self.total_count_on[i, j] > 0.5
            )

        return np.argmax([majority_wins(i) for i in range(self.K)])
    
class RUCB:
    """Relative-UCB style dueling-bandit learner driven by a win-count matrix.

    ``W[i, j]`` counts duels credited to ``i`` over ``j`` (see :meth:`update`).

    Args:
        K: Number of arms.
        alpha: Stored on the instance; the index computed in
            :meth:`select_arms` does not use it (the confidence radius is
            based on ``delta`` only).
        delta: Confidence parameter; the radius uses ``log(1 / delta)``.
        seed: If not ``None``, seeds NumPy's *global* random generator.

    Attributes ``T`` (always ``inf``) and ``t`` (number of ``select_arms``
    calls) are maintained but never read by this class.
    """

    def __init__(self, K, alpha, delta, seed):
        self.K = K
        self.alpha = alpha
        self.T = np.inf
        self.W = np.zeros((K, K))
        self.t = 0
        self.delta = delta
        if seed is not None:
            np.random.seed(seed)  # Set global random seed

    def select_arms(self):
        """Choose a champion ``c`` and a challenger ``d`` for the next duel.

        The index is ``U[i, j] = W[i, j] / n + sqrt(log(1 / delta) / (2 n))``
        with ``n = W[i, j] + W[j, i]``, using ``n = 1`` for unplayed pairs,
        and ``U[i, i] = 0.5``.  ``c`` is drawn uniformly from the arms whose
        whole row is at least 0.5 (from all arms if there is none); ``d`` is
        drawn uniformly from the arms maximising ``U[:, c]`` and may equal
        ``c``.

        Returns:
            Tuple ``(c, d)`` of arm indices.
        """
        self.t += 1
        duels = self.W + self.W.T
        # Unplayed pairs use a count of 1, which also avoids dividing by zero.
        safe_duels = np.where(duels > 0, duels, 1.0)
        U = self.W / safe_duels + np.sqrt(np.log(1 / self.delta) / (2 * safe_duels))
        np.fill_diagonal(U, 0.5)

        potential_champions = [i for i in range(self.K) if np.all(U[i, :] >= 0.5)]
        if potential_champions:
            c = np.random.choice(potential_champions)
        else:
            c = np.random.randint(self.K)

        column = U[:, c]
        candidates_d = np.where(column == np.max(column))[0]
        d = np.random.choice(candidates_d)
        return c, d

    def update(self, c, d, outcome):
        """Credit the duel to ``c`` when ``outcome == 1``, otherwise to ``d``."""
        if outcome == 1:
            self.W[c, d] += 1
        else:
            self.W[d, c] += 1

    def run_online(self, env, T, true_mu_on):
        """Run ``T`` duels against ``env``.

        Args:
            env: Object exposing ``play_duel(c, d)``.
            T: Number of rounds.
            true_mu_on: Indexable per-arm true values used for regret.

        Returns:
            ``(regret_history, reward_history)``.  The per-round reward is 1
            when the arm credited with the duel is ``argmax(true_mu_on)``,
            else 0; the regret is ``max(true_mu_on)`` minus the mean true
            value of the two arms played.
        """
        regret_history = []
        reward_history = []
        optimal_reward = np.max(true_mu_on)
        best_arm = np.argmax(true_mu_on)  # loop-invariant, computed once
        for _ in range(T):
            c, d = self.select_arms()
            outcome = env.play_duel(c, d)
            self.update(c, d, outcome)
            winner_arm = c if outcome == 1 else d
            reward_history.append(1 if winner_arm == best_arm else 0)
            regret_history.append(
                optimal_reward - (true_mu_on[c] + true_mu_on[d]) / 2
            )
        return regret_history, reward_history

    def get_best_arm(self):
        """Return the arm with an empirical win rate above 0.5 against the most opponents."""
        win_counts = np.zeros(self.K)
        for i in range(self.K):
            for j in range(self.K):
                if i == j:
                    continue
                total_comparisons = self.W[i, j] + self.W[j, i]
                if total_comparisons > 0 and self.W[i, j] / total_comparisons > 0.5:
                    win_counts[i] += 1
        return np.argmax(win_counts)

class InterleavedFilter2:
    """Interleaved-Filter style elimination learner for dueling bandits.

    A current champion ``hat_b`` duels every remaining challenger in turn.
    Challengers that the champion beats with confidence are dropped; a
    challenger that beats the champion with confidence replaces it, at which
    point all estimates are reset.

    Args:
        K: Number of arms.
        T: Budget on the total number of duels; also the length of the
            regret history returned by :meth:`run`.
        delta: Confidence parameter; the radius is ``sqrt(log(1 / delta) / t)``.
        seed: If not ``None``, seeds NumPy's *global* random generator.
    """

    def __init__(self, K, T, delta, seed):
        self.K = K
        self.T = T
        self.delta = delta
        self.total_comparisons = 0
        self.hat_b = None
        if seed is not None:
            np.random.seed(seed)  # Set global random seed

    def _confidence_radius(self, t):
        """Return the confidence half-width after ``t`` duels (``inf`` for none)."""
        return np.sqrt(np.log(1 / self.delta) / t) if t > 0 else np.inf

    def run(self, env, true_mu_on):
        """Run the filter until the budget ``T`` is spent or one arm remains.

        Args:
            env: Object exposing ``play_duel(hat_b, b)``; its outcome is
                averaged into ``P_hat[b]`` (presumably 1 when the first
                argument wins -- confirm against the environment).
            true_mu_on: Indexable per-arm true values used for regret.

        Returns:
            List of per-round regrets.  Rounds in which a duel is played
            contribute ``max(true_mu_on)`` minus the mean true value of the
            two arms that dueled; the list is then padded up to length ``T``
            with the regret of the final champion played against itself.

        Note:
            ``self.total_comparisons`` is not reset here, so the budget is
            shared across repeated calls on the same instance.
        """
        W = set(range(self.K))
        hat_b = np.random.choice(list(W))
        W.remove(hat_b)
        P_hat = {b: 0.5 for b in range(self.K)}
        counts = {b: 0 for b in range(self.K)}
        regret_history = []
        optimal_reward = np.max(true_mu_on)

        while W and self.total_comparisons < self.T:
            # One sweep: the champion duels every remaining challenger once.
            for b in list(W):
                if self.total_comparisons >= self.T:
                    break
                self.total_comparisons += 1
                outcome = env.play_duel(hat_b, b)
                counts[b] += 1
                P_hat[b] = (P_hat[b] * (counts[b] - 1) + outcome) / counts[b]
                # Both arms that were actually played enter the regret.
                regret = optimal_reward - (true_mu_on[hat_b] + true_mu_on[b]) / 2
                regret_history.append(regret)

            if self.total_comparisons >= self.T:
                break

            # Drop challengers whose lower bound exceeds 0.5.  The radius is
            # non-negative, so this also implies P_hat[b] > 0.5.  A single
            # pass suffices because removals do not change any estimate.
            for b in list(W):
                if P_hat[b] - self._confidence_radius(counts[b]) > 0.5:
                    W.remove(b)

            # Promote the first challenger whose upper bound is below 0.5.
            new_hat_b = None
            for b in list(W):
                if P_hat[b] + self._confidence_radius(counts[b]) < 0.5:
                    new_hat_b = b
                    break

            if new_hat_b is not None:
                hat_b = new_hat_b
                W.remove(hat_b)
                # Estimates were relative to the old champion: start afresh.
                P_hat = {b: 0.5 for b in range(self.K)}
                counts = {b: 0 for b in range(self.K)}

        # Remaining rounds exploit the final champion against itself.
        while len(regret_history) < self.T:
            regret = optimal_reward - (true_mu_on[hat_b] + true_mu_on[hat_b]) / 2
            regret_history.append(regret)

        self.hat_b = hat_b
        return regret_history

    def get_best_arm(self):
        """Return the champion from the last :meth:`run` (``None`` before any run)."""
        return self.hat_b