# concept_alignment/models/cbm.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import math

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from concept_alignment.models.loss import NMSE

logger = logging.getLogger(__name__)

class CBMLightningModule(pl.LightningModule):
    """PyTorch Lightning module for Concept Bottleneck Model training.

    Wraps a concept bottleneck model and a property predictor and optimizes
    them jointly on ``property_loss + lambda_ortho * orthogonality_loss``,
    where the property loss is the project's ``NMSE`` loss and the
    orthogonality loss comes from ``cbm.compute_orthogonality_loss()``.
    """

    def __init__(self, cbm, property_predictor, learning_rate, lambda_ortho, property_cols=None):
        """
        Initialize the Lightning module.

        Args:
            cbm: Instance of ConceptBottleneckModel
            property_predictor: Instance of PropertyPredictor
            learning_rate: Learning rate for optimizer
            lambda_ortho: Weight for orthogonality loss
            property_cols: Names of property columns (for logging)
        """
        super().__init__()
        self.cbm = cbm
        self.property_predictor = property_predictor
        self.learning_rate = learning_rate
        self.lambda_ortho = lambda_ortho
        self.property_cols = property_cols

        # Register the loss function
        self.nmse_loss = NMSE()

        # Save hyperparameters; the two sub-models are excluded so that only
        # the scalar/list settings end up in ``self.hparams``.
        self.save_hyperparameters(ignore=['cbm', 'property_predictor'])

    def forward(self, representations):
        """
        Forward pass through both models.

        Args:
            representations: Input passed unchanged to ``self.cbm``

        Returns:
            Tuple ``(property_predictions, concept_values, concept_vectors)``
            where the last two are the outputs of ``self.cbm`` and the first
            is ``self.property_predictor(concept_vectors)``.
        """
        concept_values, concept_vectors = self.cbm(representations)
        property_predictions = self.property_predictor(concept_vectors)
        return property_predictions, concept_values, concept_vectors

    def _compute_losses(self, property_predictions, targets):
        """Return ``(total_loss, prop_loss, ortho_loss)`` for one batch."""
        prop_loss = self.nmse_loss(property_predictions, targets)
        ortho_loss = self.cbm.compute_orthogonality_loss()
        total_loss = prop_loss + self.lambda_ortho * ortho_loss
        return total_loss, prop_loss, ortho_loss

    def _log_losses(self, stage, total_loss, prop_loss, ortho_loss):
        """Log the three loss terms under the ``<stage>/`` prefix."""
        self.log(f'{stage}/total_loss', total_loss, prog_bar=True)
        self.log(f'{stage}/property_loss', prop_loss, prog_bar=False)
        self.log(f'{stage}/ortho_loss', ortho_loss, prog_bar=True)

    def _log_concept_stats(self, stage, concept_values):
        """Log per-concept mean and std over the batch under ``<stage>/``."""
        avg_concept_values = torch.mean(concept_values, dim=0)  # Average across batch
        std_concept_values = torch.std(concept_values, dim=0)   # Std dev across batch
        for i in range(avg_concept_values.shape[0]):
            self.log(f'{stage}/concept_{i}_avg', avg_concept_values[i], prog_bar=True)
            self.log(f'{stage}/concept_{i}_std', std_concept_values[i], prog_bar=True)

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Args:
            batch: Pair ``(representations, targets)``
            batch_idx: Index of the batch (unused)

        Returns:
            The total loss to backpropagate.
        """
        representations, targets = batch

        # Forward pass
        property_predictions, concept_values, _ = self(representations)

        # Compute and log losses
        total_loss, prop_loss, ortho_loss = self._compute_losses(property_predictions, targets)
        self._log_losses('train', total_loss, prop_loss, ortho_loss)

        # Log average concept values for this batch
        self._log_concept_stats('train', concept_values)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step.

        Logs the same losses and concept statistics as ``training_step`` but
        under the ``val/`` prefix, plus one MSE value per target column.

        Args:
            batch: Pair ``(representations, targets)``; ``targets`` is indexed
                as ``targets[:, i]``, so it must have at least two dimensions
            batch_idx: Index of the batch (unused)

        Returns:
            The total loss for the batch.
        """
        representations, targets = batch

        # Forward pass
        property_predictions, concept_values, _ = self(representations)

        # Compute and log losses
        total_loss, prop_loss, ortho_loss = self._compute_losses(property_predictions, targets)
        self._log_losses('val', total_loss, prop_loss, ortho_loss)

        # Concept statistics go under 'val/' so they do not get mixed into the
        # 'train/concept_*' series written by training_step.
        self._log_concept_stats('val', concept_values)

        # Calculate per-property MSE errors
        for i in range(targets.shape[1]):
            prop_error = ((property_predictions[:, i] - targets[:, i]) ** 2).mean()
            # Fall back to a positional name when no column name is available.
            prop_name = f"val/property_{i}_mse"
            if self.property_cols and i < len(self.property_cols):
                prop_name = f"val/{self.property_cols[i]}_mse"
            self.log(prop_name, prop_error, prog_bar=True)

        return total_loss

    def test_step(self, batch, batch_idx):
        """
        Test step.

        Delegates to ``validation_step``, so test metrics are logged under
        the same ``val/`` names.
        """
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """
        Configure optimizer and learning-rate scheduler.

        Returns:
            Dict with an AdamW optimizer (weight decay 0.1 for the CBM
            parameters, 0.01 for the property predictor), a
            ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 3) and the
            metric name the scheduler monitors.
        """
        optimizer = torch.optim.AdamW([
            {'params': self.cbm.parameters(), 'weight_decay': 0.1},
            {'params': self.property_predictor.parameters(), 'weight_decay': 0.01}
        ], lr=self.learning_rate, weight_decay=0.01)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val/property_loss'
        }

    def on_save_checkpoint(self, checkpoint):
        """Store ``property_cols`` in the checkpoint dict."""
        checkpoint['property_cols'] = self.property_cols

    def on_load_checkpoint(self, checkpoint):
        """Restore ``property_cols`` from the checkpoint, keeping the current value if absent."""
        self.property_cols = checkpoint.get('property_cols', self.property_cols)


class ConceptBottleneckModel(nn.Module):
    """
    Concept Bottleneck Model for learning concepts from LLM representations.
    The model takes hidden representations from an LLM and outputs concept values.
    """
    def __init__(self, hidden_dim, num_concepts, concept_dim):
        """
        Initialize the Concept Bottleneck Model.

        Args:
            hidden_dim: Dimension of the input hidden representation
            num_concepts: Number of concepts to extract
            concept_dim: Dimension of each concept embedding
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_concepts = num_concepts
        # Store the actual embedding width; the embedding table below is built
        # with ``concept_dim``, so the attribute must report the same value.
        self.concept_dim = concept_dim

        # CBM to extract concept values from hidden representations.
        # The final Sigmoid keeps every concept value in [0, 1].
        self.cbm = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.LayerNorm(4 * hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(4 * hidden_dim, 2 * hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, num_concepts),
            nn.Sigmoid(),
        )

        # Learnable concept embeddings. The randn values are overwritten by the
        # Kaiming-uniform init on the next line (a=sqrt(5) is the setting
        # nn.Linear uses for its weights).
        self.concept_embeddings = nn.Parameter(torch.randn(num_concepts, concept_dim))
        nn.init.kaiming_uniform_(self.concept_embeddings, a=math.sqrt(5))

        logger.info(
            "Initialized CBM with %s concepts, each with dimension %s",
            num_concepts,
            concept_dim,
        )

    def forward(self, h_n):
        """
        Forward pass through the CBM.

        Args:
            h_n: Hidden representation of the last token [batch_size, hidden_dim]

        Returns:
            concept_values: Activation values for each concept [batch_size, num_concepts]
            concept_vectors: Concept vectors scaled by their activation [batch_size, num_concepts, concept_dim]
        """
        # Extract concept values
        concept_values = self.cbm(h_n)  # [batch_size, num_concepts]

        # Broadcast [batch, num_concepts, 1] against [1, num_concepts, concept_dim]
        # so each embedding is scaled by its activation value.
        concept_vectors = concept_values.unsqueeze(-1) * self.concept_embeddings.unsqueeze(0)

        return concept_values, concept_vectors

    def compute_orthogonality_loss(self):
        """
        Compute orthogonality loss for concept embeddings.
        This encourages concept embeddings to be orthogonal to each other.

        Returns:
            ortho_loss: Sum of squared cosine similarities over all ordered
                pairs of distinct concept embeddings (each unordered pair is
                counted twice).
        """
        # Normalize embeddings so dot products are cosine similarities
        normalized_embeddings = F.normalize(self.concept_embeddings, p=2, dim=1)

        # Compute pairwise dot products
        dot_products = torch.matmul(normalized_embeddings, normalized_embeddings.T)

        # Create a mask to zero out the diagonal (self-similarity)
        mask = torch.ones_like(dot_products) - torch.eye(dot_products.shape[0], device=dot_products.device)

        # Compute the loss as the sum of squared dot products (excluding diagonal)
        ortho_loss = torch.sum((dot_products * mask) ** 2)

        return ortho_loss

    def compute_activation_loss(self, concept_values):
        """
        Encourage concept values to be more decisive (closer to 0 or 1).

        Args:
            concept_values: Tensor of concept activations

        Returns:
            ``0.5 - mean(|concept_values - 0.5|)``: 0.5 when every value is
            0.5 and 0 when every value is exactly 0 or 1.
        """
        return -torch.mean(torch.abs(concept_values - 0.5)) + 0.5