import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete
from mpi4py import MPI
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical

# TRPO utilities
def flat_grads(grads):
    """Concatenate an iterable of gradient tensors into one 1-D tensor.

    Each tensor is made contiguous and flattened, and the pieces are joined
    in iteration order.
    """
    pieces = []
    for g in grads:
        pieces.append(g.contiguous().view(-1))
    return torch.cat(pieces)


def get_flat_params_from(model):
    """Return every parameter of ``model`` packed into a single 1-D tensor.

    Values are read through ``param.data`` and concatenated in the order
    yielded by ``model.parameters()``.
    """
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Overwrite ``model``'s parameters in place from a flat 1-D tensor.

    ``flat_params`` is consumed front to back: each parameter, in
    ``model.parameters()`` order, takes as many entries as it has elements,
    reshaped to the parameter's own size.
    """
    offset = 0
    for p in model.parameters():
        count = p.numel()
        chunk = flat_params[offset:offset + count]
        p.data.copy_(chunk.view(p.size()))
        offset += count


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    """Approximately solve ``A x = b`` with the conjugate gradient method.

    The matrix ``A`` is never formed; it is accessed only through the
    matrix-vector product ``Avp``. (Conjugate gradient assumes ``A`` is
    symmetric positive definite.)

    Args:
        Avp: Callable mapping a 1-D tensor ``p`` to the product ``A p``.
        b: 1-D right-hand-side tensor.
        nsteps: Maximum number of iterations to run.
        residual_tol: Iteration stops early once the squared residual norm
            ``r . r`` drops below this value.

    Returns:
        A 1-D tensor ``x`` with the same shape, dtype and device as ``b``.
    """
    # Allocate the solution like ``b`` so float64 / non-CPU inputs are not
    # silently accumulated into a default float32 CPU tensor.
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    # A zero right-hand side is solved exactly by x = 0; iterating would
    # divide 0 by 0 and fill x with NaNs.
    if rdotr == 0:
        return x
    for _ in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def combined_shape(length, shape=None):
    """Build a shape tuple with ``length`` as the leading dimension.

    ``shape`` may be omitted (result is ``(length,)``), a scalar (result is
    ``(length, shape)``), or an iterable of dimensions that is appended
    after ``length``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input first and output last. Each consecutive
            pair becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer
            except the last.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        ``nn.Sequential`` alternating linear layers and activations.
    """
    final_idx = len(sizes) - 2
    stack = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        stack.append(nn.Linear(n_in, n_out))
        nonlinearity = output_activation if idx >= final_idx else activation
        stack.append(nonlinearity())
    return nn.Sequential(*stack)


def count_vars(module):
    """Return the total number of scalar entries across ``module``'s parameters."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


def discount_cumsum(x, discount):
    """Compute discounted cumulative sums along axis 0 (trick from rllab).

    For input ``[x0, x1, x2]`` the result is::

        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]
    """
    # Run the recursive filter y[t] = x[t] + discount * y[t-1] over the
    # reversed sequence, then flip back so each entry holds the discounted
    # sum of itself and everything after it.
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, float(-discount)], reversed_x, axis=0)
    return filtered[::-1]


class Actor(nn.Module):
    """Base class for policies; subclasses supply the two hooks below."""

    def _distribution(self, obs):
        """Return the action distribution for ``obs`` (subclass hook)."""
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-likelihood of ``act`` under ``pi`` (subclass hook)."""
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Build the action distribution and optionally score given actions.

        Returns:
            ``(pi, logp_a)`` where ``pi`` is the distribution for ``obs`` and
            ``logp_a`` is the log-likelihood of ``act`` under it, or ``None``
            when no ``act`` is supplied.
        """
        dist = self._distribution(obs)
        if act is None:
            return dist, None
        return dist, self._log_prob_from_distribution(dist, act)


class MLPCategoricalActor(Actor):
    """Policy over discrete actions: an MLP outputs the logits of a ``Categorical``."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        # Input width, then the hidden widths, then one logit per action.
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = mlp(layer_sizes, activation)

    def _distribution(self, obs):
        """Return the categorical distribution whose logits are the network output."""
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        """Return ``pi.log_prob(act)``."""
        return pi.log_prob(act)


class MLPGaussianActor(Actor):
    """Gaussian policy with running observation normalization.

    An MLP maps normalized observations to the action mean; the log standard
    deviation is a learned parameter that does not depend on the observation.
    Observations are standardized with running statistics (sum, sum of
    squares, count) that are accumulated through ``update``.

    NOTE(review): the ``rms_*`` tensors are plain attributes, not registered
    buffers, so they are presumably not included in ``state_dict()`` --
    confirm how checkpoints are saved.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.obs_dim = obs_dim

        # ``log_std`` is assigned before ``mu_net`` so it comes first in
        # ``parameters()`` order.
        initial_log_std = -0.08 * np.ones(act_dim, dtype=np.float32)
        self.log_std = nn.Parameter(torch.as_tensor(initial_log_std))
        self.mu_net = mlp([obs_dim, *hidden_sizes, act_dim], activation)

        # Constants used by ``entropy``.
        self.pi = torch.tensor(np.pi, dtype=torch.float32)
        self.e = torch.tensor(np.e, dtype=torch.float32)

        # Running-statistics accumulators. None of these tensors require
        # grad, so nothing here is tracked by autograd.
        self.rms_sum = torch.zeros(obs_dim, dtype=torch.float32)
        self.rms_sumsq = torch.full_like(self.rms_sum, 1e-2)
        self.rms_count = torch.tensor(1e-2, dtype=torch.float32)
        self.rms_eps = torch.tensor(1e-2, dtype=torch.float32)
        self._refresh_normalizer()

    def _refresh_normalizer(self):
        """Recompute ``rms_mean`` / ``rms_std`` from the accumulators."""
        mean = self.rms_sum / self.rms_count
        second_moment = self.rms_sumsq / self.rms_count
        variance = second_moment - torch.pow(mean, 2)
        self.rms_mean = mean
        # The variance is floored at ``rms_eps`` before the square root.
        self.rms_std = torch.sqrt(torch.max(variance, self.rms_eps))
        # (A disabled experiment in earlier code pinned dims 17+ to mean 0 / std 1.)

    def _distribution(self, obs):
        """Return ``Normal(mu, std)`` for observations standardized and clipped to [-5, 5]."""
        normalized = ((obs - self.rms_mean) / self.rms_std).clamp(min=-5.0, max=5.0)
        mean_action = self.mu_net(normalized)
        return Normal(mean_action, torch.exp(self.log_std))

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-likelihood of ``act`` summed over the last axis."""
        # ``Normal.log_prob`` is elementwise, so the last axis is summed.
        per_dim = pi.log_prob(act)
        return per_dim.sum(axis=-1)

    def entropy(self):
        """Return the sum over action dimensions of ``log_std + 0.5 * log(2 * pi * e)``."""
        per_dim = self.log_std + .5 * torch.log(2.0 * self.pi * self.e)
        return per_dim.sum()

    def update(self, x, n):
        """Fold new sufficient statistics into the running normalizer.

        Args:
            x: Flat vector read as ``x[0:n]`` (sum), ``x[n:2*n]`` (sum of
                squares) and ``x[2*n]`` (sample count).
            n: Number of observation dimensions.
        """
        self.rms_sum = self.rms_sum + torch.tensor(x[:n], dtype=torch.float32)
        self.rms_sumsq = self.rms_sumsq + torch.tensor(x[n:2 * n], dtype=torch.float32)
        self.rms_count = self.rms_count + torch.tensor(x[2 * n], dtype=torch.float32)
        self._refresh_normalizer()


class MLPCritic(nn.Module):
    """State-value network with running observation normalization.

    Observations are standardized with running statistics (sum, sum of
    squares, count) that are accumulated through ``update``, clipped to
    [-5, 5], and passed through an MLP with a single output.
    """

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        self.obs_dim = obs_dim
        self.v_net = mlp([obs_dim, *hidden_sizes, 1], activation)

        # Running-statistics accumulators. None of these tensors require
        # grad, so nothing here is tracked by autograd.
        self.rms_sum = torch.zeros(obs_dim, dtype=torch.float32)
        self.rms_sumsq = torch.full_like(self.rms_sum, 1e-2)
        self.rms_count = torch.tensor(1e-2, dtype=torch.float32)
        self.rms_eps = torch.tensor(1e-2, dtype=torch.float32)
        self._refresh_normalizer()

    def _refresh_normalizer(self):
        """Recompute ``rms_mean`` / ``rms_std`` from the accumulators."""
        mean = self.rms_sum / self.rms_count
        second_moment = self.rms_sumsq / self.rms_count
        variance = second_moment - torch.pow(mean, 2)
        self.rms_mean = mean
        # The variance is floored at ``rms_eps`` before the square root.
        self.rms_std = torch.sqrt(torch.max(variance, self.rms_eps))
        # (A disabled experiment in earlier code pinned dims 17+ to mean 0 / std 1.)

    def forward(self, obs):
        """Return the value estimate for ``obs`` with the trailing axis squeezed."""
        normalized = ((obs - self.rms_mean) / self.rms_std).clamp(min=-5.0, max=5.0)
        values = self.v_net(normalized)
        # Drop the trailing size-1 axis so v has the expected shape.
        return values.squeeze(-1)

    def update(self, x, n):
        """Fold new sufficient statistics into the running normalizer.

        Args:
            x: Flat vector read as ``x[0:n]`` (sum), ``x[n:2*n]`` (sum of
                squares) and ``x[2*n]`` (sample count).
            n: Number of observation dimensions.
        """
        self.rms_sum = self.rms_sum + torch.tensor(x[:n], dtype=torch.float32)
        self.rms_sumsq = self.rms_sumsq + torch.tensor(x[n:2 * n], dtype=torch.float32)
        self.rms_count = self.rms_count + torch.tensor(x[2 * n], dtype=torch.float32)
        self._refresh_normalizer()


class MLPActorCritic(nn.Module):
    """Bundle of a policy network (``pi``) and a value network (``v``).

    The policy class is chosen from the action space: ``Box`` selects
    ``MLPGaussianActor`` and ``Discrete`` selects ``MLPCategoricalActor``.

    Args:
        observation_space: Space whose ``shape[0]`` gives the observation size.
        action_space: ``Box`` or ``Discrete`` action space.
        hidden_sizes: Hidden layer widths shared by both networks.
        activation: Activation module class for the hidden layers.

    Raises:
        TypeError: If ``action_space`` is neither ``Box`` nor ``Discrete``.
    """

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(64,64), activation=nn.Tanh):
        super().__init__()

        obs_dim = observation_space.shape[0]
        self.obs_dim = obs_dim

        # policy builder depends on action space
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)
        else:
            # Fail here rather than leaving ``self.pi`` undefined and hitting
            # an AttributeError on the first call to ``step``.
            raise TypeError(
                "Unsupported action space %r: expected Box or Discrete"
                % type(action_space).__name__)

        # build value function
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        """Sample an action for ``obs`` without tracking gradients.

        Returns:
            Tuple of NumPy arrays ``(action, value, log_prob)``.
        """
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        return a.numpy(), v.numpy(), logp_a.numpy()

    def act(self, obs):
        """Return only the sampled action from ``step``."""
        return self.step(obs)[0]

    def update(self, x):
        """Update observation-normalization statistics from a batch.

        Per-dimension sums, sums of squares and the sample count of ``x`` are
        summed across all MPI ranks with ``Allreduce`` and handed to the
        networks' own ``update`` methods.

        Args:
            x: Tensor of observations; it is converted with ``.numpy()`` and
                reduced along axis 0 (assumes shape ``(batch, obs_dim)`` --
                TODO confirm against callers).
        """
        with torch.no_grad():
            x = x.numpy().astype('float64')
            n = int(self.obs_dim)
            totalvec = np.zeros(n*2+1, dtype=np.float64)
            # Layout: [sum (n), sum of squares (n), count (1)], all float64
            # to match the receive buffer.
            addvec = np.concatenate([
                x.sum(axis=0).ravel(),
                np.square(x).sum(axis=0).ravel(),
                np.array([len(x)], dtype=np.float64),
            ])

            MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)

            # The categorical actor keeps no running statistics and has no
            # ``update`` method, so only forward to policies that define one.
            pi_update = getattr(self.pi, "update", None)
            if callable(pi_update):
                pi_update(totalvec, n)
            self.v.update(totalvec, n)