import math

from typing import (
    Optional,
    Union,
)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.normal import Normal


def init_weights(m):
    """Re-initialise a linear layer in place; leave every other module untouched.

    Intended to be passed to ``nn.Module.apply``. For ``nn.Linear`` modules the
    weight matrix is drawn from a Xavier/Glorot uniform distribution and the
    bias (when the layer has one) is set to zero.

    Args:
      m: any module; only ``nn.Linear`` instances are modified.
    """
    if isinstance(m, nn.Linear):
        # `xavier_uniform` (no trailing underscore) is the deprecated alias.
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False expose `bias` as None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def _choose_head(out: torch.Tensor, obs: torch.Tensor, num_heads: int) -> torch.Tensor:
    """For multi-head output, choose the appropriate head for every sample.

    We assume that the task number is one-hot encoded as a part of the observation.
    The flat model output is reinterpreted as ``(batch, features, num_heads)``,
    i.e. the head index varies fastest along the last dimension of ``out``, and
    is contracted with the one-hot vector taken from the tail of ``obs``.

    Args:
      out: multi-head output tensor from the model, shape ``(batch, features * num_heads)``
      obs: observation batch. We assume that the last ``num_heads`` dims are the
        one-hot encoding of the task
      num_heads: number of heads

    Returns:
      torch.Tensor: output for the appropriate head, shape ``(batch, features)``
    """
    batch_size = out.shape[0]
    per_head = out.reshape(batch_size, -1, num_heads)
    task_one_hot = obs[:, -num_heads:].reshape(batch_size, num_heads, 1)
    # Batched matmul needs matching dtypes; follow the network output's dtype.
    task_one_hot = task_one_hot.to(per_head.dtype)
    return torch.squeeze(per_head @ task_one_hot, dim=2)


def combined_shape(length, shape=None):
    """Build a shape tuple with ``length`` as the leading dimension.

    Args:
      length: size of the leading dimension.
      shape: trailing dimensions; ``None`` for none, a scalar for a single
        extra dimension, or an iterable of dimensions.

    Returns:
      tuple: ``(length,)``, ``(length, shape)`` or ``(length, *shape)``.
    """
    if shape is None:
        trailing = ()
    elif np.isscalar(shape):
        trailing = (shape,)
    else:
        trailing = tuple(shape)
    return (length,) + trailing


def mlp(sizes, activation, output_activation=nn.Identity, use_layer_norm: bool = False):
    """Assemble a fully connected network as an ``nn.Sequential``.

    Args:
      sizes: layer widths, input first; ``len(sizes) - 1`` linear layers are created.
      activation: module class instantiated after every hidden linear layer.
      output_activation: module class instantiated after the last linear layer.
      use_layer_norm: if True, the first linear layer is followed by
        ``LayerNorm`` and ``Tanh`` instead of ``activation`` (this also applies
        when the first layer is the only layer, in which case
        ``output_activation`` is not used).

    Returns:
      nn.Sequential: the stacked layers.
    """
    last_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        if use_layer_norm and idx == 0:
            # The normalised first layer always uses Tanh.
            modules.extend([nn.LayerNorm(n_out), nn.Tanh()])
        elif idx < last_idx:
            modules.append(activation())
        else:
            modules.append(output_activation())
    return nn.Sequential(*modules)


def count_vars(module):
    """Return the total number of scalar parameters in ``module``.

    Uses ``Tensor.numel`` so the result is always a plain Python ``int``
    (``np.prod`` of an empty shape is the float ``1.0``, which would turn the
    count into a float for models holding 0-dimensional parameters).

    Args:
      module: an ``nn.Module`` whose ``parameters()`` are counted.

    Returns:
      int: number of elements summed over all parameters.
    """
    return sum(p.numel() for p in module.parameters())


# Bounds applied to the predicted log standard deviation before it is
# exponentiated in SquashedGaussianMLPActor.forward.
LOG_STD_MAX = 2
LOG_STD_MIN = -20


class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed Gaussian policy network with optional multi-head output.

    A shared MLP trunk feeds two linear layers that predict the mean and the
    log standard deviation of a diagonal Gaussian. Samples are squashed with
    ``tanh`` and scaled by ``act_limit``.

    Args:
      obs_dim: width of the trunk input (the observation as seen by the
        network, i.e. after the task ID has been removed when it is hidden).
      act_dim: number of action dimensions.
      hidden_sizes: widths of the hidden layers of the trunk.
      activation: module class used after every trunk layer.
      act_limit: factor the squashed action is multiplied by.
      num_heads: number of output heads; when greater than one, the head is
        selected from the one-hot task encoding at the end of the observation.
      hide_task_id: if True, the last ``one_hot_len`` observation entries are
        cut off before the observation is fed to the trunk.
      use_layer_norm: forwarded to ``mlp``.
      one_hot_len: length of the one-hot task encoding in the observation.
    """

    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation,
        act_limit,
        num_heads: int = 1,
        hide_task_id: bool = False,
        use_layer_norm: bool = False,
        one_hot_len: int = 0,
    ):
        super().__init__()
        assert not hide_task_id or one_hot_len > 0, "hide_task_id requires one_hot_len > 0"

        self.num_heads = num_heads
        self.hide_task_id = hide_task_id
        self.one_hot_len = one_hot_len

        self.net = mlp(
            [obs_dim] + list(hidden_sizes), activation, output_activation=activation, use_layer_norm=use_layer_norm
        )
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim * self.num_heads)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim * self.num_heads)
        self.act_limit = act_limit

    def forward(self, x, deterministic=False, with_logprob=True, return_dist=False):
        """Sample (or pick the mean) action for the given observations.

        Args:
          x: observation tensor. Head selection (``num_heads > 1``) requires a
            batched 2-D tensor whose trailing entries are the one-hot task ID.
          deterministic: if True, use the Gaussian mean instead of a sample.
          with_logprob: if True, also compute the log-probability of the
            squashed action; otherwise ``None`` is returned in its place.
          return_dist: if True, additionally return the mean and the clamped
            log standard deviation of the pre-squash Gaussian.

        Returns:
          ``(action, logp)`` or, with ``return_dist``, ``(action, logp, mu, log_std)``.
        """
        obs = x

        if self.hide_task_id:
            x = x[..., : -self.one_hot_len]

        net_out = self.net(x)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)

        if self.num_heads > 1:
            # Head selection reads the task ID from the full observation.
            mu = _choose_head(mu, obs, self.num_heads)
            log_std = _choose_head(log_std, obs, self.num_heads)

        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # The correction is a numerically stable form of
            # sum(log(1 - tanh(u)^2)), see the SAC paper (arXiv 1801.01290), appendix C.
            # Both reductions run over the last (action) dimension so that
            # unbatched inputs work as well as batched ones.
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
            squash_correction = 2 * (math.log(2) - pi_action - F.softplus(-2 * pi_action))
            logp_pi = logp_pi - squash_correction.sum(dim=-1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        if return_dist:
            return pi_action, logp_pi, mu, log_std
        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """State-action value network with optional per-task output heads.

    The network consumes the concatenation of observation and action and
    emits ``num_heads`` values per sample; with more than one head the value
    belonging to the task encoded in the observation is selected.

    Args:
      obs_dim: width of the observation part of the network input.
      act_dim: width of the action part of the network input.
      hidden_sizes: widths of the hidden layers.
      activation: module class used after every hidden layer.
      num_heads: number of output heads.
      hide_task_id: if True, the last ``one_hot_len`` observation entries are
        removed before the observation is fed to the network.
      use_layer_norm: forwarded to ``mlp``.
      one_hot_len: length of the one-hot task encoding in the observation.
    """

    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation,
        num_heads: int = 1,
        hide_task_id: bool = False,
        use_layer_norm: bool = False,
        one_hot_len: int = 0,
    ):
        super().__init__()
        assert one_hot_len > 0 or not hide_task_id

        # Heads could be factored out into a module of their own at some point.
        self.num_heads = num_heads
        self.hide_task_id = hide_task_id
        self.one_hot_len = one_hot_len

        layer_sizes = [obs_dim + act_dim, *hidden_sizes, self.num_heads]
        self.q = mlp(layer_sizes, activation, use_layer_norm=use_layer_norm)

    def forward(self, x, act):
        """Return the Q-value of every (observation, action) pair in the batch."""
        net_obs = x[:, : -self.one_hot_len] if self.hide_task_id else x
        q_values = self.q(torch.cat([net_obs, act], dim=-1))
        if self.num_heads > 1:
            # The head is picked from the task ID in the full observation.
            q_values = _choose_head(q_values, x, self.num_heads)
        # Drop the trailing singleton so the result has shape (batch,).
        return q_values.squeeze(-1)


class MLPActorCritic(nn.Module):
    """Container for the SAC policy, the two Q-networks and the entropy coefficient.

    Args:
      observation_space: space whose ``shape[0]`` is the observation width
        (including the one-hot task ID, if present).
      action_space: space providing ``shape`` and ``high``; ``high[0]`` is
        used as the action limit.
      hidden_sizes: widths of the hidden layers of all three networks.
      activation: module class used after hidden layers.
      hide_task_id: if True, the task ID is removed from the network inputs.
      num_heads: number of output heads of every network.
      use_layer_norm: forwarded to the networks.
      one_hot_len: length of the one-hot task encoding in the observation.
      alpha_init: ``"auto"`` to learn the log entropy coefficient (one value
        per task when ``one_hot_len > 0``), or a fixed coefficient value.
      target_output_std: if given (and alpha is learned), the target entropy
        is the entropy of a Gaussian with this standard deviation per action
        dimension; otherwise it is minus the number of action dimensions.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        hidden_sizes=(256, 256),
        activation=nn.ReLU,
        hide_task_id: bool = False,
        num_heads: int = 1,
        use_layer_norm=False,
        one_hot_len: int = 0,
        alpha_init: Union[float, str] = 0.2,
        target_output_std: Optional[float] = None,
    ):
        super().__init__()

        obs_dim = observation_space.shape[0]
        if hide_task_id:
            assert one_hot_len != 0, "We cannot hide task ID if there's no task ID"
            obs_dim -= one_hot_len

        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]
        self.one_hot_len = one_hot_len

        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(
            obs_dim,
            act_dim,
            hidden_sizes,
            activation,
            act_limit,
            num_heads=num_heads,
            hide_task_id=hide_task_id,
            use_layer_norm=use_layer_norm,
            one_hot_len=one_hot_len,
        )
        self.q1 = MLPQFunction(
            obs_dim,
            act_dim,
            hidden_sizes,
            activation,
            num_heads=num_heads,
            hide_task_id=hide_task_id,
            use_layer_norm=use_layer_norm,
            one_hot_len=one_hot_len,
        )
        self.q2 = MLPQFunction(
            obs_dim,
            act_dim,
            hidden_sizes,
            activation,
            num_heads=num_heads,
            hide_task_id=hide_task_id,
            use_layer_norm=use_layer_norm,
            one_hot_len=one_hot_len,
        )

        if alpha_init == "auto":
            self.auto_alpha = True

            # One learnable log-alpha overall, or one per task.
            if one_hot_len == 0:
                self.all_log_alpha = nn.Parameter(torch.ones(1))
            else:
                self.all_log_alpha = nn.Parameter(torch.ones((one_hot_len, 1)))
            if target_output_std is None:
                self.target_entropy = -np.prod(action_space.shape).item()
            else:
                # Differential entropy of a 1-D Gaussian: log(std * sqrt(2 * pi * e)).
                target_1d_entropy = np.log(target_output_std * math.sqrt(2 * math.pi * math.e)).item()
                self.target_entropy = np.prod(action_space.shape).item() * target_1d_entropy
        else:
            self.auto_alpha = False
            # A non-persistent buffer follows `.to(device)` together with the
            # networks without adding a new key to the state dict.
            self.register_buffer("all_log_alpha", torch.tensor(alpha_init).log(), persistent=False)

    def reset_weights(self):
        """Re-initialise every linear layer of the contained networks."""
        self.apply(init_weights)

    def act(self, obs, deterministic=False):
        """Return the policy's action for ``obs`` as a numpy array (no gradients)."""
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            # `.cpu()` is a no-op on CPU tensors and makes `.numpy()` valid elsewhere.
            return a.cpu().numpy()

    def get_alpha(self, obs):
        """Return ``(alpha, log_alpha)`` for the given observation batch.

        With a fixed alpha or a single learned alpha, the stored value is
        returned as is. With per-task learned alphas, the one-hot task ID at
        the end of each observation selects the task's log-alpha, giving
        tensors of shape ``(batch,)``.
        """
        if self.one_hot_len == 0 or not self.auto_alpha:
            return self.all_log_alpha.exp(), self.all_log_alpha
        task_one_hot = obs[:, -self.one_hot_len :]
        # Squeeze only the trailing singleton so a batch of one stays 1-D.
        log_alpha = torch.matmul(task_one_hot, self.all_log_alpha).squeeze(-1)
        return log_alpha.exp(), log_alpha
