import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions
from torch.distributions import Categorical, Normal

class FeatureExtractor(nn.Module):
    """Feed-forward feature extractor made of two ReLU-activated linear layers.

        Parameters:
              obs_dim (int): size of the input observation vector (last dim of ``obs``).
              feature_topology (sequence of int): layer widths; only the first entry
                  (hidden width) and the last entry (output width) are used.
              feature_type (str): architecture selector; only ``'fnn'`` is supported.

        Raises:
              ValueError: if ``feature_type`` is not ``'fnn'``.
    """

    def __init__(self, obs_dim, feature_topology, feature_type):
        super(FeatureExtractor, self).__init__()

        self.feature_type = feature_type

        # Previously the exception object was created but never raised, which left
        # a module without f1/f2 that only failed later inside forward().
        if feature_type != 'fnn':
            raise ValueError(f'Unknown Feature type: {feature_type!r}')

        self.f1 = nn.Linear(obs_dim, feature_topology[0])
        self.f2 = nn.Linear(feature_topology[0], feature_topology[-1])

    def forward(self, obs):
        """Map observations to non-negative feature activations.

             Parameters:
                   obs (tensor): observations whose last dimension is ``obs_dim``.
             Returns:
                   h (tensor): ReLU features whose last dimension is ``feature_topology[-1]``.
         """
        h = torch.relu(self.f1(obs))
        return torch.relu(self.f2(h))


class DDQN(nn.Module):
    """Dueling Q-network with epsilon-greedy exploration.

        Two SELU hidden layers feed a scalar state-value head and a per-action
        advantage head, combined as ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.

        Parameters:
              state_dim (int): size of the observation vector.
              action_dim (int): number of discrete actions.
              hidden_size (int): width of both hidden layers.
              epsilon_start (float): initial exploration probability.
              epsilon_end (float): lower bound for the exploration probability.
              epsilon_decay_frames (int): number of decrements needed to go from
                  ``epsilon_start`` to ``epsilon_end``.
    """

    def __init__(self, state_dim, action_dim, hidden_size,  epsilon_start=1.0, epsilon_end=0.03, epsilon_decay_frames=25000):
        super(DDQN, self).__init__()
        self.action_dim = action_dim

        # Shared trunk (registration order kept for state_dict / init-RNG compatibility)
        self.f1 = nn.Linear(state_dim, hidden_size)
        self.f2 = nn.Linear(hidden_size, hidden_size)

        # Dueling heads: state value V(s) and advantages A(s, .)
        self.val = nn.Linear(hidden_size, 1)
        self.adv = nn.Linear(hidden_size, action_dim)

        # Linear epsilon schedule (see noisy_action for when it is applied)
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay_frames

    def clean_action(self, obs, return_only_action=True):
        """Select the greedy action for each observation.

             Parameters:
                   obs (tensor): batch of observations; argmax(1) below requires a
                       2-D result, so presumably (batch, state_dim).
                   return_only_action (bool): if True return only the actions.

             Returns:
                   actions (LongTensor): greedy action indices, shape (batch,), or
                   (actions, None, q_values) when ``return_only_action`` is False,
                   where q_values has shape (batch, action_dim).
         """
        info = torch.selu(self.f1(obs))
        info = torch.selu(self.f2(info))

        val = self.val(info)
        adv = self.adv(info)

        # Center advantages per sample over the action dimension. A plain
        # adv.mean() would average over the whole batch and couple samples.
        q_values = val + adv - adv.mean(dim=-1, keepdim=True)
        actions = q_values.argmax(1)
        if return_only_action:
            return actions
        return actions, None, q_values

    def noisy_action(self, obs, return_only_action=True):
        """Epsilon-greedy action selection.

             With probability ``self.epsilon`` every sample in the batch gets a
             uniformly random action; otherwise the greedy action is returned.
             Epsilon is decremented only on exploratory calls, never below
             ``epsilon_end``. NOTE(review): because of this, reaching
             ``epsilon_end`` takes more than ``epsilon_decay_frames`` calls —
             confirm this is intended.

             Returns:
                   actions (LongTensor) of shape (batch,), or
                   (actions, None, q_values) when ``return_only_action`` is False.
         """
        greedy, _, q_values = self.clean_action(obs, return_only_action=False)

        if random.random() < self.epsilon:
            # One random action per sample, matching the greedy branch's dtype/device.
            batch = q_values.shape[0]
            action = torch.tensor(
                [random.randint(0, self.action_dim - 1) for _ in range(batch)],
                dtype=torch.long,
                device=q_values.device,
            )
            if self.epsilon > self.epsilon_end:
                # Clamp so the last step cannot overshoot below the floor.
                self.epsilon = max(self.epsilon_end, self.epsilon - self.epsilon_decay_rate)
        else:
            action = greedy

        if return_only_action:
            return action
        return action, None, q_values


class MaxminDQN(nn.Module):
    """Maxmin Q-network: a shared ReLU trunk with several Q-value heads.

        Greedy actions are chosen w.r.t. the element-wise minimum over heads.

        Parameters:
              state_dim (int): size of the observation vector.
              action_dim (int): number of discrete actions.
              hidden_size (int): width of both hidden layers.
              epsilon_start (float): initial exploration probability.
              epsilon_end (float): lower bound for the exploration probability.
              epsilon_decay_frames (int): number of decrements needed to go from
                  ``epsilon_start`` to ``epsilon_end``.
              num_heads (int): number of Q-value heads (default 2, the original count).

        Raises:
              ValueError: if ``num_heads`` < 1.
    """

    def __init__(self, state_dim, action_dim, hidden_size,  epsilon_start=1.0, epsilon_end=0.03, epsilon_decay_frames=25000, num_heads=2):
        super(MaxminDQN, self).__init__()
        if num_heads < 1:
            raise ValueError(f'num_heads must be >= 1, got {num_heads}')
        self.action_dim = action_dim

        # Shared trunk (registration order kept for state_dict / init-RNG compatibility)
        self.f1 = nn.Linear(state_dim, hidden_size)
        self.f2 = nn.Linear(hidden_size, hidden_size)

        # Independent Q-value heads on top of the shared features
        self.heads = nn.ModuleList([nn.Linear(hidden_size, action_dim) for _ in range(num_heads)])

        # Linear epsilon schedule (see noisy_action for when it is applied)
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay_frames

    def forward(self, state):
        """Return a list with one Q-value tensor per head (last dim = action_dim)."""
        x = torch.relu(self.f1(state))
        x = torch.relu(self.f2(x))
        return [head(x) for head in self.heads]

    def clean_action(self, obs, return_only_action=True):
        """Greedy action w.r.t. the element-wise minimum of all heads.

             Returns:
                   actions (LongTensor) from argmax over dim 1, or
                   (actions, None, head_outputs) when ``return_only_action`` is False,
                   where head_outputs is the list returned by ``forward``.
         """
        head_output = self.forward(obs)
        q_min = torch.stack(head_output, dim=0).min(dim=0).values
        actions = q_min.argmax(1)
        if return_only_action:
            return actions
        return actions, None, head_output

    def noisy_action(self, obs, return_only_action=True):
        """Epsilon-greedy action selection on top of ``clean_action``.

             With probability ``self.epsilon`` every sample gets a uniformly random
             action; epsilon is decremented only on those exploratory calls, never
             below ``epsilon_end``. NOTE(review): confirm that decaying only on
             exploration is intended.
         """
        action, _, q_values = self.clean_action(obs, return_only_action=False)
        if random.random() < self.epsilon:
            # One random action per sample, matching the greedy branch's dtype/device.
            action = torch.tensor(
                [random.randint(0, self.action_dim - 1) for _ in range(action.shape[0])],
                dtype=torch.long,
                device=action.device,
            )
            if self.epsilon > self.epsilon_end:
                # Clamp so the last step cannot overshoot below the floor.
                self.epsilon = max(self.epsilon_end, self.epsilon - self.epsilon_decay_rate)

        if return_only_action:
            return action
        return action, None, q_values




class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
    '''
    A differentiable Categorical distribution using the reparametrization trick with Gumbel-Softmax.
    Explanation: http://amid.fish/assets/gumbel.html
    NOTE: use this in place of PyTorch's RelaxedOneHotCategorical distribution, since its
    log_prob does not behave as a categorical log-probability (it can return positive values).
    Papers:
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al, 2017)
    [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017)
    '''

    def sample(self, sample_shape=torch.Size()):
        '''Draw hard category indices with the Gumbel-max trick.

        Parameters:
            sample_shape (torch.Size or tuple): leading shape of independent draws.
        Returns:
            LongTensor of shape ``sample_shape + logits.shape[:-1]`` with category indices.

        Note: rsample (relaxed, differentiable sampling) is inherited from RelaxedOneHotCategorical.
        '''
        shape = torch.Size(sample_shape) + self.logits.shape
        u = torch.empty(shape, device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
        # uniform_ may return exactly 0, where log(-log(u)) is infinite; clamp to the
        # smallest positive normal value so the Gumbel noise stays finite.
        u = u.clamp_(min=torch.finfo(u.dtype).tiny)
        noisy_logits = self.logits - torch.log(-torch.log(u))
        return torch.argmax(noisy_logits, dim=-1)

    def log_prob(self, value):
        '''Log-probability of ``value`` under the underlying categorical.

        Parameters:
            value (tensor): either integer category indices (shape ``logits.shape[:-1]``)
                or a one-hot / relaxed vector with the same shape as ``logits``.
        Returns:
            tensor of shape ``logits.shape[:-1]``: sum over categories of value * log_softmax(logits).
        Raises:
            ValueError: if ``value`` cannot be matched to the logits' shape.
        '''
        if value.shape != self.logits.shape:
            # Treat non-matching input as category indices and one-hot encode them.
            value = F.one_hot(value.long(), self.logits.shape[-1]).to(self.logits.dtype)
            if value.shape != self.logits.shape:
                raise ValueError(
                    f'value shape {tuple(value.shape)} does not match logits shape {tuple(self.logits.shape)}'
                )
        return torch.sum(value * F.log_softmax(self.logits, -1), -1)