import torch
from jsd import jsd_bins_hist

def cnd(loader, model, device, performances_dict, metric, args, logger):
    """
    Evaluate the Conditional Neuron Divergence (CND) of ``model`` over ``loader``.

    The model is switched to eval mode, and pre-activations are gathered without
    gradient tracking via ``calculate_neuron_pre_activations``. The divergence is
    then computed by ``jsd_bins_hist``.

    Parameters:
    - loader: Dataset loader, forwarded to ``calculate_neuron_pre_activations``.
    - model: PyTorch model under evaluation; ``model.eval()`` is called on it.
    - device: Device (e.g., 'cuda' or 'cpu') used for the forward passes.
    - performances_dict: Dictionary that receives the CND result (mutated in place).
    - metric: Key under which the CND result is stored in ``performances_dict``.
    - args: Configuration namespace, forwarded to both helpers.
    - logger: Logger; the mean CND value is reported at INFO level.

    Returns:
    - performances_dict: The same dictionary, now holding ``performances_dict[metric]``.
    - predictions: Model predictions returned by ``calculate_neuron_pre_activations``.
    """
    model.eval()
    with torch.no_grad():
        activations, targets, predictions = calculate_neuron_pre_activations(
            loader, model, device, args
        )
        divergence = jsd_bins_hist(activations, targets, args)
        performances_dict[metric] = divergence
        mean_divergence = torch.mean(divergence)
        logger.info(f"CND: {mean_divergence:.4f}")
    return performances_dict, predictions


def calculate_neuron_pre_activations(loader, model, device, args, max_images=10000):
    """
    Collect neuron pre-activations, labels, and predictions from ``loader``.

    Parameters:
    - loader: Iterable yielding ``(images, labels, idx)`` batches.
    - model: Callable invoked as ``model(images, idx)`` that returns
      ``(logits, pre_activations)``; both outputs are indexed by sample along dim 0.
    - device: Device (e.g., 'cuda' or 'cpu') the images are moved to for the forward pass.
    - args: Configuration namespace. ``args.layer_indexes`` and ``args.neuron_indexes``
      are set to ``[]`` when missing. This side effect is kept because other code
      presumably reads these attributes (not visible here).
    - max_images: Maximum number of samples to collect. Defaults to 10,000. The last
      batch is truncated so at most ``max_images`` samples are returned. ``None``
      disables the limit.

    Returns:
    - pre_activations: Detached CPU tensor of pre-activations, concatenated along dim 0.
    - labels: CPU tensor of labels for the collected samples.
    - predictions: CPU tensor of arg-max predictions over ``logits`` dim 1.

    Raises:
    - RuntimeError: If no samples were collected (empty loader or ``max_images <= 0``).
    """
    args.layer_indexes = getattr(args, "layer_indexes", [])
    args.neuron_indexes = getattr(args, "neuron_indexes", [])

    pre_activation_list, label_list, prediction_list = [], [], []
    processed_images = 0

    for images, labels, idx in loader:
        if max_images is not None and processed_images >= max_images:
            break

        # Only the images are needed on the device; labels stay on CPU, where
        # they are stored anyway.
        logit, pre_activation = model(images.to(device), idx)
        _, predictions = torch.max(logit, dim=1)

        batch_size = images.size(0)
        # Truncate the final batch so the total never exceeds max_images.
        if max_images is not None and processed_images + batch_size > max_images:
            keep = max_images - processed_images
            pre_activation = pre_activation[:keep]
            labels = labels[:keep]
            predictions = predictions[:keep]
            batch_size = keep

        # detach() so that, when called outside torch.no_grad(), the stored
        # tensors do not keep each batch's autograd graph (and device memory) alive.
        pre_activation_list.append(pre_activation.detach().cpu())
        label_list.append(labels.cpu())
        prediction_list.append(predictions.cpu())

        processed_images += batch_size

    if not pre_activation_list:
        raise RuntimeError(
            "No samples were collected: the loader is empty or max_images <= 0."
        )

    pre_activations = torch.cat(pre_activation_list, dim=0)
    labels = torch.cat(label_list)
    predictions = torch.cat(prediction_list)

    return pre_activations, labels, predictions


def _merge_running_extreme(current, candidate, combine):
    """Combine ``current`` with ``candidate`` via ``combine``; write back into ``current`` when shapes match."""
    merged = combine(current, candidate)
    if isinstance(current, torch.Tensor) and merged.shape == current.shape:
        # In-place write so the caller's tensor actually reflects the update.
        # no_grad allows this even if the running tensor happens to require grad.
        with torch.no_grad():
            current.copy_(merged)
        return current
    return merged


def update_neuron_activations(pre_activation, labels, predictions, neuron_activations_dict,
                              max_activation, min_activation, args):
    """
    Accumulate per-class pre-activations and running element-wise max/min.

    For each class index ``c`` in ``range(args.num_classes)``, rows of ``pre_activation``
    with ``labels == c`` are selected. If ``args.CND_well_classified_filter`` is truthy
    (default ``True`` when absent), selection is further restricted to rows where
    ``predictions == labels``. Selected rows are detached, moved to CPU, and appended
    along dim 0 to ``neuron_activations_dict[c]``.

    Parameters:
    - pre_activation: Tensor whose dim 0 indexes samples (boolean-masked by ``labels``).
    - labels: Tensor of class labels, one per sample.
    - predictions: Tensor of predicted classes; only read when the well-classified
      filter is enabled.
    - neuron_activations_dict: Dict mapping class index -> CPU tensor; mutated in place.
    - max_activation: Running element-wise maximum. Updated in place when the merged
      result has the same shape; presumably a CPU tensor, since it is compared with
      CPU activations.
    - min_activation: Running element-wise minimum, handled like ``max_activation``.
    - args: Namespace with ``num_classes`` and optional ``CND_well_classified_filter``.

    Returns:
    - (max_activation, min_activation): The updated running extremes.
    """
    well_classified_only = getattr(args, 'CND_well_classified_filter', True)
    # Only touch predictions when the filter needs them.
    correct = (predictions == labels) if well_classified_only else None

    for class_idx in range(args.num_classes):
        mask = labels == class_idx
        if correct is not None:
            mask = mask & correct
        if not mask.any():
            continue

        neuron_activations = pre_activation[mask].detach().cpu()

        # Update running max and min; the merged values are written back to the
        # caller's tensors (plain rebinding would discard them).
        max_activation = _merge_running_extreme(
            max_activation, torch.max(neuron_activations, dim=0).values, torch.maximum
        )
        min_activation = _merge_running_extreme(
            min_activation, torch.min(neuron_activations, dim=0).values, torch.minimum
        )

        # Append to (or initialize) the per-class activation store.
        if class_idx in neuron_activations_dict:
            neuron_activations_dict[class_idx] = torch.cat(
                [neuron_activations_dict[class_idx], neuron_activations], dim=0
            )
        else:
            neuron_activations_dict[class_idx] = neuron_activations

    return max_activation, min_activation





