"""
Utility functions for model management and operations.

This module provides utilities for model saving/loading, parameter freezing,
and custom stopping criteria for text generation. It includes support for
PEFT (Parameter-Efficient Fine-Tuning) models and distributed training.
"""

import os
from pathlib import Path
import logging

import torch
import torch.distributed as dist
from transformers import StoppingCriteria
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class Stoppingcriteriasub(StoppingCriteria):
    """
    Custom stopping criteria for text generation based on specific token sequences.

    This class extends the transformers StoppingCriteria so that generation
    stops as soon as the tail of the generated sequence equals one of the
    configured stop sequences. Only the first sequence of the batch
    (``input_ids[0]``) is inspected.
    """

    def __init__(self, stops=None, encounters=1):
        """
        Initialize the stopping criteria.

        Args:
            stops (list | None): Token-ID sequences that should trigger
                stopping (tensors or anything ``torch.as_tensor`` accepts).
                ``None`` means "no stop sequences".
            encounters (int): Number of encounters needed to stop (currently
                unused; kept for interface compatibility).
        """
        super().__init__()
        # Avoid a shared mutable default; a caller-supplied list is kept by
        # reference, exactly as before.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """
        Check if generation should stop based on the current input sequence.

        Compares the last ``len(stop)`` tokens of the first batch element with
        every configured stop sequence.

        Args:
            input_ids (torch.LongTensor): Token IDs generated so far; indexed
                as ``input_ids[0]``, so presumably (batch, seq_len) as in the
                transformers contract.
            scores (torch.FloatTensor): Model output scores (unused in this implementation)
            **kwargs: Extra arguments passed by the generation loop (ignored).

        Returns:
            bool: True if generation should stop, False otherwise
        """
        sequence = input_ids[0]
        generated_len = sequence.shape[-1]

        for stop in self.stops:
            # Bring the stop sequence onto the same device as the generated
            # ids so the comparison cannot fail with a device mismatch, and
            # flatten it so its length is the number of tokens.
            stop_ids = torch.as_tensor(stop, device=sequence.device).reshape(-1)
            stop_len = stop_ids.numel()

            # An empty stop would slice the whole sequence (-0 == 0), and a
            # stop longer than the sequence would yield a shorter slice that
            # either raises or broadcasts into a false match; neither can be
            # a genuine match, so skip them.
            if stop_len == 0 or stop_len > generated_len:
                continue

            if torch.all(stop_ids == sequence[-stop_len:]).item():
                return True
        return False

def save_peft_model(model, path):
    """
    Save a PEFT-wrapped model and its associated components to disk.

    Writes the PEFT model, the input-embedding weight, the tokenizer, and the
    state dictionaries of the visual pooler, audio pooler, and connector into
    ``path``. When torch.distributed is initialized, only the rank 0 process
    saves; every other rank returns immediately. When torch.distributed is
    unavailable or not initialized (single-process runs), the calling process
    saves.

    Args:
        model: Object exposing ``llama_model``, ``llama_tokenizer``,
            ``visual_pooler``, ``audio_pooler`` and ``connector`` attributes.
        path (str): Directory path where the model components will be saved.
            It is created (including parents) if it does not exist.

    Files written under ``path``:
        - whatever ``model.llama_model.save_pretrained`` emits
        - ``embedding_weight.pt``
        - ``tokenizer/`` (via ``save_pretrained``)
        - ``visual_pooler.pt``, ``audio_pooler.pt``, ``connector.pt``
    """
    # dist.get_rank() raises when no process group exists, so only consult it
    # in a real distributed run; otherwise this process acts as "rank 0".
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return

    # Create output directory
    Path(path).mkdir(parents=True, exist_ok=True)

    # Save PEFT model (adapters and configuration)
    model.llama_model.save_pretrained(path)
    logger.info("Save peft model successfully")

    # Save embedding weights separately for proper restoration
    torch.save(
        model.llama_model.model.model.embed_tokens.weight,
        os.path.join(path, "embedding_weight.pt"),
    )

    # Save tokenizer with its configuration
    model.llama_tokenizer.save_pretrained(os.path.join(path, "tokenizer"))
    logger.info("Save tokenizer successfully")

    # Save multimodal components state dictionaries
    torch.save(model.visual_pooler.state_dict(), os.path.join(path, "visual_pooler.pt"))
    logger.info("Save visual_pooler successfully")
    torch.save(model.audio_pooler.state_dict(), os.path.join(path, "audio_pooler.pt"))
    logger.info("Save audio_pooler successfully")
    torch.save(model.connector.state_dict(), os.path.join(path, "connector.pt"))
    logger.info("Save connector successfully")


def load_peft_model(visual_pooler, audio_pooler, connector, path):
    """
    Load a PEFT-wrapped model and its associated components from disk.

    Reads the PEFT configuration from ``path``, instantiates the base causal
    LM it names, attaches the saved adapters, restores the saved input
    embedding weight, and loads the tokenizer plus the state dictionaries of
    the three multimodal modules.

    All ``torch.load`` calls map tensors to CPU first; the values are then
    copied into the already-constructed parameters, so the modules stay on
    whatever device they were on and checkpoints written from a GPU can be
    loaded on a machine without that GPU.

    Args:
        visual_pooler: Pre-initialized visual pooler module to load weights into
        audio_pooler: Pre-initialized audio pooler module to load weights into
        connector: Pre-initialized connector module to load weights into
        path (str): Directory path where the model components are saved

    Returns:
        tuple: (peft_model, llama_tokenizer, visual_pooler, audio_pooler, connector)
            - peft_model: Loaded PEFT-wrapped model
            - llama_tokenizer: Loaded tokenizer
            - visual_pooler: Visual pooler with loaded weights
            - audio_pooler: Audio pooler with loaded weights
            - connector: Connector with loaded weights
    """
    # NOTE(security): torch.load unpickles the files below; only point `path`
    # at checkpoints from a trusted source.

    # Load PEFT configuration and base model
    peft_config = PeftConfig.from_pretrained(path)
    base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
    peft_model = PeftModel.from_pretrained(base_model, path)
    logger.info("Load peft model successfully")

    # Restore embedding weights that may have been modified during training.
    # copy_ writes into the existing parameter, so loading to CPU first does
    # not change the device the embedding lives on.
    saved_embedding = torch.load(
        os.path.join(path, "embedding_weight.pt"), map_location="cpu"
    )
    base_model.model.embed_tokens.weight.data.copy_(saved_embedding)

    # Load tokenizer with its configuration
    llama_tokenizer = AutoTokenizer.from_pretrained(os.path.join(path, "tokenizer"))
    logger.info("Load tokenizer successfully")

    # Load multimodal components state dictionaries
    visual_pooler.load_state_dict(
        torch.load(os.path.join(path, "visual_pooler.pt"), map_location="cpu")
    )
    logger.info("Load visual_pooler successfully")
    audio_pooler.load_state_dict(
        torch.load(os.path.join(path, "audio_pooler.pt"), map_location="cpu")
    )
    logger.info("Load audio_pooler successfully")
    connector.load_state_dict(
        torch.load(os.path.join(path, "connector.pt"), map_location="cpu")
    )
    logger.info("Load connector successfully")

    return peft_model, llama_tokenizer, visual_pooler, audio_pooler, connector


def freeze_modules(avllm_model, module_names):
    """
    Freeze specified modules in the AVLLM model to prevent parameter updates.

    Sets ``requires_grad = False`` on every parameter of each named attribute
    of ``avllm_model``. All names are validated before any parameter is
    touched, so an unknown name never leaves the model partially frozen.

    Args:
        avllm_model: The model whose attributes (sub-modules) should be frozen
        module_names (list | str): Attribute names to freeze. A single string
            is treated as one name rather than iterated character by character.

    Raises:
        ValueError: If any specified module name doesn't exist in the model

    Examples:
        >>> freeze_modules(model, ['visual_encoder', 'audio_encoder'])
        >>> freeze_modules(model, ['llama_model'])  # Freeze LLM while training multimodal components
    """
    if isinstance(module_names, str):
        module_names = [module_names]
    else:
        # Materialize so one-shot iterables survive the two passes below.
        module_names = list(module_names)

    # Pass 1: validate everything up front (fail before mutating the model).
    for module_name in module_names:
        if not hasattr(avllm_model, module_name):
            raise ValueError(f"AVLLM model has no attribute {module_name}")

    # Pass 2: freeze all parameters in each specified module.
    for module_name in module_names:
        for param in getattr(avllm_model, module_name).parameters():
            param.requires_grad = False
        logger.info(f"Freezed {module_name} module")
                