import numpy as np
import math

class ImprovedUCB:
    """UCB-style bandit selector augmented with prior, diversity and recovery terms.

    Each arm ("algorithm") ``a`` is scored as the sum of:

    * its empirical mean reward ``R[a] / N[a]``,
    * a UCB1-style exploration bonus ``sqrt(2 * ln(T + 1) / N[a])``,
    * an externally supplied prior (e.g. from an LLM) weighted by
      ``lambda_prior_initial / (1 + T)``, so its influence shrinks as the
      total number of pulls ``T`` grows,
    * a diversity bonus ``alpha_diversity / (N[a] + 1)``,
    * a recovery bonus ``alpha_recovery * Srecovery[a]``, where
      ``Srecovery[a]`` is the most recent recovery score reported for ``a``.

    ``choose_algorithm`` returns the arm with the highest score, and
    ``update`` records the outcome of a pull.
    """

    # Added to pull counts in denominators so never-pulled arms do not
    # divide by zero.
    _EPS = 1e-6

    def __init__(self, num_algorithms, alpha_diversity=1.0, alpha_recovery=1.0, lambda_prior_initial=1.0):
        """Create a bandit over ``num_algorithms`` arms with no history.

        Args:
            num_algorithms: number of arms; must be at least 1.
            alpha_diversity: weight of the ``1 / (N + 1)`` diversity bonus.
            alpha_recovery: weight of the per-arm recovery score.
            lambda_prior_initial: weight of the prior term at ``T == 0``;
                it decays as ``1 / (1 + T)``.

        Raises:
            ValueError: if ``num_algorithms`` is less than 1 (otherwise
                ``choose_algorithm`` would fail on an empty score list).
        """
        if num_algorithms < 1:
            raise ValueError(f"num_algorithms must be >= 1, got {num_algorithms}")
        self.num_algorithms = num_algorithms
        self.alpha_diversity = alpha_diversity
        self.alpha_recovery = alpha_recovery
        self.lambda_prior_initial = lambda_prior_initial
        self.R = np.zeros(num_algorithms)  # cumulative reward per arm
        self.N = np.zeros(num_algorithms)  # pull count per arm
        self.Srecovery = np.zeros(num_algorithms)  # latest recovery score per arm
        self.T = 0  # total number of pulls across all arms
        self.llm_priors = np.zeros(num_algorithms)  # per-arm prior scores

    def update_llm_priors(self, priors):
        """Replace the per-arm prior scores.

        Args:
            priors: sequence with exactly ``num_algorithms`` numeric entries,
                one prior per arm. The values are copied into a float array,
                so later changes to the caller's object do not leak in.

        Raises:
            ValueError: if ``priors`` is not one-dimensional with exactly
                ``num_algorithms`` entries.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O`` and mis-sized priors would then be accepted silently.
        priors_arr = np.array(priors, dtype=float)
        if priors_arr.shape != (self.num_algorithms,):
            raise ValueError(
                f"expected {self.num_algorithms} priors, got array of shape {priors_arr.shape}"
            )
        self.llm_priors = priors_arr

    def calculate_ucb(self, algorithm_index):
        """Return the selection score of arm ``algorithm_index``.

        NOTE: for a never-pulled arm the ``_EPS`` denominator makes the
        exploration term roughly ``1177 * sqrt(ln(T + 1))`` once ``T > 0``.
        This effectively forces each arm to be tried, but only while rewards
        are much smaller than that. At ``T == 0`` the log term is zero, so the
        first pick is decided by the prior, diversity and recovery terms.
        """
        a = algorithm_index
        average_reward = self.R[a] / (self.N[a] + self._EPS)
        exploration = math.sqrt(2 * math.log(self.T + 1) / (self.N[a] + self._EPS))
        # Prior weight decays with total pulls so observed data dominates.
        lambda_prior = self.lambda_prior_initial / (1 + self.T)
        prior_term = lambda_prior * self.llm_priors[a]
        diversity_bonus = self.alpha_diversity / (self.N[a] + 1)
        recovery_term = self.alpha_recovery * self.Srecovery[a]

        return average_reward + exploration + prior_term + diversity_bonus + recovery_term

    def choose_algorithm(self):
        """Return the index (as ``int``) of the arm with the highest score.

        Ties go to the lowest index, following ``np.argmax``.
        """
        ucb_values = [self.calculate_ucb(i) for i in range(self.num_algorithms)]
        return int(np.argmax(ucb_values))

    def update(self, algorithm_index, reward, recovery_score=0.0):
        """Record the outcome of pulling ``algorithm_index``.

        Args:
            algorithm_index: arm that was pulled, in ``[0, num_algorithms)``.
            reward: observed reward, added to the arm's cumulative total.
            recovery_score: recovery score for this pull. It *replaces*
                (does not accumulate into) the arm's stored recovery score.

        Raises:
            IndexError: if ``algorithm_index`` is out of range. The check
                runs before any state is mutated, so a bad call leaves ``T``
                untouched. Negative indices are rejected because NumPy would
                otherwise silently wrap them onto another arm.
        """
        a = algorithm_index
        if not 0 <= a < self.num_algorithms:
            raise IndexError(
                f"algorithm_index {a} out of range for {self.num_algorithms} algorithms"
            )
        self.T += 1
        self.N[a] += 1
        self.R[a] += reward
        self.Srecovery[a] = recovery_score

if __name__ == '__main__':
    # Demo: simulate a bandit whose arms yield Gaussian rewards, with random
    # "failures" that override the reward and may report a recovery score.

    # (mean, std) of the Gaussian reward for each arm, indexed by arm number.
    reward_params = [
        (1.0, 0.5),
        (0.5, 0.3),
        (0.7, 0.4),
        (0.9, 0.2),
        (0.3, 0.1),
    ]
    num_algorithms = len(reward_params)
    failure_prob = 0.2         # chance that a pull fails
    failure_reward = -0.5      # reward assigned to a failed pull
    recovery_prob = 0.5        # chance a failed pull reports a recovery score
    recovery_value = 0.8       # recovery score reported in that case
    prior_refresh_every = 10   # trials between simulated prior refreshes
    num_trials = 100

    # Seeded generator so the demo run is reproducible.
    rng = np.random.default_rng(0)

    ucb = ImprovedUCB(num_algorithms=num_algorithms)
    initial_priors = [0.8, 0.2, 0.5, 0.7, 0.1]
    ucb.update_llm_priors(initial_priors)

    for t in range(num_trials):
        chosen_algorithm = ucb.choose_algorithm()
        print(f"Trial {t+1}: Chosen algorithm = {chosen_algorithm}")

        mean, std = reward_params[chosen_algorithm]
        reward = rng.normal(mean, std)
        recovery_score = 0.0

        # Simulated failure: replaces the sampled reward and may report a
        # partial recovery for the arm.
        if rng.random() < failure_prob:
            reward = failure_reward
            if rng.random() < recovery_prob:
                recovery_score = recovery_value

        ucb.update(chosen_algorithm, reward, recovery_score)

        if (t + 1) % prior_refresh_every == 0:
            new_priors = rng.random(num_algorithms).tolist()
            ucb.update_llm_priors(new_priors)
            print(f"Updated LLM priors: {new_priors}")

    print("\nAlgorithm selection counts:")
    print(ucb.N)
    print("\nAlgorithm average rewards:")
    # Exact per-arm mean; arms never pulled report 0 instead of a biased value.
    print(np.divide(ucb.R, ucb.N, out=np.zeros_like(ucb.R), where=ucb.N > 0))