import torch
from torch import nn
from utils.generate_nn import generate_MLP

from torch import distributions as D
from functools import partial


def weight_init(m: nn.Module, gain: float = 1.0) -> None:
    """Orthogonally initialise an ``nn.Linear`` layer and zero its bias.

    Designed to be handed to ``nn.Module.apply``: any module that is not an
    ``nn.Linear`` is left untouched.

    Args:
        m: Module to (possibly) initialise in place.
        gain: Scaling factor forwarded to ``nn.init.orthogonal_``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data, gain=gain)
    # Layers created with ``bias=False`` have ``bias is None``; skip those.
    if m.bias is not None:
        m.bias.data.fill_(0.0)


class SquashedNormal(D.TransformedDistribution):
    """Normal distribution pushed through a ``tanh`` transform.

    The pre-transform parameters are kept on the instance so that ``loc``,
    ``mean`` and ``stddev`` can be reported without touching the base
    distribution again.
    """

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        """Build ``tanh(Normal(loc, scale))``.

        Args:
            loc: Location of the underlying Normal (before squashing).
            scale: Scale of the underlying Normal.
        """
        self._loc = loc
        self.scale = scale
        gaussian = D.Normal(loc, scale)
        self.base_dist = gaussian
        # cache_size=1 lets the transform reuse the pre-tanh value of the most
        # recent sample when that sample is passed back to ``log_prob``.
        super().__init__(gaussian, [D.transforms.TanhTransform(cache_size=1)])

    @property
    def loc(self) -> torch.Tensor:
        """Base location pushed through every transform in order."""
        squashed = self._loc
        for fn in self.transforms:
            squashed = fn(squashed)
        return squashed

    @property
    def mean(self) -> torch.Tensor:
        """Alias for ``loc`` (the transformed location)."""
        return self.loc

    @property
    def stddev(self) -> torch.Tensor:
        """Scale of the underlying Normal, reported without any squashing."""
        return self.scale

    def log_prob(self, value):
        """Return the log-density of ``value`` averaged over the last axis.

        Note that the per-element log-probabilities are combined with a mean,
        not a sum.
        """
        elementwise = super().log_prob(value)
        return elementwise.mean(-1)

    def estimate_entropy(dist: D.TransformedDistribution, n_samples: int = 100):
        """Monte-Carlo estimate of the entropy as a single scalar tensor.

        The first parameter is named ``dist`` instead of ``self``, so this
        works both as ``d.estimate_entropy()`` and as
        ``SquashedNormal.estimate_entropy(d)``.

        Args:
            dist: Distribution to draw from and evaluate.
            n_samples: Number of draws taken along a new leading axis.

        Returns:
            Negative mean log-probability of the draws; computed without
            tracking gradients.
        """
        # Shape is [n_samples, batch_size, action_dim] when the parameters
        # are [batch_size, action_dim].
        draws = dist.sample((n_samples,))
        with torch.no_grad():
            draw_log_probs = dist.log_prob(draws)
        # First average over the sample axis, then over whatever remains.
        per_batch_entropy = draw_log_probs.mean(0).neg()
        return per_batch_entropy.mean()


class DiagGaussianActor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy.

    An MLP ("trunk") maps observations to the pre-squash location, the
    standard deviation is the constant ``init_std`` for every output, and
    ``forward`` wraps both in a ``SquashedNormal``.
    """

    def __init__(
        self,
        obs_dim,
        action_dim,
        hidden_dim,
        hidden_depth,
        init_std: float = 1.0,
        ortho_init: bool = True,
        dropout: float = 0.0,
        **ignore
    ):
        """Build the trunk network and initialise its weights.

        Args:
            obs_dim: Input size passed to ``generate_MLP`` as ``in_dim``.
            action_dim: Output size passed to ``generate_MLP`` as ``out_dim``.
            hidden_dim: Passed to ``generate_MLP`` as ``width``.
            hidden_depth: Passed to ``generate_MLP`` as ``n_layers``.
            init_std: Constant standard deviation used by ``forward``.
            ortho_init: If truthy, linear layers are orthogonally initialised;
                its float value is also used as the initialisation gain.
            dropout: Forwarded to ``generate_MLP``.
            **ignore: Extra keyword arguments, accepted and discarded.
        """
        super().__init__()

        self.action_dim = action_dim
        self.ortho_init = ortho_init
        # Stored as a plain tensor attribute (not a parameter or buffer).
        self.std_value = torch.tensor(init_std)

        self.trunk = generate_MLP(
            in_dim=obs_dim,
            out_dim=self.action_dim,
            width=hidden_dim,
            n_layers=hidden_depth,
            dropout=dropout,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Apply ``weight_init`` to all submodules when ``ortho_init`` is set."""
        if not self.ortho_init:
            return
        # ``ortho_init`` doubles as the gain: True converts to 1.0.
        init_fn = partial(weight_init, gain=float(self.ortho_init))
        self.apply(init_fn)

    def forward(self, obs):
        """Return the action distribution for ``obs``.

        Raises:
            ValueError: If the trunk output contains NaN; the offending
                observations are printed first.
        """
        mu = self.trunk(obs)

        if torch.isnan(mu).any():
            print("NaN in mu")
            print(obs)
            raise ValueError("NaN in mu")

        # Broadcast the constant standard deviation to the shape of ``mu``.
        std = self.std_value * torch.ones_like(mu)
        return SquashedNormal(mu, std)
