import torch
from typing import Optional
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
import numpy as np
from collections import Counter

@torch.no_grad()
def get_accuracy(logits: torch.Tensor, y_true: torch.Tensor, mode: str, idk_class: int, binary_classification: Optional[bool] = None
) -> float:
    """Return the fraction of samples counted as correct for the given mode.

    Predictions are derived from ``logits`` either by rounding the sigmoid
    (when ``binary_classification is True``) or by taking the argmax over
    dimension 1 (every other value, including ``None``).

    Args:
        logits: Model output logits.
        y_true: Ground truth labels; squeezed before comparison.
        mode: Either "merlin" or "morgana".
        idk_class: Class index of the IDK (I Don't Know) prediction. Only
            consulted for non-binary "morgana" scoring, where predicting IDK
            counts as correct.
        binary_classification: Only the literal ``True`` selects the binary path.
    Returns:
        float: Number of correct samples divided by ``len(y_true)``.
    Raises:
        ValueError: If ``mode`` is neither "merlin" nor "morgana".
    """
    is_binary = binary_classification is True

    if is_binary:
        # Sigmoid probabilities rounded to hard 0/1 predictions.
        predicted = torch.sigmoid(logits).round().squeeze()
    else:
        predicted = logits.argmax(dim=1)

    if mode not in ("merlin", "morgana"):
        raise ValueError(f"Unexpected value for mode, got `{mode}`")

    hits = predicted.eq(y_true.squeeze())
    if mode == "morgana" and not is_binary:
        # In the multi-class setting an IDK prediction is also accepted.
        hits = torch.logical_or(hits, predicted.eq(idk_class))

    return hits.sum().item() / float(len(y_true))

def plot_confusion_matrix(matrix, class_names, title, save_path=None):
    """Create an annotated confusion matrix heatmap.

    Note: this updates ``plt.rcParams`` globally (font sizes and family), so
    the styling also applies to figures created afterwards.

    Args:
        matrix: numpy array of confusion matrix values
        class_names: list of class names used for both axes' tick labels
        title: title for the plot
        save_path: optional base path; the plot is written to
            ``<save_path>.pdf`` and ``<save_path>.svg``. A trailing ``.pdf`` or
            ``.svg`` on ``save_path`` is dropped first, so callers that pass a
            full file name do not end up with ``name.pdf.pdf``.
    Returns:
        matplotlib figure
    """
    # Increase default font sizes and set font family
    plt.rcParams.update({
        'font.size': 32,          # Base font size
        'axes.titlesize': 36,     # Title font size
        'axes.labelsize': 34,     # Axis label size
        'xtick.labelsize': 30,    # X-tick label size
        'ytick.labelsize': 30,    # Y-tick label size
        'font.family': 'serif',   # Use serif font family
        'font.serif': ['Computer Modern Roman', 'Times New Roman', 'DejaVu Serif']  # Prefer LaTeX-like fonts
    })

    # Create figure with larger size
    plt.figure(figsize=(12, 10))

    # Plot confusion matrix
    sns.heatmap(
        matrix,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        annot_kws={'size': 30}  # Size of annotation text
    )

    plt.title(title, pad=30)  # Add padding to prevent title overlap
    plt.xlabel('Predicted Label', labelpad=20)  # Add padding to x-label
    plt.ylabel('True Label', labelpad=20)       # Add padding to y-label

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    # Save plot if path is provided
    if save_path is not None:
        # Some callers pass a name that already carries an extension; strip it
        # so both outputs share one clean base name.
        base_path = str(save_path)
        if base_path.lower().endswith(('.pdf', '.svg')):
            base_path = base_path[:-4]
        plt.savefig(f"{base_path}.pdf", format='pdf', bbox_inches='tight', dpi=300)
        plt.savefig(f"{base_path}.svg", format='svg', bbox_inches='tight', dpi=300)

    return plt.gcf()

def plot_histogram(feature_counts: dict, title: str, feature_interpretations: Optional[dict] = None, top_k: int = 50):
    """Create a bar plot of the most frequent features.

    Args:
        feature_counts: Dictionary mapping features to their counts. May be
            empty, in which case an empty (but labelled) figure is returned.
        title: Title for the plot
        feature_interpretations: Optional dictionary mapping features to their
            interpretations; when a feature is present its tick label becomes
            ``"<interpretation> : <feature>"``.
        top_k: Maximum number of features to show, highest counts first.
    Returns:
        matplotlib figure
    """
    # Sort features by count (descending) and keep the top_k most frequent
    sorted_features = sorted(feature_counts.items(), key=lambda x: x[1], reverse=True)
    top_features = sorted_features[:top_k]
    if top_features:
        features, counts = zip(*top_features)
    else:
        # zip(*[]) yields nothing to unpack, so handle the empty case explicitly
        features, counts = (), ()

    # Create figure with higher resolution
    fig = plt.figure(figsize=(16, 8), dpi=300)

    # Create bar plot with improved styling
    if counts:
        plt.bar(range(len(counts)), counts, color='#2978A0', alpha=0.8)

    # Customize title and labels with better spacing
    plt.title(title, pad=20, fontsize=14, fontweight='bold')
    plt.xlabel('Feature : (Block, Concept)', labelpad=15, fontsize=12)
    plt.ylabel('Count', labelpad=15, fontsize=12)

    # Format x-axis labels with interpretations if provided
    if feature_interpretations:
        labels = [
            f"{feature_interpretations[f]} : {str(f)}" if f in feature_interpretations else str(f)
            for f in features
        ]
    else:
        labels = [str(f) for f in features]

    # Rotate and align the tick labels so they look better
    plt.xticks(range(len(labels)), labels, rotation=45, ha='right', fontsize=10)
    plt.yticks(fontsize=10)

    # Add grid for better readability
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    return fig

def compute_confusion_matrix(loader: DataLoader, set_name: str, approach: str, model, merlin=None, morgana=None, device=None, num_classes=None, logger=None, save_path=None):
    """Route to the confusion matrix routine that matches ``approach``.

    Supported approaches are "sfw", "regular" and "learn_fs"; the "regular"
    routine does not receive ``merlin``/``morgana``. Any other value raises
    ``ValueError``. Nothing is returned.
    """
    if approach not in ("sfw", "regular", "learn_fs"):
        raise ValueError(f"Approach {approach} not supported for confusion matrix computation")

    if approach == "regular":
        compute_confusion_matrix_regular(loader, set_name, model, device, num_classes, logger, save_path)
        return

    # Both feature-selector approaches share the same argument list.
    selector_routine = compute_confusion_matrix_sfw if approach == "sfw" else compute_confusion_matrix_learn_fs
    selector_routine(loader, set_name, model, merlin, morgana, device, num_classes, logger, save_path)

def compute_confusion_matrix_regular(loader: DataLoader, set_name: str, model, device, num_classes: int, logger=None, save_path=None):
    """Compute, print and optionally plot/log the confusion matrix for the regular approach.

    Rows are true classes, columns are predicted classes (argmax over
    dimension 1 of the model output). Values are reported as row-wise
    percentages. A class with no samples in ``loader`` is reported as 0.0%
    across its row (with "(0 samples)") instead of NaN.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        set_name: Name of the split; used in headings and, lower-cased, as the
            log prefix ("validation" is shortened to "val").
        model: Callable model; put into eval mode.
        device: Device the batches are moved to.
        num_classes: Number of classes (matrix is ``num_classes x num_classes``).
        logger: Optional object with a ``log(dict)`` method (wandb-style).
        save_path: Optional path forwarded to ``plot_confusion_matrix``.
    """
    model.eval()
    confusion_matrix = torch.zeros(num_classes, num_classes)

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            _, predicted = outputs.max(1)

            # Add one count per (true, predicted) pair in a single call.
            # index_put_ with accumulate=True handles repeated pairs and
            # still bounds-checks the indices.
            rows = targets.view(-1).long().to(confusion_matrix.device)
            cols = predicted.view(-1).long().to(confusion_matrix.device)
            confusion_matrix.index_put_(
                (rows, cols),
                torch.ones(rows.numel(), dtype=confusion_matrix.dtype),
                accumulate=True,
            )

    # Convert counts to percentages; clamp the denominator so that classes
    # without any samples yield 0 instead of 0/0 = NaN.
    row_totals = confusion_matrix.sum(dim=1, keepdim=True)
    confusion_matrix_percent = confusion_matrix / row_totals.clamp(min=1) * 100

    # Print confusion matrix with better alignment
    print(f"{set_name} Confusion Matrix (percentages):")
    print("-" * 50)
    print("True\\Pred".ljust(15), end="")
    for i in range(num_classes):
        print(f"Class {i}".ljust(8), end="")
    print("\n" + "-" * 50)

    for i in range(num_classes):
        print(f"Class {i}".ljust(15), end="")
        for j in range(num_classes):
            print(f"{confusion_matrix_percent[i, j].item():.1f}%".ljust(8), end="")
        print(f"({int(row_totals[i].item())} samples)")

    # Per-class accuracy is the diagonal of the row-normalised matrix
    print("\nPer-class accuracy:")
    for i in range(num_classes):
        print(f"Class {i}: {confusion_matrix_percent[i, i].item():.1f}%")
    print("-" * 50)

    # Plot when either logging or saving was requested
    if logger is not None or save_path is not None:
        # Create class names for the confusion matrix
        class_names = [f"Class {i}" for i in range(num_classes)]

        # Create confusion matrix plot
        fig = plot_confusion_matrix(
            matrix=confusion_matrix_percent.cpu().numpy(),
            class_names=class_names,
            title=f"{set_name} Confusion Matrix",
            save_path=save_path
        )

        # Log to wandb if enabled
        if logger is not None:
            log_prefix = "val" if set_name.lower() == "validation" else set_name.lower()
            logger.log({
                f"{log_prefix}/confusion_matrix": wandb.Image(fig)
            })

        # Close the figure to free memory
        plt.close(fig)

def _accumulate_confusion_counts(confusion, targets, predicted):
    """Add one count per (true, predicted) pair to ``confusion`` in place."""
    rows = targets.view(-1).long().to(confusion.device)
    cols = predicted.view(-1).long().to(confusion.device)
    # accumulate=True sums repeated (row, col) pairs; indices stay bounds-checked.
    confusion.index_put_(
        (rows, cols),
        torch.ones(rows.numel(), dtype=confusion.dtype),
        accumulate=True,
    )


def _row_percentages(confusion):
    """Return ``confusion`` normalised per row to percentages (empty rows become 0, not NaN)."""
    row_totals = confusion.sum(dim=1, keepdim=True)
    return confusion / row_totals.clamp(min=1) * 100


def _print_confusion_table(heading, percent, confusion, num_classes):
    """Print one percentage table whose column at index ``num_classes`` is labelled IDK."""
    num_columns = percent.shape[1]
    print(heading)
    print("-" * 68)
    print("True\\Pred".ljust(15), end="")
    for j in range(num_columns):
        label = "IDK" if j == num_classes else f"Class {j}"
        print(f"{label}".ljust(8), end="")
    print("\n" + "-" * 68)

    for i in range(num_classes):
        print(f"Class {i}".ljust(15), end="")
        for j in range(num_columns):
            print(f"{percent[i, j].item():.1f}%".ljust(8), end="")
        print(f"({int(confusion[i].sum().item())} samples)")


def _evaluate_merlin_morgana_confusion(loader, set_name, model, merlin, morgana, device, num_classes,
                                       logger, save_path, compute_masks):
    """Shared body of the SFW / learnable-selector confusion matrix routines.

    ``compute_masks(inputs, targets)`` must return the pair
    ``(continuous_mask_merlin, continuous_mask_morgana)``; everything else
    (binarising, masking, predicting, counting, printing, plotting, logging)
    is identical for both approaches.
    """
    # Include IDK class for both Merlin and Morgana (extra last column)
    num_classes_with_idk = num_classes + 1

    merlin_confusion = torch.zeros(num_classes, num_classes_with_idk)
    morgana_confusion = torch.zeros(num_classes, num_classes_with_idk)

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        continuous_mask_merlin, continuous_mask_morgana = compute_masks(inputs, targets)

        # Apply masks and get predictions
        with torch.no_grad():
            binary_mask_merlin = merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = morgana.get_binary_mask(continuous_mask_morgana)

            masked_inputs_merlin = merlin.apply_mask(inputs, binary_mask_merlin)
            masked_inputs_morgana = morgana.apply_mask(inputs, binary_mask_morgana)

            merlin_logits = model(masked_inputs_merlin)
            morgana_logits = model(masked_inputs_morgana)

            _, merlin_predicted = merlin_logits.max(1)
            _, morgana_predicted = morgana_logits.max(1)

            _accumulate_confusion_counts(merlin_confusion, targets, merlin_predicted)
            _accumulate_confusion_counts(morgana_confusion, targets, morgana_predicted)

    # Convert counts to percentages (classes without samples report 0.0)
    merlin_percent = _row_percentages(merlin_confusion)
    morgana_percent = _row_percentages(morgana_confusion)

    # Create class names
    class_names = [f"Class {i}" for i in range(num_classes)] + ["IDK"]

    _print_confusion_table(
        f"\n{set_name} Confusion Matrix - Merlin (Completeness):",
        merlin_percent, merlin_confusion, num_classes,
    )
    _print_confusion_table(
        f"\n{set_name} Confusion Matrix - Morgana (Soundness):",
        morgana_percent, morgana_confusion, num_classes,
    )

    # Per-class metrics
    print("\nPer-class metrics:")
    for i in range(num_classes):
        # Merlin accuracy (completeness): predicted class equals the true class
        merlin_acc = merlin_percent[i, i].item()
        # Morgana accuracy (soundness): correct if predicted class matches or IDK
        morgana_acc = (morgana_percent[i, i] + morgana_percent[i, -1]).item()

        print(f"Class {i}:")
        print(f"  Completeness: {merlin_acc:.1f}%")
        print(f"  Soundness: {morgana_acc:.1f}%")
    print("-" * 60)

    # Plot when either logging or saving was requested
    if logger is not None or save_path is not None:
        log_prefix = "val" if set_name.lower() == "validation" else set_name.lower()

        # plot_confusion_matrix appends ".pdf"/".svg" itself, so only a base
        # name without extension is passed (avoids "*_merlin.pdf.pdf").
        fig_merlin = plot_confusion_matrix(
            matrix=merlin_percent.cpu().numpy(),
            class_names=class_names,
            title=f"{set_name} Confusion Matrix - Merlin (Completeness)",
            save_path=None if save_path is None else f"{save_path}_merlin"
        )

        fig_morgana = plot_confusion_matrix(
            matrix=morgana_percent.cpu().numpy(),
            class_names=class_names,
            title=f"{set_name} Confusion Matrix - Morgana (Soundness)",
            save_path=None if save_path is None else f"{save_path}_morgana"
        )

        # Log to wandb if enabled
        if logger is not None:
            logger.log({
                f"{log_prefix}/confusion_matrix_merlin": wandb.Image(fig_merlin),
                f"{log_prefix}/confusion_matrix_morgana": wandb.Image(fig_morgana)
            })

        # Close figures to free memory
        plt.close(fig_merlin)
        plt.close(fig_morgana)


def compute_confusion_matrix_sfw(loader: DataLoader, set_name: str, model, merlin, morgana, device, num_classes: int, logger=None, save_path=None):
    """Compute, print and optionally plot/log Merlin and Morgana confusion matrices for the SFW approach.

    Masks are obtained by calling ``merlin(inputs, targets, model)`` and
    ``morgana(inputs, targets, model)`` with gradients enabled. Each matrix has
    ``num_classes`` rows (true class) and ``num_classes + 1`` columns
    (predicted class, last column = IDK). Classes without samples are reported
    as 0.0% rather than NaN. Plots are saved to ``<save_path>_merlin.{pdf,svg}``
    and ``<save_path>_morgana.{pdf,svg}`` when ``save_path`` is given.
    """
    model.eval()

    def compute_masks(inputs, targets):
        # Temporarily enable gradients for mask optimization
        with torch.enable_grad():
            continuous_mask_merlin = merlin(inputs, targets, model)
            continuous_mask_morgana = morgana(inputs, targets, model)
        return continuous_mask_merlin, continuous_mask_morgana

    _evaluate_merlin_morgana_confusion(
        loader, set_name, model, merlin, morgana, device, num_classes, logger, save_path, compute_masks
    )


def compute_confusion_matrix_learn_fs(loader: DataLoader, set_name: str, model, merlin, morgana, device, num_classes: int, logger=None, save_path=None):
    """Compute, print and optionally plot/log Merlin and Morgana confusion matrices for learnable feature selectors.

    ``model``, ``merlin`` and ``morgana`` are put into eval mode and masks are
    obtained by calling ``merlin(inputs)`` / ``morgana(inputs)`` without
    gradients. Matrix layout, empty-class handling and output file names are
    the same as for ``compute_confusion_matrix_sfw``: rows are true classes,
    the extra last column is IDK, empty classes report 0.0%, and plots go to
    ``<save_path>_merlin.{pdf,svg}`` / ``<save_path>_morgana.{pdf,svg}``.
    """
    model.eval()
    merlin.eval()
    morgana.eval()

    def compute_masks(inputs, targets):
        # The learned selectors only need the inputs; targets are unused here.
        with torch.no_grad():
            continuous_mask_merlin = merlin(inputs)
            continuous_mask_morgana = morgana(inputs)
        return continuous_mask_merlin, continuous_mask_morgana

    _evaluate_merlin_morgana_confusion(
        loader, set_name, model, merlin, morgana, device, num_classes, logger, save_path, compute_masks
    )

def compute_feature_distribution(loader: DataLoader, set_name: str, model, merlin, morgana, device, num_classes: int, 
                               num_slots: int, num_blocks: int, mask_size: int, enc_type: str, approach: str, logger=None, feature_interpretations=None):
    """Tally which features Merlin and Morgana select, per true class, and log histograms.

    For every batch the masked inputs are viewed as
    ``(-1, num_slots, num_blocks, last_dim // num_blocks)``; the coordinates of
    the non-zero entries are regrouped as ``(-1, mask_size, 4)`` and the last
    two coordinates of each entry are counted as one feature for the sample's
    label. The regrouping only works if every sample keeps exactly
    ``mask_size`` non-zero entries — presumably guaranteed by the selectors;
    verify against their implementation.

    Histograms are only produced when ``logger`` is given; nothing is returned.

    Raises:
        ValueError: If ``enc_type`` is not 'one_hot_padded' or ``approach`` is
            not "sfw" / "learn_fs".
    """
    if enc_type != 'one_hot_padded':
        raise ValueError(f"Feature distribution computation is only supported for 'one_hot_padded' encoding type, got {enc_type}")

    if approach not in ["sfw", "learn_fs"]:
        raise ValueError(f"Feature distribution computation not supported for approach {approach}")

    for module in (model, merlin, morgana):
        module.eval()

    # class label -> feature tuple -> count
    merlin_counts_by_class = defaultdict(lambda: defaultdict(int))
    morgana_counts_by_class = defaultdict(lambda: defaultdict(int))

    uses_sfw = approach == "sfw"

    for inputs, targets in tqdm(loader, desc=f'Computing feature distribution for {set_name}'):
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        if uses_sfw:
            # SFW optimises the mask per batch and therefore needs gradients
            with torch.enable_grad():
                soft_mask_merlin = merlin(inputs, targets, model)
                soft_mask_morgana = morgana(inputs, targets, model)
        else:
            # learn_fs: the selectors predict the mask directly from the inputs
            with torch.no_grad():
                soft_mask_merlin = merlin(inputs)
                soft_mask_morgana = morgana(inputs)

        with torch.no_grad():
            hard_mask_merlin = merlin.get_binary_mask(soft_mask_merlin)
            hard_mask_morgana = morgana.get_binary_mask(soft_mask_morgana)

            masked_merlin = merlin.apply_mask(inputs, hard_mask_merlin)
            masked_morgana = morgana.apply_mask(inputs, hard_mask_morgana)

            # Same tallying for both agents: locate the non-zero entries and
            # keep their last two coordinates as the feature identifier.
            for masked, counts_by_class in (
                (masked_merlin, merlin_counts_by_class),
                (masked_morgana, morgana_counts_by_class),
            ):
                block_width = masked.shape[-1] // num_blocks
                encoded = masked.view(-1, num_slots, num_blocks, block_width)
                coords = torch.nonzero(encoded, as_tuple=False).view(-1, mask_size, 4)
                selected_per_image = coords[:, :, 2:]

                for label, features in zip(targets, selected_per_image):
                    for feature in features:
                        counts_by_class[label.item()][tuple(feature.tolist())] += 1

    # Without a logger there is nothing to plot
    if logger is None:
        return

    # Convert 'validation' to 'val' for folder name
    log_prefix = "val" if set_name.lower() == "validation" else set_name.lower()

    # One Merlin and one Morgana histogram per class seen by Merlin
    for class_label, merlin_counts in merlin_counts_by_class.items():
        merlin_class_fig = plot_histogram(
            merlin_counts,
            f"{set_name} Feature Distribution - Merlin - Class {class_label}",
            feature_interpretations=feature_interpretations
        )
        morgana_class_fig = plot_histogram(
            morgana_counts_by_class[class_label],
            f"{set_name} Feature Distribution - Morgana - Class {class_label}",
            feature_interpretations=feature_interpretations
        )

        logger.log({
            f"{log_prefix}/feature_distribution/merlin_class_{class_label}": wandb.Image(merlin_class_fig),
            f"{log_prefix}/feature_distribution/morgana_class_{class_label}": wandb.Image(morgana_class_fig)
        })

        plt.close(merlin_class_fig)
        plt.close(morgana_class_fig)

def compute_precision_and_entropy(loader: DataLoader, target_class: int, tolerance: float, set_name: str,
                                model, merlin, morgana, device, num_classes: int, batch_size: int, 
                                num_workers: int, seed: int, approach: str, logger=None):
    """Compute average precision and conditional entropy of Merlin's features for one class.

    For each batch of ``loader`` the samples of ``target_class`` are masked by
    Merlin. Every masked sample is then compared against the whole dataset
    (iterated through a reshuffled inner loader, masked with the same mask):
    a dataset sample "matches" when the norm of the difference of the masked
    inputs is at most ``tolerance``. Matches are counted per class and
    normalised into class probabilities, from which the per-batch precision
    (probability of ``target_class``) and entropy are derived. Batch values
    are averaged with equal weight.

    Args:
        loader: Outer data loader; its ``dataset`` is also used for the inner loader.
        target_class: Class whose Merlin features are evaluated.
        tolerance: Maximum norm difference for two masked inputs to count as equal.
        set_name: Split name for messages and the log prefix.
        model, merlin, morgana: Put into eval mode; merlin/morgana produce the masks.
        device: Device the batches are moved to.
        num_classes: Number of classes.
        batch_size, num_workers: Settings of the inner loader.
        seed: ``seed + 1`` seeds the shuffle of the inner loader.
        approach: "sfw" computes masks with gradients from
            ``(inputs, targets, model)``; any other value uses ``merlin(inputs)``.
        logger: Optional object with a ``log(dict)`` method.
    Returns:
        Tuple ``(average_precision, conditional_entropy)`` of 0-d tensors.
        Both are NaN when no batch contained samples of both the target class
        and the other classes.
    """
    model.eval()
    merlin.eval()
    morgana.eval()

    # Create inner loader by shuffling the dataset. The global RNG state is
    # restored even if building the loader fails, so callers are unaffected.
    rng_state = torch.get_rng_state()
    try:
        torch.manual_seed(seed + 1)
        shuffled_indices = torch.randperm(len(loader.dataset))
        inner_dataset = torch.utils.data.Subset(loader.dataset, shuffled_indices)
        inner_loader = DataLoader(
            inner_dataset,
            batch_size=batch_size,
            shuffle=False,  # Already shuffled through indices
            num_workers=num_workers
        )
    finally:
        torch.set_rng_state(rng_state)

    average_precision_list = []
    conditional_entropy_list = []

    # Classes handled by Morgana: every class except the target (loop-invariant)
    morgana_target_classes = torch.arange(num_classes).to(device)
    morgana_target_classes = morgana_target_classes[morgana_target_classes != target_class]

    # Progress bar for outer loop
    pbar = tqdm(loader, desc=f"Computing precision and entropy metrics for {set_name}")

    for inputs, targets in pbar:
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        # Filter data for target class (Merlin) and other classes (Morgana)
        is_merlin_sample = targets == target_class
        is_morgana_sample = torch.isin(targets, morgana_target_classes)

        inputs_merlin = inputs[is_merlin_sample]
        targets_merlin = targets[is_merlin_sample]
        inputs_morgana = inputs[is_morgana_sample]
        targets_morgana = targets[is_morgana_sample]

        # Skip if either Merlin or Morgana has no samples
        if len(inputs_merlin) == 0 or len(inputs_morgana) == 0:
            continue

        # Get masks based on approach
        if approach == "sfw":
            with torch.enable_grad():
                continuous_mask_merlin = merlin(inputs_merlin, targets_merlin, model)
                continuous_mask_morgana = morgana(inputs_morgana, targets_morgana, model)
        else:  # learn_fs
            with torch.no_grad():
                continuous_mask_merlin = merlin(inputs_merlin)
                continuous_mask_morgana = morgana(inputs_morgana)

        # Everything below is pure evaluation, so no autograd graph is needed.
        with torch.no_grad():
            binary_mask_merlin = merlin.get_binary_mask(continuous_mask_merlin)
            # NOTE(review): the Morgana mask and masked inputs are not used by
            # the metrics below; the calls are kept in case they have side
            # effects (e.g. RNG use) — confirm and drop if not.
            binary_mask_morgana = morgana.get_binary_mask(continuous_mask_morgana)

            # Apply masks
            masked_inputs_merlin = merlin.apply_mask(inputs_merlin, binary_mask_merlin)
            masked_inputs_morgana = morgana.apply_mask(inputs_morgana, binary_mask_morgana)

            # Add a broadcast axis so each Merlin sample is compared with each inner sample
            masked_inputs_merlin = masked_inputs_merlin.unsqueeze(1)
            binary_mask_merlin = binary_mask_merlin.unsqueeze(1)
            num_merlin = binary_mask_merlin.shape[0]
            occurrence_per_class = torch.zeros(num_merlin, num_classes).to(device)

            for inner_inputs, inner_targets in inner_loader:
                inner_inputs = inner_inputs.to(device)
                inner_targets = inner_targets.to(device).long()

                inner_inputs = inner_inputs.unsqueeze(0)
                masked_inner_inputs = merlin.apply_mask(inner_inputs, binary_mask_merlin)

                # Calculate feature matches using norm difference.
                # Reshape explicitly to (merlin samples, inner samples): a bare
                # squeeze() would also drop a batch axis of size 1 and make the
                # per-class sums below broadcast to the wrong rows.
                norm_diff = torch.linalg.norm(masked_inputs_merlin - masked_inner_inputs, dim=(2, 3))
                norm_diff = norm_diff.reshape(num_merlin, inner_inputs.shape[1])
                feature_bool_mask = norm_diff <= tolerance

                # Count occurrences for each class
                for label in range(num_classes):
                    label_bool_mask = (inner_targets == label).unsqueeze(0)
                    merged_bool_mask = torch.logical_and(feature_bool_mask, label_bool_mask)
                    occurrence_per_class[:, label] += merged_bool_mask.sum(dim=1)

            # Calculate metrics
            total_occurrences = torch.sum(occurrence_per_class, dim=1, keepdim=True)
            # Avoid 0/0 for samples without any match: use a non-zero
            # denominator so their class probabilities become all zeros.
            total_occurrences = total_occurrences.masked_fill(total_occurrences == 0, 0.1)

            class_probs = occurrence_per_class / total_occurrences

            # Calculate batch metrics
            average_precision_batch = torch.gather(class_probs, 1, targets_merlin.unsqueeze(1)).mean()
            conditional_entropy_batch = -torch.xlogy(class_probs, class_probs).sum(dim=1).mean()

        average_precision_list.append(average_precision_batch)
        conditional_entropy_list.append(conditional_entropy_batch)

        # Update progress bar
        pbar.set_postfix({
            'avg_prec': f'{100*average_precision_batch:.2f}%',
            'cond_ent': f'{conditional_entropy_batch:.4f}'
        })

    # Calculate final metrics
    if average_precision_list:
        average_precision = sum(average_precision_list) / len(average_precision_list)
        conditional_entropy = sum(conditional_entropy_list) / len(conditional_entropy_list)
    else:
        # Every batch was skipped: the metrics are undefined rather than a crash.
        print(f"\nNo batch of the {set_name} set contained both class {target_class} and other classes; metrics are undefined.")
        average_precision = torch.tensor(float("nan"))
        conditional_entropy = torch.tensor(float("nan"))

    # Log results
    print(f"\nFinal Metrics for {set_name} Set, Target Class {target_class}:")
    print(f"Average precision: {(100*average_precision):.2f}%")
    print(f"Conditional entropy: {conditional_entropy:.4f}")

    # Log to wandb if enabled
    if logger is not None:
        # Convert 'validation' to 'val' for folder name
        log_prefix = "val" if set_name.lower() == "validation" else set_name.lower()

        logger.log({
            f'{log_prefix}/class_{target_class}/average_precision': 100 * average_precision,
            f'{log_prefix}/class_{target_class}/conditional_entropy': conditional_entropy
        })

    return average_precision, conditional_entropy

def compute_average_occurrence(loader: DataLoader, set_name: str, model, merlin, morgana, device, num_classes: int, 
                               num_slots: int, num_blocks: int, mask_size: int, enc_type: str, approach: str, logger=None):
    """Count identical Merlin masks across a whole split and report how often they recur.

    For every sample the Merlin mask is computed, binarised and applied to the
    input. The non-zero entries of the masked input, viewed as
    ``(batch, num_slots, num_blocks, last_dim // num_blocks)``, are reduced to
    their last two index columns and sorted, which gives an order-independent
    "signature" per sample. Two statistics over these signatures are printed,
    optionally logged, and returned.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        set_name: Split name used in messages; ``"validation"`` is logged
            under the ``"val"`` prefix, any other name is lower-cased.
        model: Classifier; put into eval mode and passed to ``merlin`` when
            ``approach == "sfw"``.
        merlin: Feature selector providing ``__call__``, ``get_binary_mask``
            and ``apply_mask``.
        morgana: Only put into eval mode; its masks do not enter the
            statistics, so they are not computed.
        device: Device the batches are moved to.
        num_classes: Unused; kept for interface compatibility.
        num_slots: Size of the slot axis of the reshaped masked input.
        num_blocks: Size of the block axis of the reshaped masked input.
        mask_size: Number of non-zero entries expected per sample.
        enc_type: Must be ``'one_hot_padded'``.
        approach: ``"sfw"`` (mask computed with gradients enabled) or
            ``"learn_fs"`` (mask computed without gradients).
        logger: Optional object with a ``log(dict)`` method (e.g. wandb).

    Returns:
        Tuple ``(mean_occurrence, avg_occurrence_per_image)``:
        ``mean_occurrence`` is the number of samples divided by the number of
        distinct signatures; ``avg_occurrence_per_image`` is the mean, over
        samples, of how many *other* samples share that sample's signature.
        Both are ``nan`` when the loader yields no samples.

    Raises:
        ValueError: If ``enc_type`` or ``approach`` is not supported.
    """
    if enc_type != 'one_hot_padded':
        raise ValueError(f"Feature distribution computation is only supported for 'one_hot_padded' encoding type, got {enc_type}")

    if approach not in ["sfw", "learn_fs"]:
        raise ValueError(f"Feature distribution computation not supported for approach {approach}")

    model.eval()
    merlin.eval()
    morgana.eval()

    # Running tally of mask signatures over the entire split.
    signature_counts = Counter()

    for inputs, targets in tqdm(loader, desc=f'Computing average occurrence for {set_name}'):
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        if approach == "sfw":
            # SFW needs gradients for mask optimization
            with torch.enable_grad():
                continuous_mask_merlin = merlin(inputs, targets, model)
        else:  # learn_fs
            with torch.no_grad():
                continuous_mask_merlin = merlin(inputs)

        # The remaining processing is identical for both approaches.
        with torch.no_grad():
            binary_mask_merlin = merlin.get_binary_mask(continuous_mask_merlin)
            masked_inputs_merlin = merlin.apply_mask(inputs, binary_mask_merlin)

            per_block = masked_inputs_merlin.shape[-1] // num_blocks
            enc_masked_merlin = masked_inputs_merlin.view(-1, num_slots, num_blocks, per_block)

            # One index row (sample, slot, block, concept) per non-zero entry.
            # NOTE(review): grouping rows with view() assumes every sample has
            # exactly `mask_size` non-zero entries - confirm against the masks.
            nonzero_idx = torch.nonzero(enc_masked_merlin, as_tuple=False).view(-1, mask_size, 4)

            # Keep only blocks and concepts, as slot order is random.
            selected_features_merlin_per_image = nonzero_idx[:, :, 2:]

        for mask in selected_features_merlin_per_image.cpu().tolist():
            # Sort the pairs so the order of the rows does not matter.
            signature_counts[tuple(sorted(tuple(pair) for pair in mask))] += 1

    if signature_counts:
        num_images = sum(signature_counts.values())
        # Mean over unique masks: average number of samples per distinct signature.
        mean_occurrence = num_images / len(signature_counts)
        # Mean over samples of the support excluding the sample itself: a
        # signature seen c times contributes c samples with (c - 1) peers each.
        avg_occurrence_per_image = sum(c * (c - 1) for c in signature_counts.values()) / num_images
    else:
        mean_occurrence = float('nan')
        avg_occurrence_per_image = float('nan')

    print(f"[{set_name}] mean_occurrence (per unique mask): {mean_occurrence:.2f}")
    print(f"[{set_name}] avg_occurrence_per_image (support excl. self): {avg_occurrence_per_image:.2f}")

    # Log to wandb if enabled
    if logger is not None:
        log_prefix = "val" if set_name.lower() == "validation" else set_name.lower()
        logger.log({
            f'{log_prefix}/mean_occurrence_unique_masks': mean_occurrence,
            f'{log_prefix}/avg_occurrence_per_image': avg_occurrence_per_image,
        })

    return mean_occurrence, avg_occurrence_per_image


# Human-readable interpretation of each learned concept, keyed by
# (block index, concept index within that block).
# The table is written as one list of labels per block; a label's position in
# its list is its concept index. Block 3 has no entries in this table.
FEATURE_INTERPRETATIONS_NCB_SEED_0 = {
    (block_idx, concept_idx): interpretation
    for block_idx, interpretations in (
        # Block 0 (3 concepts)
        (0, ["large, position", "small, position", "ambiguous"]),
        # Block 1 (8 concepts)
        (1, [
            "green, brown", "blue", "red", "gray",
            "ambiguous", "purple", "cyan", "yellow",
        ]),
        # Block 2 (12 concepts)
        (2, [
            "purple, small",
            "purple, blue, metal",
            "brown, rubber",
            "green, rubber",
            "yellow, rubber",
            "green, gray, metal",
            "red, cyan, rubber, large",
            "brown, metal",
            "yellow, metal",
            "red, metal",
            "small",
            "ambiguous",
        ]),
        # Block 4 (7 concepts)
        (4, [
            "large sphere",
            "large",
            "large, position",
            "large, position",
            "large",
            "small",
            "small, position",
        ]),
        # Block 5 (15 concepts)
        (5, [
            "green",
            "brown",
            "red",
            "small",
            "yellow",
            "yellow",
            "small",
            "gray, metal",
            "yellow",
            "yellow",
            "yellow",
            "yellow",
            "yellow",
            "yellow, small",
            "yellow",
        ]),
        # Block 6 (1 concept)
        (6, ["ambiguous"]),
        # Block 7 (2 concepts)
        (7, ["small", "ambiguous"]),
        # Block 8 (5 concepts)
        (8, ["small", "cube", "cylinder", "metal sphere", "rubber sphere"]),
        # Block 9 (1 concept)
        (9, ["ambiguous"]),
        # Block 10 (11 concepts)
        (10, [
            "small",
            "small",
            "purple, large",
            "gray cube",
            "purple, large",
            "purple",
            "purple",
            "gray",
            "gray, large",
            "gray, large",
            "gray, large",
        ]),
        # Block 11 (1 concept)
        (11, ["ambiguous"]),
        # Block 12 (20 concepts)
        (12, [
            "cyan, metal",
            "cyan, rubber",
            "blue, metal",
            "blue, rubber",
            "red, rubber",
            "red, metal",
            "brown, rubber",
            "purple, rubber",
            "purple, metal",
            "green, metal",
            "yellow, metal",
            "gray, metal",
            "ambiguous",
            "green, rubber",
            "brown, metal",
            "brown, metal",
            "gray, rubber, small",
            "yellow, rubber",
            "yellow, rubber",
            "yellow, rubber, small",
        ]),
        # Block 13 (5 concepts)
        (13, ["ambiguous", "red, metal", "red, metal", "red, small", "red, metal"]),
        # Block 14 (1 concept)
        (14, ["ambiguous"]),
        # Block 15 (3 concepts)
        (15, ["small", "large", "large"]),
    )
    for concept_idx, interpretation in enumerate(interpretations)
}