import torch
import torch.nn as nn


class ModuloClassifier(nn.Module):
    """Embedding-based MLP classifier over fixed-length token sequences.

    Token ids are embedded, the per-position embeddings are flattened into a
    single vector per sample, and a two-layer perceptron
    (Linear -> ReLU -> Linear) produces one score per vocabulary entry.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores.
        d_model: Embedding width; also the hidden-layer width.
        seq_len: Tokens per sample. Fixes the input width of ``fc2`` to
            ``seq_len * d_model``.
        pretrained_path: Accepted but not referenced by this class.
    """

    def __init__(self, vocab_size, d_model, seq_len, pretrained_path=None):
        super().__init__()
        flattened_width = seq_len * d_model
        # Keep the creation order embedding -> fc2 -> fc3: attribute names
        # define the state_dict keys and the order defines seeded init.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
        )
        self.fc2 = nn.Linear(in_features=flattened_width, out_features=d_model)
        self.fc3 = nn.Linear(in_features=d_model, out_features=vocab_size)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each sample in ``x`` against every vocabulary entry.

        Args:
            x: Integer tensor of token ids. Assumed to be shaped
                ``(batch, seq_len)`` -- dimension 0 is kept and all remaining
                dimensions are merged before ``fc2``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` holding the raw outputs
            of ``fc3`` (no softmax is applied).
        """
        # Defensive copy: later stages never share storage with the
        # embedding output.
        embedded = self.embedding(x).clone()

        # Merge every non-batch dimension into one feature axis; the second
        # copy detaches the flattened tensor from the reshape view.
        batch_size = embedded.size(0)
        flattened = embedded.reshape(batch_size, -1).clone()

        # Hidden projection followed by an out-of-place ReLU.
        hidden = self.relu(self.fc2(flattened))

        # Output projection back to vocabulary size.
        return self.fc3(hidden)
    

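# Disabled alternative (kept for reference): ModuloClassifier with fc2 and fc3
# each replaced by a pair of bias-free low-rank linear layers.
# class ModuloClassifier(nn.Module):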
    # def __init__(self, vocab_size, d_model, seq_len,  pretrained_path=None,rank=15):
    #     super().__init__()
    #     self.embedding = nn.Embedding(vocab_size, d_model)

    #     # Replace fc2 with a low-rank factorization: fc2_v(fc2_u(x)), i.e. effective weight W = V @ U
    #     self.fc2_u = nn.Linear(seq_len * d_model, rank, bias=False)  # First low-rank factor
    #     self.fc2_v = nn.Linear(rank, d_model, bias=False)            # Second low-rank factor
    #     self.fc3_u = nn.Linear(d_model, rank, bias=False)  # First low-rank factor
    #     self.fc3_v = nn.Linear(rank, vocab_size, bias=False) 
    #     self.relu = nn.ReLU(inplace=False)

    # def forward(self, x):
    #     x = self.embedding(x).clone()  # Ensure no in-place modifications
    #     x = x.reshape(x.size(0), -1).clone()  # Flatten sequence and embedding dimensions
    #     x = self.fc2_u(x)  # Apply the first low-rank factor
    #     x = self.fc2_v(x)  # Apply the second low-rank factor
    #     x = self.relu(x)   # Apply ReLU activation
    #     x = self.fc3_u(x)  # Apply the first low-rank factor
    #     x = self.fc3_v(x)  # Apply the second low-rank factor
    #     return x

    
class ModuloClassifier_noEmb(nn.Module):
    """Two-layer MLP classifier that consumes raw inputs without an embedding.

    The input is cast to floating point and passed through
    Linear -> ReLU -> Linear, giving one score per vocabulary entry.

    Args:
        vocab_size: Number of output scores.
        d_model: Hidden-layer width.
        seq_len: Size of the input's last dimension (input width of ``fc2``).
        pretrained_path: Accepted but not referenced by this class.
    """

    def __init__(self, vocab_size, d_model, seq_len, pretrained_path=None):
        super().__init__()
        # Keep the creation order fc2 -> fc3: attribute names define the
        # state_dict keys and the order defines seeded init.
        self.fc2 = nn.Linear(in_features=seq_len, out_features=d_model)
        self.fc3 = nn.Linear(in_features=d_model, out_features=vocab_size)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each input row against every vocabulary entry.

        Args:
            x: Tensor whose last dimension has size ``seq_len``; it is
                converted with ``x.float()`` before the first linear layer.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``vocab_size`` (raw ``fc3`` outputs, no softmax).
        """
        hidden = self.relu(self.fc2(x.float()))
        return self.fc3(hidden)