# concept_alignment/models/property_predictor.py
import torch
import torch.nn as nn
import logging

logger = logging.getLogger(__name__)

class PropertyPredictor(nn.Module):
    """
    Property prediction model that takes concept vectors as input
    and predicts property values.

    Each property ``i`` has its own MLP, which is applied only to concept
    ``i`` (a one-to-one concept -> property mapping). If there are more
    properties than concepts, the extra properties have no matching concept:
    their MLPs are never called and their predictions are always zero.
    """
    def __init__(self, num_concepts, concept_dim, num_properties):
        """
        Initialize the property predictor.

        Args:
            num_concepts: Number of concepts from the CBM
            concept_dim: Dimension of each concept vector
            num_properties: Number of properties to predict
        """
        super().__init__()

        self.num_concepts = num_concepts
        self.num_properties = num_properties

        # One independent MLP per property:
        # concept_dim -> 2*concept_dim -> concept_dim -> concept_dim//2 -> 1.
        # NOTE: concept_dim < 2 makes the concept_dim//2 layer zero-width.
        self.property_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(concept_dim, concept_dim*2),
                nn.LayerNorm(concept_dim*2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(concept_dim*2, concept_dim),
                nn.ReLU(),
                nn.Linear(concept_dim, concept_dim//2),
                nn.ReLU(),
                nn.Linear(concept_dim//2, 1)
            ) for _ in range(num_properties)
        ])

        logger.info("Initialized PropertyPredictor with %d property MLPs", num_properties)
        if num_properties > num_concepts:
            # Surface the otherwise silent behaviour of forward(): these
            # properties have no concept to read and are always predicted as 0.
            logger.warning(
                "PropertyPredictor: num_properties (%d) > num_concepts (%d); "
                "properties with index >= %d have no matching concept and "
                "will always be predicted as 0",
                num_properties, num_concepts, num_concepts,
            )

    def forward(self, concept_vectors):
        """
        Forward pass to predict properties from concept vectors.

        Args:
            concept_vectors: Concept vectors from CBM [batch_size, num_concepts, concept_dim]

        Returns:
            property_predictions: Predicted property values [batch_size, num_properties].
                The dtype matches the MLP outputs. Columns with index
                >= num_concepts (if any) are zero.
        """
        batch_size = concept_vectors.size(0)
        num_active = min(self.num_concepts, self.num_properties)
        num_missing = self.num_properties - num_active

        # Concept i feeds only the MLP for property i; each output is
        # [batch_size, 1] -> squeezed to [batch_size].
        predictions = [
            self.property_mlps[i](concept_vectors[:, i, :]).squeeze(-1)
            for i in range(num_active)
        ]

        if not predictions:
            # No property has a matching concept: everything is zero.
            return concept_vectors.new_zeros(batch_size, num_missing)

        # Stack instead of writing into a preallocated float32 buffer, so the
        # result keeps the MLP output dtype (e.g. float64, or half under autocast).
        property_predictions = torch.stack(predictions, dim=1)
        if num_missing:
            padding = property_predictions.new_zeros(batch_size, num_missing)
            property_predictions = torch.cat([property_predictions, padding], dim=1)
        return property_predictions