
import numpy as np
from scipy.stats import norm

class MultiAgentSystem:
    """Bayesian mixture of three agents, one per (presumed) workload regime.

    Three independent agents are built from the same ``agent_type`` and
    ``agent_params``. ``p_e1`` / ``p_e2`` / ``p_e3`` hold the current
    posterior probability of the low / "mic" / high workload hypothesis.
    They start uniform and are refined by
    :meth:`update_environment_probabilities`.

    Objects produced by ``agent_type`` must expose ``q_values`` (indexable by
    action), ``calculate_variance(action)``, ``select_action()``,
    ``track_response_time(delay)`` and
    ``update(action, reward, delay, env_prob=...)``.
    """

    def __init__(self, agent_type, agent_params, **kwargs):
        """Create the three agents and a uniform prior over them.

        Args:
            agent_type: Callable (e.g. a class) used to build each agent.
            agent_params: Keyword arguments passed to ``agent_type``.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.agent_low_workload = agent_type(**agent_params)
        self.agent_mic_workload = agent_type(**agent_params)
        self.agent_high_workload = agent_type(**agent_params)
        # Uniform prior over the three environment hypotheses.
        self.p_e1 = 1/3
        self.p_e2 = 1/3
        self.p_e3 = 1/3

        # Floor for each unnormalized posterior term. A term of exactly zero
        # would zero its prior and keep it at zero on every later update.
        # The floor also keeps the normalizer strictly positive.
        self.min_val = 0.000001
        # NOTE(review): neither value is read by any method of this class;
        # update_environment_probabilities uses its own ``delta`` default.
        # Kept in case external code reads them.
        self.default_delta = 0.02
        self.high_workload_delta = 5.0

    @staticmethod
    def _interval_likelihood(reward, delta, mean, var):
        """Return P(reward - delta < X <= reward + delta) for X ~ N(mean, var).

        A non-positive variance is treated as a point mass at ``mean``:
        likelihood 1 if ``mean`` lies within ``delta`` of ``reward``, else 0.
        Passing a zero (or NaN) scale to ``scipy.stats.norm.cdf`` yields NaN.
        That NaN would otherwise propagate into every posterior permanently.
        """
        if var <= 0:
            return 1.0 if abs(reward - mean) <= delta else 0.0
        std = np.sqrt(var)
        return norm.cdf(reward + delta, mean, std) - norm.cdf(reward - delta, mean, std)

    def update_environment_probabilities(self, action, reward, delta=0.09):
        """Bayes-update ``p_e1``..``p_e3`` from one observed reward.

        For each agent i, the likelihood P(o | e_i) is the probability that a
        Gaussian with mean ``q_values[action]`` and variance
        ``calculate_variance(action)`` falls within ``delta`` of ``reward``.
        Each unnormalized posterior P(o | e_i) * P(e_i) is floored at
        ``self.min_val``. The three terms are then renormalized to sum to 1.

        Args:
            action: Action whose reward was observed (index into ``q_values``).
            reward: Observed reward.
            delta: Half-width of the observation interval around ``reward``.
        """
        agents = (self.agent_low_workload, self.agent_mic_workload, self.agent_high_workload)
        priors = (self.p_e1, self.p_e2, self.p_e3)

        joint = []
        for agent, prior in zip(agents, priors):
            mean = agent.q_values[action]
            var = agent.calculate_variance(action)
            term = self._interval_likelihood(reward, delta, mean, var) * prior
            # Floor keeps every hypothesis recoverable and the normalizer > 0.
            joint.append(self.min_val if term <= self.min_val else term)

        total = sum(joint)
        self.p_e1, self.p_e2, self.p_e3 = (term / total for term in joint)

    def select_action(self):
        """Ask each agent for an action.

        Returns:
            Tuple ``(low, mic, high)`` of the agents' chosen actions.
        """
        action_low = self.agent_low_workload.select_action()
        action_mic = self.agent_mic_workload.select_action()
        action_high = self.agent_high_workload.select_action()
        return action_low, action_mic, action_high

    def track_response_time(self, delay):
        """Forward ``delay`` to every agent.

        Returns:
            Tuple ``(low, mic, high)`` of each agent's return value.
        """
        delay_low = self.agent_low_workload.track_response_time(delay)
        delay_mic = self.agent_mic_workload.track_response_time(delay)
        delay_high = self.agent_high_workload.track_response_time(delay)
        return delay_low, delay_mic, delay_high

    def update_q_values(self, action, reward, delay):
        """Update every agent, weighting each by its current posterior."""
        self.agent_low_workload.update(action, reward, delay, env_prob=self.p_e1)
        self.agent_mic_workload.update(action, reward, delay, env_prob=self.p_e2)
        self.agent_high_workload.update(action, reward, delay, env_prob=self.p_e3)


class MultiAgentSystem_BO:
    """Three-agent Bayesian mixture whose agents also see the full posterior.

    Three independent agents (low / "mic" / high workload) are built from
    the same ``agent_type`` and ``agent_params``. ``p_e1`` / ``p_e2`` /
    ``p_e3`` hold the posterior over which agent matches the environment.
    On :meth:`update_q_values` every agent receives its own weight
    (``env_prob``) plus all three posteriors.

    Objects produced by ``agent_type`` must expose ``q_values`` (indexable by
    action), ``calculate_variance(action)``, ``select_action()``,
    ``track_response_time(delay)`` and
    ``update(action, reward, delay, env_prob=..., p_e1=..., p_e2=..., p_e3=...)``.
    """

    def __init__(self, agent_type, agent_params, **kwargs):
        """Create the three agents and a uniform prior over them.

        Args:
            agent_type: Callable (e.g. a class) used to build each agent.
            agent_params: Keyword arguments passed to ``agent_type``.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.agent_low_workload = agent_type(**agent_params)
        self.agent_mic_workload = agent_type(**agent_params)
        self.agent_high_workload = agent_type(**agent_params)
        # Uniform prior over the three environment hypotheses.
        self.p_e1 = 1/3
        self.p_e2 = 1/3
        self.p_e3 = 1/3

        # Floor for each unnormalized posterior term. A term of exactly zero
        # would zero its prior and keep it at zero on every later update.
        # The floor also keeps the normalizer strictly positive.
        self.min_val = 0.000001

    @staticmethod
    def _interval_likelihood(reward, delta, mean, var):
        """Return P(reward - delta < X <= reward + delta) for X ~ N(mean, var).

        A non-positive variance is treated as a point mass at ``mean``:
        likelihood 1 if ``mean`` lies within ``delta`` of ``reward``, else 0.
        Passing a zero (or NaN) scale to ``scipy.stats.norm.cdf`` yields NaN.
        That NaN would otherwise propagate into every posterior permanently.
        """
        if var <= 0:
            return 1.0 if abs(reward - mean) <= delta else 0.0
        std = np.sqrt(var)
        return norm.cdf(reward + delta, mean, std) - norm.cdf(reward - delta, mean, std)

    def update_environment_probabilities(self, action, reward, delta=0.09):
        """Bayes-update ``p_e1``..``p_e3`` from one observed reward.

        For each agent i, the likelihood P(o | e_i) is the probability that a
        Gaussian with mean ``q_values[action]`` and variance
        ``calculate_variance(action)`` falls within ``delta`` of ``reward``.
        Each unnormalized posterior P(o | e_i) * P(e_i) is floored at
        ``self.min_val``. The three terms are then renormalized to sum to 1.

        Args:
            action: Action whose reward was observed (index into ``q_values``).
            reward: Observed reward.
            delta: Half-width of the observation interval around ``reward``.
        """
        agents = (self.agent_low_workload, self.agent_mic_workload, self.agent_high_workload)
        priors = (self.p_e1, self.p_e2, self.p_e3)

        joint = []
        for agent, prior in zip(agents, priors):
            mean = agent.q_values[action]
            var = agent.calculate_variance(action)
            term = self._interval_likelihood(reward, delta, mean, var) * prior
            # Floor keeps every hypothesis recoverable and the normalizer > 0.
            joint.append(self.min_val if term <= self.min_val else term)

        total = sum(joint)
        self.p_e1, self.p_e2, self.p_e3 = (term / total for term in joint)

    def select_action(self):
        """Ask each agent for an action.

        Returns:
            Tuple ``(low, mic, high)`` of the agents' chosen actions.
        """
        action_low = self.agent_low_workload.select_action()
        action_mic = self.agent_mic_workload.select_action()
        action_high = self.agent_high_workload.select_action()
        return action_low, action_mic, action_high

    def track_response_time(self, delay):
        """Forward ``delay`` to every agent.

        Returns:
            Tuple ``(low, mic, high)`` of each agent's return value.
        """
        delay_low = self.agent_low_workload.track_response_time(delay)
        delay_mic = self.agent_mic_workload.track_response_time(delay)
        delay_high = self.agent_high_workload.track_response_time(delay)
        return delay_low, delay_mic, delay_high

    def update_q_values(self, action, reward, delay):
        """Update every agent with its own weight plus the full posterior."""
        posterior = {"p_e1": self.p_e1, "p_e2": self.p_e2, "p_e3": self.p_e3}
        self.agent_low_workload.update(action, reward, delay, env_prob=self.p_e1, **posterior)
        self.agent_mic_workload.update(action, reward, delay, env_prob=self.p_e2, **posterior)
        self.agent_high_workload.update(action, reward, delay, env_prob=self.p_e3, **posterior)
