from common.plot_utils import plot_rocs
from common.util import get_prediction_transformers, extraction_transformers, load_data, load_model, get_prediction_by_bs
from depth.utils import *
import numpy as np
import torch
import argparse
import os

def get_statistics(args, X_adv):
    """Return, for each adversarial example, whether it fools the target classifier.

    Loads the test split and the target model described by ``args``, predicts on
    ``X_adv`` and compares the predicted class with the ground-truth class of the
    test split.

    Args:
        args: Parsed command-line namespace. Reads ``dataset_name``, ``trans``,
            ``data_dir`` and ``checkpoints_dir``; for datasets other than
            ``'tiny'`` also ``checkpoint_number`` and ``model_type``.
            ``args.device`` is overwritten with the selected torch device.
        X_adv: Adversarial inputs. Presumably aligned index-by-index with the
            test split returned by ``load_data`` -- TODO confirm.

    Returns:
        np.ndarray: 1 where the predicted class differs from the true class
        (the attack succeeded), 0 otherwise.
    """
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, _, _, y_test = load_data(dataset_name=args.dataset_name, transformer=args.trans, data_dir=args.data_dir)
    checkpoints_dir = '{}/{}/'.format(args.checkpoints_dir, args.dataset_name)
    if args.dataset_name == 'tiny':
        model = load_model(dataset_name=args.dataset_name, checkpoints_dir=checkpoints_dir,
                           device=args.device, transformer=args.trans, checkpoint_name='large_vit_384.pth')
    else:
        model = load_model(dataset_name=args.dataset_name, checkpoints_dir=checkpoints_dir,
                           device=args.device, transformer=args.trans, ckpt_n=args.checkpoint_number,
                           model_type=args.model_type)
    model.eval()

    if args.trans:
        preds_adv = extraction_transformers(args, X_adv, model, desc='attack').detach().cpu().numpy()
    elif torch.cuda.is_available():
        # Axis 1 of the labels is the class axis (see the argmax below), so its size is the
        # number of classes; a hard-coded 10 was wrong for cifar100 / tiny.
        preds_adv = get_prediction_by_bs(X=X_adv, model=model, num_classes=y_test.shape[1])
    else:
        # Inference only: no autograd graph, and convert to numpy so the comparison below
        # operates on plain arrays, like the transformer branch.
        with torch.no_grad():
            preds_adv = model(torch.tensor(X_adv)).detach().cpu().numpy()

    return np.where(preds_adv.argmax(axis=1) != y_test.argmax(axis=1), 1, 0)

def collect_decision_by_thresholds(args, probs, thrs_size):
    """Binarise detector scores at ``thrs_size`` evenly spaced thresholds.

    The thresholds span ``[probs.min(), probs.max()]`` inclusive. At every
    threshold each sample is marked adversarial (1) or natural (0).

    Args:
        args: Namespace whose ``method_name`` selects the comparison direction.
        probs: 1-D numpy array of detector scores, one per sample.
        thrs_size: Number of thresholds to evaluate.

    Returns:
        np.ndarray: Float array of shape ``(len(probs), thrs_size)``. Column ``i``
        holds the 0/1 decisions at the ``i``-th threshold.
    """
    thresholds = np.linspace(probs.min(), probs.max(), thrs_size)
    decisions = np.zeros((len(probs), thrs_size))
    for column, threshold in enumerate(thresholds):
        # Depth is a similarity score: higher means "more natural", so a sample is flagged
        # as adversarial when it falls below the threshold. For the other detectors lower
        # means "more natural", so a sample is flagged when it exceeds the threshold.
        is_flagged = np.less if args.method_name == 'depth' else np.greater
        decisions[:, column] = np.where(is_flagged(probs, threshold), 1, 0)
    return decisions


def _load_detector_scores(args, method_name, attack):
    """Load the per-sample scores saved by detector ``method_name`` for ``attack``.

    For ``'magnet'`` the saved array is indexed along its first axis: entry 1 is
    used when present, otherwise entry 0. Other methods return the array as loaded.
    """
    res_dir = os.path.join(args.res_dir, method_name)
    if method_name == 'depth':
        file_name = 'prediction/probs_{}{}_predicted.npy'.format(args.dataset_name, attack)
    else:
        file_name = 'probs_{}{}_all.npy'.format(args.dataset_name, attack)
    probs = np.load(os.path.join(res_dir, file_name))
    if method_name == 'magnet':
        return probs[1] if probs.shape[0] > 1 else probs[0]
    return probs


def _roc_rates(decision, success):
    """Compute per-threshold false/true positive rates from 0/1 detector decisions.

    Args:
        decision: Array of shape ``(2 * N, T)``. Rows ``0..N-1`` are natural samples
            and rows ``N..2N-1`` are their adversarial counterparts.
        success: Length-``N`` 0/1 array, 1 where the adversarial example fools the
            classifier.

    Returns:
        tuple: ``(fpr, tpr)``, each of length ``T``. Only successful adversarial
        examples count towards TPR. TPR is NaN (numpy emits a RuntimeWarning)
        when no attack succeeded.
    """
    n_nat = len(success)
    success_col = np.asarray(success, dtype=float).reshape(-1, 1)
    # Row weights: natural rows keep their decision. Adversarial rows are multiplied by
    # the attack outcome, which discards examples that did not fool the classifier.
    weights = np.vstack((np.ones((n_nat, 1)), success_col))
    # Number of successful adversarial examples per row (always 0 for natural rows).
    correct_all = np.vstack((np.zeros((n_nat, 1)), success_col))
    # Ground truth: 0 = natural, 1 = adversarial.
    labels = np.concatenate((np.zeros(n_nat), np.ones(n_nat))).reshape(-1, 1)

    decision_adv = decision * weights
    has_success = correct_all > 0
    # TP: every successful adversarial example on the row was detected.
    tp = ((decision_adv == correct_all) & has_success).sum(axis=0)
    # FN: fewer detections than successful adversarial examples.
    fn = ((decision_adv < correct_all) & has_success).sum(axis=0)
    # FP / TN: natural samples flagged / not flagged as adversarial.
    fp = ((decision > 0) & (labels == 0)).sum(axis=0)
    tn = ((decision == 0) & (labels == 0)).sum(axis=0)

    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    return fpr, tpr


def collect_decision(args, method_name, attack, thrs_size=200):
    """Evaluate a detector on one attack, save its ROC plot and return the curve.

    Loads the adversarial examples for ``attack``. It determines which ones fool
    the target classifier, binarises the detector scores at ``thrs_size``
    thresholds and computes the false/true positive rates. The ROC curve is
    written to ``plots/<dataset_name>/<method_name>/roc_<attack>.pdf``.

    The score file is expected to hold ``2 * N`` scores: the ``N`` natural
    samples first, then their ``N`` adversarial counterparts.

    Args:
        args: Parsed command-line namespace (``res_dir``, ``adv_path``,
            ``dataset_name``, ``trans``, ``method_name``, ... are read).
        method_name: Detector name; one of ``'fs'``, ``'magnet'``, ``'depth'``.
        attack: Attack identifier string, interpolated into the data file names.
        thrs_size: Number of decision thresholds.

    Returns:
        tuple: ``(fpr, tpr)`` arrays of length ``thrs_size``.

    Raises:
        ValueError: If ``method_name`` is unsupported or the number of detector
            scores does not match the number of samples.
    """
    supported = ('fs', 'magnet', 'depth')
    if method_name not in supported:
        raise ValueError('Unknown method_name {!r}; expected one of {}'.format(method_name, supported))

    print(attack)
    X_test_adv = load_adv(args.adv_path, dataset_name=args.dataset_name, attack=attack, transformers=args.trans)
    n_samples = len(X_test_adv)

    # 1 where the adversarial example fools the target classifier, 0 otherwise.
    success = np.asarray(get_statistics(args=args, X_adv=X_test_adv)).reshape(-1)

    proba = _load_detector_scores(args, method_name, attack)
    decision = collect_decision_by_thresholds(args, proba, thrs_size)
    if len(success) != n_samples or decision.shape[0] != 2 * n_samples:
        raise ValueError(
            'Expected {} attack outcomes and {} detector scores (natural then adversarial), '
            'got {} and {}'.format(n_samples, 2 * n_samples, len(success), decision.shape[0]))

    print('Total. Number of Successful Attacks per Natural Sample: ', success.mean())

    fpr, tpr = _roc_rates(decision, success)

    plots_dir = os.path.join('plots', args.dataset_name, method_name)
    os.makedirs(plots_dir, exist_ok=True)
    plot_rocs([fpr], [tpr], ["AUROC"], ['red'], os.path.join(plots_dir, 'roc_{}.pdf'.format(attack)))

    return fpr, tpr

def main(args, method_name, attack):
    """Evaluate detector ``method_name`` against ``attack`` and return its ROC.

    Returns:
        tuple: ``(fpr, tpr)`` as returned by ``collect_decision``. Programmatic
        callers can reuse the curve instead of only getting the saved plot.
    """
    return collect_decision(args, method_name, attack)

if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false'/'yes'/'no'/'1'/'0'."""
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d', '--dataset_name',
        help="Dataset to use; either cifar10, cifar100 or tiny",
        required=True, type=str
    )
    parser.add_argument(
        '-m', '--method_name',
        help="Detection method to use; either fs, magnet or depth",
        required=True, type=str
    )
    # Without an explicit type, "--trans False" yielded the truthy string "False".
    parser.add_argument(
        '--trans',
        type=_str2bool, default=True,
        help="Whether the target model is a transformer (true/false)"
    )
    parser.add_argument(
        '-l', '--loss',
        help="Loss to use; either 'CE', 'KL', 'Rao', or 'g'",
        type=str, default='CE'
    )
    parser.add_argument(
        '--attack', '-a',
        type=str, default='pgdi'
    )
    parser.add_argument(
        '--epsilon', '-e',
        type=str, default=0.03125
    )
    parser.add_argument(
        '--model_type', '-type',
        type=str, default="ViT-B_16"
    )
    parser.add_argument(
        '--checkpoints_dir',
        type=str, default="model_ckpt"
    )
    parser.add_argument(
        '--res_dir',
        type=str, default="detectors/results/"
    )
    parser.add_argument(
        '--adv_path',
        type=str, default="detectors/adv_data/"
    )
    parser.add_argument(
        '--data_dir',
        type=str, default="data/"
    )

    args = parser.parse_args()
    if args.attack.startswith(('fgsm', 'pgd', 'bim')):
        attack = "{}_{}_{}".format(args.loss, args.attack, args.epsilon)
    else:
        # Must be a plain string (not a list): it is interpolated into the data file names.
        attack = "_{}".format(args.attack)

    if args.dataset_name == 'cifar10':
        args.checkpoint_number = 70000
    elif args.dataset_name == 'cifar100':
        args.checkpoint_number = 8000

    main(args, args.method_name, attack)

