"""
Multimodal alignment pooling module for question-guided feature extraction.

This module implements attention-based pooling mechanisms that align visual and audio
features with textual queries using optimal transport and cross-modal attention.
The pooling operations are guided by text prompts to extract relevant multimodal features.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from transformers.activations import ACT2FN

from models.ot import OT_dist
from utils.registry import POOLER
from .utils import weighted_adaptive_avg_pool2d_unfold, weighted_adaptive_avg_pool3d_unfold


def process_batch_of_conversation(batch_of_conversations, tokenizer):
    """
    Tokenize the opening message of every conversation in a batch.

    Only the ``"value"`` field of the first message of each conversation is
    used. All texts are tokenized in one padded call that requests PyTorch
    tensors and an attention mask.

    Args:
        batch_of_conversations (list): Conversations; each is indexable and its
            first element is a mapping with a ``"value"`` key.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer).

    Returns:
        Whatever ``tokenizer`` returns for the padded batch (for Hugging Face
        tokenizers, a ``BatchEncoding`` with ``input_ids``/``attention_mask``).
    """
    opening_texts = [turns[0]["value"] for turns in batch_of_conversations]
    tokenizer_options = dict(
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    return tokenizer(opening_texts, **tokenizer_options)


class MultimodalAlignmentPooling(nn.Module):
    """
    Text-guided pooling over a (time, spatial) grid of visual or audio features.

    A projected text embedding is dot-producted with every feature vector; the
    scores are soft-maxed jointly over all time*spatial positions of a sample
    and used to re-weight the features before a fixed-window average pooling.
    """

    def __init__(self, dim_feat, dim_text, tau=1.0, kernel_size=(2, 3), stride=(2, 3)):
        """
        Initialize the multimodal alignment pooling module.

        Args:
            dim_feat (int): Feature dimension of the visual/audio features.
            dim_text (int): Dimension of the incoming text features.
            tau (float): Factor that MULTIPLIES the similarity logits before the
                softmax (i.e. an inverse temperature): larger values produce a
                sharper weight distribution.
            kernel_size (tuple): ``avg_pool2d`` kernel over (time, spatial).
            stride (tuple): ``avg_pool2d`` stride over (time, spatial).
        """
        super().__init__()
        self.dim_feat = dim_feat
        self.dim_text = dim_text
        self.tau = tau
        self.kernel_size = kernel_size
        self.stride = stride

        # Maps text features into the feature space so they can be compared by dot product.
        self.text_proj = nn.Linear(dim_text, dim_feat)

    def forward(self, features, text_features):
        """
        Perform text-guided pooling of multimodal features.

        Args:
            features (torch.Tensor): [batch_size, timesteps, spatial, dim_feat].
                Non-contiguous tensors (e.g. after a permute) are accepted.
            text_features (torch.Tensor): [batch_size, dim_text].

        Returns:
            torch.Tensor: [batch_size, timesteps', spatial', dim_feat], where the
            reduced sizes follow ``avg_pool2d`` with ``self.kernel_size`` /
            ``self.stride`` and no padding.
        """
        batch_size, timesteps, spatial, dim_feat = features.shape

        # [B, dim_text] -> [B, dim_feat]
        text_proj = self.text_proj(text_features)

        # reshape (not view): view() raises on non-contiguous inputs.
        features_flat = features.reshape(batch_size, -1, dim_feat)  # [B, T*S, D]
        logits = torch.matmul(features_flat, text_proj.unsqueeze(-1))  # [B, T*S, 1]
        # Softmax over all T*S positions jointly, so each sample's weights sum to 1.
        similarity_scores = F.softmax(self.tau * logits, dim=1)

        # [B, T*S, 1] -> [B, T, S, 1] so the weights broadcast over the feature dim.
        similarity_scores = similarity_scores.view(batch_size, timesteps, spatial, 1)
        weighted_features = features * similarity_scores

        # Plain (non-adaptive) average pooling over the (time, spatial) plane.
        pooled_features = F.avg_pool2d(
            weighted_features.permute(0, 3, 1, 2),  # [B, D, T, S]
            kernel_size=self.kernel_size,
            stride=self.stride,
        ).permute(0, 2, 3, 1)  # [B, T', S', D]

        return pooled_features

@POOLER.register("visual_align_pooler")
class VisualAlignPooler(nn.Module):
    """
    Visual alignment pooler for question-guided visual feature extraction.

    Visual tokens and the text query are projected into a shared, L2-normalized
    space. The similarity between each visual token and the query's first-token
    embedding (e.g. [CLS]) becomes a weight map that drives a weighted
    spatio-temporal average pooling of the MLP-transformed visual features.
    When ``ot_coeff > 0`` an optimal-transport loss between visual tokens and
    text tokens is also computed.
    """

    def __init__(
            self, text_tokenizer, text_model, kernel, stride, pooling_temperature,
            visual_embeds_dim, visual_projection_dim, llama_embeds_dim=4096,
            ot_coeff=0.0, use_text_mask=False, global_act="softmax", logit_scale_init_value=1.0, **kwargs
        ):
        """
        Initialize the visual alignment pooler.

        Args:
            text_tokenizer: Tokenizer applied to the first message of each conversation.
            text_model: Text encoder called with ``output_hidden_states=True``; it is
                run under ``torch.no_grad()`` in ``forward``.
            kernel: Pooling kernel passed to ``weighted_adaptive_avg_pool3d_unfold``.
            stride: Pooling stride passed to ``weighted_adaptive_avg_pool3d_unfold``.
            pooling_temperature (float): Temperature passed to the weighted pooling helper.
            visual_embeds_dim (int): Last dimension D of the input visual embeddings.
            visual_projection_dim (int): Dimension of the shared visual/text space.
            llama_embeds_dim (int): Output feature dimension of the two-layer MLP.
            ot_coeff (float): Passed to ``OT_dist`` as ``got_lambda_wd``; the OT loss is
                only computed when this is > 0.
            use_text_mask (bool): Whether the tokenizer attention mask is passed to ``OT_dist``.
            global_act (str): ``'softmax'``, ``'sigmoid'`` or ``'no_operation'``
                (case-insensitive); other values raise ``ValueError`` in ``forward``.
            logit_scale_init_value (float): Initial value of a learnable log-scale whose
                ``exp()`` multiplies the logits in the softmax branch only.
            **kwargs: Ignored.
        """
        super().__init__()
        self.text_tokenizer = text_tokenizer
        self.text_model = text_model
        self.visual_pooling_norm = nn.LayerNorm(visual_embeds_dim)
        self.visual_pooling_visual_projection = nn.Linear(visual_embeds_dim, visual_projection_dim, bias=True)
        # NOTE(review): this projection is applied to text-model hidden states, but its
        # input size comes from `config.projection_dim`, not `config.hidden_size`.
        # Presumably both are equal for the text model in use — confirm.
        self.visual_pooling_text_projection = nn.Linear(self.text_model.config.projection_dim, visual_projection_dim, bias=True)
        self.visual_pooling_linear_1 = nn.Linear(visual_embeds_dim, llama_embeds_dim, bias=True)
        self.visual_pooling_linear_2 = nn.Linear(llama_embeds_dim, llama_embeds_dim, bias=True)
        self.kernel = kernel
        self.stride = stride
        self.pooling_temperature = pooling_temperature
        self.global_act = global_act.lower()
        self.ot_coeff = ot_coeff
        self.use_text_mask = use_text_mask
        # Stored in log space; forward() uses its exp() as the softmax logit scale.
        self.logit_scale_init_value = nn.Parameter(torch.tensor(logit_scale_init_value))
        # Pre-pooling features from the last forward pass; None until forward() runs.
        self._hidden_states = None

    def forward(self, visual_embeds, conversations, output_shape=None, output_attention=False):
        """
        Pool visual features under the guidance of the conversation text.

        Args:
            visual_embeds (torch.Tensor): Visual features of shape [N, T, P, D]. The P
                patches are rearranged as a W x H grid with W = int(sqrt(P)) and
                H = P / W.
            conversations (list): One conversation per batch element; only the first
                message's ``"value"`` is encoded.
            output_shape (optional): Forwarded to ``weighted_adaptive_avg_pool3d_unfold``.
            output_attention (bool): If True, return the raw (pre-activation)
                text-visual similarity logits of shape [N, T, W, H] instead.

        Returns:
            tuple: ``(pooled_features, ot_loss)``. ``pooled_features`` has shape
            [N, T'*W'*H', llama_embeds_dim] (T', W', H' set by the pooling helper).
            ``ot_loss`` is the value returned by ``OT_dist``, or a zero scalar tensor
            when ``ot_coeff <= 0``. When ``output_attention`` is True, a single
            logits tensor is returned instead.

        Raises:
            ValueError: If ``global_act`` is not a supported activation.
        """
        # All auxiliary tensors follow the input's device rather than a hard-coded
        # "cuda", so CPU runs and non-default GPUs work.
        device = visual_embeds.device
        num_frames = visual_embeds.shape[1]
        # Grid side length; the rearranges below infer H = P / W.
        W = int(math.sqrt(visual_embeds.shape[-2]))
        hidden_states = rearrange(visual_embeds, 'N T P D -> (N T) P D')

        # Shared-space visual tokens, L2-normalized so dot products are cosine similarities.
        visual_proj_embeds = rearrange(visual_embeds, 'N T P D -> N (T P) D')
        visual_proj_embeds = self.visual_pooling_norm(visual_proj_embeds)
        visual_proj_embeds = self.visual_pooling_visual_projection(visual_proj_embeds)
        visual_proj_embeds = visual_proj_embeds / visual_proj_embeds.norm(p=2, dim=-1, keepdim=True)

        # Frozen text encoding (no gradients flow into the text model).
        with torch.no_grad():
            inputs = process_batch_of_conversation(conversations, tokenizer=self.text_tokenizer).to(device)
            bert_hidden_states = self.text_model(**inputs, output_hidden_states=True).hidden_states

        if self.ot_coeff > 0:
            # Penultimate-layer token embeddings, dropping the first ([CLS]) position.
            bert_text_penultimate_embeds = bert_hidden_states[-2][:, 1:, :]
            bert_text_penultimate_embeds = self.visual_pooling_text_projection(bert_text_penultimate_embeds)
            bert_text_penultimate_embeds = bert_text_penultimate_embeds / bert_text_penultimate_embeds.norm(p=2, dim=-1, keepdim=True)
            ot_kwargs = {"got_lambda_wd": self.ot_coeff}
            if self.use_text_mask:
                # Sliced like the embeddings so mask positions line up with tokens.
                ot_kwargs["text_mask"] = inputs["attention_mask"][:, 1:]
            ot_loss = OT_dist(visual_proj_embeds, bert_text_penultimate_embeds, **ot_kwargs)
        else:
            ot_loss = torch.tensor(0.0, device=device)

        # Final-layer first-token embedding as the per-sample query vector.
        bert_text_cls_embeds = bert_hidden_states[-1][:, 0, :]
        bert_text_cls_embeds = self.visual_pooling_text_projection(bert_text_cls_embeds)
        bert_text_cls_embeds = bert_text_cls_embeds / bert_text_cls_embeds.norm(p=2, dim=-1, keepdim=True)

        # Each sample's query against its own visual tokens -> [N, T*P]. This equals the
        # diagonal of the all-pairs [N, N, T*P] product without the N^2 cross-sample work.
        logit_scale = self.logit_scale_init_value.exp()
        tv_logits = torch.einsum('nd,nvd->nv', bert_text_cls_embeds, visual_proj_embeds)

        if self.global_act == "softmax":
            tv_weights = torch.softmax(tv_logits * logit_scale, dim=-1)
        elif self.global_act == "sigmoid":
            # The learnable logit scale is not applied in this branch.
            tv_weights = torch.sigmoid(tv_logits)
        elif self.global_act == "no_operation":
            tv_weights = tv_logits
        else:
            raise ValueError(f"Unsupported global activation: {self.global_act}")

        if output_attention:
            # Raw similarity logits (before the global activation).
            return rearrange(tv_logits, 'N (T W H) -> N T W H', T=num_frames, W=W)

        tv_weights = rearrange(tv_weights, 'N (T W H) -> N T W H', T=num_frames, W=W)

        # Two-layer GELU MLP: D -> llama_embeds_dim -> llama_embeds_dim.
        hidden_states = self.visual_pooling_linear_1(hidden_states)
        hidden_states = ACT2FN["gelu"](hidden_states)
        hidden_states = self.visual_pooling_linear_2(hidden_states)
        hidden_states = ACT2FN["gelu"](hidden_states)

        # Pre-pooling features, exposed through the `hidden_states` property.
        self._hidden_states = rearrange(hidden_states, '(N T) (W H) D -> N (T W H) D', T=num_frames, W=W)

        hidden_states = rearrange(hidden_states, '(N T) (W H) D -> N D T W H', T=num_frames, W=W)
        hidden_states = weighted_adaptive_avg_pool3d_unfold(
            hidden_states, output_shape, self.kernel, self.stride,
            tv_weights, self.pooling_temperature
        )

        # [N, D, T', W', H'] -> [N, T'*W'*H', D]
        hidden_states = rearrange(hidden_states, 'N D T W H -> N (T W H) D')

        return hidden_states, ot_loss

    @property
    def hidden_states(self):
        """
        Get the MLP-transformed features from the last forward pass.

        Returns:
            torch.Tensor or None: [N, T*P, llama_embeds_dim] features before the
            weighted pooling, or None if ``forward`` has not been called yet.
        """
        return self._hidden_states


@POOLER.register("audio_align_pooler")
class AudioAlignPooler(nn.Module):
    """
    Audio alignment pooler for question-guided audio feature extraction.

    Mirrors ``VisualAlignPooler`` for audio features laid out on a
    (time, band) grid: audio tokens and the text query are projected into a
    shared, L2-normalized space. Per-token similarity to the query's first-token
    embedding (e.g. [CLS]) drives a weighted 2-D average pooling of the
    MLP-transformed audio features. When ``ot_coeff > 0``, an optimal-transport
    loss between audio tokens and text tokens is also computed.
    """

    def __init__(
            self, text_tokenizer, text_model, kernel, stride, pooling_temperature,
            audio_embeds_dim, audio_projection_dim, llama_embeds_dim=4096,
            ot_coeff=0.0, use_text_mask=False, global_act="softmax", logit_scale_init_value=1.0, **kwargs
        ):
        """
        Initialize the audio alignment pooler.

        Args:
            text_tokenizer: Tokenizer applied to the first message of each conversation.
            text_model: Text encoder called with ``output_hidden_states=True``; put in
                eval mode here and run under ``torch.no_grad()`` in ``forward``.
            kernel: Pooling kernel passed to ``weighted_adaptive_avg_pool2d_unfold``.
            stride: Pooling stride passed to ``weighted_adaptive_avg_pool2d_unfold``.
            pooling_temperature (float): Temperature passed to the weighted pooling helper.
            audio_embeds_dim (int): Last dimension D of the input audio embeddings.
            audio_projection_dim (int): Dimension of the shared audio/text space.
            llama_embeds_dim (int): Output feature dimension of the two-layer MLP.
            ot_coeff (float): Passed to ``OT_dist`` as ``got_lambda_wd``; the OT loss is
                only computed when this is > 0.
            use_text_mask (bool): Whether the tokenizer attention mask is passed to ``OT_dist``.
            global_act (str): ``'softmax'``, ``'sigmoid'`` or ``'no_operation'``
                (case-insensitive); other values raise ``ValueError`` in ``forward``.
            logit_scale_init_value (float): Initial value of a learnable log-scale whose
                ``exp()`` multiplies the logits in the softmax branch only.
            **kwargs: Ignored.
        """
        super().__init__()
        self.text_tokenizer = text_tokenizer
        # NOTE(review): a later parent .train() call will switch this submodule back
        # to train mode; eval() here only sets the initial mode.
        self.text_model = text_model.eval()
        self.audio_pooling_norm = nn.LayerNorm(audio_embeds_dim)
        self.audio_pooling_audio_projection = nn.Linear(audio_embeds_dim, audio_projection_dim, bias=True)
        # NOTE(review): this projection is applied to text-model hidden states, but its
        # input size comes from `config.projection_dim`, not `config.hidden_size`.
        # Presumably both are equal for the text model in use — confirm.
        self.audio_pooling_text_projection = nn.Linear(self.text_model.config.projection_dim, audio_projection_dim, bias=True)
        self.audio_pooling_linear_1 = nn.Linear(audio_embeds_dim, llama_embeds_dim, bias=True)
        self.audio_pooling_linear_2 = nn.Linear(llama_embeds_dim, llama_embeds_dim, bias=True)
        self.kernel = kernel
        self.stride = stride
        self.pooling_temperature = pooling_temperature
        self.global_act = global_act.lower()
        self.ot_coeff = ot_coeff
        self.use_text_mask = use_text_mask
        # Stored in log space; forward() uses its exp() as the softmax logit scale.
        self.logit_scale_init_value = nn.Parameter(torch.tensor(logit_scale_init_value))
        # Pre-pooling features from the last forward pass; None until forward() runs.
        self._hidden_states = None

    def forward(self, audio_embeds, conversations, output_shape=None, output_attention=False):
        """
        Pool audio features under the guidance of the conversation text.

        Args:
            audio_embeds (torch.Tensor): Audio features of shape [N, T, B, D]
                (T time steps, B bands).
            conversations (list): One conversation per batch element; only the first
                message's ``"value"`` is encoded.
            output_shape (optional): Forwarded to ``weighted_adaptive_avg_pool2d_unfold``.
            output_attention (bool): If True, return the raw (pre-activation)
                text-audio similarity logits of shape [N, T, B] instead.

        Returns:
            tuple: ``(pooled_features, ot_loss)``. ``pooled_features`` has shape
            [N, T'*B', llama_embeds_dim] (T', B' set by the pooling helper).
            ``ot_loss`` is the value returned by ``OT_dist``, or a zero scalar tensor
            when ``ot_coeff <= 0``. When ``output_attention`` is True, a single
            logits tensor is returned instead.

        Raises:
            ValueError: If ``global_act`` is not a supported activation (only
                reached when ``output_attention`` is False).
        """
        # All auxiliary tensors follow the input's device rather than a hard-coded
        # "cuda", so CPU runs and non-default GPUs work.
        device = audio_embeds.device
        time_steps = audio_embeds.shape[1]
        num_bands = audio_embeds.shape[2]

        # Shared-space audio tokens, L2-normalized so dot products are cosine similarities.
        audio_proj_embeds = rearrange(audio_embeds, 'N T B D -> N (T B) D')
        audio_proj_embeds = self.audio_pooling_norm(audio_proj_embeds)
        audio_proj_embeds = self.audio_pooling_audio_projection(audio_proj_embeds)
        audio_proj_embeds = audio_proj_embeds / audio_proj_embeds.norm(p=2, dim=-1, keepdim=True)

        # Frozen text encoding (no gradients flow into the text model).
        with torch.no_grad():
            inputs = process_batch_of_conversation(conversations, tokenizer=self.text_tokenizer).to(device)
            bert_hidden_states = self.text_model(**inputs, output_hidden_states=True).hidden_states

        if self.ot_coeff > 0:
            # Penultimate-layer token embeddings, dropping the first ([CLS]) position.
            bert_text_penultimate_embeds = bert_hidden_states[-2][:, 1:, :]
            bert_text_penultimate_embeds = self.audio_pooling_text_projection(bert_text_penultimate_embeds)
            bert_text_penultimate_embeds = bert_text_penultimate_embeds / bert_text_penultimate_embeds.norm(p=2, dim=-1, keepdim=True)
            ot_kwargs = {"got_lambda_wd": self.ot_coeff}
            if self.use_text_mask:
                # Sliced like the embeddings so mask positions line up with tokens.
                ot_kwargs["text_mask"] = inputs["attention_mask"][:, 1:]
            ot_loss = OT_dist(audio_proj_embeds, bert_text_penultimate_embeds, **ot_kwargs)
        else:
            ot_loss = torch.tensor(0.0, device=device)

        # Final-layer first-token embedding as the per-sample query vector.
        bert_text_cls_embeds = bert_hidden_states[-1][:, 0, :]
        bert_text_cls_embeds = self.audio_pooling_text_projection(bert_text_cls_embeds)
        bert_text_cls_embeds = bert_text_cls_embeds / bert_text_cls_embeds.norm(p=2, dim=-1, keepdim=True)

        # Each sample's query against its own audio tokens -> [N, T*B]. This equals the
        # diagonal of the all-pairs [N, N, T*B] product without the N^2 cross-sample work.
        logit_scale = self.logit_scale_init_value.exp()
        ta_logits = torch.einsum('nd,nvd->nv', bert_text_cls_embeds, audio_proj_embeds)

        if output_attention:
            # Raw similarity logits (before the global activation).
            return rearrange(ta_logits, 'N (T B) -> N T B', T=time_steps, B=num_bands)

        if self.global_act == "softmax":
            ta_weights = torch.softmax(ta_logits * logit_scale, dim=-1)
        elif self.global_act == "sigmoid":
            # The learnable logit scale is not applied in this branch.
            ta_weights = torch.sigmoid(ta_logits)
        elif self.global_act == "no_operation":
            ta_weights = ta_logits
        else:
            raise ValueError(f"Unsupported global activation: {self.global_act}")

        ta_weights = rearrange(ta_weights, 'N (T B) -> N T B', T=time_steps, B=num_bands)

        # Two-layer GELU MLP: D -> llama_embeds_dim -> llama_embeds_dim.
        audio_embeds = self.audio_pooling_linear_1(audio_embeds)
        audio_embeds = ACT2FN["gelu"](audio_embeds)
        audio_embeds = self.audio_pooling_linear_2(audio_embeds)
        audio_embeds = ACT2FN["gelu"](audio_embeds)

        # Pre-pooling features, exposed through the `hidden_states` property.
        self._hidden_states = rearrange(audio_embeds, 'N T B D -> N (T B) D')

        audio_embeds = rearrange(audio_embeds, 'N T B D -> N D T B')
        audio_embeds = weighted_adaptive_avg_pool2d_unfold(
            audio_embeds, output_shape, self.kernel, self.stride, ta_weights, self.pooling_temperature
        )

        # [N, D, T', B'] -> [N, T'*B', D]
        audio_embeds = rearrange(audio_embeds, 'N D T B -> N (T B) D')

        return audio_embeds, ot_loss

    @property
    def hidden_states(self):
        """
        Get the MLP-transformed features from the last forward pass.

        Returns:
            torch.Tensor or None: [N, T*B, llama_embeds_dim] features before the
            weighted pooling, or None if ``forward`` has not been called yet.
        """
        return self._hidden_states