
"""
Multimodal encoders for audio and video feature extraction.

This module provides various encoder implementations for processing different
modalities including video (InternVideo2), audio (Whisper, BEATs, SenseVoice),
and combined audio encoders (Whisper+BEATs) for multimodal learning tasks.
"""

from collections.abc import Mapping

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import WhisperFeatureExtractor

from .audio_encoders.beats import BEATs, BEATsConfig
from .video_encoders.internvideo2 import pretrain_internvideo2_1b_patch14_224
from .utils import load_model
from utils.registry import ENCODER
from .base_encoder import BaseEncoder


# Video Encoders
@ENCODER.register("internvideo2")
class InternVideo2Encoder(BaseEncoder):
    """
    InternVideo2 encoder for video feature extraction.

    Wraps the InternVideo2-1B (patch 14, 224px) backbone, kept in float16, and
    encodes variable-length videos by zero-padding them to a common frame count
    and running the backbone over consecutive windows of ``num_frames`` frames.
    """

    def __init__(self, configs, device):
        """
        Initialize the InternVideo2 encoder.

        Args:
            configs (dict): Configuration holding an ``'internvideo2'`` sub-dict.
                Its ``'num_frames'`` sets the temporal window size; a truthy
                ``'freeze'`` freezes the backbone via ``self.freeze``.
            device (torch.device): Device the backbone is moved to.
        """
        super().__init__(configs, device)
        self.internvideo2_configs = configs.get('internvideo2')
        self.num_frames = self.internvideo2_configs.get('num_frames')
        # The backbone is built from the sub-config, then moved to the device and cast to fp16.
        self.internvideo2 = pretrain_internvideo2_1b_patch14_224(self.internvideo2_configs).to(self.device).to(torch.float16)
        if self.internvideo2_configs.get('freeze'):
            self.freeze(self.internvideo2)

    @torch.no_grad()
    def encode(self, video_inputs):
        """
        Encode a list of videos into backbone features.

        Every video is zero-padded along its frame axis (dim 2) to the longest
        frame count in the list. The padded videos are concatenated along dim 0
        and the backbone runs on consecutive, non-overlapping windows of
        ``self.num_frames`` frames. Frames after the last complete window are
        not encoded.

        Args:
            video_inputs (list[torch.Tensor]): Videos laid out as
                [B, C, T, H, W]; only T may differ between entries.

        Returns:
            tuple: (video_embeds, padded_video_mask)
                - video_embeds (torch.Tensor): Outputs of all windows
                  concatenated along dim 1. The last token of each window's
                  output is dropped.
                - padded_video_mask (torch.Tensor): int64 tensor with one row
                  per list entry and ``max T`` columns: 1 for real frames,
                  0 for padding. It is created on the CPU.

        Raises:
            RuntimeError: From ``torch.cat`` when the padded frame count is
                smaller than ``self.num_frames`` (no window to encode).
        """
        max_frame = max(video.size(2) for video in video_inputs)
        padded_videos = []
        mask_rows = []

        for video in video_inputs:
            num_real = video.size(2)
            pad = max_frame - num_real
            # NOTE(review): one mask row per list entry; this lines up with the
            # embedding rows only if each entry has B == 1 -- confirm with callers.
            mask_rows.append([1] * num_real + [0] * pad)
            if pad:
                # F.pad's 6-tuple addresses the last three dims (W, H, T); only T grows.
                video = F.pad(video, (0, 0, 0, 0, 0, pad))
            padded_videos.append(video)

        padded_video_inputs = torch.cat(padded_videos, dim=0)
        padded_video_mask = torch.tensor(mask_rows)
        nframes = padded_video_inputs.size(2)

        window = self.num_frames
        video_embeds_all = []
        for i in range(nframes // window):
            clip = padded_video_inputs[:, :, i * window:(i + 1) * window]
            video_embeds_all.append(self.internvideo2(clip)[:, :-1, :])
        video_embeds = torch.cat(video_embeds_all, dim=1)
        return video_embeds, padded_video_mask
    

# Audio Encoders
@ENCODER.register("whisper")
class WhisperEncoder(BaseEncoder):
    """
    Whisper encoder for audio feature extraction.

    Uses the encoder stack of a Whisper model, plus its feature extractor, to
    turn raw waveforms into hidden states. Audio longer than one segment is
    split into fixed-length segments that are encoded independently.
    """

    def __init__(self, configs, device):
        """
        Initialize the Whisper encoder.

        Args:
            configs (dict): Configuration holding a ``'whisper'`` sub-dict.
                Keys read here: ``'path'`` (default ``"./models/Whisper"``),
                ``'cache_dir'`` (default ``'./cache'``),
                ``'target_audio_sampling_rate'`` and ``'freeze'``.
            device (torch.device): Device the Whisper encoder is moved to.
        """
        super().__init__(configs, device)
        from .audio_encoders.modeling_whisper import WhisperModel
        self.whisper_configs = configs.get('whisper')
        self.cache_dir = self.whisper_configs.get('cache_dir', './cache')
        model_path = self.whisper_configs.get('path', "./models/Whisper")
        self.feature_extractor = load_model(
            model_path=model_path,
            model_class=WhisperFeatureExtractor, cache_dir=self.cache_dir
        )
        # Only the encoder stack of the loaded WhisperModel is kept.
        self.whisper = load_model(
            model_path=model_path,
            model_class=WhisperModel, cache_dir=self.cache_dir
        ).encoder.to(self.device)

        # May be None if the key is missing; encode() re-resolves it on every call.
        self.target_sampling_rate = self.whisper_configs.get('target_audio_sampling_rate')
        if self.whisper_configs.get('freeze'):
            self.freeze(self.whisper)

    @torch.no_grad()
    def encode(self, audio_inputs=None, target_sampling_rate=16000):
        """
        Encode raw audio into Whisper encoder hidden states.

        Each waveform longer than one segment (``self.seg_len`` times the
        sampling rate, in samples) is cut into consecutive segments of that
        length; the last one may be shorter. Every segment is converted into a
        spectrogram padded to the segment length, cast to float16, and encoded.
        Samples with fewer segments are zero-padded along the segment axis.

        Args:
            audio_inputs (list): Sequence of 1-D waveforms, each sliceable and
                accepted by the Whisper feature extractor.
            target_sampling_rate (int): Sampling rate used when the config has
                no ``'target_audio_sampling_rate'``. The resolved value is
                stored on ``self.target_sampling_rate``.

        Returns:
            tuple: (whisper_embeds, whisper_seg_embeds)
                - whisper_embeds (torch.Tensor): [B, max_seg * seq_len, dim],
                  the segments of each sample laid end to end.
                - whisper_seg_embeds (torch.Tensor): [B, max_seg, seq_len, dim].
        """
        self.target_sampling_rate = self.whisper_configs.get('target_audio_sampling_rate', target_sampling_rate)
        # self.seg_len is not set in this class (presumably by BaseEncoder). It is
        # multiplied by the sampling rate, so it is treated as seconds; int() keeps
        # slicing/range valid if it happens to be a float.
        seg_samples = int(self.seg_len * self.target_sampling_rate)

        whisper_embeds_list = []
        for audio in audio_inputs:
            if len(audio) > seg_samples:
                pieces = [audio[i:i + seg_samples] for i in range(0, len(audio), seg_samples)]
            else:
                pieces = [audio]

            spectrograms = torch.stack(
                [
                    self.feature_extractor(
                        piece,
                        sampling_rate=self.target_sampling_rate,
                        return_tensors='pt',
                        max_length=seg_samples,
                    )["input_features"].squeeze()
                    for piece in pieces
                ],
                dim=0,
            ).to(torch.float16)
            whisper_embeds_list.append(
                self.whisper(spectrograms.to(self.device), return_dict=True).last_hidden_state
            )

        # pad_sequence zero-pads dim 0 (segments) of every sample to the longest
        # sample, then stacks them into [B, max_seg, seq_len, dim].
        whisper_seg_embeds = pad_sequence(whisper_embeds_list, batch_first=True, padding_value=0)
        whisper_embeds = whisper_seg_embeds.reshape(whisper_seg_embeds.size(0), -1, whisper_seg_embeds.size(-1))
        return whisper_embeds, whisper_seg_embeds

@ENCODER.register("beats")
class BEATsEncoder(BaseEncoder):
    """
    BEATs encoder for audio feature extraction.

    Runs the BEATs audio transformer on raw waveforms. Audio longer than one
    segment is split into fixed-length segments that are processed as one
    padded, masked batch per sample.
    """

    def __init__(self, configs, device):
        """
        Initialize the BEATs encoder.

        Args:
            configs (dict): Configuration holding a ``'beats'`` sub-dict. Keys
                read here: ``'path'`` (checkpoint path), ``'cache_dir'``
                (default ``"./cache"``), ``'target_audio_sampling_rate'``
                (read in ``encode``) and ``'freeze'``.
            device (torch.device): Device the BEATs model is moved to.
        """
        super().__init__(configs, device)
        self.beats_configs = configs.get('beats')
        self.beats = load_model(
            ckpt_path=self.beats_configs.get('path', "./models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"),
            model_class=BEATs,
            config_class=BEATsConfig,
            cache_dir=self.beats_configs.get("cache_dir", "./cache")
        ).to(self.device)

        if self.beats_configs.get('freeze'):
            self.freeze(self.beats)

    @torch.no_grad()
    def encode(self, audio_inputs=None, target_sampling_rate=16000):
        """
        Encode raw audio with BEATs.

        Each waveform longer than one segment (``self.seg_len`` times the
        sampling rate, in samples) is cut into consecutive segments; the last
        may be shorter. A sample's segments are right-padded with zeros into
        one batch, and a padding mask marks the padded samples. That batch goes
        through ``self.beats.extract_features``. The per-sample outputs are
        then zero-padded to a common segment count and a common frame count
        before stacking, so clips of different lengths can share a batch.

        Args:
            audio_inputs (list): Sequence of 1-D waveforms (numpy arrays,
                tensors, or anything ``torch.as_tensor`` accepts).
            target_sampling_rate (int): Sampling rate used when the config has
                no ``'target_audio_sampling_rate'``.

        Returns:
            tuple: (beats_features, beats_seg_features)
                - beats_features (torch.Tensor): [B, max_seg * max_len, dim],
                  the segments of each sample laid end to end.
                - beats_seg_features (torch.Tensor): [B, max_seg, max_len, dim].
        """
        target_sampling_rate = self.beats_configs.get('target_audio_sampling_rate', target_sampling_rate)
        # self.seg_len is not set in this class (presumably by BaseEncoder); it is
        # treated as seconds. int() keeps slicing/range valid for float values.
        seg_samples = int(self.seg_len * target_sampling_rate)

        beats_features_list = []
        for audio in audio_inputs:
            if len(audio) > seg_samples:
                pieces = [audio[i:i + seg_samples] for i in range(0, len(audio), seg_samples)]
            else:
                pieces = [audio]

            waveforms = [torch.as_tensor(piece, dtype=torch.float32) for piece in pieces]
            lengths = torch.tensor([w.shape[0] for w in waveforms])
            waveforms = pad_sequence(waveforms, batch_first=True, padding_value=0)
            # True where a position lies beyond a segment's real length (padding).
            padding_mask = torch.arange(waveforms.shape[1]).unsqueeze(0) >= lengths.unsqueeze(1)
            # Assumed to be [num_segments, frames, dim] -- TODO confirm against BEATs.
            features = self.beats.extract_features(
                waveforms.to(self.device),
                padding_mask=padding_mask.to(self.device),
                feature_only=True,
            )[0]
            beats_features_list.append(features)

        # The frame count (dim 1) depends on each sample's padded segment length,
        # so both the segment axis (dim 0) and the frame axis must be padded
        # before stacking. Padding only dim 0 made torch.stack fail on mixed-length batches.
        max_seg = max(f.size(0) for f in beats_features_list)
        max_len = max(f.size(1) for f in beats_features_list)
        beats_seg_features = torch.stack(
            [
                F.pad(f, (0, 0, 0, max_len - f.size(1), 0, max_seg - f.size(0)))
                for f in beats_features_list
            ],
            dim=0,
        )
        beats_features = beats_seg_features.reshape(beats_seg_features.size(0), -1, beats_seg_features.size(-1))
        return beats_features, beats_seg_features


@ENCODER.register("sensevoicesmall")
class SenseVoiceSmallEncoder(BaseEncoder):
    """
    SenseVoice Small encoder.

    Thin wrapper that loads ``SenseVoiceSmall`` and forwards audio file paths
    to its ``inference`` method.
    """

    def __init__(self, configs, device):
        """
        Initialize the SenseVoice Small encoder.

        Args:
            configs (dict): Configuration holding a ``'sensevoicesmall'``
                sub-dict. Keys read here: ``'path'`` (default
                ``"iic/SenseVoiceSmall"``) and ``'freeze'``.
            device (torch.device): Device passed to ``from_pretrained``.
        """
        super().__init__(configs, device)
        from .audio_encoders.sensevoice_small import SenseVoiceSmall
        self.sensevoicesmall_configs = configs.get('sensevoicesmall')
        # from_pretrained returns the model plus extra kwargs that are forwarded to inference().
        self.sensevoice, self.kwargs = SenseVoiceSmall.from_pretrained(
            model=self.sensevoicesmall_configs.get('path', "iic/SenseVoiceSmall"),
            device=self.device
        )
        if self.sensevoicesmall_configs.get('freeze'):
            self.freeze(self.sensevoice)

    def encode(self, audio_paths, language='en', use_itn=False):
        """
        Run SenseVoice inference on audio files.

        Args:
            audio_paths (list): File paths of the audio to process.
            language (str): Language hint forwarded to ``inference``
                (default ``'en'``, the previously hard-coded value).
            use_itn (bool): Forwarded to ``inference`` (default ``False``).

        Returns:
            The raw return value of ``self.sensevoice.inference``; its structure
            is defined by the SenseVoice model and is not inspected here.
        """
        return self.sensevoice.inference(
            data_in=audio_paths,
            language=language,
            use_itn=use_itn,
            # Was misspelled 'bank_emo_unk' and silently swallowed by **kwargs;
            # SenseVoice reads 'ban_emo_unk'.
            ban_emo_unk=False,
            **self.kwargs
        )


@ENCODER.register('whisper_beats')
class Whisper_BEATsEncoder(BaseEncoder):
    """
    Combined Whisper and BEATs audio encoder.

    Runs a ``BEATsEncoder`` and a ``WhisperEncoder`` on the same audio and
    concatenates their segmented outputs along the feature dimension.
    """

    def __init__(self, configs, device):
        """
        Initialize the combined Whisper-BEATs encoder.

        Args:
            configs (dict): Configuration holding a ``'whisper_beats'``
                sub-dict. That sub-dict is passed unchanged as the config of
                both sub-encoders, so it must itself contain their ``'beats'``
                and ``'whisper'`` sections.
            device (torch.device): Device for both sub-encoders.
        """
        super().__init__(configs, device)
        self.audio_encoders_configs = configs.get('whisper_beats')
        self.beats_encoder = BEATsEncoder(self.audio_encoders_configs, device)
        self.whisper_encoder = WhisperEncoder(self.audio_encoders_configs, device)

    def concat_whisper_beats(self, beats_seg_features, whisper_seg_embeds):
        """
        Concatenate BEATs and Whisper features along the feature dimension.

        Both inputs are zero-padded to the larger segment count (dim 1) and the
        larger sequence length (dim 2) so either one may be the shorter. BEATs
        features are moved to the Whisper tensor's device first.

        Args:
            beats_seg_features (torch.Tensor): [B, seg, seq_len, beats_dim].
            whisper_seg_embeds (torch.Tensor): [B, seg, seq_len, whisper_dim].

        Returns:
            tuple: (audio_embeds, audio_seg_embeds)
                - audio_embeds (torch.Tensor): [B, seg * seq_len, beats_dim + whisper_dim].
                - audio_seg_embeds (torch.Tensor): [B, seg, seq_len, beats_dim + whisper_dim].
        """
        beats_seg_features = beats_seg_features.to(whisper_seg_embeds.device)
        max_seg = max(beats_seg_features.size(1), whisper_seg_embeds.size(1))
        max_len = max(beats_seg_features.size(2), whisper_seg_embeds.size(2))

        def pad_to_common(features):
            # 6-tuple pads (dim -1: none, dim 2: to max_len, dim 1: to max_seg), all on the right.
            return F.pad(
                features,
                (0, 0, 0, max_len - features.size(2), 0, max_seg - features.size(1)),
                'constant', 0
            )

        audio_seg_embeds = torch.cat(
            [pad_to_common(beats_seg_features), pad_to_common(whisper_seg_embeds)], dim=-1
        )
        audio_embeds = audio_seg_embeds.reshape(audio_seg_embeds.size(0), -1, audio_seg_embeds.size(-1))
        return audio_embeds, audio_seg_embeds

    def encode(self, audio_inputs=None, target_sampling_rate=16000, sin_pos=False):
        """
        Encode audio with both BEATs and Whisper and concatenate the results.

        Args:
            audio_inputs (list): Sequence of 1-D waveforms.
            target_sampling_rate (int): Fallback sampling rate forwarded to
                both sub-encoders.
            sin_pos (bool): Kept for backward compatibility and currently
                unused. Neither ``BEATsEncoder.encode`` nor
                ``WhisperEncoder.encode`` accepts it, so forwarding it raised
                ``TypeError``.

        Returns:
            tuple: (audio_embeds, audio_seg_embeds), see ``concat_whisper_beats``.
        """
        _, beats_seg_features = self.beats_encoder.encode(
            audio_inputs=audio_inputs,
            target_sampling_rate=target_sampling_rate,
        )
        _, whisper_seg_features = self.whisper_encoder.encode(
            audio_inputs=audio_inputs,
            target_sampling_rate=target_sampling_rate,
        )
        return self.concat_whisper_beats(beats_seg_features, whisper_seg_features)