from workload import workloads
from environments import Env
import numpy as np

def pre_train_agent(agent, env, n_steps, env_prob=1.0):
    """Warm up ``agent`` by letting it interact with ``env`` for ``n_steps`` rounds.

    Each round, the agent picks an action and the environment samples a delay
    for it. The agent is then updated with reward ``-delay``, the raw delay,
    and ``env_prob``. Finally the environment is advanced with ``env.step()``.

    Args:
        agent: Object exposing ``select_action()`` and
            ``update(action, reward, delay, env_prob=...)``.
        env: Object exposing ``sample(action)`` and ``step()``.
        n_steps: Number of interaction rounds.
        env_prob: Value forwarded unchanged to ``agent.update``.

    Returns:
        The same ``agent`` object, updated in place.
    """
    for _ in range(n_steps):
        chosen = agent.select_action()
        observed = env.sample(chosen)
        # Lower delay is better, so the reward is the negated delay.
        agent.update(chosen, -observed, observed, env_prob=env_prob)
        env.step()
    return agent

def pre_train_agent_BO(agent, env, n_steps, p_e1=1, p_e2=1, p_e3=1):
    """Warm up a BO-style ``agent`` on ``env`` for ``n_steps`` interactions.

    This is the same interaction loop as the plain pre-training routine. The
    difference is that the three environment weights are forwarded to
    ``agent.update`` as the keyword arguments ``p_e1``, ``p_e2`` and ``p_e3``.

    Args:
        agent: Object exposing ``select_action()`` and
            ``update(action, reward, delay, p_e1=..., p_e2=..., p_e3=...)``.
        env: Object exposing ``sample(action)`` and ``step()``.
        n_steps: Number of interaction rounds.
        p_e1, p_e2, p_e3: Values forwarded unchanged to ``agent.update``.

    Returns:
        The same ``agent`` object, updated in place.
    """
    env_weights = {"p_e1": p_e1, "p_e2": p_e2, "p_e3": p_e3}
    for _ in range(n_steps):
        picked = agent.select_action()
        latency = env.sample(picked)
        # Reward is the negated latency; the raw latency is passed as well.
        agent.update(picked, -latency, latency, **env_weights)
        env.step()
    return agent


def do_run(env, select_action_class, n_steps, pre_train_steps=100, n_actions=40):
    """Simulate the three-expert selector on ``env`` for ``n_steps`` steps.

    First, the three workload sub-agents stored on ``select_action_class``
    (``agent_low_workload``, ``agent_mic_workload``, ``agent_high_workload``)
    are pre-trained with ``pre_train_agent``. Each one trains on its own
    stand-alone ``Env`` built from ``workloads(3, n_actions)``.

    Then, at every step:

    1. ``select_action_class.select_action()`` proposes one action per
       sub-agent.
    2. One proposal is drawn at random with weights ``(p_e1, p_e2, p_e3)``.
       These weights must sum to 1, as ``np.random.choice`` requires.
    3. The chosen action's delay is fed back through ``track_response_time``,
       ``update_q_values`` and ``update_environment_probabilities``.

    Args:
        env: Main environment exposing ``best_action()``, ``current_env()``,
            ``sample(action)`` and ``step()``.
        select_action_class: Controller holding the three sub-agents and the
            weights ``p_e1``, ``p_e2``, ``p_e3``.
        n_steps: Number of simulation steps.
        pre_train_steps: Pre-training rounds per sub-agent.
        n_actions: Number of actions passed to ``workloads``.

    Returns:
        dict with per-step lists under these keys:

        - ``"rewards"`` / ``"reward_log"``: ``-delay``.
        - ``"res"``: delays.
        - ``"cum_response_time"``: running sums of the delays.
        - ``"average_response_time"``: running means of the delays.
        - ``"best_chosen"``: fraction of steps so far whose chosen action
          equals ``env.best_action()``.
        - ``"selected_actions"``: chosen actions.
        - ``"env_keys"``: ``env.current_env()`` at each step.
        - ``"correct_identifications"``: ``n_steps + 1`` entries; see the
          comment below.
        - ``"average_cumulative_regret"``: cumulative sampled regret divided
          by the step count.
    """
    n_env = 3
    mus_all, sigmas_all = workloads(n_env, n_actions)

    # Stand-alone environments used only to pre-train each workload expert.
    low_workload_env = Env(mus=mus_all[0], sigmas=sigmas_all[0])
    mic_workload_env = Env(mus=mus_all[1], sigmas=sigmas_all[1])
    high_workload_env = Env(mus=mus_all[2], sigmas=sigmas_all[2])

    select_action_class.agent_low_workload = pre_train_agent(select_action_class.agent_low_workload, low_workload_env, pre_train_steps)
    select_action_class.agent_mic_workload = pre_train_agent(select_action_class.agent_mic_workload, mic_workload_env, pre_train_steps)
    select_action_class.agent_high_workload = pre_train_agent(select_action_class.agent_high_workload, high_workload_env, pre_train_steps)

    selected_actions, rewards, res, reward_log = [], [], [], []
    average_response_time, cum_response_time = [], []
    env_keys, best_chosen_fraction, avg_cum_regret = [], [], []

    # Seeded with 0.0. Later entries divide by (i + 2), so this seed counts
    # as one extra, missed identification in the running accuracy.
    correct_identifications = [0.0]
    correct_identifications_count = 0

    # Running totals replace re-summing the whole history every step,
    # which made the loop O(n_steps**2).
    best_chosen_count = 0
    total_response_time = 0  # same start value that sum() used
    cumulative_regret = 0.0

    environments = [0, 1, 2]  # candidate labels for the predicted environment

    for i in range(n_steps):
        steps_so_far = i + 1

        action_low, action_mic, action_high = select_action_class.select_action()
        actions = [action_low, action_mic, action_high]
        probabilities = [select_action_class.p_e1, select_action_class.p_e2, select_action_class.p_e3]
        action = np.random.choice(actions, p=probabilities)

        best_action = env.best_action()
        if action == best_action:
            best_chosen_count += 1
        best_chosen_fraction.append(best_chosen_count / steps_so_far)

        actual_env = env.current_env()
        env_keys.append(actual_env)

        # Regret is estimated from one fresh sample each of the chosen and the
        # best action. These samples are independent of the delay used as
        # feedback below.
        chosen_delay = env.sample(action)
        best_delay = env.sample(best_action)
        cumulative_regret += chosen_delay - best_delay
        avg_cum_regret.append(cumulative_regret / steps_so_far)

        # The predicted environment is drawn at random with the same weights
        # used to pick the action.
        predicted_env = np.random.choice(environments, p=probabilities)
        if actual_env == predicted_env:
            correct_identifications_count += 1
        correct_identifications.append(correct_identifications_count / (i + 2))

        delay = env.sample(action)
        reward = -delay
        select_action_class.track_response_time(delay)

        selected_actions.append(action)
        rewards.append(reward)
        res.append(delay)
        total_response_time += delay
        average_response_time.append(total_response_time / steps_so_far)
        cum_response_time.append(total_response_time)
        reward_log.append(reward)

        select_action_class.update_q_values(action, reward, delay)
        select_action_class.update_environment_probabilities(action, reward)
        env.step()

    return {
        "rewards": rewards,
        "cum_response_time": cum_response_time,
        "average_response_time": average_response_time,
        "res": res,
        "best_chosen": best_chosen_fraction,
        "selected_actions": selected_actions,
        "env_keys": env_keys,
        "reward_log": reward_log,
        "correct_identifications": correct_identifications,
        "average_cumulative_regret": avg_cum_regret
    }
    
    

def do_run_BO(env, select_action_class, n_steps, pre_train_steps=50, n_actions=40):
    """Simulate the three-expert BO selector on ``env`` for ``n_steps`` steps.

    First, the three workload sub-agents on ``select_action_class``
    (``agent_low_workload``, ``agent_mic_workload``, ``agent_high_workload``)
    are pre-trained with ``pre_train_agent_BO``. Each one trains on its own
    stand-alone ``Env`` built from ``workloads(3, n_actions)``.

    Then, at every step:

    1. ``select_action_class.select_action()`` proposes one action per
       sub-agent.
    2. One proposal is drawn at random with weights ``(p_e1, p_e2, p_e3)``.
       These weights must sum to 1 for ``np.random.choice``.
    3. The controller's Q-values are updated.
    4. Each sub-agent's ``update`` is called with the current weights.
    5. The weights are updated via ``update_environment_probabilities``.

    Args:
        env: Main environment exposing ``best_action()``, ``current_env()``,
            ``sample(action)`` and ``step()``.
        select_action_class: Controller holding the three sub-agents and the
            weights ``p_e1``, ``p_e2``, ``p_e3``.
        n_steps: Number of simulation steps.
        pre_train_steps: Pre-training rounds per sub-agent.
        n_actions: Number of actions passed to ``workloads``.

    Returns:
        dict with the same per-step lists as the non-BO run, plus
        ``"bo_selection_log"``, read from
        ``select_action_class.agent_low_workload.bo_selection_log``.
    """
    n_env = 3
    mus_all, sigmas_all = workloads(n_env, n_actions)

    # Stand-alone environments used only to pre-train each workload expert.
    low_workload_env = Env(mus=mus_all[0], sigmas=sigmas_all[0])
    mic_workload_env = Env(mus=mus_all[1], sigmas=sigmas_all[1])
    high_workload_env = Env(mus=mus_all[2], sigmas=sigmas_all[2])

    select_action_class.agent_low_workload = pre_train_agent_BO(select_action_class.agent_low_workload, low_workload_env, pre_train_steps)
    select_action_class.agent_mic_workload = pre_train_agent_BO(select_action_class.agent_mic_workload, mic_workload_env, pre_train_steps)
    select_action_class.agent_high_workload = pre_train_agent_BO(select_action_class.agent_high_workload, high_workload_env, pre_train_steps)

    # Main simulation phase
    selected_actions, rewards, res, reward_log = [], [], [], []
    average_response_time, cum_response_time = [], []
    env_keys, best_chosen_fraction, avg_cum_regret = [], [], []

    # Seeded with 0.0. Later entries divide by (i + 2), so this seed counts
    # as one extra, missed identification in the running accuracy.
    correct_identifications = [0.0]
    correct_identifications_count = 0

    # Running totals replace re-scanning the whole history every step
    # (np.mean / sum), which made the loop O(n_steps**2).
    best_chosen_count = 0
    total_response_time = 0  # same start value that sum() used
    cumulative_regret = 0.0

    for i in range(n_steps):
        steps_so_far = i + 1

        action_low, action_mic, action_high = select_action_class.select_action()

        # Pick one sub-agent's proposal according to the environment weights.
        actions = [action_low, action_mic, action_high]
        probabilities = [select_action_class.p_e1, select_action_class.p_e2, select_action_class.p_e3]
        action = np.random.choice(actions, p=probabilities)

        best_action = env.best_action()
        if action == best_action:
            best_chosen_count += 1
        best_chosen_fraction.append(best_chosen_count / steps_so_far)

        actual_env = env.current_env()
        env_keys.append(actual_env)

        # Regret is estimated from one fresh sample each of the chosen and the
        # best action. These samples are independent of the delay used as
        # feedback below.
        chosen_delay = env.sample(action)
        best_delay = env.sample(best_action)
        cumulative_regret += chosen_delay - best_delay
        avg_cum_regret.append(cumulative_regret / steps_so_far)

        # Predicted environment = index of the largest weight (deterministic;
        # ties resolve to the lowest index, per np.argmax).
        predicted_env = np.argmax(probabilities)
        if actual_env == predicted_env:
            correct_identifications_count += 1
        correct_identifications.append(correct_identifications_count / (i + 2))

        delay = env.sample(action)
        reward = -delay
        select_action_class.track_response_time(delay)

        selected_actions.append(action)
        rewards.append(reward)
        res.append(delay)
        total_response_time += delay
        average_response_time.append(total_response_time / steps_so_far)
        cum_response_time.append(total_response_time)
        reward_log.append(reward)

        select_action_class.update_q_values(action, reward, delay)
        # Every sub-agent sees the same feedback with the pre-update weights.
        # NOTE(review): if update_q_values already updates the sub-agents,
        # they are updated twice per step; confirm this is intended.
        select_action_class.agent_low_workload.update(action, reward, delay, p_e1=select_action_class.p_e1, p_e2=select_action_class.p_e2, p_e3=select_action_class.p_e3)
        select_action_class.agent_mic_workload.update(action, reward, delay, p_e1=select_action_class.p_e1, p_e2=select_action_class.p_e2, p_e3=select_action_class.p_e3)
        select_action_class.agent_high_workload.update(action, reward, delay, p_e1=select_action_class.p_e1, p_e2=select_action_class.p_e2, p_e3=select_action_class.p_e3)
        select_action_class.update_environment_probabilities(action, reward)
        env.step()

    return {
        "rewards": rewards,
        "cum_response_time": cum_response_time,
        "average_response_time": average_response_time,
        "res": res,
        "best_chosen": best_chosen_fraction,
        "selected_actions": selected_actions,
        "env_keys": env_keys,
        "reward_log": reward_log,
        "correct_identifications": correct_identifications,
        "bo_selection_log": select_action_class.agent_low_workload.bo_selection_log,
        "average_cumulative_regret": avg_cum_regret
    }


   
def do_run_UCB(env, select_action_class, n_steps):
    """Simulate a single bandit agent (e.g. UCB) on ``env`` for ``n_steps`` steps.

    Each step, the agent picks one action. The action's delay is sampled and
    reported back through ``track_response_time`` and
    ``update(action, reward, delay)``, where ``reward = -delay``. Then the
    environment is advanced.

    Args:
        env: Environment exposing ``best_action()``, ``current_env()``,
            ``sample(action)`` and ``step()``.
        select_action_class: Agent exposing ``select_action()``,
            ``track_response_time(delay)`` and ``update(action, reward, delay)``.
        n_steps: Number of simulation steps.

    Returns:
        dict of per-step lists under these keys:

        - ``"rewards"`` / ``"reward_log"``: ``-delay``.
        - ``"res"``: delays.
        - ``"cum_response_time"``: running sums of the delays.
        - ``"average_response_time"``: running means of the delays.
        - ``"best_chosen"``: fraction of steps whose action equalled
          ``env.best_action()``.
        - ``"selected_actions"``: chosen actions.
        - ``"env_keys"``: ``env.current_env()`` at each step.
        - ``"average_cumulative_regret"``: cumulative sampled regret divided
          by the step count.
    """
    selected_actions, rewards, res, reward_log = [], [], [], []
    average_response_time, cum_response_time = [], []
    env_keys, best_chosen_fraction, avg_cum_regret = [], [], []

    # Running totals replace re-scanning the whole history every step
    # (np.mean / sum), which made the loop O(n_steps**2).
    best_chosen_count = 0
    total_response_time = 0  # same start value that sum() used
    cumulative_regret = 0.0

    for i in range(n_steps):
        steps_so_far = i + 1

        action = select_action_class.select_action()
        best_action = env.best_action()
        if action == best_action:
            best_chosen_count += 1
        best_chosen_fraction.append(best_chosen_count / steps_so_far)

        # Previously computed but never recorded, so "env_keys" was always
        # empty; record it like the multi-agent runs do.
        actual_env = env.current_env()
        env_keys.append(actual_env)

        # Regret is estimated from one fresh sample each of the chosen and the
        # best action. These samples are independent of the delay used as
        # feedback below.
        chosen_delay = env.sample(action)
        best_delay = env.sample(best_action)
        cumulative_regret += chosen_delay - best_delay
        avg_cum_regret.append(cumulative_regret / steps_so_far)

        delay = env.sample(action)
        select_action_class.track_response_time(delay)
        reward = -delay

        selected_actions.append(action)
        rewards.append(reward)
        res.append(delay)
        total_response_time += delay
        average_response_time.append(total_response_time / steps_so_far)
        cum_response_time.append(total_response_time)
        reward_log.append(reward)

        select_action_class.update(action, reward, delay)
        env.step()

    return {"rewards": rewards,
            "cum_response_time": cum_response_time,
            "average_response_time": average_response_time,
            "res": res,
            "best_chosen": best_chosen_fraction,
            "selected_actions": selected_actions,
            "env_keys": env_keys,
            "reward_log": reward_log,
            "average_cumulative_regret": avg_cum_regret
            }












            