
import torch, random
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import RelaxedOneHotCategorical, Categorical


class FeatureExtractor(nn.Module):
    """Two-layer fully connected feature extractor.

        Parameters:
              obs_dim (int): Size of the last dimension of the input tensor.
              feature_topology (sequence of int): Layer sizes. ``feature_topology[0]``
                  is the hidden width and ``feature_topology[-1]`` the output width;
                  any intermediate entries are ignored.
              feature_type (str): Extractor type. Only ``'fnn'`` is supported.

        Raises:
              ValueError: if ``feature_type`` is not supported.
    """

    def __init__(self, obs_dim, feature_topology, feature_type):
        super(FeatureExtractor, self).__init__()

        self.feature_type = feature_type

        if feature_type != 'fnn':
            # Fail fast: otherwise the module would be built without f1/f2 and
            # only break later, with a confusing AttributeError in forward().
            raise ValueError(f'Unknown Feature type: {feature_type!r}')

        # Hidden layer on the raw observation, then projection to the output width.
        self.f1 = nn.Linear(obs_dim, feature_topology[0])
        self.f2 = nn.Linear(feature_topology[0], feature_topology[-1])

    def forward(self, obs):
        """Compute features for ``obs``.

             Parameters:
                   obs (tensor): Input whose last dimension has size ``obs_dim``.
             Returns:
                   h (tensor): ``relu(f2(relu(f1(obs))))``, with last dimension
                   ``feature_topology[-1]``.
        """
        h = torch.relu(self.f1(obs))
        return torch.relu(self.f2(h))


class GumbelPolicy(nn.Module):
    """Discrete-action policy with epsilon-greedy exploration.

    A ``FeatureExtractor`` maps observations to features; a linear head
    (``adv``) maps features to one logit per discrete action.

        Parameters:
              obs_dim (int): Size of the last dimension of the observation tensor.
              action_dim (int): Number of discrete actions.
              feature_topology (sequence of int): Layer sizes for the feature extractor.
              feature_type (str): Feature extractor type (passed to FeatureExtractor).
              eps_start (float): Initial exploration probability.
              eps_end (float): Epsilon is only decayed while it is above this value.
              decay_gen (int): Number of decay steps to go from eps_start to eps_end.
    """

    def __init__(self, obs_dim, action_dim, feature_topology=(256, 186), feature_type='fnn', eps_start=1.0, eps_end=0.05, decay_gen=10000):
        super(GumbelPolicy, self).__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # Feature Extractor Nets
        self.feature_extractor = FeatureExtractor(obs_dim, feature_topology, feature_type)

        # Action head: one logit per discrete action.
        self.adv = nn.Linear(feature_topology[-1], action_dim)

        # E-Greedy Params. These are plain tensors (not Parameters or buffers),
        # so they are not trained, not saved in state_dict() and not moved by .to().
        self.epsilon = torch.tensor([float(eps_start)])
        self.epsilon_end = torch.tensor([float(eps_end)])
        decay_rate = (eps_start - eps_end) / decay_gen
        self.epsilon_decay_rate = torch.tensor([float(decay_rate)])

        # TEMPERATURE SPECIFIC (not read anywhere in this class)
        self.TEMP_MAX = 20
        self.TEMP_MIN = 0.1

    def take_action(self, obs, return_only_action=True, noise=False):
        """Select actions for a batch of observations.

             Parameters:
                   obs (tensor): Observations. Logits are reduced over dim 1, so this
                       presumably is (batch, obs_dim) — TODO confirm with callers.
                   return_only_action (bool): If True, return only the action tensor.
                   noise (bool): If False (with return_only_action), act greedily.
             Returns:
                   - return_only_action and not noise: greedy actions ``logits.argmax(1)``.
                   - return_only_action and noise: actions sampled from the clamped
                     softmax. With probability epsilon they are replaced by uniformly
                     random actions of the same shape, dtype and device, and epsilon
                     is decayed.
                   - otherwise: ``(action, logits, log_prob, probs)`` where ``action``
                     is sampled, ``log_prob`` is the log-softmax of the raw logits and
                     ``probs`` is the clamped softmax.
        """
        h = self.feature_extractor.forward(obs)
        logits = self.adv(h)

        if return_only_action and not noise:
            return logits.argmax(1)

        probs = torch.softmax(logits, dim=1)
        # Keep every action's probability away from 0 and 1; Categorical renormalises.
        probs = torch.clamp(probs, min=0.001, max=0.99)

        action = Categorical(probs).sample()

        if return_only_action:  # Noisy Sample
            if random.random() < self.epsilon.item():
                # Uniform random action with the same shape/dtype/device as the
                # sampled one, so callers get a consistent tensor on every path.
                action = torch.randint_like(action, 0, self.action_dim)
                # NOTE: epsilon is only decayed on exploratory steps.
                if self.epsilon > self.epsilon_end:
                    self.epsilon -= self.epsilon_decay_rate
            return action

        log_prob = F.log_softmax(logits, dim=1)

        return action, logits, log_prob, probs


class QGlobal(nn.Module):
    """Twin-head Q critic over all agents' observations and joint action.

    Each agent's observation slice goes through its own FeatureExtractor. All
    agents' features are concatenated with the flattened joint action and fed to
    two independent two-layer heads that produce ``q1`` and ``q2``.

        Parameters:
              num_inputs (int): Per-agent observation size (last dim of obs).
              action_dim (int): Per-agent action size.
              hidden_size (int): Hidden width of each Q head.
              num_agents (int): Number of agents.
              feature_topology (sequence of int): Layer sizes for each FeatureExtractor.
              feature_type (str): FeatureExtractor type.
    """

    def __init__(self, num_inputs, action_dim, hidden_size, num_agents, feature_topology=(256, 256),
                 feature_type='fnn'):
        super(QGlobal, self).__init__()

        self.num_agents = num_agents

        # One feature extractor per agent. nn.ModuleList registers them as
        # submodules, so their weights appear in .parameters() (and are therefore
        # optimized / soft-updated), are saved in state_dict(), and follow
        # .to()/.cuda() together with the Q heads instead of being pinned to CUDA.
        # NOTE(review): state_dict() now contains feature_processor.* keys; older
        # checkpoints lack them, so a strict load_state_dict will report them missing.
        self.feature_processor = nn.ModuleList(
            [FeatureExtractor(num_inputs, feature_topology, feature_type) for _ in range(num_agents)])

        joint_dim = feature_topology[-1] * num_agents + action_dim * num_agents

        # Q1 architecture
        self.q1_l1 = nn.Linear(joint_dim, hidden_size)
        self.q1_l2 = nn.Linear(hidden_size, 1)

        # Q2 architecture
        self.q2_l1 = nn.Linear(joint_dim, hidden_size)
        self.q2_l2 = nn.Linear(hidden_size, 1)

    def forward(self, obs, action):
        """Return the two Q estimates for a batch of joint observations and actions.

        obs --> [batch_size, agent_id, *]; agent i's slice obs[:, i, :] is fed to
                its own feature extractor.
        action --> [batch_size, num_agents * action_dim] or
                   [batch_size, agent_id, action_dim]; flattened to 2-D (agent-major,
                   matching the feature concatenation order) before use.
        Returns (q1, q2), each of shape [batch_size, 1].
        """
        features = [extractor(obs[:, agent_id, :])
                    for agent_id, extractor in enumerate(self.feature_processor)]
        h = torch.cat(features, dim=1)

        joint_action = action.reshape(action.shape[0], -1)
        h = torch.cat([h, joint_action], dim=1)

        q1 = self.q1_l2(torch.selu(self.q1_l1(h)))
        q2 = self.q2_l2(torch.selu(self.q2_l1(h)))

        return q1, q2
