"""BigBench Audio preprocessor module for AU-Harness framework.

This module provides a preprocessor for the BigBenchAudio dataset, designed for
Speech Query Question Answering (SQQA) tasks with audio processing capabilities.
"""

import logging
from typing import Dict, List, Any

from datasets import Dataset
import numpy as np
from tqdm import tqdm

from preprocessors.base import Preprocessor

# Module-level logger (named after this module) used for the per-sample
# warnings emitted while preprocessing.
logger = logging.getLogger(__name__)


class BigBenchAudioPreprocessor(Preprocessor):
    """
    A preprocessor for the BigBenchAudio dataset, designed for
    Speech Query Question Answering (SQQA) tasks.

    Every kept dataset row is turned into a flat dictionary holding the audio
    samples (or an empty placeholder for text-only runs), the reference
    answer and the instruction that is sent to the model.
    """

    # Sampling rate assumed when a row does not carry one; also reported for
    # the empty placeholder array used in text-only evaluations.
    DEFAULT_SAMPLING_RATE = 16000

    # Generic instruction appended to the user prompt for audio evaluations.
    # TODO: An override will need to pass this added instruction too.
    # Consider a cleaner way to handle this.
    AUDIO_INSTRUCTION = "Answer the question provided in the audio."

    def process(self, dataset: Dataset, task_config: Dict[str, Any],
                run_config: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the BigBenchAudio dataset to ensure consistent audio format and structured data.

        Parameters:
        - dataset: Dataset that needs to be pre-processed. Rows are expected to
          provide "id" and "category", the configured target column and, for
          non-text modalities, an "audio" dict with "array" / "sampling_rate".
        - task_config: Dictionary containing task configuration parameters.
          Keys read here: 'modality' (default 'audio'), 'target_column'
          (required), 'instruction_column' (required for the 'text' modality)
          and 'user_prompt' (default '').
        - run_config: Dictionary containing run configuration parameters. Only
          the optional 'filter' entry is read and handed to
          `get_dataset_filters`.

        Returns:
        - List[Dict[str, Any]]: A list of dictionaries where each dictionary represents a sample,
          with keys "id", "category", "audio_content_in_text", "array",
          "sampling_rate", "model_target" and "instruction". For non-text
          modalities "array" / "sampling_rate" are the values returned by the
          base class `resample_audio`.

        Raises:
        - ValueError: If the target column is not configured, or if the
          modality is 'text' and no instruction column is configured.
        """

        # Get dataset info
        dataset_keys = list(dataset.features.keys())
        dataset_size = len(dataset)
        self.log_dataset_info(dataset_keys, dataset_size)

        # Get dataset filters
        length_filter, num_samples_filter = self.get_dataset_filters(
            run_config.get('filter', None), dataset_size
        )

        # Extract relevant information from task_config
        modality = task_config.get('modality', 'audio')
        target_column_name = task_config.get('target_column', None)
        sample_instruction_column_name = task_config.get('instruction_column', None)
        user_prompt = task_config.get('user_prompt', '')

        if not target_column_name:
            raise ValueError("[_big_bench_audio_preprocessor_] Target column name is missing. Preprocessing needs reference answers. Aborting!")

        # In text mode the transcript IS the question, so a missing column
        # would silently produce prompt-only samples. Fail early instead.
        if modality == "text" and not sample_instruction_column_name:
            raise ValueError("[_big_bench_audio_preprocessor_] Instruction column name is missing. Text modality needs the question text. Aborting!")

        processed_data = []
        total_duration = 0
        sample_count = 0

        for row in tqdm(dataset, desc="Processing samples"):
            # Enforce the sample cap for every modality, and before any audio
            # work, so the sample that hits the cap is neither resampled nor
            # counted in the reported duration.
            if num_samples_filter and sample_count >= num_samples_filter:
                break

            sample_id = row["id"]

            # Ensure official answer exists. If not, skip!
            if not row[target_column_name]:
                logger.warning("[%s] Missing official answer. Skipping sample.", sample_id)
                continue

            model_target = row[target_column_name].strip()

            # The transcript is only metadata for audio runs, so tolerate an
            # unconfigured column or an empty cell there.
            if sample_instruction_column_name:
                audio_content_in_text = (row[sample_instruction_column_name] or "").strip()
            else:
                audio_content_in_text = ""

            if modality == "text":
                # Placeholder, not used in text-only evals
                audio_array = np.array([])
                sr = self.DEFAULT_SAMPLING_RATE
                instruction = user_prompt + audio_content_in_text
            else:
                audio_data = row["audio"]

                # Validate audio data structure
                if not isinstance(audio_data, dict):
                    logger.warning("[%s] Invalid audio format. Skipping sample.", sample_id)
                    continue

                raw_array = audio_data.get("array")
                if raw_array is None:
                    # np.array(None) is a 0-d array whose len() raises.
                    logger.warning("[%s] Audio array missing. Skipping sample.", sample_id)
                    continue

                sr = audio_data.get("sampling_rate")
                if sr is None:
                    logger.warning("[%s] Sampling rate missing. Assuming 16kHz.", sample_id)
                    sr = self.DEFAULT_SAMPLING_RATE

                # Use base class method to resample audio
                audio_array, sr = self.resample_audio(np.array(raw_array), sr)

                # Calculate audio duration in seconds
                total_duration += len(audio_array) / sr

                # Apply length filtering if specified
                if length_filter and not self.check_audio_length(audio_array, sr, length_filter):
                    continue

                # For audio modality, we can define a generic instruction
                instruction = user_prompt + self.AUDIO_INSTRUCTION

            # Create structured sample. `audio_array` / `sr` always hold the
            # values prepared for the active modality, so no modality check
            # is needed here.
            processed_data.append({
                "id": sample_id,
                "category": row["category"],
                "audio_content_in_text": audio_content_in_text,
                "array": audio_array,
                "sampling_rate": sr,
                "model_target": model_target,
                "instruction": instruction,
            })
            sample_count += 1

        self.log_dataset_info(dataset_keys, dataset_size, sample_count, total_duration)
        return processed_data
