import torch
import torch.nn as nn
import torch.nn.functional as F

class hiddenDetector(nn.Module):
    """Multi-kernel 2D-CNN classifier that emits one logit per sample.

    The input is treated as an image whose channels are the ``input_dim``
    features and whose spatial axes are ``(layer_size, seq_len)``. Three
    parallel ``Conv2d -> BatchNorm2d -> ReLU`` branches (kernel widths 2, 3
    and 5 along the sequence axis) are globally pooled, concatenated, passed
    through dropout and projected to a single logit.

    Args:
        input_dim: Size of the last axis of the input (conv input channels).
        num_filters: Output channels of every convolution branch.
        layer_kernel_size: Kernel extent along the ``layer_size`` axis.
        dropout: Dropout probability applied to the pooled features.
        pooling: ``'max'`` or ``'mean'`` global pooling.

    Raises:
        ValueError: If ``pooling`` is neither ``'max'`` nor ``'mean'``.
    """

    def __init__(self, input_dim=2048, num_filters=64, layer_kernel_size=3, dropout=0.6, pooling='max'):
        super().__init__()

        # Validate before allocating any layers.
        if pooling not in ('max', 'mean'):
            raise ValueError("Pooling must be either 'max' or 'mean'.")

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.kernel_sizes = [(layer_kernel_size, 2), (layer_kernel_size, 3), (layer_kernel_size, 5)]

        # self.kernel_sizes = [(1, 1)]

        # Parallel Conv2d branches, one per kernel size, held in a ModuleList.
        self.convs = nn.ModuleList([
            nn.Conv2d(
                in_channels=input_dim,
                out_channels=num_filters,
                kernel_size=k,
                padding='same'  # keep (layer_size, seq_len) unchanged so the mask still lines up
            )
            for k in self.kernel_sizes
        ])

        self.bns = nn.ModuleList([nn.BatchNorm2d(num_filters) for _ in self.kernel_sizes])

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(self.kernel_sizes), 1)

        # Exactly one of these two flags is True.
        self.max_pooling = pooling == 'max'
        self.mean_pooling = pooling == 'mean'

    def forward(self, x, mask=None):
        """Compute one logit per sample.

        Args:
            x: Tensor of shape ``(batch_size, layer_size, seq_len, input_dim)``.
            mask: Optional tensor of shape ``(batch_size, seq_len)``; entries
                equal to 0 mark positions that are excluded from pooling.

        Returns:
            Tensor of shape ``(batch_size,)`` holding the raw logits.

        Note:
            The mask only affects pooling. Masked positions still enter the
            convolutions and the batch-norm statistics.
        """

        # if self.training:
        #     # Additive Gaussian noise: mean 0, std 0.01 (tune in the 0.01-0.1 range)
        #     noise = torch.randn_like(x) * 0.01
        #     x = x + noise

        x = x.permute(0, 3, 1, 2)  # (batch_size, input_dim, layer_size, seq_len)

        # The mask-derived tensors are identical for every branch, so build them once.
        invalid = None
        if mask is not None:
            valid = (mask != 0)[:, None, None, :]          # (batch_size, 1, 1, seq_len)
            invalid = ~valid
            valid_per_sample = valid.sum(dim=3).reshape(-1, 1)  # (batch_size, 1)
            # Samples with no valid position would otherwise pool to -inf (max) or 0/0 (mean).
            all_masked = valid_per_sample == 0
            # Every valid position contributes `layer_size` cells to the mean.
            mean_denominator = (valid_per_sample * x.size(2)).clamp(min=1)

        pooled_outputs = []

        for conv, bn in zip(self.convs, self.bns):
            features = self.relu(bn(conv(x)))  # (batch_size, num_filters, layer_size, seq_len)

            if self.max_pooling:
                if invalid is not None:
                    features = features.masked_fill(invalid, float('-inf'))
                pooled_w, _ = torch.max(features, dim=3)  # (batch_size, num_filters, layer_size)
                pooled, _ = torch.max(pooled_w, dim=2)    # (batch_size, num_filters)
                if invalid is not None:
                    # Fully masked samples fall back to 0 instead of propagating -inf/NaN.
                    pooled = pooled.masked_fill(all_masked, 0.0)
            elif invalid is not None:
                # Masked mean: sum over valid cells divided by the number of valid
                # cells, so the result does not depend on the amount of padding.
                summed = features.masked_fill(invalid, 0.0).sum(dim=(2, 3))  # (batch_size, num_filters)
                pooled = summed / mean_denominator
            else:
                pooled_w = torch.mean(features, dim=3)  # (batch_size, num_filters, layer_size)
                pooled = torch.mean(pooled_w, dim=2)    # (batch_size, num_filters)

            pooled_outputs.append(pooled)

        pooled_features = torch.cat(pooled_outputs, dim=1)  # (batch_size, num_filters * num_branches)
        pooled_features = self.dropout(pooled_features)
        logits = self.fc(pooled_features).squeeze(1)  # (batch_size)

        return logits


class linearProbe(nn.Module):
    """Linear probe: BatchNorm1d over the input features, then a linear layer to one logit.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim=2048):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.fc = nn.Linear(input_dim, 1)
        self.bn = nn.BatchNorm1d(input_dim)

    def forward(self, x, mask=None):
        """Compute one logit per sample.

        Args:
            x: Tensor of shape ``(batch_size, input_dim)``.
            mask: Accepted but not used by this method.

        Returns:
            Tensor of shape ``(batch_size,)`` holding the raw logits.
        """
        x = self.bn(x)
        logits = self.fc(x).squeeze(1)  # (batch_size)
        return logits

    def predict(self, x):
        """Return sigmoid probabilities for ``x`` without tracking gradients.

        Inference runs in eval mode (batch norm uses its running statistics).
        The module's previous train/eval mode is restored afterwards, so
        calling this in the middle of training does not leave the model
        stuck in eval mode.

        Args:
            x: Tensor of shape ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size,)`` with values in ``[0, 1]``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probs = torch.sigmoid(logits)
        finally:
            # Undo the temporary switch to eval mode, even if forward raised.
            self.train(was_training)
        return probs