from utils import max_
import numpy as np
class BayesianUCB:
    """Bandit agent combining Normal-Gamma posterior sampling with a UCB bonus.

    For every action the agent keeps weighted sufficient statistics (count,
    sum of rewards, sum of squared rewards).  ``select_action`` draws a mean
    for each action from its Normal-Gamma posterior (or from the prior when
    the action has no observations), adds an exploration bonus scaled by
    ``c``, and picks the action with the largest score.

    Args:
        c: Multiplier of the exploration bonus.
        n_actions: Number of available actions.
        mu_0: Prior mean of the reward.
        lambda_0: Prior pseudo-count attached to ``mu_0``.
        alpha_0: Prior shape of the Gamma distribution over the precision.
        beta_0: Prior rate of the Gamma distribution over the precision.
    """

    def __init__(self, c=1.0, n_actions=10, mu_0=0, lambda_0=1, alpha_0=1, beta_0=1):
        self.n_actions = n_actions
        self.c = c

        # Hyperparameters for the Normal-Gamma prior
        self.mu_0 = mu_0
        self.lambda_0 = lambda_0
        self.alpha_0 = alpha_0
        self.beta_0 = beta_0

        # Per-action weighted sufficient statistics
        self.counts = np.zeros(n_actions)
        self.sum_rewards = np.zeros(n_actions)
        self.sum_squared_rewards = np.zeros(n_actions)
        self.q_values = np.zeros(n_actions)  # To store the estimated values for each action

        # Response-time bookkeeping; the average is defined up front so it
        # can be read before the first call to track_response_time().
        self.total_response_time = 0
        self.total_responses = 0
        self.average_response_time = 0

    def track_response_time(self, delay):
        """Record one response delay and return the running average delay."""
        self.total_response_time += delay
        self.total_responses += 1
        self.average_response_time = self.total_response_time / self.total_responses
        return self.average_response_time

    def _posterior_params(self, action):
        """Return ``(mu_n, lambda_n, alpha_n, beta_n)`` for ``action``.

        Requires ``self.counts[action] > 0``.
        """
        n = self.counts[action]
        total = self.sum_rewards[action]
        mean = total / n

        lambda_n = self.lambda_0 + n
        alpha_n = self.alpha_0 + n / 2
        mu_n = (self.lambda_0 * self.mu_0 + total) / lambda_n

        # Weighted scatter around the sample mean; clamped because floating
        # point cancellation can push it marginally below zero.
        scatter = max(self.sum_squared_rewards[action] - total * mean, 0.0)
        # Disagreement between the prior mean and the sample mean.
        prior_shift = self.lambda_0 * n * (mean - self.mu_0) ** 2 / lambda_n
        beta_n = self.beta_0 + 0.5 * (scatter + prior_shift)
        return mu_n, lambda_n, alpha_n, beta_n

    def select_action(self):
        """Score every action and return the index of the best one.

        The scores (sampled mean plus exploration bonus) replace
        ``self.q_values``.

        Returns:
            Index of the action with the highest score.
        """
        # The log term is the same for every action, so compute it once.
        log_total = np.log(np.sum(self.counts) + 1)

        samples = np.zeros(self.n_actions)
        for a in range(self.n_actions):
            if self.counts[a] > 0:
                mu_n, lambda_n, alpha_n, beta_n = self._posterior_params(a)

                # Sample from the posterior distribution; numpy's gamma takes
                # a scale, i.e. the reciprocal of the rate beta_n.
                tau_n = np.random.gamma(alpha_n, 1 / beta_n)
                sigma_n = 1 / np.sqrt(tau_n * lambda_n)
                mu_sample = np.random.normal(mu_n, sigma_n)
            else:
                # If no samples have been taken for this action, use the
                # prior with the precision fixed at its mean alpha_0 / beta_0.
                prior_sigma = 1 / np.sqrt(self.lambda_0 * self.alpha_0 / self.beta_0)
                mu_sample = np.random.normal(self.mu_0, prior_sigma)

            bonus = self.c * np.sqrt(2 * log_total / (self.counts[a] + 1))
            samples[a] = mu_sample + bonus

        # Store the sampled values in q_values
        self.q_values = samples
        return np.argmax(samples)

    def update(self, action, reward, delay=0, env_prob=1.0):
        """Fold one observed reward into the statistics of ``action``.

        Args:
            action: Index of the action that was played.
            reward: Observed reward.
            delay: Response time passed on to ``track_response_time``.
            env_prob: Weight of this observation; it scales the count, the
                reward sum and the squared-reward sum alike.
        """
        self.counts[action] += env_prob
        self.sum_rewards[action] += reward * env_prob
        # The weight enters linearly (w * r**2), matching the count and the
        # reward sum; squaring it would allow a negative implied variance.
        self.sum_squared_rewards[action] += env_prob * reward ** 2

        # Update q_values with the new average reward
        if self.counts[action] > 0:
            self.q_values[action] = self.sum_rewards[action] / self.counts[action]
        else:
            self.q_values[action] = 0

        # Track response time for the action
        self.track_response_time(delay)

    def calculate_variance(self, action):
        """Return the empirical reward variance of ``action``.

        Returns 1 while the (weighted) count of the action is at most 1.
        """
        if self.counts[action] > 1:
            mean = self.sum_rewards[action] / self.counts[action]
            mean_sq = self.sum_squared_rewards[action] / self.counts[action]
            # Clamp tiny negative values caused by floating point rounding.
            return max(mean_sq - mean ** 2, 0.0)
        return 1