from __future__ import annotations

import math
from typing import Callable, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_spectral_norm_if(layer: nn.Module, use_sn: bool) -> nn.Module:
    """Optionally register spectral normalization on a linear layer.

    Only ``nn.Linear`` layers are affected; any other module is returned
    untouched regardless of *use_sn*.

    Args:
        layer: Module to (possibly) normalize.
        use_sn: Whether spectral normalization is requested.

    Returns:
        The module produced by
        ``nn.utils.parametrizations.spectral_norm(layer)`` when *use_sn* is
        true and *layer* is an ``nn.Linear``; otherwise *layer* itself.
    """
    if not use_sn or not isinstance(layer, nn.Linear):
        return layer
    return nn.utils.parametrizations.spectral_norm(layer)


def mlp(
    sizes: Sequence[int],
    activation: Callable[[], nn.Module] = nn.ReLU,
    output_activation: Callable[[], nn.Module] | None = None,
    spectral_norm: bool = False,
) -> nn.Sequential:
    """Build a fully connected network as an ``nn.Sequential``.

    One ``nn.Linear`` is created for each consecutive pair in *sizes*. Every
    linear layer except the last is followed by ``activation()``; the last
    one is followed by ``output_activation()`` unless that is None.

    Args:
        sizes: Layer widths, input width first and output width last. With
            fewer than two entries no layers are built and the returned
            (empty) ``nn.Sequential`` passes its input through unchanged.
        activation: Zero-argument factory, e.g. the class ``nn.ReLU``. It is
            called once per hidden layer so each gets a fresh module.
        output_activation: Factory for the module placed after the final
            linear layer, or None for no final activation.
        spectral_norm: If True, each linear layer is passed through
            ``apply_spectral_norm_if`` with spectral normalization enabled.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers: list[nn.Module] = []
    n_linear = len(sizes) - 1
    for i, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        # Linear layers are created in order so parameter initialization
        # consumes the RNG exactly as before.
        layers.append(apply_spectral_norm_if(nn.Linear(fan_in, fan_out), spectral_norm))
        # The activations are factories, not instances: each call yields a new module.
        act = activation if i < n_linear - 1 else output_activation
        if act is not None:
            layers.append(act())
    return nn.Sequential(*layers)


class Actor(nn.Module):
    """Gaussian policy with Tanh squashing (SAC-compatible).

    A shared MLP trunk (ReLU after every hidden layer) feeds two linear
    heads. One head produces the mean and the other the log standard
    deviation of a diagonal Gaussian over pre-squash actions. Sampled
    actions are squashed into (-1, 1) with ``tanh``.

    Args:
        obs_dim: Size of the last dimension of the observation tensor.
        act_dim: Number of action dimensions produced by each head.
        hidden_dims: Widths of the trunk's hidden layers. May be empty, in
            which case the heads read the observation directly.
        spectral_norm: If True, spectral normalization is applied to every
            linear layer (trunk and both heads).
        log_std_min: Lower clamp bound for the predicted log std
            (keyword-only).
        log_std_max: Upper clamp bound for the predicted log std
            (keyword-only).

    Raises:
        ValueError: If ``log_std_min > log_std_max``.
    """

    # Class-level defaults; __init__ stores per-instance values.
    log_std_min: float = -5.0
    log_std_max: float = 2.0

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        hidden_dims: Sequence[int],
        spectral_norm: bool = False,
        *,
        log_std_min: float = -5.0,
        log_std_max: float = 2.0,
    ):
        super().__init__()
        if log_std_min > log_std_max:
            raise ValueError(
                f"log_std_min ({log_std_min}) must not exceed log_std_max ({log_std_max})"
            )
        hidden = list(hidden_dims)
        self.net = mlp([obs_dim] + hidden, activation=nn.ReLU, output_activation=nn.ReLU, spectral_norm=spectral_norm)
        # With no hidden layers the trunk is an empty Sequential (identity),
        # so the heads consume the raw observation.
        feat_dim = hidden[-1] if hidden else obs_dim
        self.mu_layer = apply_spectral_norm_if(nn.Linear(feat_dim, act_dim), spectral_norm)
        self.log_std_layer = apply_spectral_norm_if(nn.Linear(feat_dim, act_dim), spectral_norm)
        self.log_std_min = float(log_std_min)
        self.log_std_max = float(log_std_max)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_std)`` of the pre-squash Gaussian for *obs*.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]`` so that
        ``exp(log_std)`` stays in a bounded range.
        """
        h = self.net(obs)
        mu = self.mu_layer(h)
        log_std = torch.clamp(self.log_std_layer(h), self.log_std_min, self.log_std_max)
        return mu, log_std

    def sample(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw a reparameterized, tanh-squashed action and its log-density.

        Returns:
            action: ``tanh(x)`` with ``x ~ Normal(mu, std)``; same shape as ``mu``.
            log_prob: Log-density of *action*, summed over the last dimension
                with ``keepdim=True``.
        """
        mu, log_std = self(obs)
        std = torch.exp(log_std)
        normal = torch.distributions.Normal(mu, std)
        x_t = normal.rsample()  # reparameterized draw, so gradients reach mu/std
        action = torch.tanh(x_t)
        # Change of variables for a = tanh(x): log p(a) = log p(x) - log(1 - tanh(x)^2).
        # The identity log(1 - tanh(x)^2) = 2*(log 2 - x - softplus(-2x)) stays
        # exact when tanh saturates. The old log(1 - a^2 + eps) form is capped
        # at log(eps) once tanh rounds to +/-1.
        log_det = 2.0 * (math.log(2.0) - x_t - F.softplus(-2.0 * x_t))
        log_prob = normal.log_prob(x_t).sum(-1, keepdim=True) - log_det.sum(-1, keepdim=True)
        return action, log_prob

    def deterministic(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the squashed mean action ``tanh(mu)`` (no sampling)."""
        mu, _ = self(obs)
        return torch.tanh(mu)


class QNetwork(nn.Module):
    """State-action value network Q(s, a).

    The observation and action are concatenated along the last dimension
    and fed through an MLP. The MLP has ReLU hidden activations and a single
    linear output unit with no output activation.

    Args:
        obs_dim: Size of the last dimension of the observation tensor.
        act_dim: Size of the last dimension of the action tensor.
        hidden_dims: Widths of the hidden layers.
        spectral_norm: If True, spectral normalization is applied to every
            linear layer.
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden_dims: Sequence[int], spectral_norm: bool = False):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_dims, 1]
        self.q = mlp(
            layer_sizes,
            activation=nn.ReLU,
            output_activation=None,
            spectral_norm=spectral_norm,
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        """Return Q-values for (obs, act); the output's last dimension has size 1."""
        joint = torch.cat((obs, act), dim=-1)
        return self.q(joint)
