import torch
import torch.nn as nn
import torch.utils.data
from torch import autograd

from rsl_rl.utils import utils


class AMPDiscriminator(nn.Module):
    """MLP discriminator for Adversarial Motion Priors (AMP).

    The network is a stack of ``Linear -> ReLU`` layers (``trunk``) followed by
    a single linear output unit (``amp_linear``). ``forward`` squashes that
    output with ``tanh`` into ``[-1, 1]``.

    Args:
        input_dim: Size of the last dimension of the observations fed in.
        hidden_layer_sizes: Widths of the hidden layers, in order. May be
            empty, in which case the discriminator is a single linear layer.
        device: Device the layers are moved to.
    """

    def __init__(
            self, input_dim, hidden_layer_sizes, device):
        super(AMPDiscriminator, self).__init__()

        self.device = device
        self.input_dim = input_dim
        amp_layers = []
        curr_in_dim = input_dim
        for hidden_dim in hidden_layer_sizes:
            amp_layers.append(nn.Linear(curr_in_dim, hidden_dim))
            amp_layers.append(nn.ReLU())
            curr_in_dim = hidden_dim
        self.trunk = nn.Sequential(*amp_layers).to(device)
        # Use the width of the last layer actually built (instead of
        # hidden_layer_sizes[-1]) so an empty hidden_layer_sizes also works.
        self.amp_linear = nn.Linear(curr_in_dim, 1).to(device)

        self.trunk.train()
        self.amp_linear.train()

    def _logits(self, x):
        """Return the raw (pre-tanh) discriminator output for ``x``."""
        return self.amp_linear(self.trunk(x))

    def forward(self, x):
        """Return the discriminator score for ``x``, bounded to ``[-1, 1]`` by tanh."""
        return torch.tanh(self._logits(x))

    def compute_grad_pen(self, amp_obs, lambda_=10):
        """Zero-centred gradient penalty of the discriminator at ``amp_obs``.

        Computes ``lambda_ * mean(||d logits / d amp_obs||_2 ** 2)``. The norm
        is taken over dim 1, presumably the feature dim of a
        ``(batch, input_dim)`` tensor (TODO confirm against callers).

        Args:
            amp_obs: Observations at which to evaluate the penalty. They are
                detached first, so the penalty never backpropagates into
                whatever produced them.
            lambda_: Penalty weight.

        Returns:
            A scalar tensor. It is built with ``create_graph=True`` so it can
            be backpropagated into the discriminator's parameters.
        """
        # detach() guarantees a leaf tensor: assigning requires_grad on a
        # non-leaf (e.g. when amp_obs is itself part of a graph) would raise.
        expert_data = amp_obs.detach().clone().requires_grad_(True)

        # NOTE(review): the penalty is taken on the raw logits, not on the
        # tanh output returned by forward(); confirm this is intended.
        disc = self._logits(expert_data)
        grad = autograd.grad(
            outputs=disc, inputs=expert_data,
            grad_outputs=torch.ones_like(disc),  # matches disc's dtype/device
            create_graph=True, retain_graph=True)[0]

        # Enforce that the gradient norm approaches 0.
        return lambda_ * grad.norm(2, dim=1).pow(2).mean()