from typing import Tuple, List

import torch


def differential_shannon_entropy_from_hist(
    pmf: torch.Tensor,
    bin_width: torch.Tensor,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    Histogram estimate of the differential Shannon entropy (in bits), one value per row.

    For each row the result is ``-sum_i p_i * log2(p_i / width)``, where both
    the probabilities and the width are floored at ``eps`` before use.

    Args
    ----
    pmf:
        Tensor of shape (n_neurons, n_bins); each row is a probability mass function.
    bin_width:
        Tensor of shape (n_neurons,) holding the bin width used for each row.
    eps:
        Lower bound applied to both ``pmf`` and ``bin_width`` so that neither
        the logarithm nor the division sees a zero.

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_neurons,) with one entropy estimate per row.
    """
    # Floor the probabilities so empty bins do not produce log2(0).
    probs = torch.clamp(pmf, min=eps)
    # One width per row, shaped (n_neurons, 1) to broadcast over the bin axis.
    widths = torch.clamp(bin_width.unsqueeze(1), min=eps)

    density = probs / widths
    weighted_log = probs * torch.log2(density)
    return -weighted_log.sum(dim=1)


def compute_column_pmf_vect(
    tensor: torch.Tensor,
    max_value: torch.Tensor,
    min_value: torch.Tensor,
    num_bins: int = 40,
    min_threshold: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-column PMFs via a vectorized histogram.

    Each column of ``tensor`` is binned into ``num_bins`` equal-width bins
    spanning ``[min_value, max_value]`` for that column. Every bin is
    half-open ``[lower, upper)`` except the last one, which is closed on the
    right so that samples equal to ``max_value`` are counted. Samples outside
    ``[min_value, max_value]`` are not counted.

    Args
    ----
    tensor:
        Input tensor of shape (num_samples, n_neurons).
    max_value:
        Tensor of shape (n_neurons,) with per-neuron maxima (over the FULL dataset).
    min_value:
        Tensor of shape (n_neurons,) with per-neuron minima (over the FULL dataset).
    num_bins:
        Number of histogram bins.
    min_threshold:
        Lower bound applied to every bin count before normalization.

    Returns
    -------
    pmf:
        Tensor of shape (n_neurons, n_bins) with per-neuron PMFs.
    bin_edges:
        Tensor of shape (n_neurons, n_bins + 1) with per-neuron bin edges.
    bin_centers:
        Tensor of shape (n_neurons, n_bins) with per-neuron bin centers.
    """
    device = tensor.device

    # Per-neuron edges and centers
    bin_width = (max_value - min_value) / num_bins  # (n_neurons,)
    # Degenerate range (max == min): substitute a unit width so the edges stay
    # strictly increasing and the constant value falls into the first bin.
    safe_bin_width = torch.where(bin_width == 0, torch.ones_like(bin_width), bin_width)

    base = torch.linspace(0, num_bins, steps=num_bins + 1, device=device).view(1, -1)
    bin_edges = base * safe_bin_width.view(-1, 1) + min_value.view(-1, 1)  # (n_neurons, n_bins+1)
    bin_centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])              # (n_neurons, n_bins)

    # Vectorized binning; the comparisons broadcast to (n_neurons, num_samples, n_bins).
    # NOTE: this materializes boolean masks of that full size.
    values = tensor.T.unsqueeze(2)                                    # (n_neurons, num_samples, 1)
    lower = bin_edges[:, :-1].unsqueeze(1)                            # (n_neurons, 1, n_bins)
    upper = bin_edges[:, 1:].unsqueeze(1)                             # (n_neurons, 1, n_bins)

    # Inclusive upper limit of the last bin. Normally this is the final edge,
    # but rounding in `base * width + min` can leave that edge slightly below
    # max_value, so take whichever is larger.
    top = torch.maximum(
        bin_edges[:, -1],
        max_value.reshape(-1).to(bin_edges.dtype),
    ).view(-1, 1, 1)
    is_last = torch.arange(num_bins, device=device) == (num_bins - 1)  # (n_bins,)

    # Half-open bins everywhere, except the last bin which also accepts
    # values up to and including `top`.
    below_upper = (values < upper) | (is_last & (values <= top))
    in_bin = (values >= lower) & below_upper

    # Counts per bin → (n_neurons, n_bins)
    counts = in_bin.sum(dim=1).float()

    # Threshold then normalize to PMF
    counts = counts.clamp_min(min_threshold)
    row_sums = counts.sum(dim=1, keepdim=True).clamp_min(1.0)  # avoid division by zero
    pmf = counts / row_sums

    return pmf, bin_edges, bin_centers


def jsd_bins_hist(
    pre_activations: torch.Tensor,
    labels: torch.Tensor,
    args,
) -> torch.Tensor:
    """
    Histogram-based Jensen–Shannon divergence between classes, one value per neuron.

    The estimate is the entropy of the equally weighted mean of the per-class
    PMFs minus the mean of the per-class entropies. All classes share the same
    per-neuron bins, derived from the min/max of ``pre_activations``. Classes
    with no samples in ``labels`` are skipped.

    Args
    ----
    pre_activations:
        Tensor of shape (N, n_neurons) with activations.
    labels:
        Tensor of shape (N,) with class labels in {0, ..., args.num_classes-1}.
    args:
        Object with attributes:
          - num_classes: int
          - JSD_bins (optional): int, default 40

    Returns
    -------
    torch.Tensor
        Tensor of shape (n_neurons,) with JSD estimates per neuron. Neurons
        whose activation range is zero get a JSD of exactly 0.
    """
    device = pre_activations.device
    neuron_total = pre_activations.shape[1]
    n_bins = getattr(args, "JSD_bins", 40)

    # Shared binning range so every class histogram uses identical bins.
    hi = pre_activations.max(dim=0).values  # (n_neurons,)
    lo = pre_activations.min(dim=0).values  # (n_neurons,)
    width = (hi - lo) / n_bins              # (n_neurons,)

    result = torch.zeros(neuron_total, device=device)
    entropy_total = torch.zeros(neuron_total, device=device)
    class_pmfs: List[torch.Tensor] = []

    for class_id in range(args.num_classes):
        members = labels == class_id
        if not members.any():
            continue  # class absent from this batch

        class_pmf, _, _ = compute_column_pmf_vect(
            tensor=pre_activations[members],  # (n_c, n_neurons)
            max_value=hi,
            min_value=lo,
            num_bins=n_bins,
            min_threshold=0.0,
        )
        class_pmfs.append(class_pmf)
        entropy_total += differential_shannon_entropy_from_hist(class_pmf, width)

    if class_pmfs:
        # Subtract the average class entropy ...
        result -= entropy_total / len(class_pmfs)
        # ... and add the entropy of the mean PMF across the present classes.
        mixture = torch.stack(class_pmfs, dim=0).mean(dim=0)  # (n_neurons, n_bins)
        result += differential_shannon_entropy_from_hist(mixture, width)

    # Neurons with zero activation range are defined to have JSD = 0.
    degenerate = width == 0
    if degenerate.any():
        result[degenerate] = 0.0

    # Report any NaN entries (diagnostic only; the values are returned as-is).
    nan_mask = torch.isnan(result)
    if nan_mask.any():
        nan_idx = nan_mask.nonzero(as_tuple=True)[0]
        print("Indices of NaN values:", nan_idx)

    return result
