"""
Encoder Utilities Module

This module provides utility functions for loading encoder models in multimodal
AI systems. It supports loading pretrained models (e.g. from the Hugging Face Hub)
and loading custom PyTorch checkpoints whose configuration is stored alongside the weights.

The utilities instantiate the requested model class. They load checkpoint tensors onto CPU and
forward an optional cache directory to pretrained loaders. They raise a clear error when no
model source is given.

Key Features:
    - Pretrained model loading with an optional download cache directory
    - Custom checkpoint loading with configuration restoration
    - Flexible model class and configuration specification
    - Checkpoints are loaded onto CPU (map_location="cpu"); the `device` argument is
      accepted but not yet applied to the returned model
    - Progress-description hooks for long-running loads (not yet wired to a progress bar)

Supported Loading Methods:
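    - Hugging Face-style pretrained models (via the model class's `from_pretrained`)
    - Custom PyTorch checkpoints with embedded configurations
    - Strategy selection based on which source argument is supplied (there is no
      automatic fallback from one strategy to the other)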

Dependencies:
    - torch: PyTorch framework for model operations and checkpoint handling
    - tqdm: Progress-bar utilities (imported for future progress reporting; not yet used)

Author: AI Model Development Team
License: MIT
"""

from tqdm import tqdm
import torch

def load_model(model_path=None, device=None, ckpt_path=None, model_class=None, config_class=None, cache_dir=None):
    """
    Load a model from either a pretrained path or a custom checkpoint.

    This function provides a unified interface for two loading strategies:
    pretrained models loaded through the model class's ``from_pretrained``
    (e.g. Hugging Face Hub identifiers or local directories), or custom PyTorch
    checkpoints that store the model configuration next to the weights. The
    strategy is chosen by which of ``model_path`` / ``ckpt_path`` is supplied.

    Args:
        model_path (str, optional): Path or identifier passed to
            ``model_class.from_pretrained`` (e.g. "openai/clip-vit-base-patch32").
            When both ``model_path`` and ``ckpt_path`` are given, ``model_path``
            takes precedence and ``ckpt_path`` is ignored.
        device (str, optional): Target device (e.g. "cuda:0", "cpu"). Currently
            accepted but NOT used. The returned model is not moved, and
            checkpoint tensors are always loaded onto CPU.
        ckpt_path (str, optional): Path to a checkpoint file readable by
            ``torch.load``. The loaded object must support ``['cfg']``, whose
            value is passed to ``config_class``. It must also support
            ``['model']``, a state dict passed to ``load_state_dict``.
            Used only when ``model_path`` is falsy.
        model_class (class): Model class to instantiate. It needs a
            ``from_pretrained`` classmethod for pretrained loading, or a
            constructor taking a config object for checkpoint loading.
        config_class (class, optional): Required for checkpoint loading. It is
            called as ``config_class(checkpoint['cfg'])`` to rebuild the config.
        cache_dir (str, optional): Forwarded as ``cache_dir=`` to
            ``from_pretrained``. Ignored for checkpoint loading.

    Returns:
        The loaded model instance. For the pretrained path this is whatever
        ``model_class.from_pretrained`` returns; Hugging Face models are
        typically returned in eval mode, but confirm this for your class. For
        the checkpoint path, the module is left in whatever mode its
        constructor sets (``torch.nn.Module`` defaults to training mode). Call
        ``.eval()`` before inference.

    Raises:
        ValueError: If neither ``model_path`` nor ``ckpt_path`` is provided
            (both falsy), or if checkpoint loading is requested without
            ``config_class``.
        FileNotFoundError: If ``ckpt_path`` does not exist (raised by
            ``torch.load``). For ``model_path``, the error type for a missing
            model depends on ``model_class.from_pretrained``.
        KeyError: If the checkpoint lacks the 'cfg' or 'model' entry.
        AttributeError: If ``model_class`` has no ``from_pretrained`` (pretrained path).
        RuntimeError: If the checkpoint's state dict does not match the model
            (raised by ``load_state_dict`` in its default strict mode).

    Example:
        # Load pretrained model
        >>> from transformers import CLIPModel, CLIPConfig
        >>> model = load_model(
        ...     model_path="openai/clip-vit-base-patch32",
        ...     model_class=CLIPModel,
        ...     cache_dir="./cache"
        ... )

        # Load custom checkpoint
        >>> model = load_model(
        ...     ckpt_path="./checkpoints/best_model.pth",
        ...     model_class=CustomEncoder,
        ...     config_class=CustomConfig
        ... )

    Note:
        - ``torch.load`` may unpickle arbitrary objects. Whether it does depends
          on the installed torch version's ``weights_only`` default, so only
          load checkpoints from trusted sources.
        - The ``desc`` arguments of the internal helpers are placeholders for
          progress reporting and are not displayed.
    """

    def load_model_with_progress(model_path, model_class, cache_dir=None, desc='Loading model'):
        """
        Load a pretrained model via ``model_class.from_pretrained``.

        Args:
            model_path (str): Path or identifier for the pretrained model.
            model_class (class): Class exposing a ``from_pretrained`` classmethod.
            cache_dir (str, optional): Forwarded to ``from_pretrained`` as ``cache_dir=``.
            desc (str): Progress description placeholder (currently unused).

        Returns:
            Whatever ``model_class.from_pretrained`` returns.
        """
        return model_class.from_pretrained(model_path, cache_dir=cache_dir)

    def load_ckpt_with_progress(checkpoint_path, model_class, config_class, desc="Loading Checkpoint"):
        """
        Load a model from a checkpoint holding both its config and weights.

        Args:
            checkpoint_path (str): Path to the checkpoint file.
            model_class (class): Model class, constructed as ``model_class(cfg)``.
            config_class (class): Config class, constructed as ``config_class(checkpoint['cfg'])``.
            desc (str): Progress description placeholder (currently unused).

        Returns:
            The constructed model with the checkpoint's state dict loaded.
        """
        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only hosts.
        # NOTE(review): torch.load may unpickle; only use with trusted files.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Rebuild the configuration, then the model, then restore its weights.
        model_cfg = config_class(checkpoint['cfg'])
        model = model_class(model_cfg)
        # Strict by default: raises RuntimeError on missing/unexpected keys.
        model.load_state_dict(checkpoint['model'])

        return model

    # model_path wins when both sources are supplied.
    # Note: `device` is intentionally NOT forwarded. It was previously passed
    # positionally into the helpers' `desc` slot by mistake.
    if model_path:
        return load_model_with_progress(model_path, model_class, cache_dir=cache_dir)

    if ckpt_path:
        # Fail fast, before reading a possibly large checkpoint from disk,
        # instead of a later "'NoneType' object is not callable".
        if config_class is None:
            raise ValueError('config_class must be provided when loading from ckpt_path')
        return load_ckpt_with_progress(ckpt_path, model_class, config_class)

    raise ValueError('Either model_path or ckpt_path must be provided')

