# Import Dict and Any
from typing import Dict, Any

import importlib
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import utils_output_masking as uom

from library import baseline_configs
from library import configs
from library import load_datasets
from library import metrics
from library import misc
from library import model_io
from library import models
from library import results_json
from library import train

from difflogic import GroupSum, LogicLayer


# Re-execute these project modules so that source edits made after the first
# import are picked up (presumably for iterative/notebook use — confirm).
# The order is kept as-is in case a later module depends on an earlier one.
# NOTE(review): reload rebinds attributes on the module objects in place; any
# object captured earlier via `from library.configs import X` elsewhere still
# refers to the pre-reload definition.
importlib.reload(load_datasets)
importlib.reload(baseline_configs)
importlib.reload(configs)



def analyze_neuron_importance(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    config: configs.DifflogicConfig
) -> Dict[str, Any]:
    """
    Analyzes the importance of individual neurons in the final logic layer of a DiffLogic model.

    ``model`` must support integer indexing and slicing (e.g. ``torch.nn.Sequential``)
    and end with its ``GroupSum`` head. The ``out_dim`` neurons of the logic layer
    feeding ``GroupSum`` are split into ``num_classes`` contiguous groups of
    ``neurons_per_class``; each neuron is scored as a binary detector that should
    fire (output > 0.5) exactly when a sample's true label is its group's class.

    Args:
        model: Trained DiffLogic model; it is switched to eval mode.
        test_loader: DataLoader yielding ``(x, y)`` batches with integer labels ``y``.
        config: DiffLogic configuration; ``config.model_config.device`` selects the device.

    Returns:
        Dictionary containing:
          * per-neuron NumPy arrays ``activation_rate``, ``correct_rate``,
            ``contribution_scores``, ``accuracy``, ``false_positives``,
            ``false_negatives``, ``true_positives`` and ``true_negatives`` (the
            last five are fractions of all evaluated samples);
          * ``samples_per_class`` (NumPy array) and the scalars ``neurons_per_class``,
            ``num_classes`` and ``samples_total`` (number of samples evaluated);
          * for every class ``c``: ``neurons_for_80pct_class_{c}`` and
            ``pct_neurons_for_80pct_class_{c}`` — how many (and what percentage) of
            the class's neurons, taken in decreasing contribution order, account for
            80% of the class's total contribution (0 when no neuron of the class was
            ever credited).
    """
    device = config.model_config.device
    model.eval()

    # Extract the final logic layer before GroupSum.
    final_layer = None
    for i, layer in enumerate(model):
        if isinstance(layer, models.GroupSum):
            # At i == 0, model[i - 1] would silently wrap around to the last module.
            if i > 0:
                final_layer = model[i - 1]
            break

    assert final_layer is not None, "Could not find final logic layer"

    # Get number of classes and neurons per class (GroupSum is expected to be last).
    num_classes = model[-1].k
    out_dim = final_layer.out_dim
    neurons_per_class = out_dim // num_classes

    # Class owned by each final-layer neuron: contiguous groups of neurons_per_class.
    neuron_class = torch.div(
        torch.arange(out_dim, device=device), neurons_per_class, rounding_mode="floor"
    )

    neuron_activations = torch.zeros(out_dim, device=device)
    neuron_contributions = torch.zeros(out_dim, device=device)
    correct_counts = torch.zeros(out_dim, device=device)
    neuron_accuracy = torch.zeros(out_dim, device=device)
    neuron_false_positives = torch.zeros(out_dim, device=device)
    neuron_false_negatives = torch.zeros(out_dim, device=device)
    neuron_true_positives = torch.zeros(out_dim, device=device)
    neuron_true_negatives = torch.zeros(out_dim, device=device)
    samples_per_class = torch.zeros(num_classes, device=device)
    # Count what is actually evaluated: len(dataset) is wrong with subset samplers or
    # drop_last, and undefined for iterable datasets.
    samples_total = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            # Run the logic layers once and reuse the result for both the binarized
            # activations and the GroupSum logits (for a Sequential,
            # model(x) == model[-1](model[:-1](x))).
            pre_groupsum = model[:-1](x)
            active = pre_groupsum > 0.5                       # (batch, out_dim) bool
            predictions = model[-1](pre_groupsum).argmax(dim=1)
            correct = predictions == y                        # (batch,) bool

            # ground_truth[b, j] is True iff neuron j belongs to sample b's true class.
            ground_truth = neuron_class.unsqueeze(0) == y.unsqueeze(1)

            neuron_activations += active.sum(dim=0)

            # Credit the true class's neurons that were active on correctly classified samples.
            credited = (active & ground_truth & correct.unsqueeze(1)).sum(dim=0)
            neuron_contributions += credited
            correct_counts += credited

            samples_per_class += torch.bincount(y.long(), minlength=num_classes)
            samples_total += y.shape[0]

            # Per-neuron confusion counts, treating each neuron as a class detector.
            neuron_accuracy += (active == ground_truth).sum(dim=0)
            neuron_false_positives += (active & ~ground_truth).sum(dim=0)
            neuron_false_negatives += (~active & ground_truth).sum(dim=0)
            neuron_true_positives += (active & ground_truth).sum(dim=0)
            neuron_true_negatives += (~active & ~ground_truth).sum(dim=0)

    # Normalize the confusion counts into fractions of evaluated samples.
    neuron_accuracy /= samples_total
    neuron_false_positives /= samples_total
    neuron_false_negatives /= samples_total
    neuron_true_positives /= samples_total
    neuron_true_negatives /= samples_total

    activation_rate = neuron_activations / samples_total
    # NaN for neurons of classes that never occur in the loader (0 / 0).
    correct_rate = correct_counts / samples_per_class.repeat_interleave(neurons_per_class)

    results = {
        'activation_rate': activation_rate.cpu().numpy(),
        'correct_rate': correct_rate.cpu().numpy(),
        'contribution_scores': neuron_contributions.cpu().numpy(),
        'neurons_per_class': neurons_per_class,
        'num_classes': num_classes,
        'accuracy': neuron_accuracy.cpu().numpy(),
        'false_positives': neuron_false_positives.cpu().numpy(),
        'false_negatives': neuron_false_negatives.cpu().numpy(),
        'true_positives': neuron_true_positives.cpu().numpy(),
        'true_negatives': neuron_true_negatives.cpu().numpy(),
        'samples_per_class': samples_per_class.cpu().numpy(),
        'samples_total': samples_total,
    }

    # How concentrated is each class's contribution among its neurons?
    contribution_scores = results['contribution_scores']
    for class_idx in range(num_classes):
        class_start = class_idx * neurons_per_class
        class_scores = contribution_scores[class_start:class_start + neurons_per_class]
        class_total = class_scores.sum()
        if class_total > 0:
            # Cumulative share of the class's contribution, largest neurons first.
            cumulative_contribution = np.cumsum(np.sort(class_scores)[::-1]) / class_total
            # Index of the first prefix reaching 80%; +1 turns it into a neuron count.
            threshold_80pct = int(np.searchsorted(cumulative_contribution, 0.8)) + 1
        else:
            # No neuron of this class was ever credited: avoid 0/0 and report zero.
            threshold_80pct = 0
        results[f'neurons_for_80pct_class_{class_idx}'] = threshold_80pct
        results[f'pct_neurons_for_80pct_class_{class_idx}'] = threshold_80pct / neurons_per_class * 100

    return results

def calculate_metrics(results, rng=None):
    """
    Derive per-neuron classification metrics from confusion-matrix entries.

    Reads ``true_positives``, ``false_positives``, ``false_negatives`` and
    ``true_negatives`` (per-neuron NumPy arrays, e.g. as returned by
    ``analyze_neuron_importance``) from ``results`` and writes the derived
    metrics back into the same dict.

    Args:
        results: Dict holding the four confusion arrays; mutated in place.
        rng: Optional seed or ``numpy.random.Generator`` used to draw
            ``random_score``. ``None`` (default) keeps the legacy behaviour of
            drawing from NumPy's global random state via ``np.random.rand``.

    Returns:
        ``results`` itself, with ``f1_score``, ``random_score``, ``precision``,
        ``recall``, ``sensitivity`` (same array as ``recall``), ``specificity``,
        ``accuracy`` (overwriting any existing entry) and ``mcc`` set.
    """
    eps = 1e-8  # keeps every ratio finite when its denominator is zero

    tp = results['true_positives']
    fp = results['false_positives']
    fn = results['false_negatives']
    tn = results['true_negatives']

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)  # aka sensitivity
    specificity = tn / (tn + fp + eps)
    f1_score = 2 * (precision * recall) / (precision + recall + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)

    # Matthews correlation coefficient; eps is added after the sqrt so a neuron
    # with an empty confusion row/column scores 0 instead of NaN.
    mcc_denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) + eps
    mcc = ((tp * tn) - (fp * fn)) / mcc_denominator

    # One uniform random score per neuron (presumably a chance-level baseline
    # ranking for masking experiments).
    num_neurons = tp.shape[0]
    if rng is None:
        random_scores = np.random.rand(num_neurons)
    else:
        random_scores = np.random.default_rng(rng).random(num_neurons)

    # f1_score is inserted before random_score to keep the historical key order.
    results['f1_score'] = f1_score
    results['random_score'] = random_scores
    results.update({
        'precision': precision,
        'recall': recall,
        'sensitivity': recall,  # same as recall
        'specificity': specificity,
        'f1_score': f1_score,
        'accuracy': accuracy,
        'mcc': mcc,
    })
    return results

def run(results, label, network, num_classes, test_loader, validation_loader, config,
        mask_counts=None):
    """
    Sweep how accuracy degrades as the lowest-scoring final-layer neurons are masked.

    ``results[label]`` holds one score per final-layer neuron (e.g. ``'f1_score'``
    or ``'random_score'`` from ``calculate_metrics``). The scores are split into
    ``num_classes`` consecutive groups with ``np.array_split``. For every ``n`` in
    the sweep, the ``n`` lowest-scoring neurons of each group are masked out (the
    whole group when it has ``n`` or fewer neurons). ``metrics.masked_accuracy``
    is then evaluated on the test and validation loaders with ``train_mode``
    False and True.

    Args:
        results: Dict containing the per-neuron scores under ``label``.
        label: Key of the score array in ``results``; lower scores are masked first.
        network: Model passed to ``metrics.masked_accuracy``.
        num_classes: Number of consecutive neuron groups (one per class).
        test_loader: Test DataLoader.
        validation_loader: Validation DataLoader.
        config: Configuration; ``config.data_config.device`` hosts the mask tensor.
        mask_counts: Iterable of per-class mask sizes ``n`` to sweep. Defaults to
            ``range(0, 6400, 100)``.

    Returns:
        Tuple ``(test_accuracies_train, test_accuracies_eval, val_accuracies_eval,
        val_accuracies_train, n_values)`` of lists with one entry per swept ``n``.

    Side effects:
        Stores each boolean keep-mask (True = keep, False = mask out) in
        ``results['mask']``, so the last step's mask remains there; prints progress.
    """
    if mask_counts is None:
        mask_counts = range(0, 6400, 100)

    metric = results[label]
    device = config.data_config.device

    # The per-class ranking does not depend on n, so sort each group only once.
    # Each entry is (offset of the group in `metric`, ascending argsort of the group).
    class_orders = []
    start = 0
    for part in np.array_split(metric, num_classes):
        class_orders.append((start, np.argsort(part)))
        start += len(part)

    test_accuracies_train = []
    test_accuracies_eval = []
    val_accuracies_eval = []
    val_accuracies_train = []
    n_values = []
    for n in mask_counts:
        keep = np.ones_like(metric, dtype=bool)
        for offset, order in class_orders:
            # order[:n] is the whole group when it has <= n neurons.
            keep[offset + order[:n]] = False

        results['mask'] = keep
        mask = torch.tensor(keep, dtype=torch.bool, device=device)
        print(f"Masked: {n}")
        test_acc_eval = metrics.masked_accuracy(network, test_loader, config=config, train_mode=False, mask=mask)
        test_acc_train = metrics.masked_accuracy(network, test_loader, config=config, train_mode=True, mask=mask)
        valid_acc_eval = metrics.masked_accuracy(network, validation_loader, config=config, train_mode=False, mask=mask)
        valid_acc_train = metrics.masked_accuracy(network, validation_loader, config=config, train_mode=True, mask=mask)

        test_accuracies_train.append(test_acc_train)
        test_accuracies_eval.append(test_acc_eval)
        val_accuracies_eval.append(valid_acc_eval)
        val_accuracies_train.append(valid_acc_train)
        n_values.append(n)
    return test_accuracies_train, test_accuracies_eval, val_accuracies_eval, val_accuracies_train, n_values