import torch.nn.functional as F
import torch

def same_winner_criterion(baseline_output, new_output, delta):
    """
    Accept pruning only if every sample keeps its top-1 prediction.

    Parameters:
    - baseline_output (torch.Tensor): outputs of the reference network, (B, C).
    - new_output (torch.Tensor): outputs of the pruned network, (B, C).
    - delta: unused; accepted so all criteria share one signature.

    Returns:
    - bool: True when the per-sample argmax (as returned by ``max(dim=1)``)
      is identical for the two outputs.
    """
    reference_winners = baseline_output.max(dim=1).indices
    candidate_winners = new_output.max(dim=1).indices
    # torch.equal also demands matching shapes and yields a plain Python bool.
    return torch.equal(reference_winners, candidate_winners)


def abs_mean_criterion(baseline_output, new_output, delta):
    """
    Accept pruning only if each sample's mean absolute output change is below delta.

    Parameters:
    - baseline_output (torch.Tensor): outputs of the reference network, (B, C).
    - new_output (torch.Tensor): outputs of the pruned network, (B, C).
    - delta (float): strict upper bound on the per-sample mean absolute difference.

    Returns:
    - bool: True when, for every sample, mean_c |baseline - new| < delta.
    """
    per_sample_mae = torch.mean(torch.abs(baseline_output - new_output), dim=1)
    return bool(torch.all(per_sample_mae < delta))



def abs_max_criterion(baseline_output, new_output, delta):
    """
    Accept pruning only if no single output entry moved by delta or more.

    Parameters:
    - baseline_output (torch.Tensor): outputs of the reference network, (B, C).
    - new_output (torch.Tensor): outputs of the pruned network, (B, C).
    - delta (float): strict upper bound on any entry's absolute difference.

    Returns:
    - bool: True when |baseline - new| < delta for every entry, checked via
      each sample's largest absolute difference.
    """
    abs_change = torch.abs(baseline_output - new_output)
    worst_per_sample = torch.max(abs_change, dim=1).values
    return bool(torch.all(worst_per_sample < delta))



def winner_diff_criterion(fullnet_y, pruned_y, delta):
    """
    Accept pruning only if the full network's winning logit barely moves.

    For every sample, the class that wins in ``fullnet_y`` is looked up in
    both outputs, and the absolute change of that logit must not exceed delta.

    Parameters:
    - fullnet_y (torch.Tensor): outputs of the unpruned network, (B, C).
    - pruned_y (torch.Tensor): outputs of the pruned network, (B, C).
    - delta (float): inclusive upper bound on the winner-logit change.

    Returns:
    - bool: True when |fullnet_y[i, w_i] - pruned_y[i, w_i]| <= delta for all i,
      where w_i is the full network's top-1 class.
    """
    winner_col = fullnet_y.max(dim=1).indices.unsqueeze(1)   # (B, 1)
    # gather keeps the index on the same device as the outputs.
    full_winner = fullnet_y.gather(1, winner_col).squeeze(1)  # (B,)
    pruned_winner = pruned_y.gather(1, winner_col).squeeze(1)  # (B,)
    return bool(torch.all(torch.abs(full_winner - pruned_winner) <= delta))



def winner_runner_criterion(fullnet_y, pruned_y, delta):
    """
    Accept pruning only if confident winner/runner-up margins are preserved.

    The winner and runner-up classes are taken from ``fullnet_y`` (top-2).
    Samples whose full-network margin (winner minus runner-up) is already
    below delta impose no constraint. Every other sample must keep a margin
    of at least delta in ``pruned_y``, measured on the same two classes.

    Parameters:
    - fullnet_y (torch.Tensor): outputs of the unpruned network, (B, C), C >= 2.
    - pruned_y (torch.Tensor): outputs of the pruned network, (B, C).
    - delta (float): margin threshold.

    Returns:
    - bool: True when every sample satisfies the condition above.
    """
    top_vals, top_idx = fullnet_y.topk(2, dim=1)          # both (B, 2)
    full_margin = top_vals[:, 0] - top_vals[:, 1]          # (B,)

    # Evaluate the pruned network on the *full* network's winner/runner-up pair.
    pruned_pair = pruned_y.gather(1, top_idx)              # (B, 2)
    pruned_margin = pruned_pair[:, 0] - pruned_pair[:, 1]  # (B,)

    unconstrained = full_margin < delta
    preserved = pruned_margin >= delta
    return bool(torch.all(unconstrained | preserved))



def compute_kl_divergence(distribution_g, distribution_h_new):
    """
    Compute D_KL(G || H_new) between the softmax distributions of two logit batches.

    Parameters:
    - distribution_g (torch.Tensor): baseline logits for G (batch_size x num_classes);
      softmax over dim=1 is applied internally.
    - distribution_h_new (torch.Tensor): new logits for H_new (batch_size x num_classes);
      softmax over dim=1 is applied internally.

    Returns:
    - torch.Tensor: 0-dim tensor holding
      (1 / batch_size) * sum_i sum_c G[i, c] * (log G[i, c] - log H_new[i, c]).
    """
    # F.kl_div(input, target) computes sum target * (log target - input), i.e.
    # D_KL(target || exp(input)). For D_KL(G || H_new), target must be G and
    # input must be log H_new. Passing both as log-probabilities
    # (log_target=True) avoids the log(softmax) round trip.
    log_g = F.log_softmax(distribution_g, dim=1)
    log_h_new = F.log_softmax(distribution_h_new, dim=1)
    return F.kl_div(log_h_new, log_g, reduction='batchmean', log_target=True)