import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentivePooling(nn.Module):
    """Collapse a sequence ``[B, N, D]`` into ``[B, D]`` using learned attention weights.

    A small MLP (Linear -> Tanh -> Linear) scores every position. The scores are
    softmax-normalised over ``N``, respecting an optional mask. Dropout is
    applied to the resulting weights, and the weighted sum of ``x`` is returned.

    Args:
        input_dim: Feature size D of ``x``.
        hidden_dim: Width of the scoring MLP's hidden layer.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, input_dim, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.att_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=True),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask_for_pooling=None):
        """Return the attention-weighted sum of ``x`` over dim 1.

        Args:
            x: Input of shape ``[B, N, D]``.
            mask_for_pooling: Optional ``[B, N]`` mask. Positions equal to 0
                (or False) receive zero attention weight.

        Returns:
            ``[B, D]`` pooled features. A row whose mask is entirely zero pools
            to an all-zero vector instead of NaN.
        """
        attn_scores = self.att_proj(x)  # [B, N, 1]
        padding = None
        if mask_for_pooling is not None:
            padding = (mask_for_pooling == 0).unsqueeze(-1)  # [B, N, 1]
            # Use the most negative finite value rather than -inf. A fully
            # masked row of -inf makes softmax compute 0/0 = NaN, which then
            # propagates into the output and the gradients.
            attn_scores = attn_scores.masked_fill(padding, torch.finfo(attn_scores.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=1)  # [B, N, 1]
        if padding is not None:
            # Masked positions in partially masked rows already underflow to
            # ~0. This mainly matters for fully masked rows, whose softmax
            # is uniform over padding.
            attn_weights = attn_weights.masked_fill(padding, 0.0)
        attn_weights = self.dropout(attn_weights)
        return (x * attn_weights).sum(dim=1)  # [B, D]

class CompactHasher(nn.Module):
    """Pool a sequence of frame features and project it to a compact code.

    ``AttentivePooling`` reduces ``x`` (``[B, N, input_dim]``) to
    ``[B, input_dim]``. A two-layer MLP (Linear -> ReLU -> Linear) then maps
    that to ``[B, output_dim]``.

    Args:
        input_dim: Feature size of the incoming frames.
        output_dim: Size of the produced code.
        hidden_dim: Width of the MLP's hidden layer.
    """

    def __init__(self, input_dim=768, output_dim=256, hidden_dim=512):
        super().__init__()
        # Attribute names and creation order define the state_dict layout and
        # the seeded parameter initialisation, so both are kept stable.
        self.att_pool = AttentivePooling(input_dim)
        mlp_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x, mask_for_pooling=None):
        """Map ``x`` (``[B, N, input_dim]``) to ``[B, output_dim]``.

        ``mask_for_pooling`` is forwarded unchanged to the attentive pooling.
        """
        return self.proj(self.att_pool(x, mask_for_pooling))

class FingerprintGenerator(torch.nn.Module):
    """Turn a batch of frame-level features into fixed-size fingerprints.

    The pipeline has four stages:

    1. A BiLSTM runs over the frames.
    2. Each item's valid frames are average-pooled over several window sizes.
    3. ``CompactHasher`` (attentive pooling + MLP) reduces the pooled windows
       to one vector per item.
    4. A tanh is applied to that vector.

    Args:
        input_dim: Feature size D of ``input_values``.
        lstm_hidden_dim: Total BiLSTM output width. Each direction gets half,
            so it must be even.
        lstm_layers: Number of stacked LSTM layers.
        output_dim: Size of each returned fingerprint.

    Raises:
        ValueError: If ``lstm_hidden_dim`` is odd. The BiLSTM output would then
            not match the hasher's input width.
    """

    def __init__(self, input_dim=768, lstm_hidden_dim=512, lstm_layers=2, output_dim=256):
        super().__init__()
        if lstm_hidden_dim % 2:
            raise ValueError(
                f"lstm_hidden_dim must be even (it is split across two LSTM directions), got {lstm_hidden_dim}"
            )
        self.bilstm = torch.nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_hidden_dim // 2,
            num_layers=lstm_layers,
            # LSTM dropout only acts between stacked layers. PyTorch warns
            # (and ignores it) when num_layers == 1.
            dropout=0.25 if lstm_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True,
        )

        self.compact_hasher = CompactHasher(input_dim=lstm_hidden_dim, output_dim=output_dim)
        self.tanh = nn.Tanh()

    def multiscale_feature_extraction(self, sequence, window_sizes=(20, 50, 100), stride=10):
        """Average-pool one sequence at several window sizes and concatenate the results.

        Args:
            sequence: Frames of a single item, shape ``[T, D]``.
            window_sizes: Window lengths in frames. Each is pooled with
                ``avg_pool1d`` using the shared ``stride``.
            stride: Hop between consecutive windows.

        Returns:
            ``[L, D]`` tensor with ``L`` the total number of windows over all
            scales. Scales longer than ``T`` are skipped. If every scale is
            skipped but ``T > 0``, the single global mean of the sequence is
            returned (``[1, D]``), so short inputs still produce a feature. For
            ``T == 0`` the result is ``[0, D]`` with the input's dtype and
            device.
        """
        frames = sequence.transpose(0, 1).unsqueeze(0)  # [1, D, T]
        num_frames = frames.size(-1)
        pooled_all = [
            F.avg_pool1d(frames, kernel_size=win, stride=stride).squeeze(0).transpose(0, 1)  # [L_win, D]
            for win in window_sizes
            if win <= num_frames
        ]
        if not pooled_all and num_frames > 0:
            # Too short for every window. Without this fallback the item would
            # have no pooled frames, and attention pooling over it degenerates.
            pooled_all.append(sequence.mean(dim=0, keepdim=True))  # [1, D]
        if not pooled_all:
            # Keep dtype/device consistent with the rest of the batch.
            return sequence.new_zeros((0, sequence.size(1)))
        return torch.cat(pooled_all, dim=0)

    @staticmethod
    def _pad_pooled_features(features, like):
        """Right-pad ``[L_i, D]`` tensors into ``[B, max_L, D]`` and build a ``[B, max_L]`` bool mask of real rows."""
        lengths = [feat.size(0) for feat in features]
        max_len = max(lengths, default=0)
        padded = like.new_zeros((len(features), max_len, like.size(-1)))
        mask = torch.zeros((len(features), max_len), dtype=torch.bool, device=like.device)
        for i, (feat, n) in enumerate(zip(features, lengths)):
            padded[i, :n] = feat
            mask[i, :n] = True
        return padded, mask

    def forward(self, input_values, attention_mask=None):
        """Compute a fingerprint for each item in the batch.

        Args:
            input_values: Frame features, shape ``[B, T, D]``.
            attention_mask: Optional ``[B, T]`` mask, presumably 1 for valid
                frames and 0 for padding. Each item's length is taken as the
                mask's row sum, and its first ``length`` frames are used. This
                assumes right-padded batches (TODO confirm with the data
                pipeline).

        Returns:
            ``[B, output_dim]`` fingerprints passed through tanh.
        """
        if attention_mask is not None:
            # Zero padded frames. Cast to the input dtype so half-precision
            # inputs are not promoted to float32.
            input_values = input_values * attention_mask.unsqueeze(-1).to(input_values.dtype)

        self.bilstm.flatten_parameters()
        # NOTE(review): padded batches are fed unpacked, so the backward LSTM
        # direction also reads the (zeroed) padding of shorter items.
        # pack_padded_sequence would avoid that, but it changes fingerprints.
        bilstm_out, _ = self.bilstm(input_values)  # [B, T, lstm_hidden_dim]

        if attention_mask is not None:
            lengths = [int(n) for n in attention_mask.sum(dim=1).tolist()]
        else:
            lengths = [bilstm_out.size(1)] * bilstm_out.size(0)

        per_item = [
            self.multiscale_feature_extraction(bilstm_out[i, :n, :])
            for i, n in enumerate(lengths)
        ]
        stacked, pooled_mask = self._pad_pooled_features(per_item, bilstm_out)
        return self.tanh(self.compact_hasher(stacked, pooled_mask))