import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from copy import deepcopy
import gym


def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` and return the result as a tuple.

    :param length: Leading dimension of the resulting shape.
    :param shape: ``None`` (only ``length`` is used), a scalar as judged by
        ``np.isscalar``, or an iterable of dimensions.
    :return: ``(length,)``, ``(length, shape)`` or ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Every consecutive pair in ``sizes`` becomes an ``nn.Linear`` followed by an
    activation module: ``activation`` after each layer except the last one,
    which is followed by ``output_activation``.

    :param sizes: Layer widths, input width first.
    :param activation: Zero-argument callable building a hidden activation.
    :param output_activation: Zero-argument callable building the final
        activation.
    :return: ``nn.Sequential`` of alternating linear and activation modules.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        make_act = output_activation if idx == final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(make_act())
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar entries over ``module.parameters()``.

    :param module: Object exposing ``parameters()``; the ``shape`` of each
        parameter is multiplied out with ``np.prod``.
    :return: Sum of the per-parameter element counts (0 when there are no
        parameters).
    """
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


class BaseFeaturesExtractor(nn.Module):
    """
    Common parent for modules that turn observations into a feature tensor.

    Subclasses are expected to override ``forward``; the implementation here
    only raises ``NotImplementedError``.

    :param observation_space: Space the observations come from; stored as given.
    :param features_dim: Number of features extracted. Must be strictly
        positive, so the default of 0 has to be overridden by the caller.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        # The default of 0 does not pass this check, so a positive value must
        # always be supplied.
        assert features_dim > 0
        self._features_dim = features_dim
        self._observation_space = observation_space

    @property
    def features_dim(self) -> int:
        """Number of features, as passed to the constructor."""
        return self._features_dim

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Placeholder that always raises ``NotImplementedError``."""
        raise NotImplementedError()


class NatureCNN(BaseFeaturesExtractor):
    """
    Convolutional feature extractor following the DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    :param observation_space: Image observation space; its first dimension is
        used as the number of input channels.
    :param features_dim: Number of features extracted.
        This corresponds to the number of units in the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)

        # Observations are assumed to be channels-first (CxHxW); any
        # re-ordering is expected to happen in preprocessing or a wrapper.
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one sampled observation (with an added leading batch axis)
        # through the conv stack to discover the flattened feature size.
        probe = torch.as_tensor(observation_space.sample()[None]).float()
        with torch.no_grad():
            flat_dim = self.cnn(probe).shape[1]

        projection = nn.Linear(flat_dim, features_dim)
        self.linear = nn.Sequential(projection, nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack, then the linear + ReLU projection."""
        conv_out = self.cnn(observations)
        return self.linear(conv_out)


class CategoricalMLPActor(nn.Module):
    """
    MLP policy head that produces logits over a discrete set of actions.

    :param obs_dim: Input width of the first linear layer.
    :param act_dim: Number of logits produced.
    :param hidden_sizes: Widths of the hidden layers. The last entry is used
        as the input width of the logits layer, so it must be non-empty.
    :param activation: Callable building the activation placed after every
        hidden layer (including the last one).
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        trunk_sizes = [obs_dim, *hidden_sizes]
        self.net = mlp(trunk_sizes, activation, activation)
        self.logits_layer = nn.Linear(hidden_sizes[-1], act_dim)

    def forward(self, obs, deterministic=False):
        """
        Compute an action together with the action (log-)probabilities.

        :param obs: Input passed to the MLP trunk.
        :param deterministic: If true, pick the arg-max of the logits instead
            of sampling.
        :return: Tuple ``(action, action_probs, log_action_probs)``; both
            probability tensors are normalised over the last dimension.
        """
        logits = self.logits_layer(self.net(obs))

        action_probs = F.softmax(logits, dim=-1)
        log_action_probs = F.log_softmax(logits, dim=-1)

        # The distribution is constructed in both modes; deterministic mode
        # (test-time evaluation) takes the arg-max of the logits instead of
        # sampling from it.
        pi_distribution = Categorical(logits=logits)
        pi_action = (
            torch.argmax(logits, dim=-1) if deterministic else pi_distribution.sample()
        )

        return pi_action, action_probs, log_action_probs


class CategoricalCNNActor(nn.Module):
    """
    Policy head that runs a feature extractor, then an MLP, then a logits layer.

    :param cnn_feature_extractor: Module applied to the raw input in
        ``forward``; its output is fed to the MLP trunk.
    :param obs_dim: Input width of the MLP trunk (i.e. the width of what the
        extractor outputs).
    :param act_dim: Number of logits produced.
    :param hidden_sizes: Widths of the hidden layers. The last entry is used
        as the input width of the logits layer, so it must be non-empty.
    :param activation: Callable building the activation placed after every
        hidden layer (including the last one).
    """

    def __init__(self, cnn_feature_extractor, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.cnn = cnn_feature_extractor
        trunk_sizes = [obs_dim, *hidden_sizes]
        self.net = mlp(trunk_sizes, activation, activation)
        self.logits_layer = nn.Linear(hidden_sizes[-1], act_dim)

    def forward(self, obs, deterministic=False):
        """
        Compute an action together with the action (log-)probabilities.

        :param obs: Raw input; the feature extractor is applied to it here.
        :param deterministic: If true, pick the arg-max of the logits instead
            of sampling.
        :return: Tuple ``(action, action_probs, log_action_probs)``; both
            probability tensors are normalised over the last dimension.
        """
        hidden = self.net(self.cnn(obs))
        logits = self.logits_layer(hidden)

        action_probs = F.softmax(logits, dim=-1)
        log_action_probs = F.log_softmax(logits, dim=-1)

        # The distribution is constructed in both modes; deterministic mode
        # (test-time evaluation) takes the arg-max of the logits instead of
        # sampling from it.
        pi_distribution = Categorical(logits=logits)
        pi_action = (
            torch.argmax(logits, dim=-1) if deterministic else pi_distribution.sample()
        )

        return pi_action, action_probs, log_action_probs


class MLPQFunction(nn.Module):
    """
    MLP that maps its input to ``act_dim`` outputs (one value per action).

    :param obs_dim: Input width of the first linear layer.
    :param act_dim: Width of the output layer.
    :param hidden_sizes: Widths of the hidden layers.
    :param activation: Callable building the activation placed after every
        hidden layer; the output layer uses the default of ``mlp``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs):
        """Return the network output for ``obs``."""
        return self.q(obs)


class CNNQFunction(nn.Module):
    """
    Q-network that runs a feature extractor followed by an MLP head.

    :param cnn_feature_extractor: Module applied to the raw input in
        ``forward``.
    :param obs_dim: Input width of the MLP head (i.e. the width of what the
        extractor outputs).
    :param act_dim: Width of the output layer.
    :param hidden_sizes: Widths of the hidden layers of the head.
    :param activation: Callable building the activation placed after every
        hidden layer of the head.
    :param share_extractor: When true, the extractor is evaluated under
        ``torch.no_grad()``, so no gradient reaches it through this module.
    """

    def __init__(self, cnn_feature_extractor, obs_dim, act_dim, hidden_sizes, activation, share_extractor=True):
        super().__init__()
        self.share_extractor = share_extractor
        self.cnn = cnn_feature_extractor
        head_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.q = mlp(head_sizes, activation)

    def forward(self, obs):
        """Return the head output for the features extracted from ``obs``."""
        if not self.share_extractor:
            # Extractor is evaluated in the caller's current grad mode.
            return self.q(self.cnn(obs))

        # Extractor output is computed without tracking gradients; only the
        # head below is evaluated in the caller's grad mode.
        with torch.no_grad():
            feats = self.cnn(obs)
        return self.q(feats)


class MLPActorCritic(nn.Module):
    """
    Bundle of a categorical MLP policy and two MLP Q-networks.

    :param observation_space: Space whose ``shape`` must be one-dimensional;
        ``shape[0]`` is used as the network input width.
    :param action_space: Space whose ``n`` attribute gives the number of
        actions.
    :param hidden_sizes: Hidden layer widths shared by the policy and both
        Q-networks.
    :param activation: Callable building the hidden activation.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256), activation=nn.ReLU):
        super().__init__()

        assert len(observation_space.shape) == 1, (
            "Obs-space seems to be a matrix. You should probably use a CNNActorCritic "
            f"with observations of type {observation_space}\n"
            "(you are using `MLPActorCritic`)\n"
        )

        obs_dim = observation_space.shape[0]
        act_dim = action_space.n

        # build policy and value functions
        self.pi = CategoricalMLPActor(obs_dim, act_dim, hidden_sizes, activation)
        self.q1_logits = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2_logits = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """
        Select an action for ``obs`` without tracking gradients.

        :param obs: Observation tensor passed to the policy.
        :param deterministic: Forwarded to the policy (arg-max instead of
            sampling when true).
        :return: The chosen action(s) as a NumPy array.
        """
        with torch.no_grad():
            a, _, _ = self.pi(obs, deterministic)
        # ``Tensor.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op
        # when the tensor already lives on the CPU.
        return a.cpu().numpy()


class CNNActorCritic(nn.Module):
    """
    Bundle of a categorical CNN policy and two CNN Q-networks.

    :param observation_space: Space whose ``shape`` has more than one
        dimension; it is handed to ``NatureCNN`` to build the extractor.
    :param action_space: Space whose ``n`` attribute gives the number of
        actions.
    :param hidden_sizes: Hidden layer widths shared by the policy and both
        Q-network heads.
    :param activation: Callable building the hidden activation.
    :param cnn_featdim: Number of features produced by the extractor and
        consumed by the heads.
    :param share_extractor: When true, the Q-networks reuse the policy's
        extractor instance; otherwise both Q-networks use one deep copy of it.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256), activation=nn.ReLU, cnn_featdim=512, share_extractor=True):
        super().__init__()

        assert len(observation_space.shape) > 1, (
            "You should use CNNActorCritic "
            f"only with images not with {observation_space}\n"
            "(you should probably be using `MLPActorCritic`)\n"
        )

        act_dim = action_space.n

        self.share_extractor = share_extractor
        self.CNNExtractor = NatureCNN(observation_space, cnn_featdim)

        # build policy and value functions
        self.pi = CategoricalCNNActor(self.CNNExtractor, cnn_featdim, act_dim, hidden_sizes, activation)

        if share_extractor:
            self.q_CNNExtractor = self.CNNExtractor
        else:
            self.q_CNNExtractor = deepcopy(self.CNNExtractor)

        self.q1_logits = CNNQFunction(self.q_CNNExtractor, cnn_featdim, act_dim, hidden_sizes, activation, share_extractor)
        self.q2_logits = CNNQFunction(self.q_CNNExtractor, cnn_featdim, act_dim, hidden_sizes, activation, share_extractor)

    def act(self, obs, deterministic=False):
        """
        Select an action for ``obs`` without tracking gradients.

        :param obs: Raw observation tensor (not pre-extracted features).
        :param deterministic: Forwarded to the policy (arg-max instead of
            sampling when true).
        :return: The chosen action(s) as a NumPy array.
        """
        with torch.no_grad():
            # ``self.pi`` applies the CNN extractor to its input itself, so it
            # must receive the raw observation. Extracting features here first
            # would send already-flattened features through the conv layers.
            a, _, _ = self.pi(obs, deterministic)
        # ``Tensor.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op
        # when the tensor already lives on the CPU.
        return a.cpu().numpy()
