import torch
import numpy as np
from model_utils import MaskedResNet, MaskedViT
from config import Config

class CAVManager:
    """Build Concept Activation Vectors (CAVs) from masked-image embeddings.

    A CAV here is the difference between the mean embedding of positive
    (concept-present) samples and the mean embedding of negative samples.
    """

    def __init__(self, model, config):
        """Store the model and configuration.

        Args:
            model: Expected to be a ``MaskedResNet`` or ``MaskedViT``; other
                model types are rejected by ``extract_embedding``.
            config: Object exposing a ``DEVICE`` attribute; inputs are moved
                to that device before being embedded.
        """
        self.model = model
        self.config = config
        self.device = config.DEVICE

    def extract_embedding(self, image, mask):
        """Return the embedding(s) for a batch of masked images.

        The extraction strategy depends on the model architecture.
        Gradients are disabled during extraction.

        Args:
            image: Batched image tensor; moved to ``self.device``.
            mask: Batched mask tensor; moved to ``self.device``.

        Returns:
            A tensor with one embedding row per batch element.

        Raises:
            TypeError: If ``self.model`` is neither a ``MaskedResNet`` nor a
                ``MaskedViT``.
        """
        image = image.to(self.device)
        mask = mask.to(self.device)

        with torch.no_grad():
            if isinstance(self.model, MaskedResNet):
                # CNN: the model returns (output, feature_map); average the
                # feature map over dims 2 and 3 (presumably H, W — confirm).
                _, feats = self.model(image, mask)
                return feats.mean(dim=[2, 3])

            if isinstance(self.model, MaskedViT):
                # ViT: the model exposes a pooled embedding directly
                # (assumed [B, D] — confirm against MaskedViT).
                return self.model.get_patch_embeddings(image, mask)

        # Fail loudly: otherwise an unsupported model would surface as an
        # opaque UnboundLocalError on the (never-assigned) result variable.
        raise TypeError(
            "Unsupported model type for embedding extraction: "
            f"{type(self.model).__name__} "
            "(expected MaskedResNet or MaskedViT)"
        )

    def _collect_embeddings(self, samples):
        """Embed each ``(image, mask)`` pair as a batch of one; return the list."""
        return [
            self.extract_embedding(img.unsqueeze(0), mask.unsqueeze(0))
            for img, mask in samples
        ]

    def create_cav(self, concept_name, positive_samples, negative_samples):
        """Create a CAV using the centroid-difference method.

        Args:
            concept_name: Label stored in the result and used in log output.
            positive_samples: Iterable of ``(image, mask)`` tensor pairs
                without a batch dimension (one is added via ``unsqueeze(0)``).
            negative_samples: Same format as ``positive_samples``.

        Returns:
            ``None`` if either sample set is empty. Otherwise a dict with
            ``"vector"`` (``mu_pos - mu_neg``), ``"mu_pos"`` and ``"mu_neg"``
            (per-set mean embeddings), and ``"concept"`` (``concept_name``).
        """
        pos_embeddings = self._collect_embeddings(positive_samples)
        # Without positives the CAV is skipped anyway, so don't spend model
        # forward passes on the negatives.
        neg_embeddings = (
            self._collect_embeddings(negative_samples) if pos_embeddings else []
        )

        if not pos_embeddings or not neg_embeddings:
            print(f"Skipping CAV for {concept_name}: Insufficient data.")
            return None

        # Stack the per-sample rows and take each set's centroid.
        mu_pos = torch.cat(pos_embeddings, dim=0).mean(dim=0)
        mu_neg = torch.cat(neg_embeddings, dim=0).mean(dim=0)

        return {
            "vector": mu_pos - mu_neg,
            "mu_pos": mu_pos,
            "mu_neg": mu_neg,
            "concept": concept_name,
        }