import numpy as np

class EB_TC:
    """Empirical-Best / Transportation-Cost (EB-TC) sampling rule for
    epsilon-best-arm identification.

    Once every arm has at least one sample, each round picks:

    * a *leader*: the arm with the highest empirical mean (first one on ties);
    * a *challenger*: the non-leader arm minimising the transportation cost
      ``(mu_leader - mu_i + epsilon_0) / sqrt(1/N_leader + 1/N_i)``.

    The normalisation ``sqrt(1/N_leader + 1/N_i)`` matches unit-variance
    Gaussian rewards. NOTE(review): presumably that is the intended reward
    model; confirm.

    Which of the two arms is pulled is decided by tracking, separately for
    every (leader, challenger) pair, how often the challenger was pulled
    relative to its target proportion ``1 - beta``.

    Intended protocol: alternate ``arm = select_next_arm()`` with
    ``update(arm, reward)``.

    Args:
        K: number of arms (must be >= 1).
        epsilon_0: slack added to the leader/challenger mean gap in the
            transportation cost.
        beta: target fraction of pulls given to the leader within each
            (leader, challenger) pair.

    Raises:
        ValueError: if ``K < 1``.
    """

    def __init__(self, K, epsilon_0=0.1, beta=0.3):
        if K < 1:
            raise ValueError(f"K must be at least 1, got {K}")
        self.K = K
        self.epsilon_0 = epsilon_0
        self.beta = beta
        self.n = 0  # total number of observed rewards
        self.rewards = np.zeros(K, dtype=float)  # per-arm sum of rewards
        self.counts = np.zeros(K, dtype=int)  # per-arm number of pulls
        # T_counts[i, j]: number of rounds whose (leader, challenger) was (i, j).
        self.T_counts = np.zeros((K, K), dtype=int)
        # N_counts[i, j]: of those rounds, how many were followed by a pull of j.
        self.N_counts = np.zeros((K, K), dtype=int)
        # beta_bar[i, j]: running average of beta over the rounds in T_counts[i, j].
        self.beta_bar = np.zeros((K, K), dtype=float)

        # Pair produced by the latest select_next_arm() call, consumed by update().
        self.last_leader = None
        self.last_challenger = None

    def update(self, arm, reward):
        """Record ``reward`` observed after pulling ``arm``.

        If a (leader, challenger) pair is pending from select_next_arm() and
        ``arm`` is that challenger, the pair's challenger count is
        incremented. The pending pair is then cleared, so repeated updates
        without a new selection cannot credit the same round twice.
        """
        self.counts[arm] += 1
        self.rewards[arm] += reward
        self.n += 1

        if self.last_leader is not None and self.last_challenger is not None:
            if arm == self.last_challenger:
                self.N_counts[self.last_leader, self.last_challenger] += 1
            # Each selection round is credited at most once.
            self.last_leader = None
            self.last_challenger = None

    def select_next_arm(self):
        """Return the index (a Python ``int``) of the next arm to pull.

        Arms that have never been pulled are returned first, lowest index
        first. Afterwards the method returns either the empirical-best
        leader or its transportation-cost challenger, according to the
        per-pair tracking rule. With a single arm, that arm is returned.
        """
        unpulled = np.flatnonzero(self.counts == 0)
        if unpulled.size:
            # Initialisation: every arm needs a sample before means are
            # compared; this also keeps the divisions below away from zero.
            return int(unpulled[0])

        means = self.rewards / self.counts
        leader = int(np.argmax(means))  # first maximiser on ties

        others = np.flatnonzero(np.arange(self.K) != leader)
        if others.size == 0:
            # K == 1: no challenger exists, so the only arm is always chosen.
            return leader

        # Transportation cost of every non-leader arm; the cheapest challenges.
        costs = (means[leader] - means[others] + self.epsilon_0) / np.sqrt(
            1.0 / self.counts[leader] + 1.0 / self.counts[others]
        )
        challenger = int(others[np.argmin(costs)])  # first minimiser on ties

        pair = (leader, challenger)
        t = self.T_counts[pair]
        self.beta_bar[pair] = (t * self.beta_bar[pair] + self.beta) / (t + 1)
        self.T_counts[pair] = t + 1

        # Pull the challenger while its count within this pair is at or below
        # its target share (1 - beta_bar) of the pair's rounds; else the leader.
        if self.N_counts[pair] <= (1 - self.beta_bar[pair]) * self.T_counts[pair]:
            chosen = challenger
        else:
            chosen = leader

        self.last_leader = leader
        self.last_challenger = challenger
        return chosen