# Import necessary libraries for the Audio-Visual Language Learning Model (AVLL)
import logging
from collections import defaultdict

import torch
from torch import nn
import torch.nn.functional as F

from transformers import StoppingCriteriaList, StoppingCriteria
from transformers.activations import ACT2FN

from .metrics import wups_score
from utils.dist_utils import is_main_process
from utils.load_instance import process_batch_instance

# Set up logger for this module; `AVLLModel.count_statistic` emits its
# per-task accuracy counters through it at DEBUG level.
logger = logging.getLogger(__name__)


class StoppingCriteriaSub(StoppingCriteria):
    """
    Custom stopping criteria for text generation.

    Stops generation once the configured stop token ids, counted together,
    occur at least ``encounters`` times in the first sequence of the batch.
    """

    def __init__(self, stops=None, encounters=1):
        """
        Initialize the stopping criteria.

        Args:
            stops (list | None): Token IDs to stop on. ``None`` means no stop tokens.
                A copy is stored, so later mutation of the caller's list has no effect.
            encounters (int): Total number of stop-token occurrences needed to stop.
        """
        super().__init__()
        # Avoid a shared mutable default; copy so the caller's list is not aliased.
        self.stops = list(stops) if stops is not None else []
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Check whether generation should stop.

        Args:
            input_ids (torch.LongTensor): Token IDs so far; only row 0 is inspected.
            scores (torch.FloatTensor): Generation scores (unused).
            **kwargs: Extra arguments forwarded by ``StoppingCriteriaList`` (ignored).

        Returns:
            bool: True once the summed count of all stop tokens in ``input_ids[0]``
            reaches ``self.ENCOUNTERS``.
        """
        first_row = input_ids[0]
        # Accumulate across *all* stop tokens (the previous version overwrote the
        # count on each iteration, so only the last stop token was considered).
        stop_count = sum((stop == first_row).sum().item() for stop in self.stops)
        return stop_count >= self.ENCOUNTERS


class AVLLModel(nn.Module):
    """
    Audio-Visual Language Learning Model (AVLL).

    Combines audio/video encoders, question-guided poolers, a connector and a
    language model reached through ``llama_model.model.model`` (presumably a
    PEFT-wrapped LLaMA; TODO confirm). In training, the LLM's supervised loss is
    summed with optimal-transport alignment losses (audio-visual before/after
    pooling, audio-text, visual-text).
    """

    # Fixed order in which per-modality features are handed to the connector.
    _MODALITY_ORDER = ("audio", "video")

    @property
    def device(self):
        """Device of the first registered parameter, used as the model's home device."""
        try:
            return next(self.parameters()).device
        except StopIteration:
            raise RuntimeError("AVLLModel has no parameters; cannot infer its device.") from None

    def __init__(
            self,
            modality, sub_task,
            return_raw_audios, audio_size, video_size,
            audio_encoders, video_encoders,                                        # encoders
            visual_pooler, audio_pooler,                                           # alignment_pooling
            connector,                                                             # connector
            ot_av, ot_at, ot_vt,                                                   # ot
            llama_model, llama_tokenizer, max_tgt_len,                             # llm
            prompt_template, prompt_path, test_prompt_path, max_txt_len, end_sym,  # prompt
            **kwargs
        ):
        """
        Initialize the AVLL model with all necessary components.

        Args:
            modality: Container of modality names; ``'audio'`` / ``'video'`` membership
                decides which encoders run.
            sub_task (bool): Whether ``count_statistic`` also reports per-task counters.
            return_raw_audios (bool): Read ``samples['audio_data']`` (True) or
                ``samples['audio_path']`` (False) as encoder input.
            audio_size (int | None): Tokens per audio frame; falsy values fall back to 25.
            video_size (int | None): Tokens per video frame; falsy values fall back to 256.
            audio_encoders: Audio encoder exposing ``encode(...)``.
            video_encoders: Video encoder exposing ``encode(...)``.
            visual_pooler: Question-guided video pooler (callable, exposes ``hidden_states``).
            audio_pooler: Question-guided audio pooler (callable, exposes ``hidden_states``).
            connector: Module fusing the list of pooled features into LLM input embeddings.
            ot_av: Audio-visual optimal-transport loss module.
            ot_at: Audio-text optimal-transport loss module.
            ot_vt: Visual-text optimal-transport loss module.
            llama_model: The language model.
            llama_tokenizer: Tokenizer for the language model.
            max_tgt_len (int): Maximum target length (also the default ``max_new_tokens``).
            prompt_template (str): Text inserted between BOS and the multimodal features.
            prompt_path (str): Path to training prompts.
            test_prompt_path (str): Path to testing prompts.
            max_txt_len (int): Maximum text length.
            end_sym (str): End-of-sequence symbol.
            **kwargs: Extra options kept in ``self.kwargs``; ``generate`` reads
                ``yu_lora_alpha`` from it.
        """
        super().__init__()

        # Runtime configuration
        self.modality = modality
        self.sub_task = sub_task

        # Dataset configuration
        self.return_raw_audios = return_raw_audios

        # Feature encoders for different modalities
        self.audio_encoders = audio_encoders
        self.video_encoders = video_encoders

        # Question-guided alignment poolers
        self.visual_pooler = visual_pooler
        self.audio_pooler = audio_pooler

        # Connector fusing multimodal features for the LLM
        self.connector = connector

        # Optimal-transport alignment losses
        self.ot_av = ot_av
        self.ot_at = ot_at
        self.ot_vt = ot_vt

        # Language model components
        self.llama_tokenizer = llama_tokenizer
        self.llama_model = llama_model

        # Generation and embedding size configuration
        self.max_tgt_len = max_tgt_len
        self.audio_size = audio_size or 25
        self.video_size = video_size or 256

        # Prompt configuration
        self.prompt_template = prompt_template
        self.prompt_path = prompt_path
        self.test_prompt_path = test_prompt_path
        self.max_txt_len = max_txt_len  # previously accepted but silently dropped
        self.end_sym = end_sym
        self.prompt = None

        # Activation function and additional parameters
        self.act = ACT2FN["gelu"]
        self.kwargs = kwargs

    def _embed_tokens(self, token_ids):
        """Look up input embeddings for ``token_ids`` via ``llama_model.model.model.embed_tokens``."""
        inner = self.llama_model.model.model
        return inner.embed_tokens(token_ids.to(inner.device))

    def prompt_wrap(
            self, fused_states, input_ids, target_ids, attention_mask,
    ):
        """
        Build LLM inputs as ``[BOS] + [prompt_template] + [fused_states] + [text]``.

        Args:
            fused_states (torch.Tensor): Connector output; dim 0 is batch, dim 1 is sequence.
            input_ids (torch.Tensor): Text token IDs, ``[batch, seq_len]``.
            target_ids (torch.Tensor): Labels for the text part (-100 = no loss).
            attention_mask (torch.Tensor): Attention mask for the text part.

        Returns:
            tuple: ``(inputs_embeds, targets, attention_mask, modality_lengths)`` where
            ``modality_lengths`` is the prefix length ``1 + len(prompt) + fused_states.size(1)``.
            The prefix positions get label -100 and attention 1.

        Raises:
            AssertionError: If the combined attention mask and targets differ in shape.
        """
        device = self.device
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        attention_mask = attention_mask.to(device)
        batch_size = fused_states.shape[0]

        # Tokenize the prompt template (no special tokens; BOS is added explicitly below).
        prompt_ids = self.llama_tokenizer(
            self.prompt_template,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device).input_ids
        prompt_embeds = self._embed_tokens(prompt_ids).expand(batch_size, -1, -1)
        text_embeds = self._embed_tokens(input_ids).expand(batch_size, -1, -1)

        bos_ids = torch.full(
            (batch_size, 1),
            self.llama_tokenizer.bos_token_id,
            dtype=prompt_ids.dtype,
            device=prompt_ids.device,
        )
        bos_embeds = self._embed_tokens(bos_ids)

        # torch.cat requires a common device; the embeddings live on the LLM's device.
        fused_states = fused_states.to(prompt_embeds.device)
        inputs_embeds = torch.cat([bos_embeds, prompt_embeds, fused_states, text_embeds], dim=1)

        prefix_len = 1 + prompt_embeds.size(1) + fused_states.size(1)
        empty_targets = torch.full((batch_size, prefix_len), -100, dtype=torch.long, device=device)
        atts_prefix = torch.ones((batch_size, prefix_len), dtype=torch.long, device=device)

        targets = torch.cat([empty_targets, target_ids], dim=1)
        attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)

        assert attention_mask.size() == targets.size()

        return inputs_embeds, targets, attention_mask, prefix_len

    def compute_wups(self, result, targets, threshold=0.9):
        """
        Compute the mean WUPS score between teacher-forced predictions and targets.

        Predictions are ``argmax(logits)`` shifted by one position against
        ``targets``, the same alignment ``count_statistic`` uses. Only positions whose
        label is not -100 (and not the pad id) are decoded, so each row becomes the
        predicted answer text versus the reference answer text.

        Args:
            result (dict): Model output containing ``'logits'`` ``[batch, seq, vocab]``.
            targets (torch.Tensor): Labels aligned with the logits ``[batch, seq]``
                (as produced by ``prompt_wrap``, -100 = ignored).
            threshold (float): Similarity threshold passed to ``wups_score``.

        Returns:
            float: Average WUPS score over the batch (0.0 for an empty batch).
        """
        tokenizer = self.llama_tokenizer
        predicted = torch.argmax(result['logits'], dim=-1)[:, :-1]
        labels = targets[:, 1:].to(predicted.device)
        batch_size = labels.shape[0]
        if batch_size == 0:
            return 0.0

        keep = labels != -100
        if tokenizer.pad_token_id is not None:
            keep &= labels != tokenizer.pad_token_id

        total_wups = 0.0
        for pred_row, label_row, keep_row in zip(predicted, labels, keep):
            # Decode whole sequences so sub-word pieces join into real words.
            gen_answer = tokenizer.decode(pred_row[keep_row].tolist(), skip_special_tokens=True)
            target_answer = tokenizer.decode(label_row[keep_row].tolist(), skip_special_tokens=True)
            total_wups += wups_score(gen_answer, target_answer, threshold)
        return total_wups / batch_size

    def preprocess(self, samples):
        """
        Encode the audio and/or video inputs selected by ``self.modality``.

        Args:
            samples (dict): Holds ``'audio_data'`` or ``'audio_path'`` (per
                ``return_raw_audios``) and/or ``'video_data'``.

        Returns:
            dict: ``"audio"`` -> ``[batch, n_frames, audio_size, dim]`` and/or
            ``"video"`` -> ``[batch, n_frames, video_size, dim]`` (cast to float16).
        """
        extract_feats = {}

        if 'audio' in self.modality:
            audio_key = 'audio_data' if self.return_raw_audios else 'audio_path'
            audio_embeds, _ = self.audio_encoders.encode(
                samples[audio_key], target_sampling_rate=16000, sin_pos=True
            )
            # Regroup the flat token axis into frames of `audio_size` tokens.
            extract_feats["audio"] = audio_embeds.view(
                audio_embeds.size(0), -1, self.audio_size, audio_embeds.size(-1)
            )

        if 'video' in self.modality:
            video_embeds, _ = self.video_encoders.encode(samples.get('video_data'))
            # Regroup into frames of `video_size` tokens.
            extract_feats["video"] = video_embeds.view(
                video_embeds.size(0), -1, self.video_size, video_embeds.size(-1)
            ).half()

        return extract_feats

    def _pool_features(self, extract_feats, texts):
        """Run the question-guided pooler of each extracted modality.

        Returns:
            tuple[dict, dict]: ``(pooled, hidden)`` keyed by ``"audio"`` / ``"video"``;
            ``pooled`` is each pooler's first output and ``hidden`` is the pooler's
            ``hidden_states`` attribute read right after the call.
        """
        pooled, hidden = {}, {}
        for name, pooler in (("audio", self.audio_pooler), ("video", self.visual_pooler)):
            if name in extract_feats:
                pooled[name], _ = pooler(extract_feats[name], texts)
                hidden[name] = pooler.hidden_states
        return pooled, hidden

    def _fuse(self, pooled):
        """Hand the pooled features to the connector in (audio, video) order."""
        return self.connector([pooled[name] for name in self._MODALITY_ORDER if name in pooled])

    def forward(self, samples, verbose=False):
        """
        Training forward pass: encode, pool, align (OT losses) and run the LLM.

        Args:
            samples (dict): Contains ``'output_texts'`` plus the modality inputs read
                by ``preprocess`` (and ``'task'`` when ``verbose`` and ``sub_task``).
            verbose (bool): Also merge ``count_statistic`` results into the output.

        Returns:
            dict: ``loss`` (sum of all terms), ``sft_loss``, ``ot_at_loss``,
            ``ot_vt_loss``, ``ot_av_loss_before``, ``ot_av_loss_after`` and, if
            ``verbose``, the accuracy statistics. OT terms whose modality is missing
            are zero.
        """
        output_texts = samples['output_texts']

        input_ids, target_ids, attention_mask, _ = process_batch_instance(
            self.llama_tokenizer,
            output_texts,
            max_tgt_len=self.max_tgt_len,
            prompt=self.prompt,
            generate=False,
        )

        extract_feats = self.preprocess(samples)
        pooled, hidden = self._pool_features(extract_feats, output_texts)
        has_audio = "audio" in pooled
        has_video = "video" in pooled

        # Optimal-transport alignment losses; absent modalities contribute zero
        # instead of raising NameError.
        zero = torch.zeros((), device=self.device)
        if has_audio and has_video:
            ot_av_loss_before = self.ot_av(hidden["audio"], hidden["video"], stage="before")
            ot_av_loss_after = self.ot_av(pooled["audio"], pooled["video"], stage="after")
        else:
            ot_av_loss_before = ot_av_loss_after = zero
        ot_at_loss = self.ot_at(output_texts, hidden["audio"]) if has_audio else zero
        # Visual-text alignment must use its own module (was mistakenly `ot_at`).
        ot_vt_loss = self.ot_vt(output_texts, hidden["video"]) if has_video else zero

        inputs_llama, targets, attention_mask, _ = self.prompt_wrap(
            self._fuse(pooled), input_ids, target_ids, attention_mask
        )

        llm_device = self.llama_model.device
        result = self.llama_model(
            inputs_embeds=inputs_llama.to(torch.float16).to(llm_device),
            attention_mask=attention_mask.to(llm_device),
            return_dict=True,
            labels=targets.to(llm_device),
        )

        sft_loss = result["loss"]
        loss = sft_loss + ot_av_loss_after + ot_av_loss_before + ot_at_loss + ot_vt_loss

        losses = {
            "loss": loss,
            "sft_loss": sft_loss,
            "ot_at_loss": ot_at_loss,
            "ot_vt_loss": ot_vt_loss,
            "ot_av_loss_before": ot_av_loss_before,
            "ot_av_loss_after": ot_av_loss_after,
        }
        if verbose:
            losses.update(self.count_statistic(samples, targets, result))
        return losses

    @staticmethod
    def _set_lora_alpha(model, lora_alpha):
        """Set ``lora_alpha`` and the derived ``scaling`` of every 'default' LoRA adapter in ``model``."""
        from peft.tuners.lora import LoraLayer  # imported lazily: only needed for alpha overrides
        for module in model.modules():
            if isinstance(module, LoraLayer):
                module.lora_alpha['default'] = lora_alpha
                module.scaling['default'] = lora_alpha / module.r['default']

    def generate(self, samples, generate_cfg):
        """
        Generate one reply per (question, answer) turn of a single conversation.

        Turn ``n`` > 0 folds the previous prompt, the model's previous reply (cut at
        ``end_sym``) and the new question into one human turn, so the history is
        actually conditioned on.

        Args:
            samples (dict): ``'output_texts'`` (a batch of exactly one conversation,
                a list of turn dicts with a ``'value'`` key) plus modality inputs.
            generate_cfg (dict): Optional overrides: ``max_new_tokens``, ``min_length``,
                ``do_sample``, ``num_beams``, ``repetition_penalty``, ``length_penalty``,
                ``top_p``, ``lora_alpha``. A non-dict is treated as ``{}``. A non-empty
                dict temporarily sets the LoRA alpha (default
                ``self.kwargs['yu_lora_alpha']`` or 32) and restores it afterwards.

        Returns:
            list: ``[all_gen_text]`` where ``all_gen_text[n]`` is the ``batch_decode``
            output for turn ``n``.

        Raises:
            ValueError: If the conversation has fewer than 2 turns.
            AssertionError: If the batch size is not 1.
        """
        output_texts = samples['output_texts']
        if len(output_texts[0]) < 2:
            raise ValueError("Need at least 2 turns for multi-turn test!")
        assert len(output_texts) == 1, "Only support bsz=1 for multi turn test!"

        if not isinstance(generate_cfg, dict):
            generate_cfg = {}
        override_lora = len(generate_cfg) != 0
        default_lora_alpha = self.kwargs.get('yu_lora_alpha', 32)

        # Encoder outputs do not depend on the dialogue, so compute them once.
        extract_feats = self.preprocess(samples)

        # Stop on token id 2 (presumably the LLaMA `</s>` id; TODO confirm for other backbones).
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2], encounters=1)])

        turns = output_texts[0]
        all_gen_text = []
        all_convs = None
        prev_reply = ""
        for turn_idx in range(len(turns) // 2):
            question, answer = turns[2 * turn_idx], turns[2 * turn_idx + 1]
            if turn_idx == 0:
                all_convs = [[question, answer]]
            else:
                all_convs = [
                    [
                        {
                            'from': 'human',
                            'value': f'{all_convs[0][0]["value"]}\nASSISTANT: {prev_reply}\n USER: {question["value"]}',
                        },
                        answer,
                    ]
                ]

            # Tokenize the *current* conversation (history included).
            gen_input_ids, gen_target_ids, gen_attention_mask, _ = process_batch_instance(
                self.llama_tokenizer,
                all_convs,
                max_tgt_len=self.max_tgt_len,
                generate=True,
                prompt=self.prompt,
            )

            # Same question-guided pooling + connector path as `forward`.
            pooled, _ = self._pool_features(extract_feats, all_convs)
            gen_inputs_embeds, _, _, _ = self.prompt_wrap(
                self._fuse(pooled),
                gen_input_ids,
                gen_target_ids,
                gen_attention_mask,
            )

            if override_lora:
                self._set_lora_alpha(self.llama_model, generate_cfg.get("lora_alpha", default_lora_alpha))
            try:
                gen_outputs = self.llama_model.generate(
                    inputs_embeds=gen_inputs_embeds.to(self.llama_model.device),
                    max_new_tokens=generate_cfg.get("max_new_tokens", self.max_tgt_len),
                    min_length=generate_cfg.get("min_length", 1),
                    do_sample=generate_cfg.get("do_sample", False),
                    num_beams=generate_cfg.get("num_beams", 5),
                    repetition_penalty=generate_cfg.get("repetition_penalty", 1.5),
                    length_penalty=generate_cfg.get("length_penalty", 1.0),
                    top_p=generate_cfg.get("top_p", 0.9),
                    stopping_criteria=stopping_criteria,
                )
            finally:
                # Always restore the training-time alpha, even if generation fails.
                if override_lora:
                    self._set_lora_alpha(self.llama_model, default_lora_alpha)

            gen_text = self.llama_tokenizer.batch_decode(gen_outputs, add_special_tokens=False)
            all_gen_text.append(gen_text)

            # History uses the decoded string (not the list repr), cut at end_sym.
            reply = gen_text[0]
            if self.end_sym:
                reply = reply.split(self.end_sym)[0]
            prev_reply = reply.strip()

        return [all_gen_text]

    def count_statistic(self, samples, targets, result):
        """
        Token-level accuracy of teacher-forced predictions.

        ``argmax(logits)[:, :-1]`` is compared with ``targets[:, 1:]``; positions
        labelled -100 are ignored.

        Args:
            samples (dict): Needs ``'task'`` (``[[major, minor], ...]``) only when
                ``self.sub_task`` is enabled.
            targets (torch.Tensor): Labels ``[batch, seq]``.
            result (dict): Model output containing ``'logits'``.

        Returns:
            dict: ``correct`` and ``total`` token counts; with ``sub_task`` also the
            per-major/per-minor ``correct_by_type_*``, ``total_by_type_*`` and
            ``n_sample_by_type_*`` counters.
        """
        chosen_tokens = torch.argmax(result['logits'], dim=-1)[:, :-1]
        labels = targets[:, 1:]  # shift for next-token prediction
        bsz = labels.shape[0]

        valid_mask = labels != -100
        is_correct = chosen_tokens.reshape(-1) == labels.to(chosen_tokens.device).reshape(-1)
        valid_tokens = is_correct & valid_mask.reshape(-1).to(is_correct.device)

        correct = valid_tokens.sum().item()
        total = valid_mask.sum().item()
        if not self.sub_task:
            return {"correct": correct, "total": total}

        valid_tokens = valid_tokens.reshape(bsz, -1)
        valid_mask = valid_mask.reshape(bsz, -1)

        correct_by_type_major = defaultdict(int)
        total_by_type_major = defaultdict(int)
        correct_by_type_minor = defaultdict(int)
        total_by_type_minor = defaultdict(int)
        n_sample_by_type_major = defaultdict(int)
        n_sample_by_type_minor = defaultdict(int)

        for i, task in enumerate(samples['task']):
            major = task[0]
            minor = task[0] + '_' + task[1]
            n_correct = valid_tokens[i].sum().item()
            n_valid = valid_mask[i].sum().item()
            correct_by_type_major[major] += n_correct
            total_by_type_major[major] += n_valid
            n_sample_by_type_major[major] += 1
            correct_by_type_minor[minor] += n_correct
            total_by_type_minor[minor] += n_valid
            n_sample_by_type_minor[minor] += 1

        if is_main_process():
            for name, counter in (
                ('correct_by_type_major', correct_by_type_major),
                ('total_by_type_major', total_by_type_major),
                ('correct_by_type_minor', correct_by_type_minor),
                ('total_by_type_minor', total_by_type_minor),
                ('n_sample_by_type_major', n_sample_by_type_major),
                ('n_sample_by_type_minor', n_sample_by_type_minor),
            ):
                logger.debug('%s:%s', name, counter)

        return {
            "correct": correct,
            "total": total,
            "correct_by_type_major": correct_by_type_major,
            "total_by_type_major": total_by_type_major,
            "correct_by_type_minor": correct_by_type_minor,
            "total_by_type_minor": total_by_type_minor,
            "n_sample_by_type_major": n_sample_by_type_major,
            "n_sample_by_type_minor": n_sample_by_type_minor,
        }
    

def build_avllm(
        audio_encoder, video_encoder, visual_pooler, audio_pooler,
        connector, ot_av, ot_at, ot_vt, llama_model, llama_tokenizer, configs
    ):
    """
    Assemble an ``AVLLModel`` from pre-built components and a config object.

    The config is read from three places: ``configs.run`` (modality, sub_task),
    ``configs.datasets`` (return_raw_audios) and ``configs.llm`` (audio/video
    sizes plus the ``prompt`` mapping). Missing prompt keys fall back to the
    defaults listed below.

    Args:
        audio_encoder: Audio encoder module.
        video_encoder: Video encoder module.
        visual_pooler: Question-guided visual pooler.
        audio_pooler: Question-guided audio pooler.
        connector: Multimodal feature connector.
        ot_av: Audio-visual optimal-transport module.
        ot_at: Audio-text optimal-transport module.
        ot_vt: Visual-text optimal-transport module.
        llama_model: Language model.
        llama_tokenizer: Tokenizer for the language model.
        configs: Configuration object exposing ``run``, ``datasets`` and ``llm``.

    Returns:
        AVLLModel: The configured model.
    """
    run_cfg = configs.run
    llm_cfg = configs.llm

    model_kwargs = dict(
        # Runtime / dataset switches
        modality=run_cfg.modality,
        sub_task=run_cfg.sub_task,
        return_raw_audios=configs.datasets.return_raw_audios,
        # Per-frame token counts
        audio_size=llm_cfg.audio_size,
        video_size=llm_cfg.video_size,
        # Encoders, poolers and connector
        audio_encoders=audio_encoder,
        video_encoders=video_encoder,
        visual_pooler=visual_pooler,
        audio_pooler=audio_pooler,
        connector=connector,
        # Optimal-transport alignment modules
        ot_av=ot_av,
        ot_at=ot_at,
        ot_vt=ot_vt,
        # Language model
        llama_model=llama_model,
        llama_tokenizer=llama_tokenizer,
    )

    # Prompt-related options, each with its documented fallback.
    prompt_cfg = llm_cfg.prompt
    prompt_defaults = (
        ("max_tgt_len", 512),
        ("max_txt_len", 512),
        ("prompt_template", "USER: <Img>"),
        ("prompt_path", "prompts/train_prompt.json"),
        ("test_prompt_path", "prompts/test_prompt.json"),
        ("end_sym", "</s>"),
    )
    for key, fallback in prompt_defaults:
        model_kwargs[key] = prompt_cfg.get(key, fallback)

    return AVLLModel(**model_kwargs)
