import numpy as np

from utils.replay_buffer import MemoryOptimizedReplayBuffer
from utils.policies import ArgMaxPolicy
from utils.critics import DQNCritic
from scipy.special import softmax
from utils.snippets import RootFinder 

class DQNAgent(object):
    """DQN agent: steps an environment, stores transitions in a replay buffer
    and trains a ``DQNCritic`` from sampled batches.

    Action selection while collecting data is controlled by
    ``agent_params['exploration_strategy']``: one of ``'epsilon-greedy'``,
    ``'resmax'``, ``'softmax'`` or ``'mellowmax'``. Offline evaluation
    (``eval_step``) acts through ``self.actor.get_action`` (the
    ``ArgMaxPolicy``) with no exploration.
    """

    def __init__(self, env, agent_params, preprocess_obs=None, render=False):
        """Build the critic, the greedy actor and the replay buffer.

        Args:
            env: training environment; used via ``reset()``, ``render()`` and a
                4-tuple-returning ``step(action)``.
            agent_params: configuration dict; every key read below is required.
            preprocess_obs: optional callable applied to each raw observation
                before it is stored or fed to the policy.
            render: if True, render the environment on every step.
        """
        self.env = env
        self.render = render
        self.agent_params: dict = agent_params
        self.batch_size: int = agent_params['batch_size']
        self.last_obs: np.ndarray = self.env.reset()

        self.num_actions: int = agent_params['ac_dim']
        self.learning_starts: int = agent_params['learning_starts']
        self.learning_freq: int = agent_params['learning_freq']
        self.target_update_freq: int = agent_params['target_update_freq']

        # Index of the most recently stored frame; set in step_env.
        self.replay_buffer_idx = None
        self.exploration = agent_params['exploration_schedule']
        self.episode_based_exploration = agent_params['episode_based_exploration']
        self.exploration_strategy: str = agent_params['exploration_strategy']
        self.optimizer_spec = agent_params['optimizer_spec']

        self.critic: DQNCritic = DQNCritic(agent_params, self.optimizer_spec)
        self.actor: ArgMaxPolicy = ArgMaxPolicy(self.critic, device=agent_params['device'])

        self.use_normalization_scheme = self.agent_params['use_normalization_scheme']
        self.g_bound = self.agent_params['g_bound']

        self.td_error_mg = self.agent_params['td_error_mg']  # TD-error moving average
        self.td_error_mg_lr = self.agent_params['td_error_mg_lr']  # moving-average learning rate
        self.td_error_mg_epsilon = self.agent_params['td_error_mg_epsilon']  # keeps the divisor away from 0
        self.td_error_scheduling = self.agent_params['td_error_scheduling']  # divide eta/temperature by the TD-error average

        self.preprocess_obs = preprocess_obs

        self.t: int = 0
        self.episode_num = 1
        self.num_param_updates: int = 0
        self.is_image_ob: bool = not isinstance(self.agent_params['ob_dim'], int)

        self.replay_buffer = MemoryOptimizedReplayBuffer(
            agent_params['replay_buffer_size'],
            agent_params['frame_history_len'],
            is_image=self.is_image_ob,
        )

        self.root_finder = RootFinder()

    def step_env(self):
        """Advance the training env by one step and store the transition.

        Afterwards the replay buffer holds one more transition and
        ``self.last_obs`` is the observation the next call acts on: the next
        observation or, if the step ended the episode, the first observation
        of a freshly reset episode.

        Raises:
            ValueError: if ``self.exploration_strategy`` is not recognised.
        """
        if self.preprocess_obs is not None:
            self.last_obs = self.preprocess_obs(self.last_obs)

        # Store the frame first so encode_recent_observation() includes it.
        self.replay_buffer_idx = self.replay_buffer.store_frame(self.last_obs)

        action = self._select_action(self._exploration_step())

        if self.render:
            self.env.render()

        self.last_obs, reward, done, _info = self.env.step(action)

        if done:
            self.episode_num += 1

        self.replay_buffer.store_effect(self.replay_buffer_idx, action, reward, done)

        if done:
            # The reset observation must replace last_obs; otherwise the
            # terminal observation would start the next episode.
            self.last_obs = self.env.reset()

    def _exploration_step(self):
        """Return the exploration-schedule position: episode count or env-step count."""
        return self.episode_num if self.episode_based_exploration else self.t

    def _recent_obs_batch(self):
        """Return the replay buffer's latest encoded observation as a float32 batch of one."""
        enc_last_obs = self.replay_buffer.encode_recent_observation()
        return enc_last_obs[None, :].astype(np.float32)

    def _select_action(self, exploration_step):
        """Choose an action according to ``self.exploration_strategy``.

        Args:
            exploration_step: position at which the exploration schedule is read.

        Raises:
            ValueError: if the strategy name is unknown.
        """
        strategy = self.exploration_strategy

        if strategy == 'epsilon-greedy':
            eps = self.exploration.value(exploration_step)
            # Random with probability eps, and always before learning starts.
            if np.random.random() < eps or self.t < self.learning_starts:
                return np.random.randint(self.num_actions)
            return self.actor.get_action(self._recent_obs_batch())[0]

        if strategy == 'resmax':
            # TODO: only supports discrete action spaces.
            eta = self.exploration.value(exploration_step)
            if self.td_error_scheduling:
                eta = eta / (self.td_error_mg + self.td_error_mg_epsilon)
            action_values = self.actor.get_action_values(self._recent_obs_batch())[0]
            probs = self._resmax_probs(action_values, eta, self.num_actions,
                                       self.use_normalization_scheme)
            return np.random.choice(self.num_actions, p=probs)

        if strategy == 'softmax':
            temperature = self.exploration.value(exploration_step)
            if self.td_error_scheduling:
                temperature = temperature / (self.td_error_mg + self.td_error_mg_epsilon)
            action_values = self.actor.get_action_values(self._recent_obs_batch())[0]
            probs = self._softmax_probs(action_values, temperature,
                                        self.use_normalization_scheme)
            return np.random.choice(self.num_actions, p=probs)

        if strategy == 'mellowmax':
            # Read the schedule at the same position as the other strategies
            # so episode_based_exploration is honoured here too.
            omega = self.exploration.value(exploration_step)
            action_values = self.actor.get_action_values(self._recent_obs_batch())[0]
            # RootFinder supplies the softmax temperature for this omega
            # (presumably the mellowmax-matching temperature; see RootFinder).
            temperature = self.root_finder.mellow_max_root_finder(action_values, omega)
            if self.t % 1000 == 0:
                print('temperature: ', temperature)
            probs = self._softmax_probs(action_values, temperature,
                                        self.use_normalization_scheme)
            return np.random.choice(self.num_actions, p=probs)

        raise ValueError('This exploration strategy does not exist: {}'.format(self.exploration_strategy))

    @staticmethod
    def _softmax_probs(action_values, temperature, normalize):
        """Return Boltzmann action probabilities ``softmax(q / temperature)``.

        With ``normalize``, values are shifted so their minimum is at least 0
        and then divided by their maximum, so they lie in [0, 1]. If every
        shifted value is 0, the division is skipped; the result is then the
        uniform distribution instead of NaN. The input array is not modified.
        """
        values = np.asarray(action_values, dtype=np.float64)
        if normalize:
            values = values - min(values.min(), 0.0)
            scale = values.max()
            if scale > 0:
                values = values / scale
        probs = softmax(values / temperature)
        # Renormalise to absorb rounding so np.random.choice accepts p.
        return probs / probs.sum()

    @staticmethod
    def _resmax_probs(action_values, eta, num_actions, normalize):
        """Return resmax action probabilities.

        Every action ``a`` first gets ``1 / (base + (1 / eta) * gap_a)``, where
        ``gap_a = max(q) - q[a]``. ``base`` is ``num_actions``, multiplied by
        ``max(max(q) - min(q), 1)`` when ``normalize`` is set. The greedy
        action is then topped up so the probabilities sum to 1.
        """
        values = np.asarray(action_values, dtype=np.float64)
        greedy = int(np.argmax(values))
        gaps = values[greedy] - values
        base = num_actions
        if normalize:
            base = num_actions * max(values[greedy] - values.min(), 1)
        probs = 1 / (base + (1 / eta) * gaps)
        probs[greedy] += 1 - probs.sum()
        return probs

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions from the replay buffer.

        Returns:
            The replay buffer's sample. If the buffer cannot yet provide
            ``batch_size`` transitions, returns five empty lists instead.
        """
        # Check the size actually requested, not self.batch_size.
        if not self.replay_buffer.can_sample(batch_size):
            return [], [], [], [], []
        return self.replay_buffer.sample(batch_size)

    def train(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        """Run one critic update when it is time to learn, then advance ``self.t``.

        An update runs only when all of these hold:
        - ``self.t > learning_starts``;
        - ``self.t`` is a multiple of ``learning_freq``;
        - the replay buffer can sample ``batch_size`` transitions.

        After an update, the TD-error moving average ``td_error_mg`` is
        refreshed. The target network is synced every ``target_update_freq``
        parameter updates, including the first.

        Args:
            ob_no, ac_na, re_n, next_ob_no, terminal_n: one batch of
                transitions, presumably as returned by ``sample`` (TODO confirm).

        Returns:
            ``(loss, td_error, weights_change)`` from ``critic.update``, or
            ``(0.0, 0.0, 0.0)`` when no update ran.
        """
        loss = 0.0
        td_error = 0.0
        weights_change = 0.0
        if (self.t > self.learning_starts
                and self.t % self.learning_freq == 0
                and self.replay_buffer.can_sample(self.batch_size)):

            feed_dict = {
                'lr': self.optimizer_spec.lr_schedule.value(self.t),
                'e_value': self.exploration.value(self.t),
                'ob_no': ob_no.astype(np.float32),
                'act_t_ph': ac_na.astype(np.int64),  # NOTE: int64, since the platform-dependent long caused issues
                're_n': re_n,
                'next_ob_no': next_ob_no.astype(np.float32),
                'terminal_n': terminal_n,
            }

            loss, td_error, weights_change = self.critic.update(**feed_dict)
            self.td_error_mg = (1 - self.td_error_mg_lr) * self.td_error_mg + td_error * self.td_error_mg_lr

            if self.num_param_updates % self.target_update_freq == 0:
                self.critic.update_target_network()

            self.num_param_updates += 1
        self.t += 1
        return loss, td_error, weights_change

    def offline_eval_episodes(self, episode_num: int, environment_class, seed):
        """Run ``episode_num`` evaluation episodes on a separate environment.

        The evaluation env is created with ``environment_class()`` and seeded
        with ``seed`` on the first call only; later calls reuse it (and ignore
        ``seed``).

        Returns:
            ``(episodic_returns, episodic_steps)``: the per-episode sum of
            rewards and the per-episode step counts.
        """
        if not hasattr(self, 'eval_env'):
            self.eval_env = environment_class()
            self.eval_env.seed(seed)

        episodic_returns = []
        episodic_steps = []
        for _ in range(episode_num):
            # Reset at the start of every episode; no trailing reset is needed.
            last_obs = self.eval_env.reset()
            rewards = []
            done = False
            while not done:
                last_obs, reward, done, _info = self.eval_step(last_obs)
                rewards.append(reward)
            episodic_steps.append(len(rewards))
            episodic_returns.append(np.sum(rewards))

        return episodic_returns, episodic_steps

    def eval_step(self, last_obs):
        """Take one step in ``self.eval_env`` using the actor's action, with no exploration.

        Args:
            last_obs: current raw observation of the evaluation env.

        Returns:
            The ``(obs, reward, done, info)`` tuple from ``self.eval_env.step``.
        """
        if self.preprocess_obs is not None:
            last_obs = self.preprocess_obs(last_obs)

        # NOTE(review): unlike step_env, this feeds the single observation
        # rather than the replay buffer's encoding (which may stack
        # frame_history_len frames) -- confirm this is intended.
        obs_batch = np.asarray(last_obs)[None, :].astype(np.float32)
        action = self.actor.get_action(obs_batch)[0]

        if self.render:
            self.eval_env.render()

        last_eval_obs, reward, done, info = self.eval_env.step(action)
        return last_eval_obs, reward, done, info

