import torch
import torch.nn as nn
import torch.nn.functional as F

from spec_benchmark.Engine.models.base import RMSNorm

from spec_benchmark.Engine.models.base import ModelArgs

class SamplerBlock(nn.Module):
    """Bias-free linear projection, followed by SiLU, followed by RMSNorm.

    Args:
        in_features: Size of the last dimension of the input tensor.
        out_features: Size of the last dimension of the output tensor.
        eps: Epsilon passed through to ``RMSNorm``.
    """

    def __init__(self, in_features: int, out_features: int, eps: float = 1e-5):
        super().__init__()
        # Submodule names and registration order are kept stable so that
        # state_dict keys ("linear.weight", "norm.*") stay compatible.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.norm = RMSNorm(out_features, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``norm(silu(linear(x)))``."""
        projected = self.linear(x)
        activated = F.silu(projected)
        return self.norm(activated)

class Sampler(nn.Module):
    """Two stacked ``SamplerBlock``s that map ``2 * config.dim`` features to ``config.dim``.

    The first block reduces the feature width from ``2 * dim`` to ``dim``. The
    second block keeps it at ``dim``. When ``use_residual`` is true, the first
    block's output is added to the second block's output.

    Args:
        config: Model arguments. Only ``config.norm_eps`` and ``config.dim``
            are read.
        use_residual: Whether to add a skip connection around the second block.
    """

    def __init__(self, config: ModelArgs, use_residual: bool = True):
        super().__init__()
        eps, dim = config.norm_eps, config.dim
        self.use_residual = use_residual

        # Kept as an nn.ModuleList named "layers" so state_dict keys
        # ("layers.0.*", "layers.1.*") are unchanged.
        reduce_block = SamplerBlock(2 * dim, dim, eps=eps)
        refine_block = SamplerBlock(dim, dim, eps=eps)
        self.layers = nn.ModuleList([reduce_block, refine_block])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both blocks, adding the optional residual around the second one."""
        hidden = self.layers[0](x)
        refined = self.layers[1](hidden)
        if not self.use_residual:
            return refined
        return hidden + refined