import argparse
from argparse import ArgumentParser
import sys
from time import time
import yaml
import random
import numpy as np
import torch
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
import json
from models.image_models import HolzClassifier
from utils.dataload import DeepFakeDatasetPathList
from torch.utils.data import Dataset, DataLoader
from attacks.image_attacks import *
import torchaudio
import torch.nn as nn
import os
from tqdm import tqdm
from PIL import Image


# Number of samples per batch; passed to both the training and validation DataLoader.
BATCH_SIZE = 1024


# Keyword arguments forwarded to pgd_max_single_bpda for every batch.
_PGD_PARAMS = dict(steps=10, epsilon=0.004, step_size=0.001)
# Number of passes over the training set per model.
_NUM_EPOCHS = 20


def _collect_labeled_images(root, label):
    """Return ``(path, label)`` tuples for every .png/.jpg file below *root*."""
    samples = []
    for dirpath, _subdirs, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(('.png', '.jpg')):
                samples.append((os.path.join(dirpath, name), label))
    return samples


def _prepare_batch(model, images, labels, device):
    """Swap the fake (label 1) images for PGD adversarial ones and move the batch to *device*.

    Returns the ``(images, labels)`` pair, both on *device*.
    """
    # Build the mask before moving the labels: indexing the (CPU) image batch
    # with a CUDA mask is rejected by recent PyTorch releases.
    # Assumes images and labels come from the loader on the same device.
    fake_mask = labels == 1
    labels = labels.to(device)

    # Skip the attack when the batch holds no fake sample (e.g. a small last batch).
    if fake_mask.any():
        _, adv_images = pgd_max_single_bpda(model, images[fake_mask], device,
                                            labels=labels[fake_mask.to(device)],
                                            **_PGD_PARAMS)
        # The attack output is used as plain data; detach it and bring it back
        # to wherever the batch currently lives before writing it in place.
        images[fake_mask] = adv_images.detach().to(images.device)

    return images.to(device), labels


def _mean(values):
    """Arithmetic mean of *values*, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def main(args):
    """Adversarially train one HolzClassifier per mask and keep the best checkpoint.

    Args:
        args: parsed command-line namespace providing
            ``device`` (CUDA index, or -1 to force the CPU) and
            ``model_name`` (sub-directory of ``./models`` holding ``masks.pt``
            and receiving the ``<index>.pt`` checkpoints).

    For every epoch the fake images of each batch are replaced by PGD
    adversarial examples, the model is trained on the result, and the
    validation ROC-AUC (also computed on attacked fakes) decides whether the
    weights are saved.
    """
    # Fall back to the CPU when requested explicitly or when CUDA is missing.
    if args.device == -1 or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.device))

    data_transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])

    # Real images are labelled 0, fake images 1.
    real_images_train = _collect_labeled_images("./data/real/train/", 0)
    fake_images_train = _collect_labeled_images("./data/fake/train/", 1)
    real_images_val = _collect_labeled_images("./data/real/val/", 0)
    fake_images_val = _collect_labeled_images("./data/fake/val/", 1)

    train_image_dataset = DeepFakeDatasetPathList(real_images_train, fake_images_train, data_transform)
    val_image_dataset = DeepFakeDatasetPathList(real_images_val, fake_images_val, data_transform)

    train_loader = DataLoader(train_image_dataset, BATCH_SIZE, shuffle=True)
    # ROC-AUC does not depend on sample order, so validation needs no shuffling.
    val_loader = DataLoader(val_image_dataset, BATCH_SIZE, shuffle=False)

    mean_file = torch.load("./mean.pt", map_location='cpu').to(device)
    var_file = torch.load("./var.pt", map_location='cpu').to(device)

    # NOTE(review): this always (re)writes an all-ones mask for the "at" model,
    # regardless of --model_name; kept as in the original, confirm it is intended.
    torch.save(torch.ones(1, 3, 128, 128), './models/at/masks.pt')
    masks = torch.load(f"./models/{args.model_name}/masks.pt", map_location='cpu').to(device)

    # One classifier per mask (iterating the tensor walks its first dimension).
    models = []
    for mask in masks:
        model = HolzClassifier(mean_file, var_file, mask)
        model.to(device)
        models.append(model)

    for index, model in enumerate(models):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        best_val_auc = 0

        for epoch in range(_NUM_EPOCHS):
            train_losses = []
            for images, labels in tqdm(train_loader):
                images, labels = _prepare_batch(model, images, labels, device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                # Clear gradients only now: the attack above may have left
                # gradients on the model parameters.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_losses.append(loss.item())

            probs, actual, val_losses = [], [], []
            for images, labels in val_loader:
                # The attack itself needs gradients, so only the scoring
                # forward pass below runs without autograd.
                images, labels = _prepare_batch(model, images, labels, device)

                with torch.no_grad():
                    outputs = model(images)
                    val_losses.append(criterion(outputs, labels).item())
                    # Probability assigned to class 1 (fake).
                    fake_probs = F.softmax(outputs, dim=1)[:, 1].double()

                probs += fake_probs.cpu().tolist()
                actual += labels.cpu().tolist()

            auc = roc_auc_score(actual, probs)
            print(f"model {index} epoch {epoch + 1}/{_NUM_EPOCHS}: "
                  f"train loss {_mean(train_losses):.4f}, "
                  f"val loss {_mean(val_losses):.4f}, val AUC {auc:.4f}")

            # Keep only the weights with the best validation AUC so far.
            if auc > best_val_auc:
                torch.save(model.state_dict(), f"./models/{args.model_name}/{index}.pt")
                best_val_auc = auc

if __name__ == '__main__':
    # Command-line options as (flag, type, default) rows, registered in order.
    option_table = (
        ('--device', int, 0),
        ('--model_name', str, 'd3s4'),
    )
    cli = ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    parsed_args = cli.parse_args()
    main(parsed_args)
