import itertools

import torch
from torchvision import transforms

from attacks import BIM
from data import get_someset_loader
from tester import test_transfer_attack_acc, test_acc
from defenses import DiffusionPure
from utils.seed import set_seed
from models import BaseNormModel, resnet50, ClassSelectionModel
from models.unets import get_guided_diffusion_unet

# Seed the RNGs (via the project's set_seed helper) before building the data
# pipeline and the models.
set_seed(1)

# NOTE(review): this dataset path is relative to the current working directory,
# while the checkpoint path below climbs three directories up — confirm both
# resolve from the directory this script is launched from.
loader = get_someset_loader(
    './resources/RestrictedImageNet256',
    './resources/RestrictedImageNet256/gt.npy',
    batch_size=1,
    transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]),
)
# NOTE(review): `device` is not referenced anywhere else in the visible script;
# kept because other modules/notebooks may import it — confirm before removing.
device = torch.device('cuda')
unet = get_guided_diffusion_unet()
# Deserialize the checkpoint onto the CPU: load_state_dict copies the tensors
# into the module's own parameters anyway, so this avoids allocating a second
# copy on the device the checkpoint was saved from (or failing if it's absent).
unet.load_state_dict(
    torch.load(
        '../../../resources/checkpoints/guided_diffusion/256x256_diffusion_uncond.pt',
        map_location='cpu',
    )
)


class OnlyMeanModel(torch.nn.Module):
    """Wrap a diffusion U-Net so that only its first three output channels are returned.

    ``forward`` calls the wrapped network with exactly the arguments it receives
    and slices the result along dim 1, keeping channels ``0:3``.

    NOTE(review): presumably the wrapped U-Net emits extra channels (e.g. a
    learned variance) after the 3-channel prediction that DiffusionPure expects
    — confirm against the U-Net configuration.
    """

    def __init__(self, model=None):
        """Create the wrapper.

        Args:
            model: the network to wrap. When omitted, the module-level ``unet``
                defined in this script is used, preserving the original
                zero-argument construction.
        """
        super().__init__()
        # The module-level ``unet`` is only looked up when no model is given,
        # so the wrapper can also be reused with other networks.
        self.unet = unet if model is None else model

    def forward(self, x, t, *args, **kwargs):
        """Run the wrapped network and keep only the first three channels (dim 1)."""
        return self.unet(x, t, *args, **kwargs)[:, :3]


# Build the DiffusionPure defense: the U-Net wrapper supplies the diffusion
# model and the ResNet-50 (wrapped by BaseNormModel / ClassSelectionModel) is
# the downstream classifier. The whole pipeline is put in eval mode with
# gradients disabled on its parameters.
diffpure = DiffusionPure(
    mode='sde',
    grad_checkpoint=True,
    unet=OnlyMeanModel(),
    model=ClassSelectionModel(
        BaseNormModel(resnet50(pretrained=True)),
        # NOTE(review): presumably the ImageNet class indices that make up the
        # restricted label set of this dataset — confirm against gt.npy.
        target_class=(151, 281, 30, 33, 80, 365, 389, 118, 300),
    ),
    img_shape=(3, 256, 256),
).eval().requires_grad_(False)

# Number of leading loader items (batches of size 1) to evaluate on.
NUM_EVAL_BATCHES = 16
# islice stops after the first NUM_EVAL_BATCHES items instead of loading every
# remaining image in the dataset only to throw it away.
loader = list(itertools.islice(loader, NUM_EVAL_BATCHES))

# test_acc(diffpure, loader)
# attacker = BIM([diffpure], step_size=1 / 255, total_step=80, eot_step=20, epsilon=4 / 255,
#                norm='Linf',
#                eot_batch_size=1)
# test_transfer_attack_acc(
#     attacker,
#     loader,
#     [diffpure],
# )
