"""
Base connector module for multimodal feature fusion.

This module provides the abstract base class for different types of connectors
that fuse features from multiple modalities (audio, visual, text) before feeding
them to the language model.
"""

import torch
from torch import nn


class BaseConnector(nn.Module):
    """Common parent for modules that merge per-modality features.

    A connector receives the hidden states produced for each input modality
    and combines them into a single representation for the language model.
    This base class only records the settings shared by every connector; the
    fusion strategy itself is left to subclasses, which must override
    :meth:`forward`.

    Note:
        Although described as abstract, this class does not use ``abc.ABC``.
        It can be instantiated, and the missing implementation only surfaces
        as a ``NotImplementedError`` when :meth:`forward` is called.
    """

    def __init__(self, connector_configs, num_modality, device):
        """Record the settings shared by all connectors.

        Args:
            connector_configs (dict): Connector-specific configuration, kept
                unchanged on ``self.connector_configs`` for subclasses to read.
            num_modality (int): Number of modalities the connector fuses,
                kept on ``self.num_modality``.
            device (torch.device or str): Computation device, kept on
                ``self.device``. The base class only stores it; nothing is
                moved to this device here.
        """
        super().__init__()
        # Assigned left to right: configs, then modality count, then device.
        self.connector_configs, self.num_modality, self.device = (
            connector_configs,
            num_modality,
            device,
        )

    def forward(self, hidden_states):
        """Combine the per-modality features into one fused tensor.

        Subclasses replace this with their own fusion strategy; the base
        version exists only to signal that no strategy has been defined.

        Args:
            hidden_states (list): Feature tensors, one per modality.

        Returns:
            torch.Tensor: The fused features (as produced by a subclass).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Subclasses must implement the forward method"
        raise NotImplementedError(message)


