import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal


def combined_shape(length, shape=None):
    """Return a shape tuple with ``length`` prepended to ``shape``.

    ``shape=None`` gives ``(length,)``; a scalar gives ``(length, shape)``;
    any iterable of dims gives ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length,) + tuple(shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    One ``nn.Linear`` is created for every consecutive pair in ``sizes``.
    Each Linear is followed by a fresh ``activation()`` instance, except the
    final one, which is followed by ``output_activation()``.
    """
    last_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        act_cls = output_activation if idx == last_idx else activation
        modules.append(act_cls())
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar entries across ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


# Clamp range applied to the log standard deviation predicted by
# SquashedGaussianMLPActor before it is exponentiated.
LOG_STD_MAX, LOG_STD_MIN = 2, -20

class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed Gaussian policy over running-normalized observations.

    Observations are standardized with running mean/std statistics (fed via
    :meth:`update`) and clipped to ``[-5, 5]`` before entering the MLP. The MLP
    output parameterizes a diagonal Normal; actions are squashed with ``tanh``
    and scaled by ``act_limit``.

    NOTE(review): the ``rms_*`` statistics are plain tensor attributes, not
    registered buffers, so they are absent from ``state_dict()`` and are not
    moved by ``.to(device)``. A checkpoint saved via the state dict alone loses
    the normalization; consider ``register_buffer`` if that matters.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

        # Running sums for observation normalization. The small non-zero
        # initial sumsq/count avoid division by zero and make the initial
        # statistics mean 0, std 1.
        self.rms_sum = torch.zeros(obs_dim, dtype=torch.float32)
        self.rms_sumsq = torch.tensor(np.ones(obs_dim) * 1e-2, dtype=torch.float32)
        self.rms_count = torch.tensor(1e-2, dtype=torch.float32)
        # Lower bound on the variance estimate, so std >= sqrt(rms_eps).
        self.rms_eps = torch.tensor(1e-2, dtype=torch.float32)
        self._refresh_rms()

    def _refresh_rms(self):
        """Recompute ``rms_mean`` and ``rms_std`` from the running sums."""
        self.rms_mean = self.rms_sum / self.rms_count
        var = (self.rms_sumsq / self.rms_count) - torch.pow(self.rms_mean, 2)
        # E[x^2] - E[x]^2 can come out tiny or slightly negative in float32;
        # flooring at rms_eps keeps the sqrt finite and the std non-zero.
        self.rms_std = torch.sqrt(torch.max(var, self.rms_eps))

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Compute a (squashed) action for ``obs``.

        Args:
            obs: observation tensor whose last dimension is ``obs_dim``;
                may be a single observation or a batch.
            deterministic: if True, use the Gaussian mean instead of sampling.
            with_logprob: if True, also compute the log-probability of the
                squashed action.

        Returns:
            ``(pi_action, logp_pi)`` where ``pi_action = act_limit * tanh(u)``
            and ``logp_pi`` is the log-density summed over the last dimension,
            or ``None`` when ``with_logprob`` is False.
        """
        obs = torch.clamp((obs - self.rms_mean) / self.rms_std, min=-5.0, max=5.0)
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample.
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating the policy at test time.
            pi_action = mu
        else:
            # Reparameterized sample so gradients flow through mu and std.
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Gaussian log-prob, then the tanh change-of-variables correction
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), a
            # numerically stable form (see SAC paper, arXiv 1801.01290, App. C).
            # Both sums use axis=-1 so unbatched 1-D actions work as well as
            # batched ones.
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=-1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi

    def update(self, x):
        """Fold observation(s) ``x`` into the running normalization statistics.

        ``x`` may be a single observation of shape ``(obs_dim,)`` (counted as
        one sample) or a batch shaped ``(..., obs_dim)``, where every row
        counts as one sample. Non-tensor inputs are converted with
        ``torch.as_tensor``.
        """
        with torch.no_grad():
            x = torch.as_tensor(x)
            # Flatten leading dims so a single observation and a batch share
            # the same code path.
            batch = x.reshape(-1, self.rms_sum.shape[-1])
            self.rms_sum = self.rms_sum + batch.sum(dim=0)
            self.rms_sumsq = self.rms_sumsq + (batch * batch).sum(dim=0)
            self.rms_count = self.rms_count + batch.shape[0]
            self._refresh_rms()
        


class MLPQFunction(nn.Module):
    """Q-network over a running-normalized observation and a raw action.

    Only the observation is standardized (with running mean/std statistics fed
    via :meth:`update`) and clipped to ``[-5, 5]``; the action is concatenated
    as-is.

    NOTE(review): the ``rms_*`` statistics are plain tensor attributes, not
    registered buffers, so they are absent from ``state_dict()`` and are not
    moved by ``.to(device)``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

        # Running sums for observation normalization. The small non-zero
        # initial sumsq/count avoid division by zero and make the initial
        # statistics mean 0, std 1.
        self.rms_sum = torch.zeros(obs_dim, dtype=torch.float32)
        self.rms_sumsq = torch.tensor(np.ones(obs_dim) * 1e-2, dtype=torch.float32)
        self.rms_count = torch.tensor(1e-2, dtype=torch.float32)
        # Lower bound on the variance estimate, so std >= sqrt(rms_eps).
        self.rms_eps = torch.tensor(1e-2, dtype=torch.float32)
        self._refresh_rms()

    def _refresh_rms(self):
        """Recompute ``rms_mean`` and ``rms_std`` from the running sums."""
        self.rms_mean = self.rms_sum / self.rms_count
        var = (self.rms_sumsq / self.rms_count) - torch.pow(self.rms_mean, 2)
        # E[x^2] - E[x]^2 can come out tiny or slightly negative in float32;
        # flooring at rms_eps keeps the sqrt finite and the std non-zero.
        self.rms_std = torch.sqrt(torch.max(var, self.rms_eps))

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing size-1 dimension squeezed away."""
        obs = torch.clamp((obs - self.rms_mean) / self.rms_std, min=-5.0, max=5.0)
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape.

    def update(self, x):
        """Fold observation(s) ``x`` into the running normalization statistics.

        ``x`` may be a single observation of shape ``(obs_dim,)`` (counted as
        one sample) or a batch shaped ``(..., obs_dim)``, where every row
        counts as one sample. Non-tensor inputs are converted with
        ``torch.as_tensor``.
        """
        with torch.no_grad():
            x = torch.as_tensor(x)
            # Flatten leading dims so a single observation and a batch share
            # the same code path.
            batch = x.reshape(-1, self.rms_sum.shape[-1])
            self.rms_sum = self.rms_sum + batch.sum(dim=0)
            self.rms_sumsq = self.rms_sumsq + (batch * batch).sum(dim=0)
            self.rms_count = self.rms_count + batch.shape[0]
            self._refresh_rms()
        

class MLPActorCritic(nn.Module):
    """Actor-critic bundle: a squashed-Gaussian policy plus two Q-functions.

    Each of ``pi``, ``q1`` and ``q2`` keeps its own copy of the running
    observation statistics; :meth:`update` feeds the same data to all three so
    the copies stay identical.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        # Only high[0] is read: this assumes every action dimension shares the
        # same bound, and the actor's output lies in [-act_limit, act_limit].
        act_limit = action_space.high[0]

        # Build policy and value functions.
        self.pi = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return the policy's action for ``obs`` as a NumPy array.

        Runs without gradient tracking and skips the log-prob computation.
        ``.numpy()`` requires the resulting tensor to be on the CPU.
        """
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            return a.numpy()

    def update(self, x):
        """Fold observation(s) ``x`` into the stats of ``pi``, ``q1`` and ``q2``.

        ``x`` is converted to a float32 tensor first. ``torch.as_tensor`` is
        used instead of ``torch.tensor`` so an input that is already a tensor
        is not re-copied (``torch.tensor`` warns in that case).
        """
        with torch.no_grad():
            x = torch.as_tensor(x, dtype=torch.float32)
            for module in (self.pi, self.q1, self.q2):
                module.update(x)