import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical

from core.networks.base_actor import BaseActor
from core.utils.sac_utils import initialize_last_layer, initialize_hidden_layer, TanhNormal
from core.networks.mlp import MLP

# Clamp range for the log standard deviation predicted by the stochastic
# actor; the value is clamped to [LOG_SIG_MIN, LOG_SIG_MAX] before being
# exponentiated into a standard deviation.
LOG_SIG_MIN = -20
LOG_SIG_MAX = 2


class ContinuousStochasticActor(BaseActor):
    """Stochastic actor that outputs tanh-squashed Gaussian actions.

    A shared ``MLP`` body feeds two linear heads producing the mean and the
    log standard deviation of a Gaussian. Actions are squashed with ``tanh``
    and scaled by ``max_action``.
    """

    def __init__(self, observation_dim, action_dim, max_action, layers_dim, device):
        """Build the shared body and the mean / log-std output heads.

        Args:
            observation_dim: Input size of the shared body.
            action_dim: Output size of both the mean and the log-std head.
            max_action: Factor the tanh-squashed action is multiplied by.
            layers_dim: Sequence of layer widths with at least two entries.
                The last entry is the output width of the shared body (and
                the input width of both heads); the preceding entries are
                forwarded to ``MLP`` as its third argument.
            device: Forwarded to ``TanhNormal`` when sampling.
        """
        super().__init__()

        assert len(layers_dim) > 1, 'can not define continuous stochastic actor without hidden layers'
        # define shared layers
        self.shared_body = MLP(observation_dim, layers_dim[-1], layers_dim[:-1])
        for layer in self.shared_body.layers:
            initialize_hidden_layer(layer)
        # define and initialize output layers (mean and log std)
        self.mean = nn.Linear(layers_dim[-1], action_dim)
        self.log_std = nn.Linear(layers_dim[-1], action_dim)
        initialize_last_layer(self.mean)
        initialize_last_layer(self.log_std)

        self.max_action = max_action
        self.device = device

    def forward(self, x, deterministic=False, return_log_prob=False):
        """Compute an action and the distribution parameters for ``x``.

        Args:
            x: Observation tensor fed to the shared body. The log-probability
                is summed over ``dim=1``, so a batch-first 2-D input is
                assumed when ``return_log_prob`` is set (TODO confirm).
            deterministic: If true, the action is ``tanh(mean)`` and nothing
                is sampled.
            return_log_prob: If true (and ``deterministic`` is false), also
                return the log-probability of the sampled action and its
                pre-tanh value.

        Returns:
            An 8-tuple ``(action, mean, log_std, log_prob, entropy, std,
            mean_action_log_prob, pre_tanh_value)`` where

            * ``action`` is the tanh-squashed action scaled by ``max_action``;
            * ``mean`` is the raw output of the mean head (not squashed);
            * ``log_std`` is clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]`` and
              ``std`` is its exponential;
            * ``log_prob`` and ``pre_tanh_value`` are ``None`` unless the
              action was sampled with ``return_log_prob=True``;
            * ``entropy`` and ``mean_action_log_prob`` are always ``None``.
        """
        # forward pass
        x = self.shared_body(x)
        mean, log_std = self.mean(x), self.log_std(x)

        # compute std
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        # compute other relevant quantities
        log_prob, entropy, mean_action_log_prob, pre_tanh_value = None, None, None, None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std, self.device)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.rsample(return_pretanh_value=True)
                # The log-probability is evaluated on the unscaled action,
                # i.e. before the multiplication by max_action below.
                log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                action = tanh_normal.rsample()
        action = action * self.max_action
        return action, mean, log_std, log_prob, entropy, std, mean_action_log_prob, pre_tanh_value

    def select_action(self, observation, deterministic):
        """Return the action for ``observation`` as a flat NumPy array.

        Args:
            observation: Tensor passed unchanged to ``forward``; no reshaping
                or device transfer is done here, so the caller presumably
                prepares it (TODO confirm against callers).
            deterministic: Forwarded to ``forward``.

        Returns:
            A 1-D ``numpy.ndarray`` holding the flattened action.
        """
        # Action selection never back-propagates, so avoid building the
        # autograd graph for this forward pass.
        with torch.no_grad():
            action = self(observation, deterministic=deterministic)[0]
        return action.detach().cpu().numpy().flatten()


class ContinuousDeterministicActor(BaseActor):
    """Deterministic actor mapping observations to bounded actions.

    The output of an ``MLP`` is squashed with ``tanh`` and scaled by
    ``max_action``. Gaussian exploration noise can optionally be added in
    ``select_action``.
    """

    def __init__(self, observation_dim, action_dim, max_action, expl_noise, layers_dim):
        """Build and initialize the underlying MLP.

        Args:
            observation_dim: Input size of the MLP.
            action_dim: Output size of the MLP.
            max_action: Factor the tanh output is multiplied by; also the
                symmetric clipping bound for noisy actions.
            expl_noise: Standard deviation of the Gaussian exploration noise
                added in ``select_action``.
            layers_dim: Forwarded to ``MLP`` as its third argument.
        """
        super().__init__()
        self.core_mlp = MLP(observation_dim, action_dim, layers_dim)
        layers = self.core_mlp.layers
        # Every layer but the last gets the hidden-layer initialization.
        for layer in layers[:-1]:
            initialize_hidden_layer(layer)
        initialize_last_layer(layers[-1])

        self.max_action = max_action
        self.expl_noise = expl_noise
        self.action_dim = action_dim

    def forward(self, x):
        """Return ``max_action * tanh(core_mlp(x))``."""
        x = self.core_mlp(x)
        x = self.max_action * torch.tanh(x)
        return x

    def select_action(self, observation, deterministic=True):
        """Return the action for ``observation`` as a flat NumPy array.

        Args:
            observation: Tensor passed unchanged to ``forward``.
            deterministic: If false, zero-mean Gaussian noise with standard
                deviation ``expl_noise`` is added and the result is clipped
                to ``[-max_action, max_action]``.

        Returns:
            A 1-D ``numpy.ndarray`` holding the flattened (optionally noisy)
            action.
        """
        # Action selection never back-propagates, so avoid building the
        # autograd graph for this forward pass.
        with torch.no_grad():
            action = self(observation).detach().cpu().numpy().flatten()
        if not deterministic:
            # Draw noise with the shape of the flattened action rather than
            # a fixed ``action_dim`` so batched observations also work; for
            # a single observation both shapes are identical.
            noise = np.random.normal(0, self.expl_noise, size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action
