import math
from abc import ABCMeta, abstractmethod
from typing import Tuple, Union, cast

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical

from .distributions import GaussianDistribution, SquashedGaussianDistribution
from .encoders import Encoder, EncoderWithAction


class DropPolicy(nn.Module, metaclass=ABCMeta):  # type: ignore
    """Abstract base for policies called with an input ``x`` and a tensor ``e``.

    Subclasses provide the two ``*_with_log_prob`` samplers and
    ``best_action``. The ``sample`` and ``sample_n`` helpers defined here
    simply return the first element of what those samplers produce.
    """

    def sample(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the first element of ``sample_with_log_prob(x, e)``."""
        sampled = self.sample_with_log_prob(x, e)
        return sampled[0]

    @abstractmethod
    def sample_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a pair of tensors; ``sample`` keeps only the first one."""

    def sample_n(self, x: torch.Tensor, e: torch.Tensor, n: int) -> torch.Tensor:
        """Return the first element of ``sample_n_with_log_prob(x, e, n)``."""
        sampled = self.sample_n_with_log_prob(x, e, n)
        return sampled[0]

    @abstractmethod
    def sample_n_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor, n: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a pair of tensors; ``sample_n`` keeps only the first one."""

    @abstractmethod
    def best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the action the policy selects for ``x`` and ``e``."""


class DropDeterministicPolicy(DropPolicy):
    """Deterministic policy: ``tanh`` of an MLP applied to ``[encoder(x), e]``.

    The encoder output and ``e`` are concatenated along dim 1 and passed
    through a three-layer MLP head. Sampling is not supported.
    """

    _encoder: Encoder
    _fc: nn.Sequential

    def __init__(self, encoder: Encoder, action_size: int, embedding_size: int):
        """Build the head on top of ``encoder``.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: output width of the head.
            embedding_size: number of extra input features added to the
                encoder feature size (presumably the dim-1 size of ``e``).
        """
        super().__init__()
        self._encoder = encoder
        input_size = encoder.get_feature_size() + embedding_size
        # NOTE: only the first linear layer is followed by a ReLU; the last
        # two linear layers are stacked without a non-linearity in between.
        head_layers = [
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.Linear(256, action_size),
        ]
        self._fc = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(head(cat(encoder(x), e)))``."""
        features = torch.cat([self._encoder(x), e], dim=1)
        return torch.tanh(self._fc(features))

    def __call__(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Typed wrapper around ``nn.Module.__call__``."""
        output = super().__call__(x, e)
        return cast(torch.Tensor, output)

    def sample_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Always raise ``NotImplementedError``; there is nothing to sample."""
        message = "deterministic policy does not support sample"
        raise NotImplementedError(message)

    def sample_n_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor, n: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Always raise ``NotImplementedError``; there is nothing to sample."""
        message = "deterministic policy does not support sample_n"
        raise NotImplementedError(message)

    def best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the output of ``forward(x, e)``."""
        return self.forward(x, e)


class DropNormalPolicy(DropPolicy):
    """Gaussian policy whose heads read ``encoder(x)`` concatenated with ``e``.

    The mean comes from an MLP head. The log standard deviation comes either
    from a second MLP head (clamped to ``[min_logstd, max_logstd]``) or, when
    ``use_std_parameter`` is set, from a single learned parameter squashed
    into that same range with a sigmoid.
    """

    _encoder: Encoder
    _action_size: int
    _embedding_size: int
    _min_logstd: float
    _max_logstd: float
    _use_std_parameter: bool
    _squash_distribution: bool
    # Both heads are small MLPs (nn.Sequential), not single linear layers.
    _mu: nn.Sequential
    _logstd: Union[nn.Sequential, nn.Parameter]

    def __init__(
        self,
        encoder: Encoder,
        action_size: int,
        embedding_size: int,
        min_logstd: float,
        max_logstd: float,
        use_std_parameter: bool,
        squash_distribution: bool,
    ):
        """Build the mean head and the log-std head or parameter.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: output width of each head.
            embedding_size: number of extra input features added to the
                encoder feature size (presumably the dim-1 size of ``e``).
            min_logstd: lower bound applied to the log standard deviation.
            max_logstd: upper bound applied to the log standard deviation.
            use_std_parameter: if true, use one learned ``(1, action_size)``
                parameter instead of an input-dependent head.
            squash_distribution: if true, ``dist`` returns a
                ``SquashedGaussianDistribution``; otherwise a
                ``GaussianDistribution`` centred on ``tanh(mu)``.
        """
        super().__init__()
        self._action_size = action_size
        self._embedding_size = embedding_size
        self._encoder = encoder
        self._min_logstd = min_logstd
        self._max_logstd = max_logstd
        self._use_std_parameter = use_std_parameter
        self._squash_distribution = squash_distribution

        input_size = encoder.get_feature_size() + embedding_size
        self._mu = self._build_head(input_size, action_size)
        if use_std_parameter:
            initial_logstd = torch.zeros(1, action_size, dtype=torch.float32)
            self._logstd = nn.Parameter(initial_logstd)
        else:
            self._logstd = self._build_head(input_size, action_size)

    @staticmethod
    def _build_head(input_size: int, action_size: int) -> nn.Sequential:
        """Create the three-layer MLP shared by the mean and log-std heads."""
        # NOTE: only the first linear layer is followed by a ReLU; the last
        # two linear layers are stacked without a non-linearity in between.
        # Layer order is kept fixed so parameter names in the state dict
        # ("0", "2", "3") stay the same.
        return nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.Linear(256, action_size),
        )

    def _compute_logstd(self, h: torch.Tensor) -> torch.Tensor:
        """Return the bounded log standard deviation for features ``h``.

        With ``use_std_parameter`` the value does not depend on ``h``.
        """
        if self._use_std_parameter:
            return self.get_logstd_parameter()
        logstd = cast(nn.Sequential, self._logstd)(h)
        return logstd.clamp(self._min_logstd, self._max_logstd)

    def dist(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Union[GaussianDistribution, SquashedGaussianDistribution]:
        """Build the action distribution for ``x`` and ``e``.

        Returns:
            ``SquashedGaussianDistribution(mu, std)`` when squashing is
            enabled, otherwise ``GaussianDistribution(tanh(mu), std,
            raw_loc=mu)``.
        """
        h = self._encoder(x)
        h = torch.cat([h, e], dim=1)
        mu = self._mu(h)
        std = self._compute_logstd(h).exp()
        if self._squash_distribution:
            return SquashedGaussianDistribution(mu, std)
        return GaussianDistribution(torch.tanh(mu), std, raw_loc=mu)

    def forward(
        self,
        x: torch.Tensor,
        e: torch.Tensor,
        deterministic: bool = False,
        with_log_prob: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return an action, optionally together with its log-probability.

        Args:
            x: input passed to the encoder.
            e: tensor concatenated to the encoder output along dim 1.
            deterministic: use the distribution's ``mean_with_log_prob()``
                instead of ``sample_with_log_prob()``.
            with_log_prob: return ``(action, log_prob)`` instead of only
                ``action``.
        """
        dist = self.dist(x, e)
        if deterministic:
            action, log_prob = dist.mean_with_log_prob()
        else:
            action, log_prob = dist.sample_with_log_prob()
        return (action, log_prob) if with_log_prob else action

    def sample_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample one action per input and return it with its log-probability."""
        out = self.forward(x, e, with_log_prob=True)
        return cast(Tuple[torch.Tensor, torch.Tensor], out)

    def sample_n_with_log_prob(
        self,
        x: torch.Tensor,
        e: torch.Tensor,
        n: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw ``n`` samples and return them with dims 0 and 1 swapped."""
        dist = self.dist(x, e)

        action_T, log_prob_T = dist.sample_n_with_log_prob(n)

        # (n, batch, action) -> (batch, n, action)
        transposed_action = action_T.transpose(0, 1)
        # (n, batch, 1) -> (batch, n, 1)
        log_prob = log_prob_T.transpose(0, 1)

        return transposed_action, log_prob

    def sample_n_without_squash(self, x: torch.Tensor, e: torch.Tensor, n: int) -> torch.Tensor:
        """Return the distribution's ``sample_n_without_squash(n)`` with dims 0 and 1 swapped."""
        dist = self.dist(x, e)
        action = dist.sample_n_without_squash(n)
        return action.transpose(0, 1)

    def onnx_safe_sample_n(self, x: torch.Tensor, e: torch.Tensor, n: int) -> torch.Tensor:
        """Draw ``n`` samples per input using plain tensor ops only.

        This repeats the sampling logic without going through the
        distribution classes (presumably so the graph can be exported to
        ONNX, as the name suggests).

        Returns:
            Tensor of shape ``(x.shape[0], n, action_size)``.
        """
        h = self._encoder(x)
        h = torch.cat([h, e], dim=1)
        mean = self._mu(h)
        std = self._compute_logstd(h).exp()

        # The non-squashed distribution is centred on tanh(mu), as in dist().
        if not self._squash_distribution:
            mean = torch.tanh(mean)

        # expand shape
        # (batch_size, action_size) -> (batch_size, N, action_size)
        expanded_mean = mean.view(-1, 1, self._action_size).repeat((1, n, 1))
        # With a std parameter the leading dim is 1 and broadcasts below.
        expanded_std = std.view(-1, 1, self._action_size).repeat((1, n, 1))

        # sample noise from Gaussian distribution
        noise = torch.randn(x.shape[0], n, self._action_size, device=x.device)

        if self._squash_distribution:
            return torch.tanh(expanded_mean + noise * expanded_std)
        return expanded_mean + noise * expanded_std

    def best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the deterministic action from ``forward``."""
        action = self.forward(x, e, deterministic=True, with_log_prob=False)
        return cast(torch.Tensor, action)

    def get_logstd_parameter(self) -> torch.Tensor:
        """Return the learned log-std mapped into ``[min_logstd, max_logstd]``.

        The raw parameter goes through a sigmoid and is then rescaled as
        ``min_logstd + sigmoid(p) * (max_logstd - min_logstd)``. Only valid
        when the policy was built with ``use_std_parameter=True``.
        """
        assert self._use_std_parameter
        logstd = torch.sigmoid(cast(nn.Parameter, self._logstd))
        base_logstd = self._max_logstd - self._min_logstd
        return self._min_logstd + logstd * base_logstd


class DropSquashedNormalPolicy(DropNormalPolicy):
    """``DropNormalPolicy`` with ``squash_distribution`` fixed to ``True``."""

    def __init__(
        self,
        encoder: Encoder,
        action_size: int,
        embedding_size: int,
        min_logstd: float,
        max_logstd: float,
        use_std_parameter: bool,
    ):
        """Forward every argument to the parent and enable squashing."""
        super().__init__(
            squash_distribution=True,
            use_std_parameter=use_std_parameter,
            max_logstd=max_logstd,
            min_logstd=min_logstd,
            embedding_size=embedding_size,
            action_size=action_size,
            encoder=encoder,
        )


class DropNonSquashedNormalPolicy(DropNormalPolicy):
    """``DropNormalPolicy`` with ``squash_distribution`` fixed to ``False``."""

    def __init__(
        self,
        encoder: Encoder,
        action_size: int,
        embedding_size: int,
        min_logstd: float,
        max_logstd: float,
        use_std_parameter: bool,
    ):
        """Forward every argument to the parent and disable squashing."""
        super().__init__(
            squash_distribution=False,
            use_std_parameter=use_std_parameter,
            max_logstd=max_logstd,
            min_logstd=min_logstd,
            embedding_size=embedding_size,
            action_size=action_size,
            encoder=encoder,
        )


class DropCategoricalPolicy(DropPolicy):
    """Discrete policy: a categorical distribution over ``action_size`` choices.

    The encoder output is concatenated with ``e`` along dim 1 and passed
    through an MLP head whose outputs are used as the distribution's logits.
    """

    _encoder: Encoder
    _fc: nn.Sequential

    def __init__(self, encoder: Encoder, action_size: int, embedding_size: int):
        """Build the logits head on top of ``encoder``.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: number of discrete actions (output width of the head).
            embedding_size: number of extra input features added to the
                encoder feature size (presumably the dim-1 size of ``e``).
        """
        super().__init__()
        self._encoder = encoder
        # NOTE: only the first linear layer is followed by a ReLU; the last
        # two linear layers are stacked without a non-linearity in between.
        self._fc = nn.Sequential(
            nn.Linear(encoder.get_feature_size() + embedding_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.Linear(256, action_size),
        )

    def dist(self, x: torch.Tensor, e: torch.Tensor) -> Categorical:
        """Return the categorical action distribution for ``x`` and ``e``."""
        h = self._encoder(x)
        h = torch.cat([h, e], dim=1)
        # Hand the raw head outputs over as logits instead of running softmax
        # first: log-probabilities are then derived from the logits directly
        # rather than from the log of already-exponentiated probabilities,
        # which loses precision for very unlikely actions. The distribution
        # normalizes over the last dim, which is dim 1 for the 2-D input the
        # head works on.
        return Categorical(logits=self._fc(h))

    def forward(
        self,
        x: torch.Tensor,
        e: torch.Tensor,
        deterministic: bool = False,
        with_log_prob: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return an action index, optionally with its log-probability.

        Args:
            x: input passed to the encoder.
            e: tensor concatenated to the encoder output along dim 1.
            deterministic: take the arg-max of the probabilities over dim 1
                instead of sampling.
            with_log_prob: return ``(action, log_prob)`` instead of only
                ``action``.
        """
        dist = self.dist(x, e)

        if deterministic:
            action = cast(torch.Tensor, dist.probs.argmax(dim=1))
        else:
            action = cast(torch.Tensor, dist.sample())

        if with_log_prob:
            return action, dist.log_prob(action)

        return action

    def sample_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample one action per input and return it with its log-probability."""
        out = self.forward(x, e, with_log_prob=True)
        return cast(Tuple[torch.Tensor, torch.Tensor], out)

    def sample_n_with_log_prob(
        self, x: torch.Tensor, e: torch.Tensor, n: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw ``n`` samples and return them with dims 0 and 1 swapped."""
        dist = self.dist(x, e)

        action_T = cast(torch.Tensor, dist.sample((n,)))
        log_prob_T = dist.log_prob(action_T)

        # (n, batch) -> (batch, n)
        action = action_T.transpose(0, 1)
        # (n, batch) -> (batch, n)
        log_prob = log_prob_T.transpose(0, 1)

        return action, log_prob

    def best_action(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the arg-max action from ``forward``."""
        return cast(torch.Tensor, self.forward(x, e, deterministic=True))

    def log_probs(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the distribution's normalized ``logits`` (log-probabilities of every action)."""
        dist = self.dist(x, e)
        return cast(torch.Tensor, dist.logits)
