import torch
from torch import nn

from coarsebind_public.mol_encoder.models.loose_modules.activations import SwiGLU
from coarsebind_public.mol_encoder.models.loose_modules.norms import RMSNorm


class SwiGLUResNet(nn.Module):
    """Residual SwiGLU feed-forward block with RMS normalization of the input.

    The input is RMS-normalized and fed through
    ``Dropout -> Linear(d_in, 2*d_out) -> SwiGLU -> Linear(d_out, d_out)``.
    The branch output is added to the (normalized) input, projected to
    ``d_out`` when the widths differ.

    History: 10/25 - added dropout.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        dropout: Dropout probability applied to the normalized input before
            the first linear layer.
    """

    def __init__(self, d_in, d_out, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            torch.nn.Dropout(p=dropout),
            # Assumes SwiGLU halves the last dimension (2 * d_out -> d_out),
            # as implied by the Linear(d_out, d_out) that follows it.
            nn.Linear(d_in, 2 * d_out),
            SwiGLU(),
            nn.Linear(d_out, d_out),
        )
        # The norm is applied to the block input, so it must span d_in
        # (identical to the old RMSNorm(d_out) whenever d_in == d_out).
        self.norm = RMSNorm(d_in)
        # Residual path: a parameter-free identity when widths match (no new
        # state_dict keys, no RNG consumed), otherwise a bias-free projection
        # so the sum with the d_out-wide branch output is well-defined.
        self.shortcut = (
            nn.Identity() if d_in == d_out else nn.Linear(d_in, d_out, bias=False)
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor whose last dimension is ``d_in``.

        Returns:
            Tensor whose last dimension is ``d_out``.
        """
        x = self.norm(x)
        # NOTE(review): the residual carries the *normalized* input rather than
        # the raw input (a standard pre-norm block would be x + net(norm(x))).
        # Kept as-is to preserve the behavior of existing checkpoints — confirm
        # this is intended.
        return self.net(x) + self.shortcut(x)


class SwiGLUResNet_v2(nn.Module):
    """Residual SwiGLU feed-forward block with LayerNorm before and after.

    Computes ``final_norm(net(initial_norm(x)) + shortcut(x))``, where ``net``
    is ``Linear(d_in, 2*d_out) -> SwiGLU -> Dropout -> Linear(d_out, d_out)``.
    The residual carries the raw (un-normalized) input. ``shortcut`` is the
    identity when ``d_in == d_out`` and a bias-free projection otherwise.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        dropout: Dropout probability applied to the SwiGLU output before the
            final linear layer.
    """

    def __init__(self, d_in, d_out, dropout=0.0):
        super().__init__()
        # Initial normalization on the input for stability.
        self.initial_norm = nn.LayerNorm(d_in)

        self.net = nn.Sequential(
            # Assumes SwiGLU halves the last dimension (2 * d_out -> d_out),
            # as implied by the Linear(d_out, d_out) that follows it.
            nn.Linear(d_in, 2 * d_out),
            SwiGLU(),
            torch.nn.Dropout(p=dropout),
            nn.Linear(d_out, d_out),
        )
        # Final normalization is applied AFTER the residual connection.
        self.final_norm = nn.LayerNorm(d_out)
        # Residual path: a parameter-free identity when widths match (no new
        # state_dict keys, no RNG consumed), otherwise a bias-free projection
        # so the raw input can be added to the d_out-wide branch output.
        # Registered last so existing submodules keep their init order.
        self.shortcut = (
            nn.Identity() if d_in == d_out else nn.Linear(d_in, d_out, bias=False)
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor whose last dimension is ``d_in``.

        Returns:
            Tensor whose last dimension is ``d_out``, LayerNorm-normalized
            over that dimension.
        """
        # First, normalize the input for the branch only.
        normed_x = self.initial_norm(x)
        # Then add the (possibly projected) raw input and normalize the sum.
        return self.final_norm(self.net(normed_x) + self.shortcut(x))
