import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import argparse
import random
from config import get_config
from networks import get_network
from datasets import load_data

def _str2bool(value):
    """Parse a command-line boolean value.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This accepts the usual spellings
    explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
parser.add_argument('--stoch_varianz', default=0.05, type=float, help='Only applicable for SNN models - variance of noise added to the input')
parser.add_argument('--device', default=0, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=100, type=int, help='Amount of samples used during inference')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
# _str2bool instead of type=bool so that "--smooth False" actually yields False
parser.add_argument('--smooth', type=_str2bool, default=False)
args = parser.parse_args()
# get_config fills in the remaining settings (batch_size, root_dir, epochs, randseed, layer, ...)
# used below — presumably from a per-dataset/per-net config; confirm in config.py
args = get_config(args)
# Make the selected GPU the current device so every bare .cuda() call below targets it
torch.cuda.set_device(args.device)


def main(args):
    """Load a trained model and save input-gradient draws for the test set.

    The checkpoint is read from ``<root_dir>/models/<dataset>/``. Its file name
    encodes net, dataset, epochs and seed, plus ``layer`` and the noise variance
    for SNN, or the dropout rate otherwise. The stacked gradients are written to
    ``<root_dir>/results/`` as a ``.npy`` file.

    Args:
        args: parsed, config-completed CLI namespace. Note that
            ``args.randseed`` is incremented by 1000 in place before the
            gradients are computed, and the incremented value is used in the
            output file name.
    """
    # get data (only the test loader is used)
    _, _, red_test_loader = load_data(args.dataset, args.batch_size, args.root_dir)

    # get model
    model = get_network(args)

    # load model; the checkpoint name differs between SNN (noise variance) and the other nets (dropout rate)
    if args.net == 'SNN':
        ckpt_name = f'model_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{args.layer}_{args.stoch_varianz}.bin'
    else:
        ckpt_name = f'model_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{args.droprate}.bin'
    # map to CPU: load_state_dict copies the values into the model's own tensors
    # anyway, and this avoids restoring them onto the GPU they were saved from.
    parameter = torch.load(Path(args.root_dir, f'models/{args.dataset}/{ckpt_name}'), map_location='cpu')
    # strict=False is kept, but mismatches are reported instead of silently
    # leaving some weights at their initial values.
    incompatible = model.load_state_dict(parameter, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f'Warning: checkpoint/model mismatch. missing keys: {incompatible.missing_keys}, '
              f'unexpected keys: {incompatible.unexpected_keys}')

    if args.net == 'ResNet':
        # Put only the BatchNorm layers into eval mode (use running statistics).
        # NOTE(review): calculate_grads() calls model.eval() on the whole model,
        # which overrides this selective setting (any dropout layers are switched
        # to eval mode too) — confirm that is intended.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    # different seed than during training
    args.randseed += 1000
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    grad_for_variances = get_grad_for_variances(red_test_loader, model)

    out_path = Path(args.root_dir,
                    f'results/CLT_GradVar_{args.dataset}_{args.net}_{args.droprate}_{args.layer}_{args.stoch_varianz}_{args.randseed}.npy')
    # np.save does not create missing parent directories
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(out_path, np.stack(grad_for_variances, 0))


def get_grad_for_variances(test_loader, net, n_points=1000):
    """Collect repeated input-gradient draws for the first test examples.

    For each of the first ``n_points`` examples of ``test_loader.dataset``
    (capped at the number of labels), ``calculate_grads`` runs on that single
    example, with a batch dimension of one.

    Args:
        test_loader: loader whose ``dataset`` exposes per-example inputs in
            ``list_IDs.data`` and targets in ``labels.data``.
        net: model to differentiate.
        n_points: maximum number of test examples to process (default 1000).

    Returns:
        list with one ``calculate_grads`` result per processed example.
    """
    dataset = test_loader.dataset
    inputs, labels = dataset.list_IDs.data, dataset.labels.data
    n_points = min(n_points, len(labels))
    gradients = []
    for idx in range(n_points):
        if idx % 10 == 0:
            print(idx)  # progress indicator
        # unsqueeze(0) adds a leading batch dimension of size 1
        X_mb, t_mb = inputs[idx].unsqueeze(0), labels[idx]
        gradients.append(calculate_grads(X_mb, t_mb, net))
    return gradients


def predict_model(model, test_data):
    """Run ``model.forward`` on ``test_data`` and return its scores.

    The network type is read from the module-level ``args.net``. For ``'SNN'``
    the raw forward output is returned unchanged; every other network gets a
    softmax over dimension 1.
    """
    output = model.forward(test_data)
    return output if args.net == 'SNN' else F.softmax(output, dim=1)


def get_loss(pred, y, x):
    """Return the margin between the true-class score and the best other class.

    Args:
        pred: scores of shape ``(batch, num_classes)``. Because the label mask is
            built from a single ``y``, only a batch size of 1 is supported.
        y: true class index as a one-element tensor.
        x: the input batch; only ``x.shape[0]`` is used.

    Returns:
        one-element tensor ``pred[0, y] - max_{c != y} pred[0, c]``.
    """
    num_classes = pred.shape[1]
    # Work on pred's own device instead of forcing CUDA, and derive the class
    # count from pred instead of assuming 10.
    idx_all = torch.arange(num_classes, device=pred.device).view(1, num_classes).expand_as(pred)
    mask = y.view(1, 1).to(pred.device)
    mask = mask.expand_as(idx_all) != idx_all  # True for every class except y
    rest = pred[mask].view(x.shape[0], num_classes - 1)
    second_prob, _ = torch.max(rest, 1)
    loss = pred[0, y] - second_prob
    return loss


def calculate_grads(input, label, model, n_samples=1000):
    """Draw ``n_samples`` gradients of the margin loss with respect to one input.

    Each draw runs a fresh forward pass through ``predict_model`` and
    differentiates ``get_loss`` with respect to the input. Any randomness in
    the model's forward pass therefore shows up as variation between draws.

    Args:
        input: input tensor with a leading batch dimension; the caller passes a
            batch of one.
        label: target class index as a tensor, forwarded to ``get_loss``.
        model: network to evaluate. It is moved to the current CUDA device and
            switched to eval mode.
        n_samples: number of gradient draws (default 1000).

    Returns:
        numpy array stacking the draws along axis 0. For a batch of one its
        shape is ``(n_samples, *input.shape[1:])``.
    """
    model.cuda()
    # NOTE(review): eval() switches every submodule (including any dropout)
    # to inference mode, overriding main()'s BatchNorm-only eval for ResNet —
    # confirm this is intended.
    model.eval()
    x = input.cuda().requires_grad_(True)
    gradients = []
    for _ in range(n_samples):
        pred = predict_model(model, x)
        loss = get_loss(pred, label, x)
        # autograd.grad returns d(loss)/dx directly. There is no .grad
        # accumulation on x to reset, and the model parameters' .grad buffers
        # are left untouched.
        (grad_value,) = torch.autograd.grad(loss, x)
        gradients.append(grad_value.detach().cpu().squeeze(0))
    return np.stack(gradients, 0)


if __name__ == '__main__':
    # Script entry point: run with the CLI arguments parsed at module level.
    main(args)
