from abc import ABCMeta, abstractmethod
from typing import Tuple, cast

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

from .encoders import Encoder, EncoderWithAction



class DropImitator(nn.Module, metaclass=ABCMeta):  # type: ignore
    """Abstract interface for imitator modules that take two tensors, ``x`` and ``e``.

    Subclasses must implement :meth:`forward` and :meth:`compute_error`;
    the class cannot be instantiated until both are provided.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Compute the module output from ``x`` and ``e`` (subclass hook)."""

    def __call__(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Invoke the module through ``nn.Module.__call__`` with a typed result.

        The override exists only to give callers a ``torch.Tensor`` return
        type; the call itself is delegated unchanged to the parent class.
        """
        output = super().__call__(x, e)
        return cast(torch.Tensor, output)

    @abstractmethod
    def compute_error(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the training loss for ``action`` given ``x`` and ``e`` (subclass hook)."""


class DropDiscreteImitator(DropImitator):
    """Imitator producing one logit per discrete action.

    The encoder output for ``x`` is concatenated with ``e`` along dim 1 and
    fed to a fully connected head; the head's output is turned into
    log-probabilities with ``log_softmax``.
    """

    _encoder: Encoder
    _beta: float
    _fc: nn.Sequential

    def __init__(
        self,
        encoder: Encoder,
        action_size: int,
        embedding_size: int,
        beta: float,
    ):
        """Build the head on top of ``encoder``.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: number of logits produced by the head.
            embedding_size: added to the encoder feature size to size the
                first linear layer (presumably the width of ``e`` -- confirm
                against callers).
            beta: coefficient of the squared-logit penalty used in
                :meth:`compute_error`.
        """
        super().__init__()
        self._encoder = encoder
        self._beta = beta

        hidden_size = 256
        input_size = encoder.get_feature_size() + embedding_size
        # NOTE: the last two linear layers are stacked without an activation
        # between them; the layer order is kept exactly as before.
        head = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, action_size),
        ]
        self._fc = nn.Sequential(*head)

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return the action log-probabilities for ``x`` and ``e``."""
        log_probs, _ = self.compute_log_probs_with_logits(x, e)
        return log_probs

    def compute_log_probs_with_logits(
        self, x: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(log_probs, logits)`` computed from ``x`` and ``e``.

        ``log_probs`` is ``log_softmax`` of ``logits`` over dim 1.
        """
        features = torch.cat([self._encoder(x), e], dim=1)
        logits = self._fc(features)
        return F.log_softmax(logits, dim=1), logits

    def compute_error(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the NLL loss plus ``beta`` times the mean squared logit.

        ``action`` is flattened to 1-D with ``view(-1)`` before being passed
        to ``nll_loss`` as the target.
        """
        log_probs, logits = self.compute_log_probs_with_logits(x, e)
        nll = F.nll_loss(log_probs, action.view(-1))
        logit_penalty = logits.pow(2).mean()
        return nll + self._beta * logit_penalty


class DropDeterministicRegressor(DropImitator):
    """Regressor that maps ``x`` and ``e`` to a ``tanh``-squashed output.

    The encoder output for ``x`` is concatenated with ``e`` along dim 1 and
    passed through a fully connected head; the result is squashed by ``tanh``.
    """

    _encoder: Encoder
    _fc: nn.Sequential

    def __init__(self, encoder: Encoder, action_size: int, embedding_size: int):
        """Build the head on top of ``encoder``.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: output width of the head.
            embedding_size: added to the encoder feature size to size the
                first linear layer (presumably the width of ``e`` -- confirm
                against callers).
        """
        super().__init__()
        self._encoder = encoder

        hidden_size = 512
        input_size = encoder.get_feature_size() + embedding_size
        # NOTE: the last two linear layers are stacked without an activation
        # between them; the layer order is kept exactly as before.
        head = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, action_size),
        ]
        self._fc = nn.Sequential(*head)

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return ``tanh`` of the head output for ``x`` and ``e``."""
        features = torch.cat([self._encoder(x), e], dim=1)
        raw_output = self._fc(features)
        return torch.tanh(raw_output)

    def compute_error(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean squared error between the prediction and ``action``."""
        # ``forward`` is called directly (not via ``__call__``), as before.
        predicted = self.forward(x, e)
        return F.mse_loss(predicted, action)


class DropProbablisticRegressor(DropImitator):
    """Regressor that models its output as a diagonal ``Normal`` distribution.

    The encoder output for ``x`` is concatenated with ``e`` along dim 1 and
    fed to two separate heads: one produces the mean, the other the log
    standard deviation (clamped to ``[min_logstd, max_logstd]``).

    NOTE: :meth:`forward` and :meth:`compute_error` squash values with
    ``tanh``, whereas :meth:`sample_n` returns raw (un-squashed) samples.
    """

    _min_logstd: float
    _max_logstd: float
    _encoder: Encoder
    # Both heads are small MLPs, not single linear layers.
    _mu: nn.Sequential
    _logstd: nn.Sequential

    def __init__(
        self,
        encoder: Encoder,
        action_size: int,
        embedding_size: int,
        min_logstd: float,
        max_logstd: float,
    ):
        """Build the mean and log-std heads on top of ``encoder``.

        Args:
            encoder: module applied to ``x``; must expose ``get_feature_size()``.
            action_size: output width of each head.
            embedding_size: added to the encoder feature size to size the
                first linear layer of each head (presumably the width of
                ``e`` -- confirm against callers).
            min_logstd: lower clamp bound for the predicted log-std.
            max_logstd: upper clamp bound for the predicted log-std.
        """
        super().__init__()
        self._min_logstd = min_logstd
        self._max_logstd = max_logstd
        self._encoder = encoder

        input_size = encoder.get_feature_size() + embedding_size
        # Creation order (mean head, then log-std head) is kept so parameter
        # ordering and random initialisation match the previous layout.
        self._mu = self._build_head(input_size, action_size)
        self._logstd = self._build_head(input_size, action_size)

    @staticmethod
    def _build_head(input_size: int, action_size: int) -> nn.Sequential:
        """Create one output head: Linear -> ReLU -> Linear -> Linear."""
        hidden_size = 512
        # NOTE: the last two linear layers are stacked without an activation
        # between them; the layer order is kept exactly as before.
        return nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.Linear(hidden_size, action_size),
        )

    def _features(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and concatenate ``e`` to the result along dim 1."""
        return torch.cat([self._encoder(x), e], dim=1)

    def dist(self, x: torch.Tensor, e: torch.Tensor) -> Normal:
        """Return the ``Normal`` distribution predicted for ``x`` and ``e``.

        The scale is ``exp`` of the log-std head output after clamping it to
        ``[min_logstd, max_logstd]``.
        """
        h = self._features(x, e)
        mu = self._mu(h)
        logstd = self._logstd(h)
        clipped_logstd = logstd.clamp(self._min_logstd, self._max_logstd)
        return Normal(mu, clipped_logstd.exp())

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return ``tanh`` of the mean head output (the log-std head is unused)."""
        mu = self._mu(self._features(x, e))
        return torch.tanh(mu)

    def sample_n(self, x: torch.Tensor, e: torch.Tensor, n: int) -> torch.Tensor:
        """Draw ``n`` reparameterised samples from :meth:`dist`.

        The samples are returned with the first two dimensions swapped, so
        the sample dimension comes second. No ``tanh`` is applied here.
        """
        dist = self.dist(x, e)
        actions = cast(torch.Tensor, dist.rsample((n,)))
        # (n, batch, action) -> (batch, n, action)
        return actions.transpose(0, 1)

    def compute_error(
        self, x: torch.Tensor, action: torch.Tensor, e: torch.Tensor
    ) -> torch.Tensor:
        """Return the MSE between a ``tanh``-squashed sample and ``action``.

        A single reparameterised sample is drawn from :meth:`dist`, so the
        loss is stochastic.
        """
        dist = self.dist(x, e)
        return F.mse_loss(torch.tanh(dist.rsample()), action)
