import torch
import numpy as np
from sklearn.metrics import classification_report

def accuracy_evaluation(model, dataloader, classes, is_attacked=False, method=None, epsilon=8 / 255, alpha=16 / 255, detailed=True, device: str = "cuda"):
    """
    Evaluate classification accuracy on clean or adversarial examples.

    Computes classification performance either on clean inputs or on inputs
    perturbed by an adversarial attack. The attack must follow the calling
    signature `method(images, labels, model, epsilon, alpha)` and return the
    perturbed images.

    The function either returns the overall accuracy or prints a class-wise
    report produced by `sklearn.metrics.classification_report`.

    Args:
        model (torch.nn.Module): The model to evaluate. It is switched to eval
            mode for the evaluation. If it was in training mode, that mode is
            restored afterwards, even if an error is raised.
        dataloader (Iterable): Yields batches whose first two entries are
            (images, labels). Extra entries (e.g. a sample index) are ignored.
        classes (Sequence[str]): Maps class indices to human-readable labels.
            Only used to build the detailed report.
        is_attacked (bool): If True, evaluate on adversarial examples produced
            by `method`; otherwise evaluate on clean data.
        method (callable): Attack function with the signature above. Required
            when `is_attacked` is True.
        epsilon (float): Perturbation budget passed to `method` (default: 8/255).
        alpha (float): Step size passed to `method` (default: 16/255).
        detailed (bool): If True, print a full classification report and return
            None. If False, return the accuracy.
        device (str): Device to run inference on (default: "cuda").

    Returns:
        float or None:
            - If detailed=False, the fraction of correctly classified samples
              (0-1), or NaN if the dataloader yielded no samples.
            - If detailed=True, None (the report is printed instead).

    Raises:
        ValueError: If `is_attacked` is True but no `method` is given.
    """
    if is_attacked and method is None:
        raise ValueError("`method` must be provided when is_attacked=True")

    was_training = model.training
    model.eval()
    pred_names = []
    true_names = []
    correct = 0
    total = 0
    try:
        for data in dataloader:
            # Batches may be (images, labels) or (images, labels, index, ...);
            # only the first two entries are needed here.
            images, labels = data[0].to(device), data[1].to(device)
            if is_attacked:
                # Attacks usually need input gradients, so this call must stay
                # outside torch.no_grad().
                images = method(images, labels, model, epsilon, alpha)
            with torch.no_grad():
                pred_idx = torch.argmax(model(images), 1)
            # Compare integer indices rather than class names, so duplicate
            # names in `classes` cannot inflate the accuracy.
            correct += (pred_idx == labels).sum().item()
            total += labels.shape[0]
            if detailed:
                pred_names.extend(classes[k] for k in pred_idx.cpu().tolist())
                true_names.extend(classes[k] for k in labels.cpu().tolist())
    finally:
        # Restore the caller's mode even if the attack or forward pass raised.
        if was_training:
            model.train()

    if detailed:
        print(classification_report(true_names, pred_names, digits=4))
        return None
    # An empty dataloader has no defined accuracy; avoid a 0/0 warning.
    return correct / total if total else float("nan")