import torch.nn as nn
from models.classifier import SAB

class MLPFeatureSelector(nn.Module):
    """Predict a ``(num_slots, num_blocks)`` score map per sample with an MLP.

    Each sample is flattened and passed through
    ``Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear``.
    The final ``num_slots * num_blocks`` outputs are reshaped into a
    ``(batch, num_slots, num_blocks)`` tensor. No output activation is
    applied, so the returned values are raw (unnormalized) scores.

    Args:
        input_dim: Number of features per sample after flattening every
            non-batch dimension.
        num_slots: Size of the second output dimension.
        num_blocks: Size of the third output dimension.
        hidden_dim: Width of the first hidden layer; the second hidden
            layer has ``hidden_dim // 2`` units.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(
        self,
        input_dim: int,
        num_slots: int,
        num_blocks: int,
        hidden_dim: int = 512,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.num_slots = num_slots
        self.num_blocks = num_blocks

        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_slots * num_blocks),
        )

    def forward(self, x):
        """Return raw scores of shape ``(batch, num_slots, num_blocks)``.

        Args:
            x: Tensor whose first dimension is the batch. All remaining
                dimensions are flattened and must multiply to ``input_dim``.
        """
        batch_size = x.size(0)
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed/permuted tensors, while reshape copies only if needed.
        flat = x.reshape(batch_size, -1)
        scores = self.network(flat)
        # Linear output is contiguous, so view is safe here.
        return scores.view(batch_size, self.num_slots, self.num_blocks)
    
    
class SetTransformerFeatureSelector(nn.Module):
    """Score each encoded set element over ``num_blocks`` using Set Transformer blocks.

    The input passes through ``SAB -> Dropout -> SAB`` (``self.enc``), and
    then through a two-layer MLP (``self.mask_generator``). That MLP maps
    the last feature dimension from ``dim_hidden`` to ``num_blocks``. No
    output activation is applied, so the returned values are raw scores.

    NOTE(review): ``num_slots`` is stored but not used in ``forward``. The
    output's leading dimensions are whatever ``self.enc`` produces, not a
    fixed ``num_slots``. Presumably ``SAB`` maps ``(batch, set_size,
    dim_in)`` to ``(batch, set_size, dim_out)``, so callers would need to
    pass ``num_slots`` set elements per sample. TODO: confirm.

    Args:
        input_dim: Feature size fed to the first ``SAB`` (``dim_in``).
        num_slots: Stored on the module; not used in the computation.
        num_blocks: Size of the last output dimension.
        dim_hidden: Output size of both ``SAB`` blocks and width of the
            mask-generator hidden layer.
        num_heads: Forwarded to both ``SAB`` blocks as ``num_heads``.
        ln: Forwarded to both ``SAB`` blocks as ``ln``. Presumably this
            toggles layer normalization inside ``SAB``; confirm.
        dropout: Dropout probability used between the two ``SAB`` blocks
            and inside the mask generator.
    """

    def __init__(
        self,
        input_dim: int,
        num_slots: int,
        num_blocks: int,
        dim_hidden: int = 128,
        num_heads: int = 4,
        ln: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.num_slots = num_slots
        self.num_blocks = num_blocks

        # Set Transformer encoder: two self-attention blocks with dropout between.
        self.enc = nn.Sequential(
            SAB(dim_in=input_dim, dim_out=dim_hidden, num_heads=num_heads, ln=ln),
            nn.Dropout(dropout),
            SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln),
        )

        # Per-element head: maps each encoded feature vector to num_blocks scores.
        self.mask_generator = nn.Sequential(
            nn.Linear(dim_hidden, dim_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_hidden, num_blocks),
        )

    def forward(self, x):
        """Encode ``x`` with ``self.enc`` and return raw per-element scores.

        The result keeps the leading dimensions produced by ``self.enc``.
        Its last dimension is ``num_blocks``.
        """
        encoded = self.enc(x)
        return self.mask_generator(encoded)
