import pandas as pd
import torch
import numpy as np
from pathlib import Path

import utilis
import wandb
import argparse
import random
import torch.nn as nn
from config import get_config
from networks import get_network
from datasets import load_data
import os

# Disable wandb syncing; set before wandb.init() so the run starts disabled.
os.environ["WANDB_MODE"] = "disabled"
wandb.init(project='Bound', reinit=True)


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--smooth False`` would yield True). This parser
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
parser.add_argument('--smooth', type=_str2bool, default=True)  # <- fixed
parser.add_argument('--adv_load', type=_str2bool, default=True)  # <- fixed
parser.add_argument('--device', default=1, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=50, type=int, help='Amount of samples used during inference')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
parser.add_argument('--max_eps', type=float, default=.2,
                    help='Upper bound on the attack strength; samples are evaluated at 10 equidistant strengths '
                         'eps = 0, max_eps/10, ..., with max_eps itself excluded')
parser.add_argument('--attack', type=str, default='FGM')  # <- fixed
parser.add_argument('--attack_samples', type=int, default=10)  # <- fixed
args = parser.parse_args()
# get_config (project helper) may add further fields to args, e.g. root_dir / batch_size used below.
args = get_config(args)
wandb.config.update(args)
torch.cuda.set_device(args.device)
# NOTE(review): main() later forces droprate to 0 for non-ResNet nets, so this
# run name may show a droprate that differs from the one actually used.
wandb.run.name = f'''{args.net}_{args.droprate}_{args.attack_samples}_{wandb.run.id}'''
print(args.droprate)

def main(args):
    """Evaluate stored adversarial examples against a trained smooth model.

    Loads the data, builds the network named by ``args.net`` and restores its
    weights from ``<root_dir>/models/<dataset>/``. It then loads the adversarial
    samples for every eps level and evaluates each level with
    ``eval_adv_samples``.

    Mutates ``args``:

    - ``droprate`` is forced to 0 for non-ResNet networks, before the
      checkpoint path is built.
    - ``randseed`` is increased by 2000 in total.
    """
    # Only the third loader returned by load_data is used here.
    _, _, test_loader = load_data(args.dataset, args.batch_size, args.root_dir)

    model = get_network(args)
    if args.net != 'ResNet':
        # Dropout only applies to ResNet (see --droprate); other networks'
        # checkpoints are named with droprate 0.
        args.droprate = 0

    checkpoint = Path(
        args.root_dir,
        f'models/{args.dataset}/model_smooth_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{args.droprate}_{args.smooth_level}.bin')
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    if args.net == 'ResNet':
        # Switch only the BatchNorm layers to eval mode; the rest of the model
        # keeps its current mode (presumably so dropout stays active for
        # sampling at inference — confirm).
        batchnorms = (layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d))
        for bn in batchnorms:
            bn.eval()

    # Re-seed every RNG with an offset so sampling differs from training.
    args.randseed += 1000
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(args.randseed)

    X, advs, deltas, epsilon, targets = load_adv_samples(model, test_loader, args)

    # This second offset does not re-seed the RNGs here; the new value is only
    # read later through args (e.g. in the result file name).
    args.randseed += 1000
    for benign, adv, delta, target, eps in zip(X, advs, deltas, targets, epsilon):
        eval_adv_samples(model, benign, adv, delta, target, eps)


def load_adv_samples(model, data_loader, args):
    """Load the stored adversarial sweep, one entry per attack strength.

    Side effects: moves ``model`` to the GPU via ``model.cuda()``.
    ``data_loader`` is accepted but not used by this function.

    The eps levels are ``np.arange(0, args.max_eps, round(args.max_eps / 10, 4))``,
    each rounded to 4 decimals. For each level, five ``.npy`` files are read
    from ``<root_dir>/data/adversarial/``. Their names are built from
    ``utilis.get_string(args, eps)`` and ``args.smooth_level``.

    Returns:
        Five lists, each holding one entry per eps level:
        benign data, adversarial inputs, deltas, epsilons, targets.
        Adversarial inputs are converted to torch tensors; the other entries
        stay NumPy arrays.
    """
    model.cuda()
    adv_dir = Path(args.root_dir, 'data/adversarial')

    def _load(prefix, file_name):
        # Every artefact follows '<prefix>_smooth_<file_name>_<smooth_level>.npy'.
        return np.load(adv_dir / f'{prefix}_smooth_{file_name}_{args.smooth_level}.npy')

    benign_all, advs_all, delta_all, eps_all, target_all = [], [], [], [], []
    step = np.round(args.max_eps / 10, 4)
    for raw_eps in np.arange(0, args.max_eps, step):
        # Round so the level matches the name used when the files were written.
        file_name = utilis.get_string(args, np.round(raw_eps, 4))
        adv = _load('advs', file_name)
        delta = _load('deltas', file_name)
        eps_values = _load('epsilons_all', file_name)
        target = _load('benign_target', file_name)
        benign = _load('benign_data', file_name)

        advs_all.append(torch.from_numpy(adv))
        benign_all.append(benign)
        delta_all.append(delta)
        target_all.append(target)
        eps_all.append(eps_values)
    return benign_all, advs_all, delta_all, eps_all, target_all


def eval_adv_samples(model, true_data, advs, deltas, targets, eps, *, n_eval=100, config=None):
    """Evaluate one eps level of the adversarial sweep and store per-sample results.

    ``utilis.predict_model_smooth`` is run on the first ``n_eval`` samples; the
    count is clamped to the number of samples available. Three metrics are then
    printed and logged to wandb:

    - adversarial accuracy;
    - the fraction of samples whose minimal linear r-value exceeds the L2 norm
      of the perturbation;
    - the same fraction for the smooth r-value.

    Per-sample results are pickled to ``<root_dir>/results/``.

    Args:
        model: network passed through to ``utilis.predict_model_smooth``.
        true_data: benign inputs, indexed per sample (each converted with
            ``torch.from_numpy``).
        advs: adversarial inputs, indexed per sample (torch tensors).
        deltas: perturbations; flattened per sample to compute their L2 norm.
        targets: true labels per sample.
        eps: per-sample attack strength; ``eps[0]`` labels this level.
        n_eval: maximum number of samples to evaluate (default 100).
        config: run configuration; defaults to the module-level ``args``.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    cfg = args if config is None else config
    n = min(n_eval, *(len(a) for a in (true_data, advs, deltas, targets, eps)))
    if n <= 0:
        raise ValueError('eval_adv_samples: no samples to evaluate')

    deltas_norm = np.linalg.norm(deltas.reshape(deltas.shape[0], -1), ord=2, axis=1)[:n]

    preds, r_lin, r_smooth = [], [], []
    for i in range(n):
        if i % 10 == 0:
            print(i)
        # predict_model_smooth returns 8 values; only the predicted label and the
        # minimal linear / smooth r-values are used here.
        (pred_label, min_r_lin, min_r_smooth,
         _min_r_idx_lin, _min_r_idx_smooth, _zaehler, _norm, _alpha) = utilis.predict_model_smooth(
            model, advs[i].float().cuda(), torch.from_numpy(true_data[i]).float().cuda(), targets[i],
            cfg.n_samples, deltas[i], cfg, eps[i])
        preds.append(pred_label)
        r_lin.append(min_r_lin)
        r_smooth.append(min_r_smooth)

    pred = np.array(preds)
    r_values_lin = np.array(r_lin)
    r_values_smooth = np.array(r_smooth)
    real = targets[:n]
    level = eps[0]

    # Compute each metric once and reuse it for both the print and the wandb log.
    adv_acc = 1 - (np.count_nonzero(pred - real) / len(pred))
    r_acc_lin = np.count_nonzero(r_values_lin - deltas_norm > 0) / n
    r_acc_smooth = np.count_nonzero(r_values_smooth - deltas_norm > 0) / n

    print(f'accuracy : {adv_acc} for eps = {level}')
    wandb.log({'adv_accuracy': adv_acc, 'step': level})
    print(f'r accuracy : {r_acc_lin} for eps = {level}')
    wandb.log({'r_value_acc_lin': r_acc_lin, 'step': level})
    print(f'smooth r accuracy : {r_acc_smooth} for eps = {level}')
    wandb.log({'r_value_acc_smooth': r_acc_smooth, 'step': level})

    data_pred = np.column_stack((eps[:n], real, deltas_norm, pred, r_values_lin, r_values_smooth))
    info = pd.DataFrame(data=data_pred,
                        columns=['eps', 'real', 'deltas_norm', 'predicted_adv', 'r_value_lin', 'r_values_smooth'])
    # NOTE(review): written with to_pickle despite the '.h5' suffix; the name is
    # kept because downstream readers presumably expect it.
    # NOTE(review): the attack name is hard-coded as 'FGM' rather than taken from
    # cfg.attack — confirm this is intended when other attacks are evaluated.
    info.to_pickle(Path(cfg.root_dir,
                        f'''results/r_value_adv_{n}_smooth_{cfg.net}_{cfg.dataset}_{cfg.droprate}_{cfg.n_samples}_FGM_{cfg.attack_samples}_{level}_{cfg.randseed}.h5'''))


if __name__ == '__main__':
    # Run the evaluation with the arguments parsed at module import time.
    main(args=args)
