from utils import max_
import math
import numpy as np


class UCB1:
    """Upper-confidence-bound bandit agent with weighted updates.

    The agent keeps, per action, a (possibly fractional) pull count, the
    accumulated weighted reward and the resulting mean estimate.  It also
    maintains a running average of the delays passed to :meth:`update`.

    Attributes:
        n_actions: Number of available actions.
        c: Multiplier applied to the exploration bonus.
        q_values: Per-action mean reward estimate (``total_rewards / n_actions_taken``).
        n_actions_taken: Per-action accumulated weight (``env_prob`` per update).
        total_rewards: Per-action accumulated ``reward * env_prob``.
        reward_history: Per-action list of the weighted rewards seen so far;
            read by :meth:`calculate_variance`.
        total_response_time: Sum of all delays recorded so far.
        total_responses: Number of delays recorded so far.
        average_response_time: ``total_response_time / total_responses``
            (0 until the first delay is recorded).
    """

    def __init__(self, c=1.0, n_actions=10):
        """Create an agent with ``n_actions`` arms and exploration weight ``c``."""
        self.n_actions = n_actions
        self.c = c
        self.q_values = np.zeros(n_actions)
        self.n_actions_taken = np.zeros(n_actions)
        self.total_rewards = np.zeros(n_actions)
        self.reward_history = [[] for _ in range(n_actions)]  # To store rewards for variance calculation

        self.total_response_time = 0
        self.total_responses = 0
        # Defined up front so the attribute can be read before the first update.
        self.average_response_time = 0

    def track_response_time(self, delay):
        """Record one delay and return the updated running average.

        Args:
            delay: Response time of the latest interaction.

        Returns:
            The mean of every delay recorded so far, including this one.
        """
        self.total_response_time += delay
        self.total_responses += 1
        # The counter was just incremented; the fallback is purely defensive.
        self.average_response_time = self.total_response_time / self.total_responses if self.total_responses else 0
        return self.average_response_time

    def select_action(self):
        """Return the index of the action with the highest upper confidence bound.

        The bound is ``q + c * sqrt(2 * log(N + 1) / (n + 1))`` where ``N`` is
        the total accumulated weight and ``n`` the per-action weight.  The
        ``+ 1`` terms keep the expression finite while counts are still zero.
        Ties resolve to the lowest index (``np.argmax`` semantics).
        """
        total_counts = np.sum(self.n_actions_taken)
        ucb_values = self.q_values + self.c * np.sqrt(2 * np.log(total_counts + 1) / (self.n_actions_taken + 1))
        return np.argmax(ucb_values)

    def update(self, action, reward, delay=0, env_prob=1.0):
        """Fold one observation into the statistics of ``action``.

        Args:
            action: Index of the action that was played.
            reward: Observed reward.
            delay: Response time of this interaction; forwarded to
                :meth:`track_response_time`.
            env_prob: Weight of the observation.  Both the pull count and the
                reward are scaled by it.
        """
        self.n_actions_taken[action] += env_prob
        self.total_rewards[action] += reward * env_prob
        self.reward_history[action].append(reward * env_prob)

        # A zero accumulated weight (e.g. env_prob == 0 on the first pull)
        # would produce 0/0 = NaN and poison select_action's argmax, so the
        # previous estimate is kept in that case.
        weight = self.n_actions_taken[action]
        if weight != 0:
            self.q_values[action] = self.total_rewards[action] / weight

        # Track response time for the action
        self.track_response_time(delay)

    def calculate_variance(self, action):
        """Return the variance of the weighted rewards recorded for ``action``.

        Uses ``np.var`` with its default arguments.  With fewer than two
        recorded rewards the method returns 1 instead.
        """
        if len(self.reward_history[action]) > 1:
            return np.var(self.reward_history[action])
        else:
            return 1
    
    
class UCB1_multiple_C(UCB1):
    """UCB1 agent whose exploration constant ``c`` is re-tuned periodically.

    Every ``c_timestep_change`` calls to :meth:`update` the current ``c`` is
    scored with the running average response time.  While the step counter is
    below ``c_timestep_change * len(c_list) - 1`` the agent cycles through
    ``c_list``; afterwards it asks three externally supplied proposal
    functions for candidate values and adopts the smallest one.

    Args:
        add_points_function_low / _mic / _high: Callables invoked as
            ``fn(points, scores, actions=..., c_min=..., c_max=...,
            total_prob=..., total_obs=...)``.  Element ``[0]`` of each result
            is taken as a candidate ``c``.  (The original comments describe
            them as Bayesian-optimisation processes.)
        remove_points_function: Callable invoked as
            ``fn(points, scores, actions=..., c_value=..., current_cycle=...,
            current_iteration=...)`` returning the pruned ``(points, scores)``.
        c_list: Initial ``c`` values; ``c_list[0]`` is used first.
        n_actions: Number of available actions.
        c_timestep_change: Number of updates between two re-tuning steps.
        verbose: Stored on the instance; not read by this class.
        c_min, c_max: Forwarded unchanged to the proposal functions.
    """

    def __init__(self,
                 add_points_function_low,
                 add_points_function_mic,
                 add_points_function_high,
                 remove_points_function,
                 c_list,
                 n_actions=10,
                 c_timestep_change=100,
                 verbose=False,
                 c_min=None,
                 c_max=None):
        super().__init__(c=c_list[0], n_actions=n_actions)
        self.c_timestep_change = c_timestep_change
        self.num_actions = n_actions
        self.timestep = 0
        self.c_list = c_list
        self.c_ind = 0
        self.add_points_function_low = add_points_function_low
        self.add_points_function_mic = add_points_function_mic
        self.add_points_function_high = add_points_function_high
        self.remove_points_function = remove_points_function

        self.c_min = c_min
        self.c_max = c_max

        self.c = self.c_list[self.c_ind]
        self.c_history = []          # c in effect at every update call
        self.removed_points = []
        self.unique_c_history = []   # c chosen at every re-tuning step
        self.previous_c_values = []  # c values scored during the c_list cycle
        self.previous_c_scores = []  # their scores, index-aligned
        self.verbose = verbose
        # Numeric from the start: track_response_time overwrites it with a float.
        self.average_response_time = 0.0

        self.total_response_time = 0.0
        self.total_responses = 0
        self.c_score = 0.0

        # Column vectors of evaluated c values and their scores.  Starting
        # empty keeps the proposal branch usable even when no cycling step
        # ran first (e.g. c_timestep_change == 1 with a single-entry c_list).
        self.points_next_round = np.empty((0, 1))
        self.scores_next_round = np.empty((0, 1))

        self.bo_selection_log = []

        self.total_prob_low = 0
        self.total_prob_mic = 0
        self.total_prob_high = 0

        self.total_obs_low = 0
        self.total_obs_mic = 0
        self.total_obs_high = 0

    def track_response_time(self, delay):
        """Record one delay and return the updated running average."""
        return super().track_response_time(delay)

    def update(self, chosen_act, reward, delay, p_e1, p_e2, p_e3, **kwargs):
        """Update arm statistics and, on schedule, re-tune ``c``.

        Args:
            chosen_act: Index of the action that was played.
            reward: Observed reward.
            delay: Response time of this interaction.
            p_e1, p_e2, p_e3: Values accumulated into ``total_prob_low``,
                ``total_prob_mic`` and ``total_prob_high`` respectively.
            **kwargs: Forwarded to ``UCB1.update`` (e.g. ``env_prob``).
        """
        self.c_history.append(self.c)
        # The base update already records `delay` through track_response_time.
        super().update(chosen_act, reward, delay, **kwargs)

        self.total_prob_low += p_e1
        self.total_prob_mic += p_e2
        self.total_prob_high += p_e3

        self.total_obs_low += 1
        self.total_obs_mic += 1
        self.total_obs_high += 1

        if self.timestep % self.c_timestep_change == 0:
            # Read the average that super().update just refreshed.  Calling
            # track_response_time(delay) again here would count this delay
            # twice in the running average.
            c_score = self.average_response_time

            if self.timestep < (self.c_timestep_change * len(self.c_list) - 1):
                self._cycle_to_next_listed_c(c_score)
            else:
                self._propose_next_c(c_score)

        self.timestep += 1
        return

    def _cycle_to_next_listed_c(self, c_score):
        """Store the score of the current ``c`` and move on through ``c_list``."""
        self.previous_c_values.append(self.c)
        self.previous_c_scores.append(c_score)

        self.c_ind += 1
        # Modulo wraps back to the first entry once the list is exhausted.
        self.c = self.c_list[self.c_ind % len(self.c_list)]
        self.unique_c_history.append(self.c)

        self.points_next_round = np.atleast_2d(self.previous_c_values).reshape(-1, 1)
        self.scores_next_round = np.atleast_2d(self.previous_c_scores).reshape(-1, 1)

    def _propose_next_c(self, c_score):
        """Add the scored ``c`` to the data set, prune it, and pick a new ``c``."""
        self.points_next_round = np.concatenate(
            (self.points_next_round, np.atleast_2d([self.c]).reshape(-1, 1))
        )
        self.scores_next_round = np.concatenate(
            (self.scores_next_round, np.atleast_2d([c_score]).reshape(-1, 1))
        )

        self.points_next_round, self.scores_next_round = self.remove_points_function(
            self.points_next_round,
            self.scores_next_round,
            actions=self.num_actions,
            c_value=self.c,
            current_cycle=self.timestep,
            current_iteration=self.timestep,
        )

        # Query each proposal function in the fixed order low, micro, high.
        proposers = (
            ("Low", self.add_points_function_low, self.total_prob_low, self.total_obs_low),
            ("Micro", self.add_points_function_mic, self.total_prob_mic, self.total_obs_mic),
            ("High", self.add_points_function_high, self.total_prob_high, self.total_obs_high),
        )
        candidates = []
        for label, propose, total_prob, total_obs in proposers:
            new_points = propose(
                self.points_next_round,
                self.scores_next_round,
                actions=self.num_actions,
                c_min=self.c_min,
                c_max=self.c_max,
                total_prob=total_prob,
                total_obs=total_obs,
            )
            self.bo_selection_log.append(label)
            candidates.append(new_points[0])

        # The smallest of the three candidates becomes the new c.
        self.c = min(candidates)
        self.unique_c_history.append(self.c)

    
    
class UCB1_:
    """Plain-Python UCB1 agent using lists instead of NumPy arrays.

    Attributes:
        c: Stored exploration constant (not read by ``select_action``).
        actions: Number of available actions.
        counts: Number of times each action has been played.
        values: Running mean reward of each action.
        action_total_reward: Summed reward of each action.
        action_avg_reward: For each action, the history of its average reward,
            extended by one entry on every call to ``update``.
    """

    def __init__(self, c=1, actions=10):
        """Set up zeroed statistics for ``actions`` arms."""
        self.c = c
        self.actions = actions

        self.counts = [0] * self.actions
        self.values = [0.0] * self.actions

        self.action_total_reward = [0.0] * self.actions
        self.action_avg_reward = [[] for _ in range(self.actions)]

    def max_(self, values):
        """Return the index of the first largest element of ``values``."""
        best_index = 0
        best_value = values[best_index]
        for index, candidate in enumerate(values):
            if candidate > best_value:
                best_value, best_index = candidate, index
        return best_index

    def select_action(self):
        """Pick the next action to play.

        Any action that has never been played is returned first (lowest index
        wins).  Otherwise the action maximising
        ``value + sqrt(2 * log(total_plays) / plays)`` is returned.
        """
        for action, plays in enumerate(self.counts):
            if plays == 0:
                return action

        total_counts = sum(self.counts)
        ucb_values = [
            self.values[action] + math.sqrt((2 * math.log(total_counts)) / float(plays))
            for action, plays in enumerate(self.counts)
        ]
        return self.max_(ucb_values)

    def update(self, chosen_act, reward):
        """Record ``reward`` for ``chosen_act`` and extend the average histories."""
        self.counts[chosen_act] += 1
        n = self.counts[chosen_act]

        # Incremental mean: weight the old estimate by (n - 1) / n.
        previous = self.values[chosen_act]
        self.values[chosen_act] = ((n - 1) / float(n)) * previous + (1 / float(n)) * reward

        self.action_total_reward[chosen_act] += reward
        for a in range(self.actions):
            plays = self.counts[a]
            # Actions not yet played contribute a 0 placeholder.
            self.action_avg_reward[a].append(
                self.action_total_reward[a] / plays if plays else 0
            )