from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch_dct as dct
import torch.nn.functional as F
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import json
import sys
import operator as op
from functools import reduce
# from torchaudio import transforms

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

import yaml
import json
from models.image_models import HolzClassifier, autoattack_wrapper
from utils.dataload import DeepFakeDatasetPathList
import attacks.image_attacks as image_attacks

from IPython import embed

sys.path.insert(0, './auto-attack/')

from autoattack import AutoAttack

## SALIENCY MASK
# Script configuration: device, test data, classifier and attack settings used
# by the saliency-mask computation further down in this file.

# Everything (model, images, perturbations) is moved to the first CUDA device;
# there is no CPU fallback.
device = torch.device("cuda:0")

# The attack loop below relies on a batch of exactly one image: its early-stop
# test `y_loss < thresh` needs a single-element tensor and the saliency map is
# reduced with `squeeze(0)`.
BATCH_SIZE = 1

##### LOAD DATA

# NOTE(review): the same module is already imported as `image_attacks` at the
# top of the file; the `attacks` alias is not used in the visible part of the
# script.
import attacks.image_attacks as attacks
# Resize, centre-crop to 128x128 and convert to a tensor. The mask construction
# at the end of the script hard-codes a 3x128x128 layout to match.
data_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor()
])

# (path, label) pairs: label 0 for real images, label 1 for fake images.
real_images_test = [("./data/real/test/" + x, 0) for x in os.listdir("./data/real/test/")]
fake_images_test = [("./data/fake/test/" + x, 1) for x in os.listdir("./data/fake/test/")]

# NOTE(review): the meaning of the empty-list second argument is defined by
# DeepFakeDatasetPathList and is not visible from this file.
real_image_dataset = DeepFakeDatasetPathList(real_images_test, [], data_transform)
fake_image_dataset = DeepFakeDatasetPathList(fake_images_test, [], data_transform)

# Only `fake_loader` is iterated in the visible part of the script;
# `real_loader` is built but not consumed here.
real_loader = DataLoader(real_image_dataset, BATCH_SIZE, shuffle=True)
fake_loader = DataLoader(fake_image_dataset, BATCH_SIZE, shuffle=True)

#### LOAD MODEL



# Not used in the visible part of the script.
pil_to_tens = transforms.ToTensor()

# Tensors handed to the classifier constructor; presumably the statistics it
# uses to normalise its input -- confirm against HolzClassifier.
mean_file = torch.load("./mean.pt", map_location="cpu")
var_file = torch.load("./var.pt", map_location="cpu")

# NOTE(review): unlike the other loads this one has no `map_location`, so the
# tensor is restored onto whatever device it was saved from.
freq_mask = torch.load("./models/at/masks.pt")
model = HolzClassifier(mean_file, var_file, freq_mask)
model.load_state_dict(torch.load("./models/at/0.pt", map_location="cpu"))
# Inference mode, then move the weights to the attack device.
model.eval()
model.to(device)

# The attack on an image stops as soon as the softmax probability of class 1
# drops below this value.
thresh = 0.0068

##### ATTACK

# Per-element bound on |adversarial - original| enforced after every step.
epsilon = 1
# Magnitude of each signed-gradient step.
step_size = 0.001

# Maximum number of attack steps per image.
steps = 1000

# Initial (count, mean, M2) state for the running aggregate maintained by
# update() and read out by finalize().
existingAggregate = (0, 0, 0)

def update(existingAggregate, newValue):
    """Fold one sample into a running Welford mean/variance aggregate.

    Args:
        existingAggregate: ``(count, mean, M2)`` tuple, where ``M2`` is the
            running sum of squared differences from the current mean. Start
            from ``(0, 0, 0)``.
        newValue: The new sample; a scalar or any array/tensor type that
            supports element-wise ``-``, ``+``, ``*`` and division by an int.

    Returns:
        A new ``(count, mean, M2)`` tuple that includes ``newValue``.
    """
    (count, mean, M2) = existingAggregate
    count += 1
    delta = newValue - mean
    # Use out-of-place arithmetic: for tensors/arrays ``+=`` would write into
    # the objects still referenced by the caller's aggregate tuple and
    # silently corrupt it.
    mean = mean + delta / count
    delta2 = newValue - mean
    M2 = M2 + delta * delta2
    return (count, mean, M2)


def finalize(existingAggregate):
    """Read the statistics out of a running Welford aggregate.

    Args:
        existingAggregate: ``(count, mean, M2)`` tuple as produced by
            ``update``.

    Returns:
        ``(mean, variance, sampleVariance)`` where ``variance`` divides ``M2``
        by ``count`` and ``sampleVariance`` divides it by ``count - 1``.
        When fewer than two samples were aggregated, a bare ``float("nan")``
        is returned instead of a tuple.
    """
    count, mean, M2 = existingAggregate

    # Guard clause: with fewer than two samples no variance is reported.
    if count < 2:
        return float("nan")

    population_variance = M2 / count
    sample_variance = M2 / (count - 1)
    return (mean, population_variance, sample_variance)

# Attack every fake image with iterative signed-gradient ascent and fold the
# resulting per-pixel saliency map into the running Welford aggregate.

# Loop-invariant, so it is built once instead of once per image.
criterion = torch.nn.CrossEntropyLoss()

# The labels yielded by the loader are not used: the loss target is fixed to
# class 1 for every sample.
for image, labels in tqdm(fake_loader):

    # Random start: small uniform noise drawn on the CPU, then moved over.
    pert = torch.FloatTensor(*image.shape).uniform_(-0.004, 0.004).to(device)
    image = image.to(device)

    # The loss is always taken against class 1; ascending it pushes the
    # prediction away from that class.
    target = torch.ones(image.size(0), dtype=torch.long, device=device)

    full_pert_set = torch.clamp(image + pert, 0.0, 1.0)

    # `_step` avoids shadowing any outer loop index.
    for _step in range(steps):
        pert_image = full_pert_set
        pert_image.requires_grad_()

        loss_s = criterion(model(pert_image), target)
        loss_s.backward()

        # Gradient of the loss w.r.t. the current adversarial image; the
        # value from the last executed step feeds the saliency map below.
        grad = pert_image.grad

        # Signed-gradient ascent step, then project back into the epsilon
        # box around the original image and into the valid pixel range.
        pert_image = pert_image.detach() + step_size * grad.sign()
        pert_image = torch.min(torch.max(pert_image, image - epsilon),
                               image + epsilon)
        pert_image = torch.clamp(pert_image, 0.0, 1)

        full_pert_set = pert_image

        # Early-stop probe only: no autograd graph is needed for it.
        with torch.no_grad():
            y_loss = F.softmax(model(full_pert_set), dim=1)[:, 1].double().cpu()

        # Relies on BATCH_SIZE == 1 (single-element tensor as a boolean).
        if y_loss < thresh:
            break

    # Only the final adversarial image is needed on the CPU, so copy it once
    # after the loop rather than after every step.
    final_pert_images = full_pert_set.cpu()

    # Saliency = |last gradient| * |total perturbation|, element-wise.
    sal = torch.abs(grad.cpu()) * torch.abs((final_pert_images - image.cpu()))

    # Drop the batch dimension (size 1) before aggregating.
    existingAggregate = update(existingAggregate, sal.squeeze(0))

# Turn the aggregated saliency into `n` interleaved binary masks and save them.
final_values = finalize(existingAggregate)

# finalize() returns a bare NaN instead of a tuple when fewer than two samples
# were aggregated; fail with a clear message rather than an opaque
# "'float' object is not subscriptable".
if not isinstance(final_values, tuple):
    raise RuntimeError(
        "Need saliency maps from at least two images to build masks; "
        "got {}.".format(existingAggregate[0])
    )

# Keep only the mean saliency map and min-max normalise it.
final_values = final_values[0]
lowest = torch.min(final_values)
value_range = torch.max(final_values) - lowest
if value_range == 0:
    # A constant map would otherwise produce 0/0 = NaN everywhere.
    norm_final_values = torch.zeros_like(final_values)
else:
    norm_final_values = (final_values - lowest) / value_range

sal = norm_final_values

# Number of masks to distribute the positions over.
n = 4

# Rank every position by ascending saliency. A stable sort over the C-order
# flattened map breaks ties by flat index, which is the same order a sort of
# (value, i, j, k) tuples produces, without tens of thousands of Python-level
# element reads and tensor comparisons.
flat_sal = sal.detach().cpu().reshape(-1)
order = torch.from_numpy(np.argsort(flat_sal.numpy(), kind="stable"))

# The position with rank r goes to mask r % n, so each mask receives an evenly
# interleaved slice of the saliency ranking.
flat_masks = torch.zeros(n, flat_sal.numel())
flat_masks[torch.arange(flat_sal.numel()) % n, order] = 1

# Stacked tensor of shape (n, *sal.shape).
masks = flat_masks.reshape(n, *sal.shape)

# Create the target directory first so the save cannot fail on a missing path
# after the long attack run.
out_path = "./models/d3s4/masks.pt"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
torch.save(masks, out_path)