import numpy as np
import torch
import torch.nn.functional as F

def generate_gradcam(activations, gradients, input_tensor):
    """
    Build a Grad-CAM heatmap for one exit, resized to the input's spatial size.

    Parameters:
    - activations (torch.Tensor): Feature maps at the exit. Dims 2 and 3 are
      treated as spatial and dim 1 as channels (assumes an (N, C, H, W)
      layout -- TODO confirm against caller).
    - gradients (torch.Tensor): Gradients matching `activations`.
    - input_tensor (torch.Tensor): Model input; only `shape[2:]` is read, as
      the target size of the heatmap.

    Returns:
    - np.ndarray: Heatmap min-max scaled to [0, 1]. All size-1 dimensions are
      squeezed away, so a single-sample call yields a 2-D (H, W) array.
    """
    # Channel importance = gradient averaged over the spatial dimensions.
    channel_weights = gradients.mean(dim=(2, 3), keepdim=True)

    # Weighted sum over channels; ReLU keeps only positive evidence.
    weighted_sum = (channel_weights * activations).sum(dim=1)
    heatmap = torch.relu(weighted_sum)

    # Upsample to the input resolution (interpolate needs a channel dim).
    target_size = input_tensor.shape[2:]
    upsampled = F.interpolate(
        heatmap.unsqueeze(1),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )

    heatmap_np = upsampled.detach().squeeze().cpu().numpy()

    # Min-max scaling; the epsilon guards against a constant map.
    lowest = heatmap_np.min()
    value_range = heatmap_np.max() - lowest
    return (heatmap_np - lowest) / (value_range + 1e-10)


def generate_pfams(activations, gradients, input_tensor, pfams_list):
    """
    Generate and accumulate PFAM (Progressive Feature Attribution Maps).

    Parameters:
    - activations (torch.Tensor): Activations passed on to generate_gradcam.
    - gradients (torch.Tensor): Gradients passed on to generate_gradcam.
    - input_tensor (torch.Tensor): Input passed on to generate_gradcam.
    - pfams_list (list of np.ndarray): Maps collected so far. The new map is
      appended to this list in place.

    Returns:
    - Tuple (pfams_list, cumulative_map): the same list object that was passed
      in, and the element-wise mean of every map it now contains.
    """
    # Attribution map for the current exit, recorded on the caller's list.
    current_map = generate_gradcam(activations, gradients, input_tensor)
    pfams_list.append(current_map)

    # Cumulative average (paper-aligned): mean over all maps collected so far.
    stacked_maps = np.stack(pfams_list)
    running_mean = np.mean(stacked_maps, axis=0)

    return pfams_list, running_mean

def normalize_tensor(tensor, eps=1e-6):
    """
    Normalize a tensor to have zero mean and unit variance.

    The mean and standard deviation are computed over all elements of the
    tensor (no dimension is specified).

    Parameters:
    - tensor (torch.Tensor): Input tensor
    - eps (float): Small constant to avoid division by zero

    Returns:
    - torch.Tensor: Normalized tensor with the same shape as the input. For a
      tensor with fewer than two elements only the mean is removed, so a
      single value normalizes to 0 instead of NaN.
    """
    centered = tensor - tensor.mean()

    # torch.std() applies Bessel's correction by default, which makes it NaN
    # when there are fewer than two elements; skip the scaling in that case.
    if tensor.numel() < 2:
        return centered

    return centered / (tensor.std() + eps)

def calculate_consistency_index(cumulative_maps):
    """
    Compute the consistency index Q_i(x) from the two most recent cumulative
    PFAM maps.

    The index is the mean absolute difference between the last map and the one
    before it, so a lower value means the cumulative attention changed less
    between the two most recent exits.

    Parameters:
    - cumulative_maps (list of np.ndarray): Cumulative attribution maps Ci(x),
      ordered by exit. Only the last two entries are read.

    Returns:
    - float: Consistency index Q_i(x), or 0.0 when fewer than two maps are
      available.
    """
    # A difference needs two maps; with fewer there is no consistency signal.
    if len(cumulative_maps) >= 2:
        latest_map = cumulative_maps[-1]
        previous_map = cumulative_maps[-2]

        # Mean absolute difference across all spatial locations.
        abs_change = np.abs(latest_map - previous_map)
        return np.mean(abs_change)

    return 0.0
def compute_iees_score(activations, gradients, output, progressive_score, w=None):
    """
    Compute the Interpretability-Based Early-Exit Score (IEES) for an exit.

    Parameters:
    - activations (torch.Tensor): Activations from the exit layer.
    - gradients (torch.Tensor): Gradients from the exit layer.
    - output (torch.Tensor): Model output (logits) at the exit. Softmax is taken
      over dim 1 and the maximum is read with .item(), so it must hold exactly
      one sample.
    - progressive_score (float or torch.Tensor): Progressive score for the model's
      progression. A tensor with more than one element is standardized with
      normalize_tensor; a float or a single-element tensor is used as given.
    - w (list of floats): Weights for A_iees, C_iees, and progressive score, in
      that order. Required despite the None default in the signature.

    Returns:
    - Tuple containing iees_score, C_iees, A_iees, activation_score, gradient_score,
      and normalized_progressive_score. iees_score is a tensor when the
      progressive score is a tensor, otherwise a float.

    Raises:
    - TypeError: If w is None.
    """
    # Fail fast with a readable message instead of an opaque
    # "'NoneType' object is not subscriptable" after all the tensor work.
    if w is None:
        raise TypeError(
            "compute_iees_score() requires w: a sequence of three weights "
            "for A_iees, C_iees and the progressive score"
        )

    # Normalize activations and gradients
    normalized_activations = normalize_tensor(activations)
    normalized_gradients = normalize_tensor(gradients)

    # Standardizing is only meaningful for two or more values: the default
    # (Bessel-corrected) std of a single element is NaN, which would poison the
    # final score. Single-element tensors are therefore treated like floats.
    if isinstance(progressive_score, torch.Tensor) and progressive_score.numel() > 1:
        normalized_progressive_score = normalize_tensor(progressive_score)
    else:
        normalized_progressive_score = progressive_score

    # Attribution-based Component (A_iees)
    attribution_map = normalized_activations * normalized_gradients
    A_iees = attribution_map.abs().mean().item()

    # Confidence-based Component (C_iees): highest class probability
    confidence_scores = F.softmax(output, dim=1)
    C_iees, _ = confidence_scores.max(dim=1)
    C_iees = C_iees.item()

    # Combined IEES score
    iees_score = w[0] * A_iees + w[1] * C_iees + w[2] * normalized_progressive_score

    # Additional scores for interpretability analysis (computed on the raw,
    # un-normalized tensors)
    activation_score = torch.mean(activations).item()
    gradient_score = torch.mean(torch.abs(gradients)).item()

    return iees_score, C_iees, A_iees, activation_score, gradient_score, normalized_progressive_score

def compute_iees_score1(activations, gradients, output, consistency_index, w):
    """
    Compute IEES score for an exit.
    Parameters:
        activations: Feature maps at exit
        gradients: Gradients for target class
        output: Logits at exit. Softmax is taken over dim 1 and the maximum is
            read with .item(), so it must hold exactly one sample.
        consistency_index: Consistency index Qi(x). A tensor with more than one
            element is standardized with normalize_tensor; a float or a
            single-element tensor is used as given.
        w: list of weights [w1, w2, w3] applied to S_iees, F_iees and Q_iees
            respectively
    Returns:
        Tuple: iees_score, S_iees, F_iees, activation_score, gradient_score, Q_iees
    """
    # Normalize activations and gradients
    norm_act = normalize_tensor(activations)
    norm_grad = normalize_tensor(gradients)

    # Attribution strength (F_iees)
    attribution_map = norm_act * norm_grad
    F_iees = attribution_map.abs().mean().item()

    # Confidence (S_iees): highest class probability
    S_iees = F.softmax(output, dim=1).max(dim=1).values.item()

    # Consistency index (Q_iees): Lower is better.
    # Standardizing is only meaningful for two or more values: the default
    # (Bessel-corrected) std of a single element is NaN, which would poison the
    # final score. Single-element tensors are therefore treated like floats.
    if isinstance(consistency_index, torch.Tensor) and consistency_index.numel() > 1:
        Q_iees = normalize_tensor(consistency_index)
    else:
        Q_iees = consistency_index

    # IEES score (higher is better)
    # NOTE(review): Q_iees is lower-is-better yet is added here, so it only
    # penalizes inconsistency when w[2] is negative -- confirm with the caller.
    iees_score = w[0] * S_iees + w[1] * F_iees + w[2] * Q_iees

    # Additional scores for interpretability (computed on the raw,
    # un-normalized tensors)
    activation_score = torch.mean(activations).item()
    gradient_score = torch.mean(torch.abs(gradients)).item()

    return iees_score, S_iees, F_iees, activation_score, gradient_score, Q_iees
