import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.configuration_utils import PretrainedConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the final axis is divided by the square root of
    (its mean squared value + ``eps``) and then multiplied element-wise by a
    learnable ``weight`` initialised to ones.

    Args:
        hidden_size: Length of the learnable scale vector.
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return the normalised, scaled input in the input's dtype."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input
        # dtype; the result is cast back only at the very end.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = (upcast * inv_rms) * self.weight
        return scaled.to(input_dtype)

class SamplerBlock(nn.Module):
    """Bias-free linear projection followed by SiLU and RMS normalisation.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output; also the
            width of the normalisation layer.
        eps: Epsilon forwarded to the ``RMSNorm`` layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
    ):
        super().__init__()
        # Submodule names and creation order are kept so parameter order
        # and state-dict keys remain the same.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.norm = RMSNorm(out_features, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply projection, activation and normalisation, in that order."""
        projected = self.linear(x)
        activated = F.silu(projected)
        return self.norm(activated)

class Sampler(nn.Module):
    """Two stacked ``SamplerBlock`` layers with an optional residual sum.

    The first block maps ``2 * config.hidden_size`` features down to
    ``config.hidden_size``; the second keeps the width at
    ``config.hidden_size``. Both use ``config.rms_norm_eps`` as their
    normalisation epsilon.

    Args:
        config: Configuration object; ``hidden_size`` and ``rms_norm_eps``
            are read from it.
        use_residual: When true, the output of the first block is added to
            the output of the second block.
    """

    def __init__(self, 
        config: PretrainedConfig,
        use_residual: bool = True,
    ):
        super().__init__()

        self.use_residual = use_residual

        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        # NOTE(review): the doubled input width presumably corresponds to two
        # hidden-size vectors concatenated by the caller — confirm.
        blocks = [
            SamplerBlock(2 * width, width, eps=norm_eps),
            SamplerBlock(width, width, eps=norm_eps),
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both blocks; add the first block's output if residual is on."""
        first_block, second_block = self.layers
        reduced = first_block(x)
        refined = second_block(reduced)
        if not self.use_residual:
            return refined
        return reduced + refined
