from typing import Iterable
import torch.nn as nn
from tqdm import tqdm

# prepare your pytorch model as "model"
# prepare a batch of data and label as "cln_data" and "true_label"
# ...

from advertorch.attacks import LinfPGDAttack
import torch

def compute_accuracy(logits, labels):
    """Return the fraction of rows in ``logits`` whose argmax matches ``labels``.

    ``logits`` is reduced with argmax over dim 1 (presumably the class
    dimension); ``labels`` is expected to hold one class index per row.
    The result is a Python float in [0, 1].
    """
    batch_size = labels.shape[0]
    num_correct = logits.argmax(dim=1).eq(labels).sum().item()
    return num_correct / batch_size


def eval_robustness(
    model: nn.Module,
    dataloader: Iterable,
    epsilon: float = 3 / 255,
    progress=True,
    max_num_batches=10,
):
    """Estimate the accuracy of ``model`` under an untargeted L-inf PGD attack.

    Each batch from ``dataloader`` is perturbed with ``LinfPGDAttack``
    (20 iterations, step size 0.01, random init, outputs clipped to
    [0, 1]), and the model is evaluated on the perturbed inputs.

    Args:
        model: Classifier whose output is reduced with argmax over dim 1.
            This function does not change its train/eval mode or device;
            the caller should put it in eval mode and on the same device as
            the batches (NOTE(review): confirm callers do this).
        dataloader: Iterable of ``(images, labels)`` batches. Images are
            assumed to lie in [0, 1] to match the attack's clipping range.
        epsilon: L-inf radius of the adversarial perturbation.
        progress: Show a tqdm progress bar when True.
        max_num_batches: Evaluate at most this many batches; ``None`` means
            iterate over the whole dataloader.

    Returns:
        Fraction of evaluated samples still classified correctly after the
        attack (weighted per sample, not per batch).

    Raises:
        ValueError: If no samples were evaluated (empty dataloader or
            ``max_num_batches == 0``), or if ``max_num_batches`` is negative.
    """
    from itertools import islice

    adversary = LinfPGDAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=epsilon,
        nb_iter=20,
        eps_iter=0.01,
        rand_init=True,
        clip_min=0.0,
        clip_max=1.0,
        # Untargeted: maximize the loss w.r.t. the true labels. A targeted
        # attack given the true labels as targets would instead push inputs
        # towards the correct class and overstate robustness.
        targeted=False,
    )

    # Bound the iteration up front so ``max_num_batches=0`` means "no batches"
    # and the progress bar reflects the real number of batches.
    try:
        total = len(dataloader)
    except TypeError:  # plain iterables/generators have no len()
        total = None
    if max_num_batches is None:
        batches = dataloader
    else:
        batches = islice(dataloader, max_num_batches)
        total = max_num_batches if total is None else min(total, max_num_batches)

    num_correct = 0
    num_samples = 0
    for images, labels in tqdm(batches, total=total, disable=not progress):
        adv_images = adversary.perturb(images, labels)
        # Only the attack needs gradients; the final evaluation does not.
        with torch.no_grad():
            logits_adv = model(adv_images)
        num_correct += (logits_adv.argmax(dim=1) == labels).sum().item()
        num_samples += labels.shape[0]

    if num_samples == 0:
        raise ValueError("eval_robustness: no samples were evaluated")
    return num_correct / num_samples
