import numpy as np
from gym import spaces


class EnvironmentsWithFixedAgent:
    """Single-agent view of a two-player (``u``/``v``) environment.

    One of the two players is driven by ``fixed_agent``; the other player's
    action is supplied by the caller through :meth:`step`.  The object exposes
    the attributes a Gym-style consumer typically reads (``observation_space``,
    ``action_space``, ``metadata``, ``reward_range``, ``unwrapped``) without
    subclassing ``gym.Env``.

    The wrapped ``env`` must provide ``state_dim``, ``state``, ``reset()``,
    ``step(u_action, v_action)`` and the ``{u,v}_action_{dim,min,max}``
    attributes of the player that is NOT fixed.
    """

    def __init__(self, env, fixed_agent_index='', fixed_agent=None, dt_multiplier=1):
        """Wrap ``env`` so that only the non-fixed player is controlled.

        Args:
            env: The two-player environment to wrap.
            fixed_agent_index: ``'u'`` or ``'v'`` -- which player is driven by
                ``fixed_agent``.
            fixed_agent: Object exposing ``predict(state)`` (returning a pair
                whose first item is the action) or ``get_action(state)``;
                optionally ``reset()``.
            dt_multiplier: Number of inner ``env.step`` calls performed per
                call to :meth:`step`.

        Raises:
            ValueError: If ``fixed_agent_index`` is not ``'u'``/``'v'`` or
                ``dt_multiplier`` is smaller than 1.
        """
        # Fail early with a clear message: previously an invalid index only
        # printed a warning and then crashed on a missing attribute.
        if fixed_agent_index not in ('u', 'v'):
            raise ValueError(
                'Unexpected fixed_agent_index: {!r} (expected \'u\' or \'v\')'.format(
                    fixed_agent_index))
        # With fewer than one inner step, step() would have no actions or
        # done flag to report.
        if dt_multiplier < 1:
            raise ValueError(
                'dt_multiplier must be >= 1, got {!r}'.format(dt_multiplier))

        self.env = env
        self.fixed_agent_index = fixed_agent_index
        self.fixed_agent = fixed_agent
        self.dt_multiplier = dt_multiplier

        # Attributes commonly read by Gym-style consumers.
        self.metadata = {}
        self.unwrapped = self
        self.reward_range = (-float("inf"), float("inf"))

        self.state_dim = self.env.state_dim
        # The exposed action space belongs to the player that is NOT fixed.
        if self.fixed_agent_index == 'u':
            self.action_dim = self.env.v_action_dim
            self.action_min = self.env.v_action_min
            self.action_max = self.env.v_action_max
        else:
            self.action_dim = self.env.u_action_dim
            self.action_min = self.env.u_action_min
            self.action_max = self.env.u_action_max

        self.observation_space = self.get_space_box(-float("inf"), float("inf"), self.state_dim)
        self.action_space = self.get_space_box(self.action_min, self.action_max, self.action_dim)

        # Per-episode statistics, appended whenever an episode finishes.
        self.total_rewards = []
        self.total_timesteps = []
        # Sessions with the lowest (u) / highest (v) total reward seen so far.
        self.u_best_session = None
        self.v_best_session = None
        # Counts step() calls over the wrapper's lifetime (never reset).
        self.current_timesteps = 0

    def get_space_box(self, var_min, var_max, var_dim):
        """Build a ``spaces.Box`` with ``var_dim`` components.

        Args:
            var_min: Lower bound, either a scalar or an indexable sequence.
            var_max: Upper bound, either a scalar or an indexable sequence.
            var_dim: Number of components of the space.

        Returns:
            A ``spaces.Box`` of shape ``(var_dim,)``.
        """
        if var_dim == 1:
            # One-dimensional spaces are built from scalar bounds.  Only index
            # into the bound when it is not already a scalar (the observation
            # space passes plain floats, which cannot be subscripted).
            low = var_min if np.ndim(var_min) == 0 else var_min[0]
            high = var_max if np.ndim(var_max) == 0 else var_max[0]
            return spaces.Box(low, high, shape=(1,))
        return spaces.Box(var_min, var_max, shape=(var_dim,))

    def reset(self):
        """Reset the wrapped environment and start recording a new session.

        Also calls ``fixed_agent.reset()`` when the fixed agent provides it.

        Returns:
            The state returned by ``env.reset()``.
        """
        self.state = self.env.reset()
        self.current_session = {'states': [], 'u_actions': [], 'v_actions': [],
                                'rewards': [], 'dones': []}
        agent_reset = getattr(self.fixed_agent, 'reset', None)
        if callable(agent_reset):
            agent_reset()
        return self.state

    def step(self, action):
        """Apply ``action`` for the controlled player and advance the env.

        The inner environment is stepped up to ``dt_multiplier`` times with the
        same controlled action (the fixed agent is re-queried on every inner
        step); inner rewards are summed and the loop stops early on ``done``.

        Args:
            action: Action of the player that is not fixed.

        Returns:
            Tuple ``(state, reward, done, info)`` where ``state`` is
            ``env.state`` after the last inner step, ``info`` is an empty dict
            and ``reward`` is the summed inner reward, negated when the fixed
            player is ``'v'``.
        """
        self.current_timesteps += 1
        fixed_is_u = self.fixed_agent_index == 'u'

        reward = 0
        for _ in range(self.dt_multiplier):
            if fixed_is_u:
                u_action = self.get_action(self.env.state)
                v_action = action
            else:
                u_action = action
                v_action = self.get_action(self.env.state)

            _, inner_reward, done, _ = self.env.step(u_action, v_action)
            reward += inner_reward

            if done:
                break

        # Only the last inner step's state/actions are recorded per outer step.
        session = self.current_session
        session['states'].append(self.env.state)
        session['u_actions'].append(u_action)
        session['v_actions'].append(v_action)
        session['rewards'].append(reward)
        session['dones'].append(done)

        if done:
            total_reward = np.sum(session['rewards'])
            self.total_rewards.append(total_reward)
            self.total_timesteps.append(self.current_timesteps)
            # total_rewards already contains this episode, so ties with the
            # current extreme replace the stored session.
            if total_reward <= np.min(self.total_rewards):
                self.u_best_session = session
            if total_reward >= np.max(self.total_rewards):
                self.v_best_session = session

        # The recorded reward keeps the env's sign; the returned one is
        # flipped when the controlled player is u.
        returned_reward = reward if fixed_is_u else -reward
        return self.env.state, returned_reward, done, {}

    def get_action(self, state):
        """Return the fixed agent's action for ``state``.

        Uses ``fixed_agent.predict(state)`` (taking the first element of the
        returned pair) when available, otherwise ``fixed_agent.get_action(state)``.
        """
        if hasattr(self.fixed_agent, 'predict'):
            action, _ = self.fixed_agent.predict(state)
        else:
            action = self.fixed_agent.get_action(state)
        return action
