from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch_dct as dct
import torch.nn.functional as F
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import json
import sys
import operator as op
from functools import reduce

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

import yaml
import json
from models.image_models import HolzClassifier, autoattack_wrapper
from utils.dataload import DeepFakeDatasetPathList
import attacks.image_attacks as image_attacks

from IPython import embed


# Put the local ``./auto-attack/`` directory first on the module search path
# before importing ``autoattack`` below. The path is relative, so it is
# resolved against the current working directory at run time.
# NOTE(review): presumably this makes the import pick up a locally modified
# AutoAttack checkout (main() passes a non-standard ``threshold`` keyword) —
# confirm.
sys.path.insert(0, './auto-attack/')

from autoattack import AutoAttack


def main(args):
    """Run one AutoAttack evaluation against a saved classifier ensemble.

    Loads every sub-model stored under ``./models/<model_name>/``, collects
    fake test images from ``./data/fake/test/`` and runs the requested attack
    on them through ``AutoAttack.run_standard_evaluation``.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``device`` (CUDA index, or ``-1`` to force CPU), ``model_name``,
            ``attack``, ``norm``, ``eps`` and ``steps``.

    Raises:
        ValueError: If ``args.model_name`` has no known threshold.
        RuntimeError: If no fake test images could be collected.
    """
    # Threshold handed to AutoAttack, keyed by checkpoint-set name. Validate
    # first so an unsupported name fails immediately with a clear message
    # instead of an unbound-variable error after all models were loaded.
    thresholds = {"at": 0.0068, "d3s4": 0.08}
    try:
        threshold = thresholds[args.model_name]
    except KeyError:
        raise ValueError(
            f"No threshold known for model_name={args.model_name!r}; "
            f"expected one of {sorted(thresholds)}"
        ) from None

    # -1 forces CPU; any other index silently falls back to CPU without CUDA.
    use_cuda = args.device != -1 and torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.device) if use_cuda else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])

    BATCH_SIZE = 64

    # Only fake images (label 1) are attacked.
    fake_dir = "./data/fake/test/"
    fake_images_test = [(fake_dir + name, 1) for name in os.listdir(fake_dir)]
    fake_image_dataset = DeepFakeDatasetPathList(fake_images_test, [], data_transform)
    fake_loader = DataLoader(fake_image_dataset, BATCH_SIZE, shuffle=True)

    mean_file = torch.load("./mean.pt", map_location="cpu")
    var_file = torch.load("./var.pt", map_location="cpu")

    #### LOAD MODELS
    model_dir = f"./models/{args.model_name}"
    # On a CPU run, remap tensors that were saved from a CUDA device so the
    # load cannot fail; on a GPU run keep the saved placement untouched.
    mask_location = "cpu" if device.type == "cpu" else None
    masks = torch.load(f"{model_dir}/masks.pt", map_location=mask_location)

    model_list = []
    # Index-based on purpose: only len() and integer indexing are assumed of
    # the loaded ``masks`` container.
    for i in range(len(masks)):
        model = HolzClassifier(mean_file, var_file, masks[i])
        # load_state_dict copies values into the existing parameters, so the
        # checkpoint can always be materialised on CPU first.
        state = torch.load(f"{model_dir}/{i}.pt", map_location="cpu")
        model.load_state_dict(state)
        model.to(device)
        model.eval()
        model_list.append(model)

    # Keep only the first sample of every shuffled batch, i.e. one image per
    # BATCH_SIZE images in the test set.
    # NOTE(review): confirm this subsampling is intentional.
    images, labels = [], []
    for batch in fake_loader:
        images.append(batch[0][0])
        labels.append(batch[1][0])

    if not images:
        raise RuntimeError(f"No fake test images found in {fake_dir}")

    adversary = AutoAttack(autoattack_wrapper(model_list).forward, norm=args.norm, eps=args.eps,
                           version='custom', device=device, threshold=threshold)
    # 'custom' version: run only the single attack selected on the CLI.
    adversary.attacks_to_run = [args.attack]
    adversary.apgd.n_iter = args.steps
    _ = adversary.run_standard_evaluation(torch.stack(images), torch.stack(labels), bs=BATCH_SIZE)




if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = (
        ('--device', int, 0),            # CUDA index; -1 makes main() use the CPU
        ('--model_name', str, 'd3s4'),   # sub-directory of ./models/ to load
        ('--attack', str, 'apgd-ce'),    # single entry for attacks_to_run
        ('--norm', str, 'Linf'),         # forwarded to AutoAttack(norm=...)
        ('--eps', float, 0.004),         # forwarded to AutoAttack(eps=...)
        ('--steps', int, 50),            # assigned to adversary.apgd.n_iter
    )

    parser = ArgumentParser()
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    cli_args = parser.parse_args()
    main(cli_args)
