import numpy as np

class Env:
    """
    Stationary environment whose response times are Gaussian per action.

    Sampling action ``a`` draws from a normal distribution with mean
    ``mus[a]`` and standard deviation ``sigmas[a]``. The parameters never
    change, so ``step`` does nothing and the environment key stays 0.

    Parameters:
    - mus: Mean response time for each action, indexable by action
    - sigmas: Standard deviation of the response time for each action
    """

    def __init__(self, mus, sigmas):
        self.mus = mus
        self.sigmas = sigmas
        # There is only one parameter set, so the key is fixed.
        self.env_key = 0

    def sample(self, action):
        """Draw one response time for ``action``."""
        mu, sigma = self.mus[action], self.sigmas[action]
        return np.random.normal(mu, sigma)

    def step(self, **kwargs):
        """Advance one timestep; a no-op here. Keyword arguments are ignored."""

    def best_action(self):
        """Return the index of the action with the smallest mean response time."""
        return np.argmin(self.mus)

    def current_env(self):
        """Return the key of the active environment (always 0).

        Present so this class exposes the same methods as the dynamic
        environments in this module.
        """
        return self.env_key

class DynamicEnvOrdered:
    """
    Dynamic environment that walks through a list of parameter sets in order.

    The active parameter set is replaced by the next one in the list every
    ``change_every`` calls to ``step``; after the last entry the cycle
    restarts at index 0.

    Parameters:
    - mus_all: Sequence of per-action mean values, one entry per environment
    - sigmas_all: Sequence of per-action standard deviations, one entry per environment
    - change_every: Number of timesteps between environment switches
    """

    def __init__(self, mus_all, sigmas_all, change_every):
        self.mus_all, self.sigmas_all = mus_all, sigmas_all
        self.change_every = change_every
        self.timestep = 0
        # Start one position before the first entry so that the initial
        # advance below lands on index 0.
        self.env_key = -1
        self.reselect_env()

    def current_env(self):
        """Return the index of the active environment."""
        return self.env_key

    def best_action(self):
        """Return the index of the action with the smallest current mean."""
        return np.argmin(self.mus)

    def sample(self, action):
        """Draw one response time for ``action`` from the active environment."""
        return np.random.normal(self.mus[action], self.sigmas[action])

    def reselect_env(self):
        """Activate the next environment in the list, wrapping after the last."""
        self.env_key = (self.env_key + 1) % len(self.mus_all)
        # Look both entries up before assigning either attribute.
        active_mus = self.mus_all[self.env_key]
        active_sigmas = self.sigmas_all[self.env_key]
        self.mus = active_mus
        self.sigmas = active_sigmas

    def step(self, **kwargs):
        """Advance one timestep and switch environment on every ``change_every``-th step.

        Keyword arguments are accepted and ignored.
        """
        self.timestep += 1
        switch_due = self.timestep % self.change_every == 0
        if switch_due:
            self.reselect_env()
    
    
    
# Incremental environment: parameters are linearly interpolated from one environment to the next.

class DynamicEnvOrdered_Incremental:
    """
    Dynamic environment that drifts gradually between consecutive parameter sets.

    The active parameters are a linear interpolation between the current
    environment and the next one in the list. After ``transition_steps``
    calls to ``step`` the next environment becomes the current one and a new
    transition towards its successor begins; the list is cycled indefinitely.

    Parameters:
    - mus_all: Per-action mean values, one row per environment
    - sigmas_all: Per-action standard deviations, one row per environment
    - transition_steps: Number of steps a full transition takes (default: 2000)

    Raises:
    - ValueError: if ``mus_all`` is empty or ``transition_steps`` is not positive
    """

    def __init__(self, mus_all, sigmas_all, transition_steps=2000):
        self.mus_all = np.array(mus_all)
        self.sigmas_all = np.array(sigmas_all)
        if len(self.mus_all) == 0:
            raise ValueError("mus_all must contain at least one environment")
        if transition_steps <= 0:
            raise ValueError("transition_steps must be positive")
        self.transition_steps = transition_steps
        self.timestep = 0
        self.env_key = 0
        # Wrap exactly as reselect_env does, so a single-environment list
        # targets itself instead of indexing out of range.
        self.next_env_key = 1 % len(self.mus_all)
        self.progress = 0.0
        # Whole steps taken inside the current transition. Progress is derived
        # from this counter rather than accumulated as a float, so rounding
        # error cannot delay the switch (ten additions of 0.1 stay below 1.0).
        self._transition_step = 0
        self.mus = np.array(self.mus_all[self.env_key])
        self.sigmas = np.array(self.sigmas_all[self.env_key])
        self.mus_next = np.array(self.mus_all[self.next_env_key])
        self.sigmas_next = np.array(self.sigmas_all[self.next_env_key])

    def sample(self, action):
        """Draw one response time for ``action`` from the interpolated parameters."""
        mu, sigma = self.mus[action], self.sigmas[action]
        return np.random.normal(mu, sigma)

    def reselect_env(self):
        """Promote the target environment to current and pick the following target."""
        self.env_key = self.next_env_key
        self.next_env_key = (self.next_env_key + 1) % len(self.mus_all)
        self.mus_next = np.array(self.mus_all[self.next_env_key])
        self.sigmas_next = np.array(self.sigmas_all[self.next_env_key])

    def best_action(self):
        """Return the index of the action with the smallest interpolated mean."""
        return np.argmin(self.mus)

    def step(self, **kwargs):
        """Advance one timestep and refresh the interpolated parameters.

        Keyword arguments are accepted and ignored.
        """
        self.timestep += 1
        self._transition_step += 1

        if self._transition_step >= self.transition_steps:
            # Transition complete: the target becomes the current environment.
            self.reselect_env()
            self._transition_step = 0

        self.progress = self._transition_step / self.transition_steps

        weight = self.progress
        self.mus = (1 - weight) * np.array(self.mus_all[self.env_key]) + weight * self.mus_next
        self.sigmas = (1 - weight) * np.array(self.sigmas_all[self.env_key]) + weight * self.sigmas_next

    def current_env(self):
        """Return the index of the environment being transitioned away from."""
        return self.env_key
    

# Random environment: the active environment is re-drawn at random with a hazard probability per step.

class DynamicEnvRandom:
    """
    A dynamic environment whose active parameter set is re-drawn at random.

    On every call to ``step`` a new environment index is drawn with
    probability ``hazard``; otherwise the current environment is kept.
    The draw covers all environments, so it may pick the one that is
    already active.

    Parameters:
    - mus_all: List of arrays containing mean values for each action across different environments
    - sigmas_all: List of arrays containing standard deviation values for each action
    - hazard: Probability of re-drawing the environment at each timestep
      (default: 1.0, i.e. a re-draw on every step)
    """
    def __init__(self, mus_all, sigmas_all, hazard=1.0):
        self.mus_all, self.sigmas_all = mus_all, sigmas_all
        self.hazard = hazard
        self.timestep = 0
        # The starting environment is itself chosen at random.
        self._draw_env()

    def _draw_env(self):
        """Pick a random environment index and load its parameters."""
        self.env_key = np.random.choice(len(self.mus_all))
        self.mus = self.mus_all[self.env_key]
        self.sigmas = self.sigmas_all[self.env_key]

    def current_env(self):
        """Return the index of the active environment."""
        return self.env_key

    def best_action(self):
        """Return the index of the action with the smallest current mean."""
        return np.argmin(self.mus)

    def sample(self, action):
        """Draw one response time for ``action`` from the active environment."""
        return np.random.normal(self.mus[action], self.sigmas[action])

    def step(self, **kwargs):
        """Advance one timestep, re-drawing the environment with probability ``hazard``.

        Keyword arguments are accepted and ignored.
        """
        self.timestep += 1
        redraw = np.random.random() < self.hazard
        if redraw:
            self._draw_env()

    
# Mixed environment: a fixed weighted blend of three environments.
class DynamicEnvOrdered_Mix_Env:
    """
    Stationary environment built as a weighted blend of three environments.

    The per-action means and standard deviations are combined once, at
    construction, with weights ``alpha``, ``bet`` and ``1 - alpha - bet``
    for environments 1, 2 and 3 respectively. The weights are used as
    given; they are not checked or normalised.

    Parameters:
    - mus_env1, mus_env2, mus_env3: Per-action mean values of the three environments
    - sigmas_env1, sigmas_env2, sigmas_env3: Per-action standard deviations of the three environments
    - alpha: Weight of environment 1 (default: 0.7)
    - bet: Weight of environment 2, stored as ``self.beta`` (default: 0.3)
    """

    def __init__(self, mus_env1, sigmas_env1, mus_env2, sigmas_env2, mus_env3, sigmas_env3, alpha=0.7, bet=0.3):
        self.mus_env1, self.sigmas_env1 = mus_env1, sigmas_env1
        self.mus_env2, self.sigmas_env2 = mus_env2, sigmas_env2
        self.mus_env3, self.sigmas_env3 = mus_env3, sigmas_env3
        self.alpha = alpha
        self.beta = bet

        weight1, weight2 = self.alpha, self.beta
        # Environment 3 receives whatever weight the first two leave over.
        weight3 = 1 - weight1 - weight2

        def blend(first, second, third):
            return weight1 * np.array(first) + weight2 * np.array(second) + weight3 * np.array(third)

        self.mus = blend(self.mus_env1, self.mus_env2, self.mus_env3)
        self.sigmas = blend(self.sigmas_env1, self.sigmas_env2, self.sigmas_env3)

    def current_env(self):
        """Return the mixing weights as ``{"alpha": ..., "beta": ...}``."""
        return {"alpha": self.alpha, "beta": self.beta}

    def best_action(self):
        """Return the index of the action with the smallest blended mean."""
        return np.argmin(self.mus)

    def sample(self, action):
        """Draw one response time for ``action`` from the blended parameters."""
        return np.random.normal(self.mus[action], self.sigmas[action])

    def step(self, **kwargs):
        """Advance one timestep; a no-op because the blend is fixed. Keyword arguments are ignored."""