
"""
Multimodal Dataset Module

This module provides a comprehensive dataset implementation for multimodal AI systems supporting
audio, video, and image modalities. It handles data loading, preprocessing, and batching for
supervised fine-tuning tasks in audio-visual question answering and captioning scenarios.

The dataset supports flexible modality combinations including audio-only, video-only, and
audio-video multimodal configurations with configurable preprocessing pipelines and caching
mechanisms for efficient training workflows.

Key Features:
    - Multimodal data loading (audio, video, image)
    - Duration-based video filtering for quality control
    - Audio resampling and caching for performance optimization
    - Video frame extraction with configurable segment lengths
    - Question-answer pair handling for conversational AI training
    - Batch collation with modality-specific processing
    - Support for both training and validation modes

Supported Modalities:
    - audio: Audio-only processing with optional raw audio caching
    - video: Video-only processing with frame extraction
    - audio_video: Combined audio-video multimodal processing
    - image: Image processing (placeholder implementation)

Dependencies:
    - torch: PyTorch framework for tensor operations and data loading
    - cv2: OpenCV for video processing and duration calculation
    - soundfile: Audio file loading and processing
    - librosa: Audio resampling and signal processing
    - pytorchvideo: Video processing and transformation utilities
    - transformers: Hugging Face transformers for feature extraction

Author: AI Model Development Team
License: MIT
"""

import json
import logging

import cv2
import soundfile as sf

import torch
from torch.utils.data import Dataset
import torch.distributed as dist
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from transformers import WhisperFeatureExtractor
from models.encoders.video_encoders.internvideo2.intern_dataset.video_utils import read_frames_decord
import librosa
from .utils import get_train_transform, get_test_transform


# Default bounds (in seconds) for the video-duration filter in load_json_data.
MIN_DURATION, MAX_DURATION = 0, 70


def load_json_data(data_path, min_duration=MIN_DURATION, max_duration=MAX_DURATION):
    """
    Load a JSON list of records and drop those whose video duration is out of range.

    For every record with a non-empty ``video_path`` the video is opened with
    OpenCV and its duration is estimated as ``frame_count / fps`` from the
    container properties; this function does not decode any frames itself.

    Args:
        data_path (str): Path to a JSON file whose top-level value is a list of
            record dictionaries. Every record must have a ``"video_path"`` key.
        min_duration (float, optional): Minimum accepted duration in seconds
            (inclusive). Defaults to MIN_DURATION.
        max_duration (float, optional): Maximum accepted duration in seconds
            (inclusive). Defaults to MAX_DURATION.

    Returns:
        list: The records that passed the filter, in their original order.

    Raises:
        KeyError: If a record has no ``"video_path"`` key.

    Note:
        - Records whose ``video_path`` is empty are kept unchanged: there is no
          video to measure, and the dataset class filters empty paths itself.
        - Records whose video reports a non-positive FPS (OpenCV reports 0 when
          it cannot open the file) are dropped with a warning instead of
          raising ``ZeroDivisionError``.
    """
    logger = logging.getLogger(__name__)

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(data_path, 'r', encoding='utf-8') as f:
        records = json.load(f)

    new_records, remove_video_paths = [], []

    for record in records:
        video_path = record["video_path"]

        # Nothing to measure for records that carry no video.
        if not video_path:
            new_records.append(record)
            continue

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            # Always free the capture handle, even if a property read fails.
            cap.release()

        if fps <= 0:
            # Duration cannot be computed; treat the video as unusable.
            logger.warning("Skipping video with unreadable FPS: %r", video_path)
            remove_video_paths.append(video_path)
            continue

        duration = frame_count / fps
        if duration < min_duration or duration > max_duration:
            remove_video_paths.append(video_path)
            continue

        new_records.append(record)

    if remove_video_paths:
        logger.info(
            "Dropped %d of %d records from %s (unreadable or outside [%s, %s] seconds)",
            len(remove_video_paths), len(records), data_path, min_duration, max_duration,
        )

    return new_records


class MultiModalDataset(Dataset):
    """
    Dataset for multimodal supervised fine-tuning over audio, video, and image inputs.

    Records are read from a JSON file and split into per-field lists (paths,
    tasks, questions, answers). ``__getitem__`` dispatches on ``modality`` and
    returns one sample dictionary, or None when the sample cannot be loaded;
    ``collater`` drops the None entries and groups the rest into a batch.

    Attributes:
        modality (str): One of 'audio', 'video', 'audio_video', 'image'.
        return_raw_audios (bool): Whether samples carry decoded audio in addition
            to the audio file path.
        training (bool): Selects the train transform (True) or the test
            transform (False) for video frames.
        seg_len (int): Forwarded to ``read_frames_decord`` as ``seg_len``.
        audio_resampling (bool): Whether audio is resampled when its native rate
            differs from ``audio_sampling_rate``.
        audio_sampling_rate (int): Target sampling rate for resampling.
        raw_audios_cache (dict): Maps audio path -> decoded audio. The cache has
            no eviction, so it grows with the number of distinct files read.

    Note:
        The 'image' modality is a placeholder: ``get_image`` always returns None.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self,
                 modality,
                 data_json_path,
                 training,
                 return_raw_audios,
                 audio_resampling,
                 audio_sampling_rate,
                 seg_len,
                 ):
        """
        Store the configuration and load the annotation lists from JSON.

        Args:
            modality (str): Which data to serve: 'audio', 'video', 'audio_video'
                or 'image'.
            data_json_path (str): Path to the JSON annotation file.
            training (bool): True to use the training video transform.
            return_raw_audios (bool): If True, decoded audio is loaded and
                returned; if False, only audio file paths are returned.
            audio_resampling (bool): Whether to resample audio to
                ``audio_sampling_rate``.
            audio_sampling_rate (int): Target sampling rate for resampling.
            seg_len (int): Value passed to ``read_frames_decord`` as ``seg_len``.
        """
        super().__init__()
        self.modality = modality
        self.return_raw_audios = return_raw_audios
        self.training = training
        self.seg_len = seg_len
        self.audio_resampling = audio_resampling
        self.audio_sampling_rate = audio_sampling_rate

        # Load and parse dataset from JSON file
        self.image_path_list, self.audio_path_list, self.video_path_list, \
            self.task_list, self.ques_list, self.ans_list = self.get_data_json(data_json_path)

        # Decoded audio keyed by file path, filled lazily on first access.
        self.raw_audios_cache = {}

    def load_audio(self, audio_path, target_sampling_rate, resampling):
        """
        Read an audio file as a 1-D array, optionally resampling it.

        Args:
            audio_path (str): Path to the audio file.
            target_sampling_rate (int): Sampling rate to resample to.
            resampling (bool): If True and the file's rate differs from
                ``target_sampling_rate``, resample with librosa.

        Returns:
            numpy.ndarray or None: The audio samples, or None if reading or
            resampling raised. Failures are logged as warnings.

        Note:
            Two-dimensional (multi-channel) audio is reduced to its first
            channel; the other channels are discarded, not mixed.
        """
        try:
            audio, sr = sf.read(audio_path)
            if audio.ndim == 2:
                audio = audio[:, 0]  # keep the first channel only
            if resampling and sr != target_sampling_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sampling_rate)
            return audio
        except Exception:
            # Best-effort loading: report the failure but let the caller skip the sample.
            self._logger.warning("Failed to load audio %r", audio_path, exc_info=True)
            return None

    def get_data_json(self, data_path):
        """
        Load the (duration-filtered) JSON records and split them into per-field lists.

        Args:
            data_path (str): Path to the JSON dataset file.

        Returns:
            tuple: ``(image_path_list, audio_path_list, video_path_list,
            task_list, ques_list, ans_list)``.

        Raises:
            KeyError: If a record lacks one of the keys ``video_path``,
                ``audio_path``, ``image_path``, ``task``, ``question``, ``answer``.

        Note:
            Empty strings are dropped from each list independently. The lists
            therefore only stay index-aligned when every record populates the
            same set of fields; a warning is logged when the non-empty lists end
            up with different lengths, because samples are paired by index.
        """
        json_data = load_json_data(data_path)

        # Keys in the order the fields are read from each record.
        columns = {
            "video_path": [],
            "audio_path": [],
            "image_path": [],
            "task": [],
            "question": [],
            "answer": [],
        }

        for item in json_data:
            for key, values in columns.items():
                value = item[key]
                # Only keep non-empty paths and annotations
                if value != '':
                    values.append(value)

        populated = {key: len(values) for key, values in columns.items() if values}
        if len(set(populated.values())) > 1:
            self._logger.warning(
                "Field lists have different lengths after dropping empty values (%s); "
                "paths and annotations are paired by index and may be misaligned.",
                populated,
            )

        return (columns["image_path"], columns["audio_path"], columns["video_path"],
                columns["task"], columns["question"], columns["answer"])

    def __len__(self):
        """
        Return the number of samples for the active modality.

        Returns:
            int: Number of image paths for 'image', audio paths for 'audio',
            video paths for 'video' and 'audio_video', and 0 for any other
            modality value.
        """
        if self.modality == "image":
            return len(self.image_path_list)
        if self.modality == "audio":
            return len(self.audio_path_list)
        if self.modality in ("video", "audio_video"):
            return len(self.video_path_list)
        return 0

    def _build_output_texts(self, i):
        """Return the human/gpt conversation pair for annotation index ``i``."""
        return [{"from": "human", "value": self.ques_list[i]},
                {"from": "gpt", "value": self.ans_list[i]}]

    def _get_raw_audio(self, audio_path):
        """Return decoded audio for ``audio_path`` via the cache, or None if loading fails."""
        if audio_path not in self.raw_audios_cache:
            audio = self.load_audio(audio_path=audio_path,
                                    target_sampling_rate=self.audio_sampling_rate,
                                    resampling=self.audio_resampling)
            # Failed loads are not cached, so a later access retries the file.
            if audio is not None:
                self.raw_audios_cache[audio_path] = audio
        return self.raw_audios_cache.get(audio_path)

    def get_image(self, i):
        """
        Placeholder for image loading; always returns None.

        Args:
            i (int): Index of the image sample.

        Returns:
            None: Image loading is not implemented yet.
        """
        # Placeholder for image loading and processing implementation.
        # The path for index i would be self.image_path_list[i % len(self.image_path_list)].
        return None

    def get_audio(self, i):
        """
        Return the audio part of sample ``i``.

        Args:
            i (int): Sample index; wrapped with modulo over the audio path list.

        Returns:
            dict or tuple or str or None:
                - 'audio' modality: dict with ``audio_path``, ``output_texts``,
                  ``modality``, ``task``, plus ``audio_data`` when
                  ``return_raw_audios`` is True.
                - 'audio_video' modality with raw audio: ``(audio_path, raw_audio)``.
                - 'audio_video' modality without raw audio: ``audio_path``.
                - None for other modalities or when there are no audio paths.

        Note:
            When raw audio is requested but cannot be loaded, the raw-audio
            value is None; the sample itself is still returned.
        """
        if not self.audio_path_list:
            return None

        i = i % len(self.audio_path_list)
        audio_path = self.audio_path_list[i]
        raw_audio = self._get_raw_audio(audio_path) if self.return_raw_audios else None

        if self.modality == 'audio_video':
            if self.return_raw_audios:
                return audio_path, raw_audio
            return audio_path

        if self.modality != 'audio':
            return None

        sample = dict(audio_path=audio_path)
        if self.return_raw_audios:
            sample['audio_data'] = raw_audio
        sample['output_texts'] = self._build_output_texts(i)
        sample['modality'] = self.modality
        sample['task'] = self.task_list[i]
        return sample

    def get_video(self, i):
        """
        Decode and transform the video for sample ``i``.

        Args:
            i (int): Sample index; wrapped with modulo over the video path list.

        Returns:
            dict or tuple or None:
                - 'video' modality: dict with ``video_path``, ``video_data``,
                  ``output_texts``, ``modality``, ``task``.
                - 'audio_video' modality: ``(video_data, video_path)``.
                - None for other modalities, when there are no video paths, or
                  when decoding/transforming raised (logged as a warning).

        Note:
            ``video_data`` is the transform output with its first two axes
            swapped, a leading axis of size 1 added, and cast to float16.
        """
        if not self.video_path_list:
            return None

        i = i % len(self.video_path_list)
        video_path = self.video_path_list[i]
        try:
            # Extract video frames using decord; only the frames are used here.
            video_data, _, _ = read_frames_decord(video_path=video_path, seg_len=self.seg_len)

            # Apply appropriate transformations based on mode
            trans = get_train_transform() if self.training else get_test_transform()
            video_data = trans(video_data)

            # Swap axes 0 and 1, then add a leading batch axis.
            # NOTE(review): original comment says [T, C, H, W] -> [C, T, H, W];
            # this depends on the transform output layout — confirm.
            video_data = video_data.permute(1, 0, 2, 3).unsqueeze(0).to(torch.float16)

            if self.modality == 'audio_video':
                return video_data, video_path
            if self.modality == 'video':
                return dict(
                    video_path=video_path,
                    video_data=video_data,
                    output_texts=self._build_output_texts(i),
                    modality=self.modality,
                    task=self.task_list[i])
            return None
        except Exception:
            # Best-effort loading: report the failure but let the collater skip the sample.
            self._logger.warning("Failed to load video %r", video_path, exc_info=True)
            return None

    def get_audio_video(self, i):
        """
        Return the combined audio-video sample ``i``.

        Args:
            i (int): Sample index. Video and audio indices are derived
                separately with modulo over their own path lists; annotations
                are taken at the video index.

        Returns:
            dict or None: Dict with ``video_path``, ``video_data``,
            ``audio_path``, ``output_texts``, ``modality``, ``task``, plus
            ``audio_data`` when ``return_raw_audios`` is True. None when the
            modality is not 'audio_video', when either path list is empty, or
            when the video or audio part could not be produced.

        Note:
            The video is decoded once and the audio is fetched once; both
            results are reused to build the sample.
        """
        if self.modality != 'audio_video':
            return None
        if not self.video_path_list or not self.audio_path_list:
            return None

        video_index = i % len(self.video_path_list)
        audio_index = i % len(self.audio_path_list)

        # Resolve annotations first so an index problem surfaces before decoding.
        output_texts = self._build_output_texts(video_index)
        task = self.task_list[video_index]

        video_result = self.get_video(video_index)
        audio_result = self.get_audio(audio_index)
        if video_result is None or audio_result is None:
            return None

        video_data, video_path = video_result

        if self.return_raw_audios:
            audio_path, audio_data = audio_result
            return dict(
                video_path=video_path,
                video_data=video_data,
                audio_path=audio_path,
                audio_data=audio_data,
                output_texts=output_texts,
                modality=self.modality,
                task=task,
            )

        return dict(
            video_path=video_path,
            video_data=video_data,
            audio_path=audio_result,
            output_texts=output_texts,
            modality=self.modality,
            task=task,
        )

    def __getitem__(self, i):
        """
        Return sample ``i`` for the configured modality.

        Args:
            i (int): Index of the sample to retrieve.

        Returns:
            dict or None: The sample produced by the modality-specific getter,
            or None for an unsupported modality or a failed load.
        """
        if self.modality == "image":
            return self.get_image(i)
        if self.modality == "audio":
            return self.get_audio(i)
        if self.modality == "video":
            return self.get_video(i)
        if self.modality == "audio_video":
            return self.get_audio_video(i)
        return None

    def collater(self, instances):
        """
        Group sample dictionaries into one batch dictionary.

        Args:
            instances (list): Items returned by ``__getitem__``; None entries
                (failed loads) are ignored.

        Returns:
            dict or None: None when no valid instance remains. Otherwise a dict
            with ``modality`` (taken from the first instance), ``output_texts``
            and ``task`` lists, and, depending on the modality:
                - ``image_data`` / ``image_path``
                - ``audio_path``, plus ``audio_data`` when ``return_raw_audios``
                  is True
                - ``video_data`` / ``video_path``
            Values are plain Python lists; nothing is stacked into a tensor.
        """
        # Filter out failed loading attempts
        instances = [instance for instance in instances if instance is not None]
        if not instances:
            return None

        # Substring checks so that 'audio_video' enables both audio and video.
        has_image = 'image' in self.modality
        has_audio = 'audio' in self.modality
        has_video = 'video' in self.modality

        image_data, image_path = [], []
        audio_data, audio_path = [], []
        video_data, video_path = [], []
        output_texts, task = [], []

        # Collect data from each instance
        for instance in instances:
            if has_image and 'image_data' in instance:
                image_data.append(instance['image_data'])
                image_path.append(instance['image_path'])
            if has_audio and 'audio_path' in instance:
                audio_path.append(instance['audio_path'])
                if self.return_raw_audios and 'audio_data' in instance:
                    audio_data.append(instance['audio_data'])
            if has_video and 'video_data' in instance:
                video_data.append(instance['video_data'])
                video_path.append(instance['video_path'])

            if 'output_texts' in instance:
                output_texts.append(instance['output_texts'])
            if 'task' in instance:
                task.append(instance['task'])

        # Create batched output dictionary
        collated_batch = {"modality": instances[0]["modality"], "output_texts": output_texts, "task": task}

        # Add modality-specific data to batch
        if has_image and image_data:
            collated_batch['image_data'] = image_data
            collated_batch['image_path'] = image_path
        if has_audio and audio_path:
            # Paths are always part of the batch; decoded audio only on request.
            collated_batch['audio_path'] = audio_path
            if self.return_raw_audios and audio_data:
                collated_batch['audio_data'] = audio_data
        if has_video and video_data:
            collated_batch['video_data'] = video_data
            collated_batch['video_path'] = video_path

        return collated_batch
    
