import math
import torch

class NormalizedLinear(torch.nn.Linear):
    """Bias-free linear layer whose weight rows have unit L2 norm at creation.

    Construction arguments are forwarded unchanged to ``torch.nn.Linear``.
    After the parent initializer runs, the bias is discarded (even if one was
    requested) and every row of the weight matrix is rescaled to unit L2 norm.
    The normalization happens once, at construction; nothing in this class
    re-normalizes the weight afterwards.

    The forward pass divides the plain linear output by
    ``sqrt(out_features / in_features)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # This layer never uses a bias, regardless of the ``bias`` argument.
        self.bias = None

        # Rescale each row (one per output feature) to unit norm in place.
        # no_grad() keeps the parameter a leaf and avoids both the ``.data``
        # escape hatch and a throwaway autograd graph for the norm.
        with torch.no_grad():
            self.weight /= self.weight.norm(dim=-1, keepdim=True)

        # torch.nn.Linear stores its weight as (out_features, in_features).
        out_features, in_features = self.weight.shape
        # Original note: head_dim / embedding_dim.
        self._correction_factor: float = math.sqrt(out_features / in_features)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the bias-free linear map, scaled by ``1 / _correction_factor``."""
        return super().forward(input) / self._correction_factor


class SwiGLU(torch.nn.Module):
    """Gated activation ``u * v_s * sigmoid(v_s)`` built from two projections.

    ``u`` and ``v`` are both produced from the same input by separate
    ``NormalizedLinear`` layers mapping ``dim -> ff_dim_factor * dim``;
    ``v_s`` is ``v`` multiplied by ``sqrt(dim)``.
    """

    def __init__(self, dim: int = 1000, ff_dim_factor: int = 4):
        super().__init__()
        self.dim = dim

        hidden_dim = ff_dim_factor * dim
        self.Wu = NormalizedLinear(dim, hidden_dim)
        self.Wv = NormalizedLinear(dim, hidden_dim)

        # The gate scale uses the projection's *input* width (``dim``), read
        # from the second weight axis, not the expanded width
        # ``ff_dim_factor * dim``.
        gate_in_features = self.Wv.weight.shape[1]
        self.v_scale = math.sqrt(gate_in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated product of the two projections of ``x``."""
        value = self.Wu(x)
        gate_input = self.Wv(x) * self.v_scale

        # Same association as (value * gate_input) * sigmoid(gate_input).
        gated = value * gate_input
        return gated * torch.sigmoid(gate_input)


### Monte Carlo estimate of E[||swiglu(x)||] for inputs x drawn on the unit sphere.
# Assumption: the linear layers map from d=1000 to d=4000.
input_dim = 1000
ff_dim_factor = 4
num_mc_samples = 10_000

swiglu = SwiGLU(dim=input_dim, ff_dim_factor=ff_dim_factor)
x = torch.randn(num_mc_samples, input_dim)  # <- use enough samples
x /= x.norm(dim=-1, keepdim=True)  # rescale every sample to unit L2 norm

# This is a pure forward-pass estimate: disable gradient tracking so autograd
# does not retain graph state for the (num_mc_samples, ff_dim_factor * input_dim)
# activations.
with torch.no_grad():
    mc_norm_estimate = swiglu(x).norm(dim=-1).mean()

print(f'{mc_norm_estimate}')
print(f'normalizing factor = {1 / mc_norm_estimate}')