import torch
from attacks import BIM
from data import get_someset_loader, get_NIPS17_loader, get_CIFAR10_test
from tester import test_transfer_attack_acc, test_acc
from defenses import DiffusionPure
from utils.seed import set_seed
from models import BaseNormModel, resnet50, ClassSelectionModel
from models.unets.EDM import get_edm_imagenet_64x64_cond
from torchvision import transforms
from utils.ImageHandling import save_image

# Seed the run through the project's set_seed helper before anything else
# is constructed.
set_seed(1)

# Per-image preprocessing: resize to 64x64, then convert to a tensor.
# NOTE(review): 64x64 presumably matches the input size of the EDM unet
# loaded below -- confirm.
resize_to_tensor = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# One image per batch; labels are read from the gt.npy file next to the images.
loader = get_someset_loader(
    './resources/RestrictedImageNet256',
    './resources/RestrictedImageNet256/gt.npy',
    batch_size=1,
    transform=resize_to_tensor,
)


class Post:
    """Callable that resizes its input to 256x256.

    An instance is handed to ``DiffusionPure`` as ``post_transforms``.
    """

    def __init__(self):
        # Build the resize transform once and reuse it on every call.
        self.transform = transforms.Resize((256, 256))

    def __call__(self, x):
        """Return ``x`` after applying the stored 256x256 resize."""
        return self.transform(x)


# NOTE(review): `device` is not referenced anywhere else in the visible
# script; kept in case code outside this view relies on it.
device = torch.device('cuda')

# Build the EDM unet and restore its pretrained weights.
unet = get_edm_imagenet_64x64_cond()
# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from (a different GPU index would otherwise fail to deserialize,
# or allocate a throwaway GPU copy); load_state_dict then copies the values
# into the unet's own parameters.
# NOTE(review): this path is relative to '../../..' while the dataset paths
# above are relative to '.' -- confirm the intended working directory.
unet.load_state_dict(
    torch.load(
        '../../../resources/checkpoints/EDM/edm-imagenet-64x64-cond.pt',
        map_location='cpu',
    )
)

# Defense under test: DiffusionPure in 'edm' mode wrapping a pretrained
# ResNet-50 (inside BaseNormModel and ClassSelectionModel), with Post()
# supplied as the post-transform. The whole module is put in eval mode and
# has gradients disabled on its parameters.
# NOTE(review): target_class presumably lists the ImageNet class indices the
# classifier output is restricted to -- confirm against ClassSelectionModel.
diffpure = DiffusionPure(
    mode='edm',
    grad_checkpoint=True,
    unet=unet,
    model=ClassSelectionModel(
        BaseNormModel(resnet50(pretrained=True)),
        target_class=(151, 281, 30, 33, 80, 365, 389, 118, 300),
    ),
    post_transforms=Post(),
).eval().requires_grad_(False)
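# Debugging aid (disabled): uncomment the lines below to replace the diffusion
# module with a stub that only adds Gaussian noise (std 0.5) to its input.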
# class FakeModel(torch.nn.Module):
#     def __init__(self):
#         super(FakeModel, self).__init__()
#
#     def forward(self, x):
#         return x + torch.randn_like(x) * 0.5
#
#
# diffpure.diffusion = FakeModel()

# Keep only the first 32 batches and materialize them as a list.
# zip() stops as soon as range(32) is exhausted, so the rest of the dataset is
# never loaded or transformed (filtering over enumerate() walked all of it).
loader = [batch for _, batch in zip(range(32), loader)]

# test_acc(diffpure, loader)

# BIM attack against the defended model: Linf budget 4/255, step size 1/255,
# 80 steps, with eot_step=20 and eot_batch_size=2.
# NOTE(review): eot_* presumably control expectation-over-transformation
# sampling over the stochastic defense -- confirm against attacks.BIM.
attacker = BIM(
    [diffpure],
    step_size=1 / 255,
    total_step=80,
    eot_step=20,
    epsilon=4 / 255,
    norm='Linf',
    eot_batch_size=2,
)

# The same model is used both to craft the attack and as the evaluation target.
test_transfer_attack_acc(
    attacker,
    loader,
    [diffpure],
)
