import torch
import torch.nn.functional as F

from torch import optim
from torch import nn
from utils.snippets import torch_one_hot, convert_args_to_tensor, huber_loss
from copy import deepcopy

class DQNCritic:
    """Q-value critic for DQN-style agents with a discrete action space.

    The critic owns an online Q-network (``q_t_values``), a target Q-network
    (``q_tp1_values``) and the optimizer that trains the online network.
    Bootstrapped targets are computed with double Q-learning, plain Q-learning
    or expected SARSA (under an epsilon-greedy, resmax or softmax policy).
    """

    def __init__(self, hparams, optimizer_spec, **kwargs):
        """Read the hyper-parameters and build both networks.

        Args:
            hparams: Mapping providing the keys ``ob_dim``, ``device``,
                ``in_channels``, ``ac_dim``, ``double_q``,
                ``grad_norm_clipping``, ``gamma``, ``algorithm``,
                ``exploration_strategy``, ``use_normalization_scheme`` and
                ``q_func``. ``input_shape`` is additionally required when
                ``ob_dim`` is not an ``int``.
            optimizer_spec: Object whose ``constructor`` attribute builds an
                optimizer from an iterable of parameters.
            **kwargs: Forwarded unchanged to the next ``__init__`` in the MRO
                (useful when the class is combined with other bases).

        Raises:
            KeyError: If a required key is missing from ``hparams``.
        """
        super().__init__(**kwargs)
        self.ob_dim = hparams['ob_dim']
        self.device = hparams['device']
        # An integer observation dimension is used directly as the network
        # input shape; otherwise the shape must be supplied explicitly.
        if isinstance(self.ob_dim, int):
            self.input_shape = self.ob_dim
        else:
            self.input_shape = hparams['input_shape']

        self.in_channels = hparams['in_channels']
        self.ac_dim = hparams['ac_dim']
        self.double_q = hparams['double_q']
        self.grad_norm_clipping = hparams['grad_norm_clipping']
        self.gamma = hparams['gamma']
        self.algorithm = hparams['algorithm']
        self.exploration_strategy: str = hparams['exploration_strategy']
        self.use_normalization_scheme = hparams['use_normalization_scheme']

        # TODO: learning-rate scheduling and an optimizer option should be
        # added in the future.
        self.optimizer_spec = optimizer_spec
        self._build(hparams['q_func'])

    def _build(self, q_func):
        """Instantiate the online network, the target network and the optimizer.

        Args:
            q_func: Factory called as ``q_func(input_shape, ac_dim)`` or, when
                ``in_channels`` is not ``None``, as
                ``q_func(input_shape, ac_dim, in_channels)``. The returned
                object must support ``.to(device)`` and ``.parameters()``.
        """

        def make_network():
            # Only pass the channel count to factories that expect it.
            if self.in_channels is not None:
                network = q_func(self.input_shape, self.ac_dim, self.in_channels)
            else:
                network = q_func(self.input_shape, self.ac_dim)
            return network.to(self.device)

        # Online network: evaluated on the CURRENT observations (time t).
        self.q_t_values = make_network()

        # Target network: evaluated on the NEXT observations (time t+1).
        # NOTE(review): it is built by a second, independent factory call, so
        # it is not guaranteed to share the online weights until
        # update_target_network() is called -- confirm the training loop does so.
        self.q_tp1_values = make_network()

        # Only the online network is optimized (by minimizing the TD error).
        # TODO: hyperparameters can be changed and a learning schedule can be added.
        self.optimizer = self.optimizer_spec.constructor(self.q_t_values.parameters())

    def _expected_action_probs(self, q_values, e_value):
        """Return the policy's action probabilities for every row of ``q_values``.

        Args:
            q_values: Two-dimensional tensor of Q-values, one row per sample
                and one column per action.
            e_value: Exploration parameter; read as epsilon for
                ``'epsilon-greedy'``, eta for ``'resmax'`` and the temperature
                for ``'softmax'``.

        Returns:
            Tensor shaped like ``q_values`` holding the action probabilities.

        Raises:
            ValueError: If ``self.exploration_strategy`` is not supported.
        """
        q_max, greedy_actions = torch.max(q_values, dim=1)
        # Index with (row, column) tensors on the same device. This avoids the
        # host round trip and the legacy nested-list indexing semantics.
        batch_rows = torch.arange(q_values.shape[0], device=q_values.device)

        if self.exploration_strategy == 'epsilon-greedy':
            eps = e_value
            # Uniform mass eps/|A| on every action, the rest on the greedy one.
            action_probs = torch.zeros_like(q_values) + eps / self.ac_dim
            action_probs[batch_rows, greedy_actions] += 1 - eps

        elif self.exploration_strategy == 'resmax':
            # TODO: currently this only works for discrete spaces; we should
            # think about using this for continuous spaces.
            eta = e_value
            scaled_gap = (1 / eta) * (q_max.unsqueeze(1) - q_values)
            if self.use_normalization_scheme:
                q_min = torch.min(q_values, dim=1)[0]
                # Scale by the per-row Q-value range, but never by less than 1.
                spread = torch.maximum(q_max - q_min, torch.ones_like(q_max))
                action_probs = 1 / (self.ac_dim * spread.unsqueeze(1) + scaled_gap)
            else:
                action_probs = 1 / (self.ac_dim + scaled_gap)
            # The greedy action receives whatever probability mass is left.
            action_probs[batch_rows, greedy_actions] += 1 - action_probs.sum(1)

        elif self.exploration_strategy == 'softmax':
            # TODO: this hyperparameter should be added to soft-max.
            temperature = e_value
            action_probs = torch.softmax((1 / temperature) * q_values, dim=1)

        else:
            raise ValueError('This exploration strategy does not exist: {}'.format(self.exploration_strategy))

        return action_probs

    def __calc_target_vals(self, next_ob_no, re_n, terminal_n, e_value):
        """Compute the bootstrapped targets for the Bellman error.

        Everything runs under ``torch.no_grad()``, so no gradient flows
        through the returned targets.

        Args:
            next_ob_no: Batch of next observations fed to the networks.
            re_n: Rewards of the sampled transitions.
            terminal_n: Terminal indicators; the bootstrap term is multiplied
                by ``1 - terminal_n``.
            e_value: Exploration parameter, only used by expected SARSA.

        Returns:
            ``re_n + gamma * next_state_update * (1 - terminal_n)``.

        Raises:
            ValueError: If ``double_q`` is false and ``self.algorithm`` is
                neither ``'q-learning'`` nor ``'expected_sarsa'``, or if the
                exploration strategy is unknown under expected SARSA.
        """
        with torch.no_grad():
            # Target-network Q-values of the NEXT observations (time t+1).
            q_tp1_values = self.q_tp1_values(next_ob_no)

            if self.double_q:  # TODO: add expected SARSA for double Q-learning
                # In double Q-learning, the best action is selected using the
                # Q-network that is being updated, but the Q-value for this
                # action is obtained from the target Q-network. See page 5 of
                # https://arxiv.org/pdf/1509.06461.pdf for more details.
                greedy_actions = torch.argmax(self.q_t_values(next_ob_no), dim=1)
                batch_rows = torch.arange(q_tp1_values.shape[0], device=q_tp1_values.device)
                next_state_update = q_tp1_values[batch_rows, greedy_actions]

            elif self.algorithm == 'expected_sarsa':
                # TODO: this version of expected SARSA and Q-learning does not
                # work with continuous action spaces, so we should consider
                # implementing it in the future.
                action_probs = self._expected_action_probs(q_tp1_values, e_value)
                # Expectation of the target Q-values under the policy.
                next_state_update = (action_probs * q_tp1_values).sum(1)

            elif self.algorithm == 'q-learning':
                next_state_update = torch.max(q_tp1_values, dim=1)[0]

            else:
                # Previously this fell through and crashed with an
                # UnboundLocalError; fail with an explicit message instead.
                raise ValueError('This algorithm does not exist: {}'.format(self.algorithm))

            # Terminal transitions do not bootstrap from the next state.
            return re_n + self.gamma * next_state_update * (1 - terminal_n)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.q_tp1_values.load_state_dict(self.q_t_values.state_dict())

    @convert_args_to_tensor()
    def update(self, ob_no, next_ob_no, act_t_ph, re_n, terminal_n, lr, e_value):
        """Run one gradient step on the online network.

        Args:
            ob_no: Batch of current observations.
            next_ob_no: Batch of next observations.
            act_t_ph: Actions taken in ``ob_no``; turned into a one-hot mask
                over ``ac_dim`` actions via ``torch_one_hot``.
            re_n: Rewards of the transitions.
            terminal_n: Terminal indicators of the transitions.
            lr: Learning rate written into every optimizer parameter group
                before the step.
            e_value: Exploration parameter forwarded to the target computation.

        Returns:
            Tuple ``(loss, td_error, weights_change)``: the mean Huber loss and
            the mean squared TD error (both converted with ``.numpy()``), and a
            placeholder ``0`` for the weight change.
        """
        ob_no, next_ob_no, act_t_ph, re_n, terminal_n = (
            x.to(self.device) for x in (ob_no, next_ob_no, act_t_ph, re_n, terminal_n)
        )

        # Setting the learning-rate value for this step.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        target_q_t = self.__calc_target_vals(next_ob_no, re_n, terminal_n, e_value)

        # Q-value of the action that was actually taken in each transition.
        action_mask = torch_one_hot(act_t_ph, self.ac_dim)
        q_t = torch.sum(self.q_t_values(ob_no) * action_mask, dim=1)

        # Bellman error (i.e. TD error between q_t and target_q_t).
        total_error = torch.mean(huber_loss(q_t, target_q_t))
        # INFO: used only for logging purposes.
        td_error = ((target_q_t - q_t) ** 2).mean().detach().cpu().numpy()

        # Train the critic by minimizing the TD error, clipping the gradient
        # norm of the online network before the optimizer step.
        self.optimizer.zero_grad()
        total_error.backward()
        nn.utils.clip_grad_norm_(self.q_t_values.parameters(), self.grad_norm_clipping)
        self.optimizer.step()

        weights_change = 0  # placeholder: the weight change is not measured

        return total_error.detach().cpu().numpy(), td_error, weights_change
