import torch
import math
from torch import nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions import Normal

import utils
import wandb


class TanhTransform(pyd.transforms.Transform):
    """Bijective tanh map from the real line onto the interval (-1, 1).

    The log-Jacobian is evaluated through a softplus identity rather than
    log(1 - tanh(x)^2), so it does not collapse to log(0) when tanh saturates.
    """
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        # Forwarded unchanged to torch's Transform, which uses cache_size to
        # decide whether to remember the most recent (x, y) pair.
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Return artanh(x) = 0.5 * (log(1 + x) - log(1 - x)), via log1p."""
        return (torch.log1p(x) - torch.log1p(-x)) * 0.5

    def __eq__(self, other):
        # The transform carries no parameters, so any two instances are equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        """Apply tanh elementwise (forward direction)."""
        return torch.tanh(x)

    def _inverse(self, y):
        """Map y in (-1, 1) back to the real line with artanh."""
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """Return log|d tanh(x) / dx| elementwise.

        Uses log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)); ``y`` is
        part of the Transform interface but is not needed here.
        """
        half_log_det = math.log(2.0) - x - F.softplus(-2.0 * x)
        return half_log_det * 2.0


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Gaussian N(loc, scale) pushed through a single TanhTransform."""

    def __init__(self, loc, scale):
        # Keep the pre-squash Gaussian parameters; `mean` starts from loc.
        self.loc, self.scale = loc, scale
        super().__init__(pyd.Normal(loc, scale), [TanhTransform()])

    @property
    def mean(self):
        """Return loc mapped through every transform in order.

        Note: this is the image of the Gaussian mean (tanh(loc)), not the true
        expectation of the squashed distribution, which generally differs.
        """
        squashed = self.loc
        for transform in self.transforms:
            squashed = transform(squashed)
        return squashed


class DiagGaussianActor(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed through tanh.

    The trunk maps an observation to ``2 * action_dim`` values that are split
    into the mean ``mu`` and an unconstrained log standard deviation; the
    latter is rescaled into ``log_std_bounds`` before use.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth, log_std_bounds):
        super().__init__()

        # (log_std_min, log_std_max): target range used by _constrain_log_std.
        self.log_std_bounds = log_std_bounds
        # NOTE(review): assumes utils.mlp returns an iterable container such as
        # nn.Sequential (it is enumerated in _collect_network_params) — confirm.
        self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim, hidden_depth)
        # Detached statistics from the most recent forward pass, read by log().
        self.outputs = {}
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Run the trunk and split its output into Gaussian parameters.

        Args:
            obs: observation tensor whose last dimension is the feature axis.

        Returns:
            (mu, log_std), each with last dimension ``action_dim``; ``log_std``
            is already constrained to ``log_std_bounds``.
        """
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # Constrain log_std within bounds
        log_std = self._constrain_log_std(log_std)

        # Stored detached: these are only used for logging, and keeping the
        # live tensors would pin the whole autograd graph until the next call.
        self.outputs['mu'] = mu.detach()
        self.outputs['std'] = log_std.detach().exp()

        return mu, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample a reparameterized action and its log-probability.

        The action is a = tanh(u) with u ~ N(mu, std) drawn via rsample, so
        gradients flow through the sample. The log-density uses the
        change-of-variables correction log p(a) = log N(u) - log(1 - a^2).

        Args:
            state: observation tensor (batched, or a single unbatched one).
            epsilon: added inside the log so that saturated actions (|a| == 1
                in floating point) do not produce log(0).

        Returns:
            (action, log_prob): ``action`` in [-1, 1]; ``log_prob`` summed over
            the last (action) axis with ``keepdim=True``.
        """
        mu, log_std = self.forward(state)
        dist = Normal(mu, log_std.exp())
        # rsample already lives on mu's device, which must match state's.
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)
        log_prob = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + epsilon)
        # Reduce over the last axis so unbatched inputs work as well.
        return action, log_prob.sum(dim=-1, keepdim=True)

    def get_action(self, state):
        """Sample a stochastic action from the squashed Gaussian policy.

        Samples follow a(s, e) = tanh(mu(s) + sigma(s) * e), with e ~ N(0, I).

        Returns:
            The sampled action as a detached CPU tensor.
        """
        mu, log_std = self.forward(state)
        dist = Normal(mu, log_std.exp())
        action = torch.tanh(dist.rsample())
        return action.detach().cpu()

    def get_det_action(self, state):
        """Return the deterministic action tanh(mu(s)) as a detached CPU tensor."""
        mu, _ = self.forward(state)
        return torch.tanh(mu).detach().cpu()

    def _constrain_log_std(self, log_std):
        """Affinely map tanh(log_std) from [-1, 1] onto [log_std_min, log_std_max]."""
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        return log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

    def log(self):
        """Log the latest forward-pass statistics and trunk parameters to wandb.

        Everything goes out in a single ``wandb.log`` call. Each call advances
        wandb's step counter by default, so this keeps one logging event on one
        step.
        """
        metrics = {
            f'train_actor/{key}_hist_mean': value.mean().item()
            for key, value in self.outputs.items()
        }
        metrics.update(self._collect_network_params(self.trunk))
        if metrics:
            wandb.log(metrics)

    def _log_network_params(self, network):
        """Log the mean weight/bias of each nn.Linear in ``network`` in one call."""
        metrics = self._collect_network_params(network)
        if metrics:
            wandb.log(metrics)

    @staticmethod
    def _collect_network_params(network):
        """Return mean weight/bias metrics for every nn.Linear in ``network``.

        Keys are ``train_actor/fc{i}_weight_mean`` / ``train_actor/fc{i}_bias_mean``,
        where ``i`` is the layer's position in ``network``. Linear layers built
        with ``bias=False`` contribute only the weight entry.
        """
        metrics = {}
        for i, layer in enumerate(network):
            if not isinstance(layer, nn.Linear):
                continue
            metrics[f'train_actor/fc{i}_weight_mean'] = layer.weight.mean().item()
            if layer.bias is not None:
                metrics[f'train_actor/fc{i}_bias_mean'] = layer.bias.mean().item()
        return metrics
