import torch
import torch.nn as nn

class AmortizedExplainerFCN(nn.Module):
    """Fully convolutional 1-D network producing per-timestep, per-channel scores.

    Three Conv1d -> BatchNorm1d -> ReLU -> Dropout stages with shrinking
    kernels (7, 5, 3) are followed by a pointwise (kernel size 1) convolution
    and a sigmoid, so every output value lies in [0, 1]. Each convolution is
    padded with ``(kernel_size - 1) // 2`` zeros per side, so the time axis
    keeps its length from input to output.
    """

    def __init__(self, c_in, c_out, hidden_dim=128, dropout=0.2):
        """Build the convolutional stack.

        Args:
            c_in: Number of input channels. Per the original author's note
                this is raw + saliency, i.e. 2 * original_channels.
            c_out: Number of output channels (original_channels).
            hidden_dim: Width of the first and third stages; the second
                stage is twice as wide.
            dropout: Dropout probability applied at the end of every stage.
        """
        super().__init__()

        def make_stage(n_in, n_out, kernel):
            # "Same" padding for odd kernels at stride 1: length is unchanged.
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            )

        wide_dim = hidden_dim * 2

        # Widest receptive field first: broad temporal patterns.
        self.conv1 = make_stage(c_in, hidden_dim, 7)
        # Intermediate refinement at double width.
        self.conv2 = make_stage(hidden_dim, wide_dim, 5)
        # Narrowest kernel: local sensitivity, back down to hidden_dim.
        self.conv3 = make_stage(wide_dim, hidden_dim, 3)

        # Pointwise projection back to the original channel space.
        self.out_projection = nn.Conv1d(hidden_dim, c_out, kernel_size=1)

        # Squashes the projected scores into the [0, 1] range.
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Score an input laid out as [Batch, Time, Channels].

        Returns a tensor in the same [Batch, Time, Channels] layout whose
        last dimension is ``c_out`` and whose values are sigmoid outputs.
        """
        # Conv1d wants channels ahead of time: [Batch, Channels, Time].
        channels_first = x.transpose(1, 2)

        features = self.conv3(self.conv2(self.conv1(channels_first)))
        scores = self.activation(self.out_projection(features))

        # Restore [Batch, Time, Channels] for the loss computation.
        return scores.transpose(1, 2)

