from pathlib import Path

import torch
from PIL import Image
from torch.nn import Module
from torchattacks import FGSM
from torchvision.transforms.v2 import ToTensor

from applications.common import IMG_MODELS


class Model(Module):
    """Bundle a registered image model's processor, network and post-processing.

    ``IMG_MODELS[model_name]`` yields a model card. Its ``processor_class`` and
    ``model_class`` are both instantiated with ``from_pretrained`` on the card's
    ``model_name`` checkpoint, and its ``post_processor`` callable is applied to
    the network output. Calling the module therefore runs
    processor -> model -> post_processor, which lets it be passed anywhere a
    plain ``model(images)`` callable is expected.
    """

    def __init__(self, model_name: str, device: int):
        """Load the processor and model for ``model_name`` onto ``device``.

        Args:
            model_name: Key into ``IMG_MODELS`` selecting the model card.
            device: Target passed to ``.to(device=...)`` for the model and,
                in ``forward``, for the processed inputs.
        """
        super().__init__()
        card = IMG_MODELS[model_name]
        checkpoint = card.model_name

        self.processor = card.processor_class.from_pretrained(checkpoint)

        network = card.model_class.from_pretrained(checkpoint)
        self.model = network.to(device=device)

        self.post_processor = card.post_processor
        self.device = device

    def forward(self, x):
        """Preprocess ``x``, run the wrapped model and post-process its output.

        Args:
            x: Image batch passed to the processor as ``images=``.

        Returns:
            Whatever ``post_processor`` produces from the model output
            (presumably logits, since FGSM applies a classification loss to
            it; confirm against the model card).

        NOTE(review): FGSM differentiates its loss w.r.t. ``x``. If the
        processor converts tensors to NumPy/PIL internally, or rescales input
        that is already in [0, 1] a second time, the gradient path or the
        pixel scale breaks. Verify against the concrete processor class.
        """
        processed = self.processor(images=x, return_tensors="pt")
        pixel_values = processed.pixel_values.to(device=self.device)
        raw_output = self.model(pixel_values)
        return self.post_processor(raw_output)


def main(
    model_name: str,
    image_name: str,
    class_number: int,
    device: int,
    eps: float = 0.3,
) -> torch.Tensor:
    """Run a single FGSM attack on one image and return the adversarial batch.

    The image is read from ``<cwd>/../images/<image_name>.png``, so the
    script must be launched from a sibling directory of ``images``.

    Args:
        model_name: Key into ``IMG_MODELS`` selecting the model to attack.
        image_name: File stem of the PNG inside the ``images`` directory.
        class_number: Label index given to FGSM for its loss. Presumably
            the image's true class, since FGSM's default mode is untargeted;
            confirm against the model's label space.
        device: Device passed to ``.to(device=...)`` (e.g. CUDA index ``0``).
        eps: FGSM perturbation budget. Defaults to the previously hard-coded
            ``0.3``.

    Returns:
        The adversarial image batch produced by FGSM. It has the same shape as
        the input batch, ``(1, 3, H, W)``.
    """
    model = Model(model_name, device)

    image_path = Path.cwd().parent / "images" / f"{image_name}.png"
    # The context manager releases the file handle once pixels are loaded.
    # Converting to RGB turns RGBA/palette/grayscale PNGs into the usual
    # 3-channel input instead of a 4- or 1-channel tensor.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # ToTensor gives a (C, H, W) float tensor in [0, 1]; add the batch axis.
    images = ToTensor()(image).unsqueeze(0).to(device=device)
    labels = torch.tensor([class_number], device=device)

    fgsm = FGSM(model, eps=eps)
    # Return the result; the original discarded the adversarial images.
    return fgsm(images, labels)


if __name__ == '__main__':
    # Demo run: attack images/cats.png with the "imagenet" model card,
    # label index 3, on device 0.
    main(model_name="imagenet", image_name="cats", class_number=3, device=0)
