import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.distributions.normal import Normal
import numpy as np


class RNNAgent(nn.Module):
    """Recurrent agent network: linear encoder -> GRU cell -> linear head.

    ``args`` must expose ``rnn_hidden_dim`` (width of the encoder output and
    of the GRU state) and ``n_actions`` (width of the output head).
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        hidden_dim = args.rnn_hidden_dim

        # Encoder, recurrent cell, output head (created in this order).
        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, args.n_actions)

    def init_hidden(self):
        """Return an all-zero hidden state of shape ``(1, rnn_hidden_dim)``.

        The tensor is allocated from ``fc1.weight`` so that it has the same
        dtype and lives on the same device as the model parameters.
        """
        return self.fc1.weight.new_zeros(1, self.args.rnn_hidden_dim)

    def forward(self, inputs, hidden_state):
        """Advance the recurrent network by one step.

        Args:
            inputs: Tensor fed to ``fc1``; its last dimension must equal the
                ``input_shape`` given at construction.
            hidden_state: Previous hidden state. Any leading dimensions are
                flattened into rows of width ``rnn_hidden_dim``.

        Returns:
            Tuple ``(q, h)``: the ``fc2`` output (one value per action) and
            the new hidden state produced by the GRU cell.
        """
        encoded = F.relu(self.fc1(inputs))
        prev_hidden = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        next_hidden = self.rnn(encoded, prev_hidden)
        return self.fc2(next_hidden), next_hidden



# Bounds applied to the predicted log standard deviation before it is
# exponentiated, keeping sigma within [exp(-20), exp(2)].
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class RNNAgent_Gaussian(nn.Module):
    """Recurrent agent with a tanh-squashed Gaussian action head.

    Pipeline: ``fc1`` -> ReLU -> GRU cell -> ``fc2`` -> ReLU -> two linear
    heads producing the mean and the log standard deviation of a diagonal
    Normal distribution over ``n_actions`` dimensions.

    ``args`` must expose ``rnn_hidden_dim``, ``n_actions`` and
    ``action_spaces``; ``agent_return_logits`` is optional (default False).
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        hidden_dim = args.rnn_hidden_dim

        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)

        self.fc3_mu = nn.Linear(hidden_dim, args.n_actions)
        self.fc3_log_sigma = nn.Linear(hidden_dim, args.n_actions)

        # Stored for external readers; not consulted anywhere in this class.
        self.agent_return_logits = getattr(self.args, "agent_return_logits", False)
        # Action scale: upper bound of the first dimension of the first action
        # space. The same scale is applied to every action dimension.
        self.act_limit = self.args.action_spaces[0].high[0]

    def init_hidden(self):
        """Return an all-zero hidden state of shape ``(1, rnn_hidden_dim)``.

        Allocated from ``fc1.weight`` so it shares the model's dtype/device.
        """
        return self.fc1.weight.new_zeros(1, self.args.rnn_hidden_dim)

    def forward(self, inputs, hidden_state, actions=None):
        """Sample a squashed Gaussian action and advance the hidden state.

        Args:
            inputs: Tensor fed to ``fc1``; last dimension must equal the
                ``input_shape`` given at construction.
            hidden_state: Previous hidden state; leading dimensions are
                flattened into rows of width ``rnn_hidden_dim``.
            actions: Accepted for interface compatibility; currently unused.

        Returns:
            dict with keys:
              ``"actions"``      - ``act_limit * tanh(u)`` for a
                                   reparameterised sample ``u ~ N(mu, sigma)``.
              ``"mu_actions"``   - ``act_limit * tanh(mu)`` (deterministic).
              ``"logp_pi"``      - log-probability of the sample, summed over
                                   the action dimension, with the tanh
                                   change-of-variables correction applied.
              ``"hidden_state"`` - new GRU hidden state.
              ``"log_sigma"``    - raw (unclamped) output of the log-sigma head.
        """
        x = F.relu(self.fc1(inputs))
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        x2 = F.relu(self.fc2(h))

        mu = self.fc3_mu(x2)
        log_sigma = self.fc3_log_sigma(x2 + 1e-6)
        # Only the sigma used for sampling is clamped; log_sigma is returned raw.
        sigma = torch.exp(torch.clamp(log_sigma, LOG_STD_MIN, LOG_STD_MAX))
        pi_distribution = Normal(mu, sigma)
        pre_tanh = pi_distribution.rsample()

        # log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2), using the
        # numerically stable identity
        #   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)).
        # Both terms reduce over the last (action) dimension.
        gaussian_logp = pi_distribution.log_prob(pre_tanh).sum(dim=-1)
        squash_correction = 2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))
        logp_pi = gaussian_logp - squash_correction.sum(dim=-1)

        pi_action = self.act_limit * torch.tanh(pre_tanh)
        mu_action = self.act_limit * torch.tanh(mu)

        return {
            "actions": pi_action,
            "mu_actions": mu_action,
            "logp_pi": logp_pi,
            "hidden_state": h,
            "log_sigma": log_sigma,
        }
