import numpy as np
from itertools import product
import torch
import torch.nn.functional as F
import numpy as np

from src.iees_utils import generate_pfams, calculate_consistency_index, compute_iees_score


def evaluate_iees(model, dataloader, device, mode, alpha, beta, gamma, *thresholds):
    """
    Evaluate model accuracy under the IEES-based early-exit policy.

    Each batch is pushed through the exits in order. At every exit except the
    last, the PFAMs, consistency index and IEES score are computed. The batch
    leaves at the first exit whose IEES score reaches that exit's threshold.
    The final exit always answers, so it is not scored and its threshold is
    never consulted.

    Parameters:
    - model: early-exit DNN exposing ``num_exits``, ``forward_to_exit(x, i, flag)``
      and the ``activations`` / ``gradients`` attributes read by the IEES helpers
    - dataloader: iterable of ``(x, target)`` batches
    - device: torch device
    - mode: 'shared' uses thresholds[0] at every exit; any other value
      (normally 'per_exit') uses thresholds[i] at exit i
    - alpha, beta, gamma: weights for confidence, attribution, consistency
    - *thresholds: 1 value for 'shared'; otherwise at least ``num_exits - 1``
      values (a trailing value for the final exit is accepted and ignored)

    Returns:
    - accuracy: correct predictions / evaluated samples (0.0 if nothing was evaluated)

    Raises:
    - ValueError: if the model has no exits or too few thresholds are supplied

    NOTE(review): the exit decision (``if iees_score >= threshold``) is taken
    once per batch, so all samples of a batch exit together. Use batch_size=1
    for true per-sample early exit.
    """
    model.eval()
    num_exits = model.num_exits
    if num_exits < 1:
        raise ValueError("model must expose at least one exit")

    last_exit = num_exits - 1
    # Only exits 0 .. last_exit-1 make a decision, so only they need thresholds.
    if mode == "shared":
        if last_exit and not thresholds:
            raise ValueError("mode='shared' requires one threshold")
        exit_thresholds = list(thresholds[:1]) * last_exit
    else:
        if len(thresholds) < last_exit:
            raise ValueError(
                f"per-exit mode requires at least {last_exit} thresholds, got {len(thresholds)}"
            )
        exit_thresholds = list(thresholds[:last_exit])

    weights = [alpha, beta, gamma]  # loop-invariant
    correct = 0
    total = 0

    # Gradients are left enabled on purpose: the IEES helpers read
    # model.gradients (presumably filled by forward_to_exit(..., True) — confirm).
    for x, target in dataloader:
        x = x.to(device)
        target = target.to(device)
        pfams_list = []  # PFAMs accumulate across the exits of one batch only

        for i in range(num_exits):
            logits, _, _ = model.forward_to_exit(x, i, True)
            if i == last_exit:
                break  # final exit answers unconditionally; scoring it is wasted work

            pfams_list, cumulative_map = generate_pfams(
                model.activations, model.gradients, x, pfams_list
            )
            consistency_index = calculate_consistency_index([cumulative_map])
            iees_score, *_ = compute_iees_score(
                model.activations, model.gradients, logits, consistency_index, weights
            )
            if iees_score >= exit_thresholds[i]:
                break

        # `logits` now belongs to the exit that fired, or to the final exit,
        # so no second forward pass is needed for the fallback.
        pred = logits.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.numel()  # count samples, not batches

    return correct / total if total > 0 else 0.0

def grid_search_iees(model, testloader, device, calibration_mode="shared"):
    """
    Grid search to find the IEES weights and threshold(s) maximizing accuracy.

    Weights (alpha, beta, gamma) are drawn from {0.0, 0.5, 1.0} and kept only
    when they sum to 1. Thresholds come from ``np.linspace(0.1, 0.95, 9)``.

    In "per_exit" mode a threshold is searched for every exit except the last.
    The final exit always produces the prediction (either by passing its
    threshold or via the fallback), so its threshold cannot change accuracy.
    It is fixed to the first grid value. That is exactly what an exhaustive
    search with strict-improvement tie-breaking would report, at 1/9 the cost.

    NOTE(review): the evaluation loader doubles as calibration data. Tuning on
    the test set inflates the reported accuracy; prefer a held-out validation loader.

    Parameters:
    - model: Trained model (must expose ``num_exits``)
    - testloader: DataLoader for evaluation
    - device: torch.device
    - calibration_mode: "shared" or "per_exit"

    Returns:
    - best_config: dict with "alpha", "beta", "gamma" and "threshold0" (shared)
      or "threshold0" ... "threshold{num_exits-1}" (per_exit)

    Raises:
    - ValueError: for an unknown calibration_mode
    """
    raw_vals = [0.0, 0.5, 1.0]
    threshold_vals = np.linspace(0.1, 0.95, 9)

    # Keep only weight triplets on the simplex alpha + beta + gamma = 1.
    weight_triplets = [
        w for w in product(raw_vals, repeat=3) if abs(sum(w) - 1.0) < 1e-5
    ]

    if calibration_mode == "shared":
        threshold_grid = [(tau,) for tau in threshold_vals]
    elif calibration_mode == "per_exit":
        final_tau = (threshold_vals[0],)  # final exit's threshold is irrelevant
        threshold_grid = [
            taus + final_tau
            for taus in product(threshold_vals, repeat=model.num_exits - 1)
        ]
    else:
        raise ValueError(
            f"calibration_mode must be 'shared' or 'per_exit', got {calibration_mode!r}"
        )

    # -inf guarantees a config is recorded even if every accuracy is 0.
    best_acc = float("-inf")
    best_config = {}

    for (alpha, beta, gamma), taus in product(weight_triplets, threshold_grid):
        acc = evaluate_iees(model, testloader, device, calibration_mode, alpha, beta, gamma, *taus)
        if acc > best_acc:
            best_acc = acc
            best_config = {"alpha": alpha, "beta": beta, "gamma": gamma}
            best_config.update({f"threshold{k}": tau for k, tau in enumerate(taus)})

    print("\n[GRID RESULT] Best Accuracy: {:.4f}".format(best_acc))
    print("[GRID RESULT] Best Config:", best_config)
    return best_config


import torch
import torch.nn.functional as F


def evaluate_confidence(model, dataloader, device, mode="shared", *thresholds):
    """
    Evaluate early-exit accuracy using a softmax-confidence exit policy.

    Every sample leaves at the first exit whose maximum softmax probability
    reaches that exit's threshold. Samples that never qualify are answered by
    the final exit. Exit decisions are made per sample, so batches of any size
    are supported. Only samples still pending are forwarded to later exits.

    Parameters:
    - model: early-exit DNN exposing ``num_exits`` and ``forward_to_exit(x, i, flag)``
    - dataloader: iterable of ``(x, target)`` batches (target assumed to hold
      one class index per sample — confirm for your dataset)
    - device: torch device (e.g., "cuda")
    - mode: 'shared' uses thresholds[0] at every exit; any other value
      (normally 'per_exit') uses thresholds[i] at exit i
    - *thresholds: 1 value for 'shared'; otherwise at least ``num_exits - 1``
      values (a trailing value for the final exit is accepted and ignored)

    Returns:
    - accuracy: correct predictions / evaluated samples (0.0 if nothing was evaluated)

    Raises:
    - ValueError: if the model has no exits or too few thresholds are supplied

    NOTE(review): if forward_to_exit(..., False) performs no backward pass,
    wrapping the loop in torch.no_grad() would save memory — confirm.
    """
    model.eval()
    num_exits = model.num_exits
    if num_exits < 1:
        raise ValueError("model must expose at least one exit")

    last_exit = num_exits - 1
    # Only exits 0 .. last_exit-1 make a decision, so only they need thresholds.
    if mode == "shared":
        if last_exit and not thresholds:
            raise ValueError("mode='shared' requires one threshold")
        exit_thresholds = list(thresholds[:1]) * last_exit
    else:
        if len(thresholds) < last_exit:
            raise ValueError(
                f"per-exit mode requires at least {last_exit} thresholds, got {len(thresholds)}"
            )
        exit_thresholds = list(thresholds[:last_exit])

    correct = 0
    total = 0

    for x, target in dataloader:
        x = x.to(device)
        target = target.to(device)
        batch_size = target.shape[0]
        if batch_size == 0:
            continue

        # Batch indices of the samples that have not exited yet.
        pending = torch.arange(batch_size, device=target.device)
        for i in range(num_exits):
            logits, _, _ = model.forward_to_exit(x[pending], i, False)
            confidence, prediction = F.softmax(logits, dim=1).max(dim=1)

            if i < last_exit:
                done = confidence >= exit_thresholds[i]
            else:
                # Final exit answers for everything still pending.
                done = torch.ones_like(confidence, dtype=torch.bool)

            correct += (prediction[done] == target[pending[done]]).sum().item()
            pending = pending[~done]
            if pending.numel() == 0:
                break

        total += batch_size  # count samples, not batches

    return correct / total if total > 0 else 0.0


def grid_search_confidence(model, testloader, device, calibration_mode="shared"):
    """
    Grid search for the confidence-based exit threshold(s) maximizing accuracy.

    Thresholds come from ``np.linspace(0.1, 0.95, 9)``. In "per_exit" mode a
    threshold is searched for every exit except the last. The final exit
    always produces the prediction, so its threshold cannot change accuracy.
    It is fixed to the first grid value, which is what an exhaustive search
    with strict-improvement tie-breaking would report, at 1/9 the cost.

    NOTE(review): the evaluation loader doubles as calibration data. Tuning on
    the test set inflates the reported accuracy; prefer a held-out validation loader.

    Parameters:
    - model: Trained model (must expose ``num_exits``)
    - testloader: DataLoader for evaluation
    - device: torch.device
    - calibration_mode: "shared" or "per_exit"

    Returns:
    - best_config: dict with "threshold0" (shared) or
      "threshold0" ... "threshold{num_exits-1}" (per_exit)

    Raises:
    - ValueError: for an unknown calibration_mode
    """
    threshold_vals = np.linspace(0.1, 0.95, 9)

    if calibration_mode == "shared":
        threshold_grid = [(tau,) for tau in threshold_vals]
    elif calibration_mode == "per_exit":
        final_tau = (threshold_vals[0],)  # final exit's threshold is irrelevant
        threshold_grid = [
            taus + final_tau
            for taus in product(threshold_vals, repeat=model.num_exits - 1)
        ]
    else:
        raise ValueError(
            f"calibration_mode must be 'shared' or 'per_exit', got {calibration_mode!r}"
        )

    # -inf guarantees a config is recorded even if every accuracy is 0.
    best_acc = float("-inf")
    best_config = {}

    for taus in threshold_grid:
        acc = evaluate_confidence(model, testloader, device, calibration_mode, *taus)
        if acc > best_acc:
            best_acc = acc
            best_config = {f"threshold{k}": tau for k, tau in enumerate(taus)}

    print("\n[GRID RESULT] Best Accuracy: {:.4f}".format(best_acc))
    print("[GRID RESULT] Best Config:", best_config)
    return best_config



# The grid-search helpers return a dict of tuned values; read e.g. best_config["threshold0"] (and best_config["alpha"], ["beta"], ["gamma"] for the IEES search).
