import torchattacks
import torch
from torchmetrics.classification import Accuracy
from tqdm import tqdm
from torch.utils.data import DataLoader


class AdversarialRobustnessEval:
    """Measure clean vs. adversarial top-1/top-5 accuracy of a classifier.

    One torchattacks attack (selected by ``attack_name``) perturbs every batch
    of a dataset. The model is then evaluated on both the original and the
    perturbed images.

    Each attack is configured with ``set_normalization_used(mean, std)``. That
    tells torchattacks the model expects normalized inputs: the attack perturbs
    the raw (unnormalized) images and normalizes them itself before each
    forward pass. ``run`` applies the same normalization before its own
    forward passes, so accuracy is measured on exactly the inputs the attack
    optimized against.
    """

    def __init__(
        self,
        model,
        device: str = "cuda:0",
        num_classes: int = 1000,
        batch_size: int = 16,
        attack_name: str = "FGSM",
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        """Build the evaluator and select the attack.

        Args:
            model: Classifier to evaluate. Moved to ``device`` and switched to
                eval mode. Presumably returns logits of shape
                (batch, num_classes) — TODO confirm against callers.
            device: Device the model is moved to and evaluated on.
            num_classes: Number of classes for the accuracy metrics.
            batch_size: Batch size of the DataLoader built in ``run``.
            attack_name: Key into ``self.all_attacks`` selecting the attack.
            mean: Per-channel normalization mean the model expects. Defaults
                to the ImageNet statistics this class previously hard-coded.
            std: Per-channel normalization std the model expects.

        Raises:
            ValueError: If ``attack_name`` is not one of the known attacks.
        """
        # Move the model first: the attacks below are built from it and
        # should see it on its final device.
        self.model = model.to(device)
        self.model.eval()
        self.device = device
        self.batch_size = batch_size

        # Reshaped to (1, C, 1, 1) so they broadcast over image batches.
        # This assumes NCHW batches, the same layout torchattacks' own
        # normalization uses — TODO confirm for non-standard models.
        self._norm_mean = torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
        self._norm_std = torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1)

        self.accuracy_top_1 = Accuracy(
            task="multiclass", num_classes=num_classes, top_k=1
        )
        self.accuracy_top_5 = Accuracy(
            task="multiclass", num_classes=num_classes, top_k=5
        )

        # Linf budget (8/255) and step size (2/255) shared by several attacks.
        eps = 8 / 255
        alpha = 2 / 255

        self.all_attacks = {
            "FGSM": torchattacks.FGSM(self.model, eps=3 / 255),
            "PGD": torchattacks.PGD(
                self.model,
                eps=eps,
                alpha=alpha,
                steps=10,
                random_start=True,
            ),
            "OnePixel": torchattacks.attacks.onepixel.OnePixel(
                self.model, pixels=1, steps=10, popsize=10, inf_batch=128
            ),
            "Jitter": torchattacks.attacks.jitter.Jitter(
                self.model,
                eps=eps,
                alpha=alpha,
                steps=10,
                scale=10,
                std=0.1,
                random_start=True,
            ),
            "Square": torchattacks.attacks.square.Square(
                self.model,
                norm="Linf",
                eps=eps,
                n_queries=5000,
                n_restarts=1,
                p_init=0.8,
                loss="margin",
                resc_schedule=True,
                seed=0,
                verbose=False,
            ),
            "Pixle": torchattacks.attacks.pixle.Pixle(
                self.model,
                x_dimensions=(2, 10),
                y_dimensions=(2, 10),
                pixel_mapping="random",
                restarts=20,
                max_iterations=10,
                update_each_iteration=False,
            ),
        }

        for attack in self.all_attacks.values():
            attack.set_normalization_used(mean=list(mean), std=list(std))

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if attack_name not in self.all_attacks:
            raise ValueError(
                f"Expected attack_name to be one of: {self.all_attacks.keys()}\nbut got: {attack_name}"
            )
        self.attack = self.all_attacks[attack_name]
        self.attack_name = attack_name

    def _normalize(self, images: torch.Tensor) -> torch.Tensor:
        """Apply the same per-channel normalization the attacks use internally."""
        mean = self._norm_mean.to(device=images.device, dtype=images.dtype)
        std = self._norm_std.to(device=images.device, dtype=images.dtype)
        return (images - mean) / std

    @staticmethod
    def _accuracy(metric, logits, labels) -> float:
        """Return ``metric`` evaluated on exactly ``(logits, labels)`` as a float."""
        # Calling the metric directly (its forward) would also accumulate into
        # its global state across calls and across run() invocations; reset
        # first so every reported number covers only these inputs.
        metric.reset()
        metric.update(logits, labels)
        return metric.compute().item()

    def run(self, dataset):
        """Evaluate clean and adversarial accuracy over ``dataset``.

        Args:
            dataset: Dataset yielding ``(image, label)`` pairs and consumable
                by ``DataLoader``. Images must be unnormalized, because the
                attack receives them as-is. Normalization is applied inside
                the attack and in this method before every forward pass.

        Returns:
            A dict with:

            - ``"original"``: top-1/top-5 accuracy (floats) on clean images.
            - ``"adversarial"``: top-1/top-5 accuracy (floats) on attacked
              images.
            - ``"attack"``: the attack's name and ``repr``.

        Raises:
            ValueError: If ``dataset`` yields no batches.
        """
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        all_labels = []
        all_logits = []
        all_logits_adv = []

        for images, labels in tqdm(dataloader):
            images = images.to(self.device)
            # The attack needs gradients, so it runs outside no_grad.
            adv_images = self.attack(images, labels)

            with torch.no_grad():
                # Normalize exactly as the attack does before calling the
                # model, so both evaluations see the inputs the attack used.
                y = self.model(self._normalize(images))
                y_adv = self.model(self._normalize(adv_images.to(self.device)))
            all_logits.append(y.cpu())
            all_logits_adv.append(y_adv.cpu())
            all_labels.append(labels.cpu())
            # Free cached GPU memory between batches.
            torch.cuda.empty_cache()

        if not all_labels:
            raise ValueError("dataset is empty; nothing to evaluate")

        all_logits = torch.cat(all_logits, dim=0)
        all_logits_adv = torch.cat(all_logits_adv, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        return {
            "original": {
                "accuracy_top_1": self._accuracy(
                    self.accuracy_top_1, all_logits, all_labels
                ),
                "accuracy_top_5": self._accuracy(
                    self.accuracy_top_5, all_logits, all_labels
                ),
            },
            "adversarial": {
                "accuracy_top_1": self._accuracy(
                    self.accuracy_top_1, all_logits_adv, all_labels
                ),
                "accuracy_top_5": self._accuracy(
                    self.accuracy_top_5, all_logits_adv, all_labels
                ),
            },
            "attack": {"name": self.attack_name, "info": repr(self.attack)},
        }
