import json
import argparse

import torch
import numpy as np

import foolbox.attacks as fa
from foolbox import PyTorchModel, accuracy
from foolbox.criteria import Misclassification

from models import DIM
from utils import get_dev, get_full_dataloader, boolean_string



def attack_batch(fmodel, images, labels, attack, epsilons, print_batch=False):
    """Attack one batch and report its clean and robust accuracy.

    Args:
        fmodel: foolbox-wrapped model under attack.
        images: input batch fed to both the model and the attack.
        labels: ground-truth labels for ``images``; also used to build the
            ``Misclassification`` criterion.
        attack: foolbox attack instance, called as
            ``attack(fmodel, images, criterion, epsilons=epsilons)``.
        epsilons: perturbation budgets to evaluate.
        print_batch: when True, print this batch's results.

    Returns:
        ``(clean_acc, robust_accuracy)``: the value returned by foolbox
        ``accuracy`` and a float64 tensor equal to ``1 - success.mean(-1)``,
        i.e. one robust-accuracy value per row of the attack's success
        tensor (foolbox documents one row per epsilon).
    """
    clean = accuracy(fmodel, images, labels)
    # Only the third element of the attack result (per-sample success flags)
    # is needed here; the raw and clipped adversarials are discarded.
    _, _, is_adv = attack(fmodel, images, Misclassification(labels), epsilons=epsilons)
    robust = 1 - is_adv.to(torch.float64).mean(dim=-1)

    if not print_batch:
        return clean, robust

    print("", flush=True)
    print(f"clean accuracy:  {clean * 100:.1f} %", flush=True)
    print("robust accuracy for perturbations with", flush=True)
    for budget, frac in zip(epsilons, robust):
        print(f"  epsilon ≤ {budget:<6}: {frac.item() * 100:4.1f} %", flush=True)
    return clean, robust


def _load_json_config(path):
    """Return the parsed JSON object stored in the file at ``path`` (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _build_dim_fmodel(dev, cfg_path_base="checkpoints/DIMcfg",
                      denoiser="ae_pepper", column="column3"):
    """Build the DIM model from its JSON configs and wrap it for foolbox.

    Args:
        dev: device value from ``get_dev``; passed to ``DIM``,
            ``torch.device`` and ``PyTorchModel``.
        cfg_path_base: directory holding the ``<name>.json`` config files.
        denoiser: base name of the denoiser config file.
        column: base name of the column config file.

    Returns:
        A foolbox ``PyTorchModel`` wrapping the DIM model in eval mode.
    """
    config_denoiser = _load_json_config(f"{cfg_path_base}/{denoiser}.json")
    config_column = _load_json_config(f"{cfg_path_base}/{column}.json")

    model = DIM(config_denoiser, config_column, dev,
                binary_cutoff_1=False, binary_cutoff_2=False)
    model.to(torch.device(dev))
    model.eval()
    # bounds=(0, 1) declares the valid input range to foolbox
    # (assumes the test loader yields images scaled to [0, 1] — confirm).
    return PyTorchModel(model, bounds=(0, 1), device=dev)


def main(args):
    """Run an L2 DeepFool attack against the DIM model over the test set.

    Prints clean accuracy and, for each epsilon in ``np.arange(0, 15, 0.5)``,
    robust accuracy averaged over every sample actually evaluated.

    Args:
        args: parsed CLI namespace; reads ``seed``, ``device``,
            ``batch_size`` and ``print_batch``.
    """
    # Seed torch and NumPy from --seed so repeated runs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # set device & attack
    dev = get_dev(args.device)
    attack = fa.L2DeepFoolAttack()
    epsilons = np.arange(0., 15., step=0.5)

    # load DIM model
    fmodel = _build_dim_fmodel(dev)

    # load data & attack
    testLoader = get_full_dataloader(bs=args.batch_size)
    n_seen = 0
    clean_correct = 0.
    robust_correct = np.zeros(len(epsilons), dtype=np.float64)
    for i_batch, (images, labels) in enumerate(testLoader):
        if args.print_batch:
            print(f"for batch {i_batch}", flush=True)
        else:
            print(f"----batch[{i_batch+1}/{len(testLoader)}]", end="", flush=True)
        images = images.to(dev)
        labels = labels.to(dev)
        batch_size = labels.shape[0]

        clean_acc, robust_acc = attack_batch(fmodel, images, labels, attack, epsilons,
                                             print_batch=args.print_batch)

        # Accumulate per-sample totals and divide by the number of samples
        # actually seen. Pre-weighting by len(dataset) is only correct when
        # every dataset sample is visited exactly once (not with drop_last
        # or a sampler).
        clean_correct += clean_acc * batch_size
        robust_correct += robust_acc.detach().cpu().numpy() * batch_size
        n_seen += batch_size

    print("", flush=True)
    if n_seen == 0:
        # Avoid dividing by zero / iterating a scalar when the loader is empty.
        print("no test samples were evaluated", flush=True)
        return

    total_cacc = clean_correct / n_seen
    total_racc = robust_correct / n_seen
    print(f"total clean accuracy:  {total_cacc * 100:.1f} %", flush=True)
    print("total robust accuracy for perturbations with", flush=True)
    for eps, acc in zip(epsilons, total_racc):
        print(f"  epsilon ≤ {eps:<6}: {acc.item() * 100:4.1f} %", flush=True)


if __name__ == "__main__":
    # Parse command-line options and hand the resulting namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--seed", type=int, default=19)
    cli.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        choices=["cpu"] + [f"cuda:{idx}" for idx in range(4)],
    )
    cli.add_argument("--batch_size", type=int, default=3000)
    # boolean_string comes from utils; presumably it maps "True"/"False"
    # text to a bool (argparse's type=bool would treat any non-empty text as True).
    cli.add_argument("--print_batch", type=boolean_string, default=False)

    main(cli.parse_args())