import random
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import logging

from matplotlib.patches import FancyArrowPatch
from models.model import cifar10_small_config, gtsrb_config, taxinet_config


### SHARED HELPERS: WINNER–RUNNER & ABS-MAX LOGIT METRICS (PLUS TAXINET REGRESSION METRICS) AND THEIR FORMATTERS ###



def compute_logit_metrics(
    dataset: str,
    full_model: nn.Module,
    pruned_model: nn.Module,
    X: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Compare the outputs of ``full_model`` and ``pruned_model`` on the same batch.

    Both models are put in eval mode and evaluated under ``torch.no_grad()`` on
    ``X`` moved to ``device``.

    - Regression (``dataset == 'taxinet'``): outputs of rank 1 are promoted to
      ``[N, 1]``.  Returns both outputs, the per-dimension absolute
      differences ``[N, D]``, and per-sample max absolute difference (with its
      dimension index), MAE and MSE.
    - Classification (any other dataset): the winner / runner-up classes are
      the full model's top-2 logits; the pruned model's logits are read at
      those same class indices.  Also returns the per-sample maximum absolute
      logit difference (with its class) and the absolute change of the winner
      logit.  Requires at least two classes (uses ``torch.topk(..., 2)``).

    The output tensors are always stored under 'full_logits' / 'pruned_logits'
    so both branches share key names.
    """
    for net in (full_model, pruned_model):
        net.eval()
    inputs = X.to(device)

    with torch.no_grad():
        out_full = full_model(inputs)
        out_pruned = pruned_model(inputs)

        if dataset == 'taxinet':
            # Promote [N] outputs to [N, 1] so every metric is per-dimension.
            if out_full.dim() == 1:
                out_full = out_full.unsqueeze(1)
            if out_pruned.dim() == 1:
                out_pruned = out_pruned.unsqueeze(1)

            delta = out_full - out_pruned
            abs_delta = delta.abs()
            worst_val, worst_dim = abs_delta.max(dim=1)
            return {
                'dataset': dataset,
                'full_logits': out_full,          # name kept for compatibility
                'pruned_logits': out_pruned,      # name kept for compatibility
                'abs_diff_per_dim': abs_delta,    # [N, D]
                'abs_max_diff': worst_val,        # [N]
                'abs_max_class': worst_dim,       # [N] (dimension index)
                'mae': abs_delta.mean(dim=1),     # [N]
                'mse': delta.pow(2).mean(dim=1),  # [N]
            }

        def pick(logits, idx):
            # Select logits[n, idx[n]] for every row n.
            return logits.gather(1, idx.unsqueeze(1)).squeeze(1)

        # Classes are ranked by the FULL model only.
        ranked = torch.topk(out_full, 2, dim=1)
        winner_idx = ranked.indices[:, 0]
        runner_idx = ranked.indices[:, 1]

        w_full = pick(out_full, winner_idx)
        r_full = pick(out_full, runner_idx)
        w_pruned = pick(out_pruned, winner_idx)
        r_pruned = pick(out_pruned, runner_idx)

        max_gap, max_gap_cls = (out_full - out_pruned).abs().max(dim=1)

        return {
            'dataset': dataset,
            'full_logits': out_full,
            'pruned_logits': out_pruned,
            'winner_idx': winner_idx,
            'runner_idx': runner_idx,
            'winner_logit_full': w_full,
            'runner_logit_full': r_full,
            'winner_runner_diff_full': w_full - r_full,
            'winner_logit_pruned': w_pruned,
            'runner_logit_pruned': r_pruned,
            'winner_runner_diff_pruned': w_pruned - r_pruned,
            'abs_max_diff': max_gap,
            'abs_max_class': max_gap_cls,
            'winner_logits_diff': torch.abs(w_full - w_pruned),
        }

def format_logit_comparison(metrics):
    """
    Render the classification metrics produced by compute_logit_metrics as report lines.

    For every sample in the batch this emits: the full and pruned logit vectors,
    the winner/runner class indices, the winner/runner logits and their gap for
    both models, the largest absolute logit difference (with its class), and the
    absolute change of the winner logit.  Floats are shown with 3 decimals and
    every sample block ends with a 50-dash separator.
    """
    def vector(key, row):
        # Comma-separated, 3-decimal rendering of one sample's logit vector.
        return ", ".join(f"{v:.3f}" for v in metrics[key][row].tolist())

    def scalar(key, row):
        # Python number for one sample of a per-sample metric tensor.
        return metrics[key][row].item()

    out = []
    for row in range(metrics['full_logits'].shape[0]):
        out.extend([
            f"Sample {row}:",
            f"Full logits         : [{vector('full_logits', row)}]",
            f"Pruned logits       : [{vector('pruned_logits', row)}]",
            f"Winner idx          : {scalar('winner_idx', row)}",
            f"Runner idx          : {scalar('runner_idx', row)}",
            f"Winner logit (full) : {scalar('winner_logit_full', row):.3f}",
            f"Runner logit (full) : {scalar('runner_logit_full', row):.3f}",
            f"Winner-Runner diff (full): {scalar('winner_runner_diff_full', row):.3f}",
            f"Winner logit (pruned): {scalar('winner_logit_pruned', row):.3f}",
            f"Runner logit (pruned): {scalar('runner_logit_pruned', row):.3f}",
            f"Winner-Runner diff (pruned): {scalar('winner_runner_diff_pruned', row):.3f}",
            f"Abs max diff           : {scalar('abs_max_diff', row):.3f} (class {scalar('abs_max_class', row)})",
            f"Winner logit diff (abs): {scalar('winner_logits_diff', row):.3f}",
            "-" * 50,
        ])
    return out

def format_regression_comparison(metrics):
    """
    Render regression (taxinet) metrics from compute_logit_metrics(dataset='taxinet') as report lines.

    Per sample: the full and pruned outputs, then, only for keys present in
    ``metrics``, the per-dimension absolute differences, MAE, MSE, and the
    largest absolute difference with its dimension index.  Values use 3
    decimals; every sample block ends with a 50-dash separator.
    """
    def as_list(tensor_row):
        # .tolist() on a 0-d tensor yields a bare number; wrap it so joins always work.
        values = tensor_row.tolist()
        return values if isinstance(values, list) else [values]

    def joined(tensor_row):
        return ", ".join(f"{v:.3f}" for v in as_list(tensor_row))

    has_per_dim = 'abs_diff_per_dim' in metrics
    has_mae = 'mae' in metrics
    has_mse = 'mse' in metrics
    has_abs_max = 'abs_max_diff' in metrics and 'abs_max_class' in metrics

    out = []
    for row in range(metrics['full_logits'].shape[0]):
        out.append(f"Sample {row}:")
        out.append(f"Full output         : [{joined(metrics['full_logits'][row])}]")
        out.append(f"Pruned output       : [{joined(metrics['pruned_logits'][row])}]")
        if has_per_dim:
            out.append(f"Abs diff per dim     : [{joined(metrics['abs_diff_per_dim'][row])}]")
        if has_mae:
            out.append(f"MAE                  : {metrics['mae'][row].item():.3f}")
        if has_mse:
            out.append(f"MSE                  : {metrics['mse'][row].item():.3f}")
        if has_abs_max:
            out.append(f"Abs diff         : {metrics['abs_max_diff'][row].item():.3f} (dim {metrics['abs_max_class'][row].item()})")
        out.append("-" * 50)
    return out

def visualize_pruned_vision_model(model, plot_path, components, title="Pruned Net Visualization", dataset='cifar10', **kwargs):
    """
    Save a pruning visualization of a conv model, choosing the plot by pruning granularity.

    ``components['granularity'] == 'conv_channels'`` draws the per-filter grid
    (visualize_channels_grid); every other granularity (expected: 'conv_heads')
    draws the per-layer graph (visualize_pruned_vision_model_conv_heads).
    ``components['active']`` is forwarded to the chosen plotter.  Extra keyword
    arguments are accepted for call-site compatibility and ignored.
    """
    if components['granularity'] != 'conv_channels':
        # Any non-channel granularity is treated as 'conv_heads'.
        visualize_pruned_vision_model_conv_heads(components['active'], model, plot_path, title, dataset=dataset)
        return
    visualize_channels_grid(model, components['active'], plot_path, title=title)


def visualize_pruned_vision_model_conv_heads(active_heads, model, plot_path, title, dataset):
    """
    Draw the model's conv layers as a graph, marking whole layers as kept or pruned, and save it.

    Each nn.Conv2d whose name contains neither 'downsample' nor 'shortcut' becomes
    a node labelled with its output-channel count ``T`` and pruned-filter count
    ``P``: 0 when the layer name is in ``active_heads``, otherwise all of its
    filters (the node is then drawn grey).  Consecutive conv layers are chained
    with solid edges.  Dashed blue arcs link every third conv to the conv three
    positions later; presumably these stand in for residual connections of a
    CIFAR-style ResNet, since the layout is hard-coded (verify for other
    architectures).  Nodes are laid out in one column per top-level name prefix
    (e.g. 'conv1', 'layer1').

    Not implemented for ``dataset == 'gtsrb'``: logs a message and returns
    without plotting.

    Args:
        active_heads: iterable of conv layer names that are kept.
        model: module to visualize.
        plot_path: destination passed to ``savefig``.
        title: axes title.
        dataset: dataset name; only 'gtsrb' changes behaviour.
    """
    if dataset == 'gtsrb':
        logging.info("Visualizing GTSRB model with conv heads granularity is not defined.")
        return
    G = nx.DiGraph()
    layer_colors = {
        'conv1': 'lightgreen',
        'layer1': 'gold',
        'layer2': 'lightblue',
        'layer3': 'steelblue'
    }
    prev_node = None
    conv_layer_names = []
    active_heads = set(active_heads)
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d) and 'downsample' not in name.lower() and 'shortcut' not in name.lower():
            conv_layer_names.append(name)
            out_channels, _, _, _ = layer.weight.shape
            # Granularity is the whole layer: either no filter or every filter is pruned.
            pruned_filters = 0 if name in active_heads else out_channels
            inner_color = 'grey' if pruned_filters == out_channels else layer_colors.get(name.split('.')[0], 'skyblue')
            outer_color = inner_color

            G.add_node(name, label=f"{name}\nT:[{out_channels}]\nP:[{pruned_filters}]",
                       inner_color=inner_color, group=outer_color)

            if prev_node is not None:
                G.add_edge(prev_node, name)

            prev_node = name

    # One column per top-level prefix, nodes stacked downward in definition order.
    pos = {}
    layers_by_type = {}
    for node in conv_layer_names:
        layers_by_type.setdefault(node.split('.')[0], []).append(node)
    x_step = 3
    y_step = -1
    for x_idx, (layer_type, nodes) in enumerate(layers_by_type.items()):
        for y_idx, node in enumerate(nodes):
            pos[node] = (x_idx * x_step, y_idx * y_step)

    labels = nx.get_node_attributes(G, 'label')
    node_colors = [G.nodes[node]['inner_color'] for node in G.nodes()]
    outline_colors = [G.nodes[node]['group'] for node in G.nodes()]

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        nx.draw(G, pos, labels=labels, with_labels=True, node_size=1500, node_color=node_colors, font_size=8,
                font_color='black', font_weight='bold', edge_color='black', arrows=True, ax=ax,
                edgecolors=outline_colors, linewidths=4, width=1, arrowsize=15)
        # Dashed arcs every 3 conv layers (hard-coded residual layout, see docstring).
        for i in range(0, len(conv_layer_names) - 3, 3):
            start, end = conv_layer_names[i], conv_layer_names[i + 3]
            rad = 0.00 if start.endswith("2.conv2") else 0.5
            arrow = FancyArrowPatch(pos[start], pos[end], connectionstyle=f"arc3,rad={rad}",
                                    arrowstyle='-|>', mutation_scale=20, color='blue', linestyle='dashed', linewidth=1,
                                    transform=ax.transData)
            ax.add_patch(arrow)
        legend_handles = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=10, label='conv1'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gold', markersize=10, label='layer1'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue', markersize=10, label='layer2'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='steelblue', markersize=10, label='layer3'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='grey', markersize=10, label='Pruned Filter')
        ]
        ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1.05, 0.5))
        ax.set_title(title)
        ax.axis('off')
        fig.savefig(plot_path, bbox_inches='tight')
    finally:
        # Release the figure; otherwise every call leaks an open pyplot figure.
        plt.close(fig)


def visualize_channels_grid(model, active_channels, plot_path,
                            title="Channels Visualization"):
    """
    Save a grid of per-filter weight thumbnails for each conv layer, greying out pruned filters.

    Conv2d layers whose names contain 'downsample' or 'shortcut' are skipped.
    The first remaining layer occupies a single column; every later layer is
    split across two columns.  A filter is drawn with 'viridis' when
    ``(layer_name, filter_index)`` is in ``active_channels`` and as a flat grey
    tile otherwise.  Each filter's weights are min-max normalized; filters
    with exactly 3 input channels are shown as RGB, all others show only
    input channel 0.  If the model has no eligible Conv2d layer, a warning is
    logged and nothing is saved.

    Args:
        model: module containing nn.Conv2d layers.
        active_channels: iterable of (layer_name, filter_index) tuples.
        plot_path: destination passed to ``savefig``.
        title: figure title.
    """
    # 1) collect conv2d layers in definition order
    conv_layers = [
        (name, layer) for name, layer in model.named_modules()
        if isinstance(layer, nn.Conv2d)
           and 'downsample' not in name.lower()
           and 'shortcut'  not in name.lower()
    ]
    if not conv_layers:
        logging.warning("no eligible nn.Conv2d layers found in the model; nothing to plot.")
        return

    # 2) build groups: conv1 single-column, others split into 2.
    #    Weights are copied to CPU/numpy once per layer, not once per filter.
    groups = []
    for idx, (name, layer) in enumerate(conv_layers):
        weights = layer.weight.detach().cpu().numpy()
        C = weights.shape[0]
        splits = 1 if idx == 0 else 2
        rows = int(np.ceil(C / splits))
        groups.append((name, weights, C, rows, splits))

    total_cols = sum(s for *_, s in groups)
    max_rows   = max(r for *_, r, _ in groups)

    # 3) create figure — narrow columns but extra horizontal gap
    fig, axs = plt.subplots(max_rows, total_cols,
                            figsize=(total_cols * 0.8, max_rows * 0.8),
                            squeeze=False)
    try:
        # increase wspace to separate conv groups further
        fig.subplots_adjust(top=0.90, wspace=3.0, hspace=0.1)

        active_set   = set(active_channels)
        layer_colors = {'conv1':'lightgreen','layer1':'gold',
                        'layer2':'lightblue','layer3':'steelblue'}

        # 4) plot each filter
        col_offset = 0
        for name, weights, C, rows, splits in groups:
            for split in range(splits):
                col = col_offset + split
                for r in range(max_rows):
                    ax = axs[r, col]
                    ax.axis('off')
                    filt = split * rows + r
                    if filt >= C:
                        continue

                    w = weights[filt]
                    w = (w - w.min()) / (w.max() - w.min() + 1e-8)
                    img = np.transpose(w, (1,2,0)) if w.shape[0]==3 else w[0]

                    if (name, filt) not in active_set:
                        img = np.ones_like(img) * 0.5
                        cmap, vmin, vmax = 'gray', 0, 1
                        tc = 'grey'
                    else:
                        cmap, vmin, vmax = 'viridis', None, None
                        tc = layer_colors.get(name.split('.')[0], 'grey')

                    ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax,
                              interpolation='nearest')
                    ax.set_title(f"Ch {filt}", fontsize=6, color=tc)
            col_offset += splits

        # 5) header above each group, centred over the group's columns
        header_y = 0.94
        col_offset = 0
        for name, _, _, _, splits in groups:
            c0, c1 = col_offset, col_offset + splits - 1
            x0 = axs[0, c0].get_position().x0
            x1 = axs[0, c1].get_position().x1
            xc = (x0 + x1) / 2
            fig.text(xc, header_y, name, ha='center', va='bottom',
                     fontsize=9, weight='bold')
            col_offset += splits

        # 6) move overall title above headers
        fig.suptitle(title, y=1.02, fontsize=12)
        fig.savefig(plot_path, bbox_inches='tight')
    finally:
        plt.close(fig)


def vision_circuit_stats(dataset, full_net, pruned_net, X, device, save_to_path, components, timeouts=None, sample_ids=None, **kwargs):
    """
    Build and emit a pruning report for a vision model.

    The report contains, in order:
    - optional sample indexes;
    - the dataset's total channel count ("N/A" for unknown datasets);
    - active/dead channels or conv heads, depending on ``components['granularity']``;
    - the per-sample output comparison (regression formatter for 'taxinet', logit formatter otherwise);
    - argmax predictions of both networks;
    - optional timeouts.

    The lines are passed to emit_to_file (logged, and appended to ``save_to_path`` when given).
    Extra keyword arguments are ignored.
    """
    lines = [] if sample_ids is None else [f"Sample Indexes: {sample_ids}"]
    lines += ["\nPruning Statistics", "------------------"]

    exact_configs = {'taxinet': taxinet_config, 'gtsrb': gtsrb_config}
    if dataset.startswith('cifar10'):
        total_channels = cifar10_small_config['total_channels']
    elif dataset in exact_configs:
        total_channels = exact_configs[dataset]['total_channels']
    else:
        total_channels = "N/A"
    lines.append(f"Total Channels: {total_channels}")

    if components['granularity'] == 'conv_channels':
        active, dead = components['active'], components['dead']
        lines += [
            f"# Active Channels {len(active)}: {active}",
            f"# Dead Channels {len(dead)}:{dead}",
        ]
    else:  # granularity == 'conv_heads'
        active, dead = components['active'], components['dead']
        lines += [
            f"# Active Conv Heads {len(active)}: {active}",
            f"# Dead Conv Heads {len(dead)}: {dead}",
        ]

    metrics = compute_logit_metrics(dataset=dataset, full_model=full_net, pruned_model=pruned_net, X=X, device=device)
    formatter = format_regression_comparison if dataset == 'taxinet' else format_logit_comparison
    lines += formatter(metrics)

    # NOTE(review): predictions are argmax over dim 1, even for the 'taxinet' regression outputs.
    lines += [
        "-" * 50 + " CLASS_PREDICTION",
        f"Original Network Prediction: {MlpCircuitStats.predictor(full_net, X, device)}",
        f"Pruned Network Prediction: {MlpCircuitStats.predictor(pruned_net, X, device)}",
    ]
    if timeouts is not None:
        lines.append(f"# of Timeouts: {timeouts} | count={len(timeouts)}")

    emit_to_file(lines, save_to_path)
#### MNIST SPECIFIC FUNCTIONS ####


def visualize_all_neuron_weights_as_images(
        model,
        active_neurons_only=True,
        img_dim=(28, 28),
        figsize=(12, 12)
):
    """
    Show each nn.Linear neuron's incoming weights as a grayscale image, one figure per layer.

    Only layers whose ``in_features`` equals ``img_dim[0] * img_dim[1]`` are
    drawn; others are skipped with a warning.  Each neuron's weights are
    min-max normalized to [0, 1] independently.  Figures are shown
    non-blocking and left open.

    Args:
        model: PyTorch model (with nn.Linear layers).
        active_neurons_only: If True, only neurons whose summed absolute
            incoming weight exceeds 1e-8 are drawn.
        img_dim: Image dimensions to reshape each weight row into (default
            28x28, i.e. MNIST-sized inputs).
        figsize: Size of the grid figure for each layer.

    Returns:
        None.
    """
    layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]

    if not layers:
        logging.warning("no nn.Linear layers found in the model.")
        return

    for layer_idx, layer in enumerate(layers):
        weights = layer.weight.data  # Shape: (out_features, in_features)

        # Check if weights can be reshaped to the desired img_dim
        if weights.shape[1] != img_dim[0] * img_dim[1]:
            logging.warning(f"skipping layer {layer_idx}: weights cannot be reshaped to {img_dim}.")
            continue

        # Filter active neurons if required, remembering their original indices for titles.
        if active_neurons_only:
            alive_mask = weights.abs().sum(dim=1) > 1e-8
            active_weights = weights[alive_mask]
            neuron_ids = torch.nonzero(alive_mask).flatten().tolist()
        else:
            active_weights = weights
            neuron_ids = list(range(weights.shape[0]))

        num_active_neurons = active_weights.shape[0]

        if num_active_neurons == 0:
            logging.warning(f"no active neurons found in layer {layer_idx}. skipping visualization.")
            continue

        # Normalize each neuron's weights to [0, 1]; epsilon avoids divide-by-zero on constant rows.
        min_weight = active_weights.min(dim=1, keepdim=True)[0]
        max_weight = active_weights.max(dim=1, keepdim=True)[0]
        scaled_weights = (active_weights - min_weight) / (max_weight - min_weight + 1e-8)

        reshaped_weights = scaled_weights.view(-1, *img_dim).cpu().numpy()

        # Near-square grid
        cols = int(np.ceil(np.sqrt(num_active_neurons)))
        rows = int(np.ceil(num_active_neurons / cols))

        # squeeze=False keeps a 2-D Axes array even for a 1x1 grid (single neuron).
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

        for i, ax in enumerate(axes.flat):
            if i < num_active_neurons:
                ax.imshow(reshaped_weights[i], cmap='gray')
                ax.set_title(f"Layer {layer_idx} Neuron {neuron_ids[i]}", fontsize=8)
            ax.axis('off')  # hide ticks on used axes and hide unused axes entirely

        plt.suptitle(f"Layer {layer_idx} Visualization ({num_active_neurons} Neurons)", fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show(block=False)


def visualize_saliency_in_input_space(pruned_net, x, device, active_neurons_only=True, img_dim=(28, 28), figsize=(12, 12)):
    """
    Plot input-space saliency maps |d activation / d input| for neurons of every nn.Linear layer.

    For each nn.Linear layer the network is run forward once on ``x`` while a
    forward hook records that layer's output.  For each selected neuron, that
    neuron's activation for sample 0 is back-propagated to the input, and the
    absolute input gradient is reshaped to ``img_dim`` and drawn as a heat map.
    For the last nn.Linear layer the network's final output
    ``pruned_net(x)[0, neuron]`` is used instead.  One figure per layer is
    shown non-blocking and left open.

    Args:
        pruned_net: Pruned PyTorch model.
        x: Input tensor; its element count must equal
            ``img_dim[0] * img_dim[1]`` (e.g. shape [1, 784] for 28x28).
            It is copied, so the caller's tensor is not modified.
        device: Device to run on.
        active_neurons_only: If True, only neurons whose summed absolute
            incoming weight exceeds 1e-8 are used.
        img_dim: Image dimensions (default 28x28).
        figsize: Size of each grid figure.

    Returns:
        None.
    """
    pruned_net.eval()
    # Private leaf copy: gradients are taken w.r.t. it without flipping the caller's tensor to requires_grad.
    x = x.detach().clone().to(device).requires_grad_(True)
    layers = [layer for layer in pruned_net.modules() if isinstance(layer, nn.Linear)]
    last_idx = len(layers) - 1

    for layer_idx, layer in enumerate(layers):
        weights = layer.weight.data

        if active_neurons_only:
            alive_mask = weights.abs().sum(dim=1) > 1e-8
            # flatten() (not squeeze()) keeps a list even when exactly one neuron is alive.
            active_neurons = torch.nonzero(alive_mask).flatten().tolist()
        else:
            active_neurons = list(range(weights.shape[0]))

        if not active_neurons:
            logging.warning(f"no active neurons found in layer {layer_idx}. skipping visualization.")
            continue

        # One forward pass per layer, capturing this layer's output via a hook.
        captured = {}

        def _capture(module, inputs, output):
            captured['out'] = output

        handle = layer.register_forward_hook(_capture)
        try:
            with torch.enable_grad():
                net_out = pruned_net(x)
        finally:
            handle.remove()

        target = net_out if layer_idx == last_idx else captured.get('out')
        if target is None:
            logging.warning(f"layer {layer_idx} was not used in the forward pass. skipping visualization.")
            continue

        saliency_maps = []
        for neuron_idx in active_neurons:
            pruned_net.zero_grad()
            if x.grad is not None:
                x.grad.zero_()  # gradients accumulate across backward calls
            # retain_graph: the same graph is reused for every neuron of this layer.
            target[0, neuron_idx].backward(retain_graph=True)
            saliency = x.grad.detach().abs().view(*img_dim).cpu().numpy()
            saliency_maps.append((neuron_idx, saliency))

        # Plot saliency maps on a near-square grid
        cols = int(len(saliency_maps) ** 0.5) + 1
        rows = (len(saliency_maps) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        flat_axes = axes.flatten()

        for ax, (neuron_idx, saliency) in zip(flat_axes, saliency_maps):
            ax.imshow(saliency, cmap='hot')
            ax.set_title(f"Neuron {neuron_idx}", fontsize=8)
            ax.axis('off')

        for ax in flat_axes[len(saliency_maps):]:  # Turn off unused axes
            ax.axis('off')

        plt.suptitle(f"Input-Space Saliency Maps for Layer {layer_idx}", fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show(block=False)

def visualize_mlp_with_active_neurons(
    model,
    plot_path,
    components,
    input_edge_fraction=0.07,
    edge_keep_fraction=1.0,
    seed=42,
    plot_height=10.0,
    **kwargs
):
    """
    Draw an MLP as a layered graph, highlighting active neurons, and save it to ``plot_path``.

    Each nn.Linear layer (in ``model.named_modules()`` order) becomes a column of
    nodes, preceded by an input column sized by the first layer's
    ``in_features``.  Input nodes are always active.  A layer node is active
    when ``(layer_name, neuron_index)`` is in ``components['active']``.  Edges
    are created only from active source nodes to active target nodes and are
    labelled with the corresponding weight.  They are then randomly
    sub-sampled per layer, keeping at least one edge whenever any exist.  If
    the model has no nn.Linear layer, a warning is logged and nothing is saved.

    Args:
        model: PyTorch model containing nn.Linear layers (top-level or nested).
        plot_path: destination passed to ``savefig``.
        components: dict whose 'active' entry lists (layer_name, neuron_index)
            tuples, e.g. [('fc2', 8), ('fc1', 9), ...].
        input_edge_fraction: Fraction of input-to-first-layer edges to keep.
        edge_keep_fraction: Fraction of edges to keep for every later layer.
        seed: Seeds both ``random`` and ``torch`` (global RNG state is modified).
        plot_height: Total vertical extent used to space nodes within a column.
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    active_neurons_set = set(components['active'])
    # Single name -> module map reused for every lookup; also resolves nested names such as 'net.0'.
    modules_by_name = dict(model.named_modules())
    layer_names = [name for name, module in modules_by_name.items() if isinstance(module, nn.Linear)]
    if not layer_names:
        logging.warning("no nn.Linear layers found in the model; nothing to plot.")
        return

    G = nx.DiGraph()

    # Input layer: always active, sized by the first Linear layer's fan-in.
    input_nodes = [f"input_{i}" for i in range(modules_by_name[layer_names[0]].in_features)]
    for input_node in input_nodes:
        G.add_node(input_node, active=True)
    layer_nodes = {"input": input_nodes}
    prev_nodes = input_nodes

    for idx, layer_name in enumerate(layer_names):
        layer = modules_by_name[layer_name]
        nodes = [f"{layer_name}_{i}" for i in range(layer.out_features)]
        layer_nodes[layer_name] = nodes

        active_targets = []
        for node_idx, node in enumerate(nodes):
            is_active = (layer_name, node_idx) in active_neurons_set
            G.add_node(node, active=is_active)
            if is_active:
                active_targets.append(node_idx)

        # Convert once to nested Python floats ([out_features][in_features]) instead of per-element .item().
        weight_rows = layer.weight.detach().cpu().tolist()
        edges = [
            (prev_node, nodes[node_idx], weight_rows[node_idx][prev_idx])
            for prev_idx, prev_node in enumerate(prev_nodes)
            if G.nodes[prev_node]["active"]  # only edges from active source neurons
            for node_idx in active_targets
        ]

        keep_fraction = input_edge_fraction if idx == 0 else edge_keep_fraction
        n_keep = min(len(edges), max(1, int(len(edges) * keep_fraction)))
        for u, v, w in random.sample(edges, n_keep):
            G.add_edge(u, v, weight=w)

        prev_nodes = nodes

    # Set positions: one column per layer, nodes evenly spread over plot_height.
    pos = {}
    all_layers = [layer_nodes["input"]] + [layer_nodes[name] for name in layer_names]
    for layer_idx, nodes_in_layer in enumerate(all_layers):
        spacing = plot_height / max(len(nodes_in_layer) - 1, 1)
        for node_idx, node_name in enumerate(nodes_in_layer):
            pos[node_name] = (layer_idx, node_idx * spacing)

    # Draw
    plt.figure(figsize=(13, 6))

    node_colors = ["skyblue" if G.nodes[n]["active"] else "lightgrey" for n in G.nodes()]
    edge_colors = ["grey" for _ in G.edges()]

    nx.draw(G, pos, with_labels=True, node_size=600, node_color=node_colors,
            edge_color=edge_colors, arrows=False)

    edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

    # Layer labels above each column
    for idx, layer_nodes_group in enumerate(all_layers):
        label = "Input Layer" if idx == 0 else f"{layer_names[idx-1]}"
        plt.text(
            idx, plot_height + 0.5,
            f"{label}\n{len(layer_nodes_group)} neurons",
            ha="center", fontsize=12,
            bbox=dict(facecolor="white", alpha=0.8, edgecolor="black")
        )

    plt.axis('off')
    plt.savefig(plot_path, bbox_inches='tight')
    plt.close()


def emit_to_file(lines, save_to_path=None):
    """
    Log ``lines`` as a single newline-joined INFO message and optionally append them to a file.

    Args:
        lines: Iterable of strings, one per output line.  It is consumed only
            once, so generators are safe.
        save_to_path: If truthy, the file is opened in append mode and the
            joined text plus a trailing newline is written as UTF-8.
    """
    text = "\n".join(lines)
    logging.info(text)
    if save_to_path:
        # Explicit UTF-8: report lines may contain non-ASCII characters (e.g. "L∞"),
        # which the platform default encoding may not be able to encode.
        with open(save_to_path, "a", encoding="utf-8") as f:
            f.write(text + "\n")


class MlpCircuitStats:
    """Static helpers for reporting how a pruned network compares with the original one."""

    @staticmethod
    def predictor(model, X, device):
        """
        Return the per-sample argmax over dim 1 of ``model(X)`` as a Python list.

        ``X`` is moved to ``device`` and the forward pass runs under
        ``torch.no_grad()``; the model's train/eval mode is not changed.
        """
        with torch.no_grad():
            output = model(X.to(device))
            return torch.argmax(output, dim=1).tolist()

    @staticmethod
    def count_non_zero_params(net):
        """
        Count non-zero entries across all parameters of ``net``.

        Returns the sum of ``torch.count_nonzero`` results (a 0-d tensor), or
        the int 0 when ``net`` has no parameters.
        """
        return sum(torch.count_nonzero(p) for p in net.parameters())

    @staticmethod
    def print_stats(dataset, orig_net, pruned_net, X, device,
                    save_to_path=None, components=None, timeouts=None, sample_ids=None):
        """
        Emit a report comparing ``orig_net`` and ``pruned_net`` on ``X``.

        The report contains, in order:
        - optional sample indexes;
        - active/dead neuron lists, included only when ``components`` (a dict
          with 'active' and 'dead') is given;
        - the per-sample output comparison (regression formatter for
          'taxinet', logit formatter otherwise);
        - argmax predictions of both networks;
        - optional timeouts.

        The lines are passed to emit_to_file.
        """
        lines = []
        if sample_ids is not None:
            lines.append(f"Sample Indexes: {sample_ids}")

        # components defaults to None, so it must be guarded before subscripting.
        if components is not None:
            active_components = components['active']
            dead_components = components['dead']
            lines.append(f"# Active neurons: {len(active_components)} -> {active_components}")
            lines.append(f"# Dead neurons: {len(dead_components)} -> {dead_components}")

        metrics = compute_logit_metrics(dataset=dataset, full_model=orig_net, pruned_model=pruned_net, X=X, device=device)
        lines += (format_regression_comparison(metrics) if dataset == 'taxinet' else format_logit_comparison(metrics))
        lines.append("-" * 50 + " CLASS_PREDICTION")
        lines.append(f"Original Prediction: {MlpCircuitStats.predictor(orig_net, X, device)}")
        lines.append(f"Pruned Prediction: {MlpCircuitStats.predictor(pruned_net, X, device)}")
        # Timeouts if any
        if timeouts is not None:
            lines.append(f"# Timeouts: {timeouts} | count={len(timeouts)}")
        emit_to_file(lines, save_to_path)

def print_logits_for_adv_x(dataset, pruned_model, full_model, x, adv_x, device, save_to_path=None):
    """
    Emit a report comparing both models on ``adv_x``, plus distances from ``x`` to ``adv_x``.

    The output comparison uses the regression formatter for 'taxinet' and the
    logit formatter otherwise.  The L2, L1 and L-infinity distances are
    computed over the whole ``adv_x - x`` tensor (all samples together).
    The lines are passed to emit_to_file.
    """
    metrics = compute_logit_metrics(dataset=dataset, full_model=full_model, pruned_model=pruned_model, X=adv_x, device=device)
    formatter = format_regression_comparison if dataset == 'taxinet' else format_logit_comparison

    delta = adv_x.to(device) - x.to(device)
    magnitude = torch.abs(delta)

    report = ["=" * 40 + " ADVERSARIAL LOGIT COMPARISON " + "=" * 40]
    report += formatter(metrics)
    report += [
        "-" * 40 + " DISTANCES " + "-" * 40,
        f"Euclidean (L2): {torch.norm(delta).item():.6f}",
        f"Manhattan (L1): {magnitude.sum().item():.6f}",
        f"Chebyshev (L∞): {magnitude.max().item():.6f}",
    ]
    emit_to_file(report, save_to_path)