import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from utils import initialize_uniformly

class ResBlock4(nn.Module):
    """Residual block made of four ``Linear -> LayerNorm`` stages.

    The first three stages are each followed by a ReLU; the fourth is not.
    The result of the fourth stage is added to the block input, and no
    activation is applied after the addition (callers do that themselves
    if they want one).

    Args:
        hidden: Feature width; every ``Linear`` maps ``hidden -> hidden``
            and every ``LayerNorm`` normalises over ``hidden`` features.
    """

    def __init__(self, hidden: int):
        super().__init__()
        # Submodules are registered as fc1, ln1, fc2, ln2, ... so that the
        # parameter order and the state_dict keys stay stable.
        self.fc1 = nn.Linear(hidden, hidden)
        self.ln1 = nn.LayerNorm(hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.ln2 = nn.LayerNorm(hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.ln3 = nn.LayerNorm(hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        self.ln4 = nn.LayerNorm(hidden)

    def forward(self, x):
        """Return ``x`` plus the output of the four-stage branch."""
        # Stage 1
        branch = self.fc1(x)
        branch = self.ln1(branch)
        branch = F.relu(branch)
        # Stage 2
        branch = self.fc2(branch)
        branch = self.ln2(branch)
        branch = F.relu(branch)
        # Stage 3
        branch = self.fc3(branch)
        branch = self.ln3(branch)
        branch = F.relu(branch)
        # Stage 4: normalised but deliberately left without an activation.
        branch = self.fc4(branch)
        branch = self.ln4(branch)
        # Skip connection.
        return x + branch

class Critic_r(nn.Module):
    """Scalar critic: input projection, a stack of residual blocks, linear head.

    Args:
        in_dim: Size of the last dimension of the state input.
        hidden: Width of the projection and of every residual block.
        num_blocks: Number of ``ResBlock4`` blocks stacked after the projection.
    """

    def __init__(self, in_dim: int, hidden: int = 256, num_blocks: int = 2):
        super().__init__()
        projection_layers = [
            nn.Linear(in_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(inplace=True),
        ]
        self.in_proj = nn.Sequential(*projection_layers)

        residual_stack = [ResBlock4(hidden) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(residual_stack)

        # Single-output head, re-initialised by the project's helper.
        self.out = nn.Linear(hidden, 1)
        initialize_uniformly(self.out)

    def forward(self, state):
        """Map ``state`` to one scalar per row.

        The original note on the return value reads: shape ``[B, 1]``,
        interpreted as ``h(s)`` up to a constant.
        """
        features = self.in_proj(state)
        for block in self.blocks:
            # ReLU after each residual add; in-place is fine because the
            # block output is a fresh tensor.
            features = F.relu(block(features), inplace=True)
        return self.out(features)

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with the full change-of-variables term.

    A pre-squash sample ``u ~ Normal(mu, std)`` is mapped to the action
    ``a = action_scale * tanh(u)``, which lies in
    ``[-action_scale, action_scale]`` (``[-2, 2]`` with the default scale).

    Args:
        in_dim: Size of the last dimension of the state input.
        out_dim: Number of action dimensions.
        action_scale: Positive multiplier applied after ``tanh``. Defaults to
            2.0, the value that used to be hard-coded.

    Raises:
        ValueError: If ``action_scale`` is not strictly positive.
    """

    def __init__(self, in_dim: int, out_dim: int, action_scale: float = 2.0):
        super().__init__()
        if action_scale <= 0:
            raise ValueError("action_scale must be positive")
        self.hidden1 = nn.Linear(in_dim, 128)
        self.mu_layer = nn.Linear(128, out_dim)
        self.log_std_layer = nn.Linear(128, out_dim)
        # Both heads are re-initialised by the project's helper.
        initialize_uniformly(self.mu_layer)
        initialize_uniformly(self.log_std_layer)
        # Clamp range for log(std): std stays within [exp(-5), exp(2)].
        self.min_log_std = -5.0
        self.max_log_std = 2.0
        self.action_scale = action_scale

    @staticmethod
    def _tanh_correction(u):
        """Return ``log(1 - tanh(u)^2)`` elementwise.

        Uses the identity ``log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u))``
        instead of evaluating ``1 - tanh(u)^2`` directly, which underflows to
        zero for large ``|u|``.
        """
        # Plain Python float so the arithmetic does not rely on how a numpy
        # scalar on the left-hand side dispatches against a tensor.
        log_two = float(np.log(2.0))
        return 2.0 * (log_two - u - F.softplus(-2.0 * u))

    def _features(self, state):
        """Shared hidden representation used by both policy heads."""
        return F.relu(self.hidden1(state))

    def forward(self, state):
        """Sample an action and return it with its log-density.

        Returns:
            Tuple ``(a, logp)``: ``a`` is the squashed, scaled action with the
            same shape as the mean head's output; ``logp`` is the log-density
            of ``a`` summed over the last dimension.
        """
        x = self._features(state)
        mu = self.mu_layer(x)
        log_std = torch.clamp(self.log_std_layer(x), self.min_log_std, self.max_log_std)
        std = torch.exp(log_std)
        normal = Normal(mu, std)
        # Reparameterised draw, so gradients reach mu and std through u.
        eps = torch.randn_like(mu)
        u = mu + std * eps
        a = torch.tanh(u) * self.action_scale
        # log|da/du| = log(action_scale) + log(1 - tanh(u)^2) per dimension.
        # Leaving out log(action_scale) would give the density of tanh(u),
        # not of the action that is actually returned.
        log_jacobian = float(np.log(self.action_scale)) + self._tanh_correction(u)
        logp = (normal.log_prob(u) - log_jacobian).sum(dim=-1)
        return a, logp

    def deterministic(self, state):
        """Return the noise-free action ``action_scale * tanh(mu(state))``."""
        mu = self.mu_layer(self._features(state))
        return torch.tanh(mu) * self.action_scale
