import torch
import torch.nn as nn

class FidelityGateNet(nn.Module):
    """Per-element sigmoid gate over a ``[Batch, Time, Channels]`` attribution map.

    The network has two Conv1d -> BatchNorm1d -> ReLU stages (kernel sizes 7
    and 5), followed by a Conv1d projection (kernel size 3) and a sigmoid. Every
    convolution is stride 1 with padding ``kernel_size // 2``, so the time
    length is unchanged. The projection maps back to ``c_in`` channels, so the
    returned mask has exactly the input's shape. Each entry of the mask is
    squashed independently into (0, 1).

    Args:
        c_in: Channel count of the attribution map (the last input dimension);
            also the channel count of the produced mask.
        hidden_dim: Number of filters in the two hidden convolution stages.
    """

    def __init__(self, c_in, hidden_dim=64):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch, width):
            # Odd width with width // 2 padding -> output length == input length.
            return nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=width, padding=width // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            )

        # Wider kernel first to pick up local patterns, then a narrower refinement.
        self.conv1 = conv_bn_relu(c_in, hidden_dim, 7)
        self.conv2 = conv_bn_relu(hidden_dim, hidden_dim, 5)

        # Map back to c_in channels so the mask lines up with the input.
        self.out_projection = nn.Conv1d(
            hidden_dim, c_in, kernel_size=3, padding=1
        )

        # Independent sigmoid per (time, channel) entry.
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return a gate mask with the same ``[Batch, Time, Channels]`` shape as ``x``.

        Args:
            x: Tensor laid out as ``[Batch, Time, Channels]`` with
                ``Channels == c_in``.

        Returns:
            Tensor of the same shape as ``x`` holding sigmoid outputs.
        """
        # Conv1d expects channels on dim 1: [B, T, C] -> [B, C, T].
        h = x.transpose(1, 2)
        for stage in (self.conv1, self.conv2, self.out_projection, self.activation):
            h = stage(h)
        # Restore the caller's layout: [B, C, T] -> [B, T, C].
        return h.transpose(1, 2)


