import torch.nn as nn
import torch.nn.functional as F
import torch as th
from torch.distributions import Normal
import numpy as np

# Bounds applied to the Gaussian mean before it is squashed with tanh.
MEAN_MIN = -9.0
MEAN_MAX = 9.0
# Bounds applied to the predicted log standard deviation before exponentiation.
LOG_STD_MIN = -5
LOG_STD_MAX = 2
# Margin keeping squashed actions strictly inside (-1, 1) before atanh is applied.
EPS = 1e-7


class RNNContAgent(nn.Module):
    """Recurrent policy that outputs tanh-squashed Gaussian actions.

    The network is ``fc1 -> ReLU -> GRUCell -> fc2`` followed by two linear
    heads producing the mean and the log standard deviation of a diagonal
    Normal distribution over ``args.n_actions`` dimensions.

    ``args`` must expose ``rnn_hidden_dim`` and ``n_actions``.
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim)

        self.mu_head = nn.Linear(args.rnn_hidden_dim, args.n_actions)
        self.sigma_head = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, rnn_hidden_dim)``.

        The tensor is created from ``fc1.weight`` so it shares the model's
        device and dtype.
        """
        return self.fc1.weight.new_zeros(1, self.args.rnn_hidden_dim)

    def _gaussian_params(self, inputs, hidden_state):
        """Run the shared trunk and both heads.

        Returns ``(raw_mu, sigma, h)`` where ``raw_mu`` is the *unclamped*
        mean, ``sigma`` is ``exp`` of the log-std clamped to
        ``[LOG_STD_MIN, LOG_STD_MAX]`` and ``h`` is the new GRU hidden state.
        """
        features = F.relu(self.fc1(inputs))
        # Flatten any leading dimensions so the GRUCell sees (N, hidden_dim).
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(features, h_in)
        features = self.fc2(h)

        raw_mu = self.mu_head(features)
        log_sigma = th.clamp(self.sigma_head(features), LOG_STD_MIN, LOG_STD_MAX)
        return raw_mu, th.exp(log_sigma), h

    @staticmethod
    def _squash_correction(pre_tanh):
        """Return ``log(1 - tanh(u)^2)`` summed over the last dimension.

        Uses the identity ``log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))``
        rather than evaluating ``1 - tanh(u)^2`` directly.
        """
        return (2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))).sum(dim=-1)

    def forward(self, inputs, hidden_state):
        """Sample a squashed action and compute its log-density.

        Args:
            inputs: observation batch fed to ``fc1`` (presumably
                ``(N, input_shape)`` -- confirm against the caller).
            hidden_state: previous GRU state; reshaped to
                ``(-1, rnn_hidden_dim)``.

        Returns:
            Tuple ``(action, logp_pi, a_distribution, mu, h)``:
            ``action`` is ``tanh`` of a reparameterised sample,
            ``logp_pi`` is the log-density of that squashed action summed
            over the last dimension, ``a_distribution`` is the pre-squash
            ``Normal``, ``mu`` is ``tanh`` of the mean clamped to
            ``[MEAN_MIN, MEAN_MAX]`` and ``h`` is the new hidden state.
        """
        raw_mu, sigma, h = self._gaussian_params(inputs, hidden_state)

        # NOTE(review): the sampling distribution uses the unclamped mean,
        # whereas get_log_density clamps the mean to [MEAN_MIN, MEAN_MAX]
        # first; the two differ whenever |mean| exceeds MEAN_MAX. Kept as-is
        # to preserve existing sampling behaviour -- confirm this is intended.
        a_distribution = Normal(raw_mu, sigma)
        pre_tanh = a_distribution.rsample()

        # Change of variables for a = tanh(u): log p(a) = log p(u) - log(1 - tanh(u)^2).
        logp_pi = a_distribution.log_prob(pre_tanh).sum(dim=-1)
        logp_pi = logp_pi - self._squash_correction(pre_tanh)

        action = th.tanh(pre_tanh)
        mu = th.tanh(th.clamp(raw_mu, MEAN_MIN, MEAN_MAX))

        return action, logp_pi, a_distribution, mu, h

    def get_log_density(self, x, y, hidden_state):
        """Evaluate the log-density of given squashed actions.

        Args:
            x: observation batch fed to ``fc1``.
            y: actions in the squashed space; clamped to
                ``[-1 + EPS, 1 - EPS]`` before ``atanh`` so the inverse
                stays finite.
            hidden_state: previous GRU state; reshaped to
                ``(-1, rnn_hidden_dim)``.

        Returns:
            Tuple ``(logp_pi, h)``: the log-density of ``y`` summed over the
            last dimension, and the new hidden state.
        """
        raw_mu, sigma, h = self._gaussian_params(x, hidden_state)

        # Map the squashed action back to the pre-tanh space.
        pre_tanh = th.atanh(th.clamp(y, -1. + EPS, 1. - EPS))

        a_distribution = Normal(th.clamp(raw_mu, MEAN_MIN, MEAN_MAX), sigma)
        logp_pi = a_distribution.log_prob(pre_tanh).sum(dim=-1)
        logp_pi = logp_pi - self._squash_correction(pre_tanh)

        return logp_pi, h
