import copy
import math
from typing import Optional

import torch
import torch.nn as nn

# Public API of this module: the two projection heads and the init helper.
__all__ = ["StateActionProjector", "GoalProjector", "cold_start_linear"]


def cold_start_linear(linear: nn.Linear, epsilon: float = 1e-12) -> None:
    """Cold-start ``linear`` by re-sampling its parameters from a tiny range.

    The weight, and the bias when the layer has one, are overwritten in place
    with values drawn uniformly between ``-epsilon`` and ``epsilon``. The layer
    therefore starts out with near-zero parameters.

    Args:
        linear: Layer whose parameters are re-initialised in place.
        epsilon: Half-width of the sampling interval.
    """
    lower, upper = -epsilon, epsilon
    # Weight is sampled before bias; keep this order so that seeded
    # initialisation consumes the RNG identically.
    for param in (linear.weight, linear.bias):
        if param is None:  # layer was built with bias=False
            continue
        nn.init.uniform_(param, a=lower, b=upper)


def _build_mlp(
    input_dim: int,
    hidden_dims: list[int],
    activation: nn.Module = nn.ReLU(inplace=True),
    use_layer_norm: bool = True,
) -> nn.Sequential:
    layers: list[nn.Module] = []
    in_dim = input_dim
    for hidden_dim in hidden_dims:
        linear = nn.Linear(in_dim, hidden_dim)
        nn.init.kaiming_normal_(linear.weight, nonlinearity="relu")
        nn.init.zeros_(linear.bias)
        layers.append(linear)
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(activation)
        in_dim = hidden_dim
    return nn.Sequential(*layers)


class StateActionProjector(nn.Module):
    """Projection head mapping state-action features into the contrastive space.

    Inputs pass through an MLP backbone built by ``_build_mlp``. A final
    linear layer then maps them to ``projection_dim``. That layer is
    cold-started with ``cold_start_linear``, so it begins with near-zero
    weight and bias.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dims: Widths of the backbone's hidden layers; must be non-empty.
        projection_dim: Output dimension of the projection.
        use_layer_norm: Whether the backbone inserts LayerNorm after each
            hidden linear layer.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        projection_dim: int,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()
        if not hidden_dims:
            raise ValueError("hidden_dims 至少包含一层。")
        # Backbone is created before the head so parameter initialisation
        # consumes the RNG in a fixed order.
        self.backbone = _build_mlp(
            input_dim,
            hidden_dims,
            use_layer_norm=use_layer_norm,
        )
        head = nn.Linear(hidden_dims[-1], projection_dim)
        cold_start_linear(head)
        self.proj = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(backbone(x))``."""
        return self.proj(self.backbone(x))


class GoalProjector(nn.Module):
    """Projection head mapping goal states into the contrastive space.

    Inputs pass through an MLP backbone built by ``_build_mlp``. A final
    linear layer then maps them to ``projection_dim``. That layer is
    cold-started with ``cold_start_linear``, so it begins with near-zero
    weight and bias.

    Args:
        input_dim: Size of the last dimension of the goal features.
        hidden_dims: Widths of the backbone's hidden layers; must be non-empty.
        projection_dim: Output dimension of the projection.
        use_layer_norm: Whether the backbone inserts LayerNorm after each
            hidden linear layer.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        projection_dim: int,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()
        if not hidden_dims:
            raise ValueError("hidden_dims 至少包含一层。")
        # Backbone is created before the head so parameter initialisation
        # consumes the RNG in a fixed order.
        self.backbone = _build_mlp(
            input_dim,
            hidden_dims,
            use_layer_norm=use_layer_norm,
        )
        head = nn.Linear(hidden_dims[-1], projection_dim)
        cold_start_linear(head)
        self.proj = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(backbone(x))``."""
        return self.proj(self.backbone(x))