import pandas as pd
import torch
import numpy as np
from pathlib import Path

import utilis
import wandb
import argparse
import random
import torch.nn as nn
from config import get_config
from networks import get_network
import os

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` simply calls ``bool(value)``, so every
    non-empty string -- including ``'False'`` and ``'0'`` -- would parse as
    ``True``. This converter accepts the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Set before wandb.init so the run is created with wandb disabled.
os.environ["WANDB_MODE"]="disabled"
wandb.init(project='Bound', reinit=True)
parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
# Boolean flags go through _str2bool so that e.g. `--smooth False` really yields False.
parser.add_argument('--smooth', type=_str2bool, default=False)
parser.add_argument('--adv_load', type=_str2bool, default=True)
parser.add_argument('--device', default=0, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=100, type=int, help='Amount of samples used during inference')
parser.add_argument('--attack_iteration_number', default=100, type=int, help='Only applicable for PGD: number of iterations')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
parser.add_argument('--stoch_varianz', default=0.05, type=float, help='Only applicable for SNN models - variance of noise added to the input')
parser.add_argument('--attack', type=str, default='CW',
                    choices=['FGM', 'PGD', 'CW', 'margin', 'FGSM', 'strong'], help='Different attacks')
parser.add_argument('--attack_samples', type=int, default=100, help='Number of samples used during the attack')

args = parser.parse_args()
# get_config fills in additional settings on the namespace (e.g. root_dir,
# epochs, randseed are read later) -- presumably dataset/net specific; TODO confirm.
args = get_config(args)
wandb.config.update(args)

torch.cuda.set_device(args.device)

wandb.run.name = f'''{args.net}_{args.n_samples}_{args.attack_samples}_{wandb.run.id}'''


def main(args):
    """Load a trained network and evaluate it on stored adversarial examples.

    The checkpoint file under ``<root_dir>/models/<dataset>/`` is chosen from
    ``args``:

    * ``args.smooth`` -> ``model_smooth_<net>_<dataset>_<epochs>_<randseed>_<droprate>_<smooth_level>.bin``
    * ``args.net == 'SNN'`` -> ``model_<net>_<dataset>_<epochs>_<randseed>_<layer>_<stoch_varianz>.bin``
    * otherwise -> ``model_<net>_<dataset>_<epochs>_<randseed>_<droprate>.bin``

    Note: ``args.randseed`` is mutated in place. It is increased by 1000
    before the adversarial data is loaded and by a further 6000 afterwards.
    The final value is later read (via ``args.randseed``) when the result
    file is named.
    """
    model = get_network(args)

    # Pick the checkpoint name; the last suffix encodes the hyper-parameter
    # that distinguishes the model variant.
    if args.smooth:
        ckpt_name = (f'model_smooth_{args.net}_{args.dataset}_{args.epochs}_'
                     f'{args.randseed}_{args.droprate}_{args.smooth_level}.bin')
    elif args.net == 'SNN':
        ckpt_name = (f'model_{args.net}_{args.dataset}_{args.epochs}_'
                     f'{args.randseed}_{args.layer}_{args.stoch_varianz}.bin')
    else:
        ckpt_name = (f'model_{args.net}_{args.dataset}_{args.epochs}_'
                     f'{args.randseed}_{args.droprate}.bin')

    # map_location='cpu': the checkpoint may have been saved on a GPU index
    # that is not present on this machine; load_state_dict copies the tensors
    # onto the model's own parameters regardless of where they were loaded.
    parameter = torch.load(Path(args.root_dir, 'models', args.dataset, ckpt_name),
                           map_location='cpu')
    model.load_state_dict(parameter)

    # Only the BatchNorm layers are switched to eval mode (frozen running
    # statistics). The other modules keep whatever mode get_network left them
    # in, presumably so stochastic layers (e.g. dropout) stay active for
    # sampling -- TODO confirm.
    if args.net == 'ResNet':
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

    # Use a different seed than during training.
    args.randseed += 1000
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    X, advs, deltas, epsilon, targets = load_advs(model, args)
    # Use a different seed than during adversarial-example creation.
    args.randseed += 1000*6
    # Evaluate the adversarial examples.
    eval_adv_samples(model, X, advs, deltas, targets, epsilon)



def load_advs(model, args):
    """Move ``model`` to the GPU and load the stored adversarial-attack arrays.

    The epsilon used to build the file name, and ``args.max_eps`` (set on
    ``args`` as a side effect), depend on the configuration:

    * ``args.net == 'ResNet'``           -> eps=0.3, max_eps=1.0
    * ``args.dataset == 'FashionMNIST'`` -> eps=1.5, max_eps=2.5

    Arrays are read from
    ``<root_dir>/data/adversarial/<prefix>_<file_name>.npy``, where
    ``file_name = utilis.get_string(args, eps)``.

    Returns:
        Tuple ``(X, advs, deltas, epsilon, targets)`` of arrays loaded with
        ``np.load``. The prefixes are ``benign_data``, ``advs``, ``deltas``,
        ``epsilons_all`` and ``benign_target``, respectively.

    Raises:
        ValueError: if no epsilon is configured for the net/dataset
            combination in ``args``.
    """
    if args.net == 'ResNet':
        eps, max_eps = 0.3, 1.0
    elif args.dataset == 'FashionMNIST':
        eps, max_eps = 1.5, 2.5
    else:
        # Fail loudly instead of hitting an unbound `eps` further down.
        raise ValueError(
            f'no attack epsilon configured for net={args.net!r}, '
            f'dataset={args.dataset!r}')

    model.cuda()
    args.max_eps = max_eps

    file_name = utilis.get_string(args, eps)
    adv_dir = Path(args.root_dir, 'data', 'adversarial')

    def _load(prefix):
        # Each array lives in its own .npy file sharing the configuration suffix.
        return np.load(adv_dir / f'{prefix}_{file_name}.npy')

    advs = _load('advs')
    deltas = _load('deltas')
    X = _load('benign_data')
    targets = _load('benign_target')
    epsilon = _load('epsilons_all')

    return X, advs, deltas, epsilon, targets


def eval_adv_samples(model, true_data, advs, deltas, targets, eps, n_eval=1000, run_args=None):
    """Evaluate ``model`` on stored adversarial examples and save per-sample results.

    For each of the first ``n`` samples, ``utilis.predict_model`` is called
    with the adversarial input and the benign input (both moved to the GPU as
    float tensors), the target, ``n_samples`` and the perturbation. Two summary
    metrics are printed and logged to wandb with ``step = eps[0]``:

    * ``adv_accuracy`` -- fraction of samples whose prediction equals the target.
    * ``r_value_acc``  -- fraction of samples whose r-value exceeds the L2
      norm of the perturbation.

    A per-sample DataFrame is saved with ``DataFrame.to_pickle`` to
    ``<root_dir>/results/r_value_adv_<n>_full_<file_name>_<n_samples>_<randseed>.h5``.
    Despite the ``.h5`` suffix, the file is a pickle.

    Args:
        model: network forwarded to ``utilis.predict_model``.
        true_data: benign inputs, indexed per sample.
        advs: adversarial inputs, indexed per sample.
        deltas: perturbations; each one is flattened to compute its L2 norm.
        targets: labels of the benign inputs.
        eps: per-sample epsilons; ``eps[0]`` is the logging step and is
            passed to ``utilis.get_string`` for the output file name.
        n_eval: maximum number of samples to evaluate. It is clipped to
            ``len(advs)``, so ``n = min(n_eval, len(advs))``.
        run_args: run configuration; defaults to the module-level ``args``.

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    cfg = args if run_args is None else run_args
    n = min(n_eval, len(advs))
    if n <= 0:
        raise ValueError('no adversarial samples to evaluate')

    deltas_norm = np.linalg.norm(deltas.reshape(deltas.shape[0], -1), ord=2, axis=1)[:n]

    preds, r_values, positions = [], [], []
    pred_diff, norm_pred, angles = [], [], []
    for i in range(n):
        # `zaehler` (German: numerator), `norm` and `alpha` are indexable by
        # `position`; only the entry selected by predict_model is kept.
        predicted, r_value, position, zaehler, norm, alpha = utilis.predict_model(
            model,
            torch.from_numpy(advs[i]).float().cuda(),
            torch.from_numpy(true_data[i]).float().cuda(),
            targets[i], cfg.n_samples, deltas[i], cfg)
        pred_diff.append(zaehler[position])
        norm_pred.append(norm[position])
        angles.append(alpha[position])
        preds.append(predicted)
        r_values.append(r_value)
        positions.append(position)
        if i % 100 == 0:
            print(i)

    pred = np.array(preds)
    norm_pred = np.array(norm_pred)
    angles = np.array(angles)
    pred_diff = np.array(pred_diff)
    r_values = np.array(r_values)
    positions = np.array(positions)

    step = eps[0]
    adv_acc = 1 - (np.count_nonzero(pred - targets[:n]) / len(pred))
    r_acc = np.count_nonzero(r_values - deltas_norm > 0) / n
    print(f'accuracy : {adv_acc} for eps = {step}')
    wandb.log({'adv_accuracy': adv_acc, 'step': step})
    print(f'r accuracy : {r_acc} for eps = {step}')
    wandb.log({'r_value_acc': r_acc, 'step': step})

    data_pred = np.column_stack(
        (eps[:n], targets[:n], pred, positions, deltas_norm, norm_pred, angles, pred_diff, r_values))
    info = pd.DataFrame(data=data_pred,
                        columns=['eps', 'real', 'predicted_adv', 'positions', 'deltas_norm', 'norm_pred', 'angle',
                                 'pred_diff', 'r_value'])

    file_name = utilis.get_string(cfg, step)
    print(file_name)
    info.to_pickle(Path(cfg.root_dir,
                        f'''results/r_value_adv_{n}_full_{file_name}_{cfg.n_samples}_{cfg.randseed}.h5'''))


if __name__ == '__main__':
    # Script entry point: run the evaluation with the parsed CLI configuration.
    main(args=args)