import torch
import numpy as np
from pathlib import Path
import argparse
import random
import torch.nn as nn
import torch.nn.functional as F
from config import get_config
from networks import get_network
from datasets import load_data

def _str2bool(value):
    """Convert a command-line token such as 'true', 'False', '1' or 'no' to a bool.

    ``type=bool`` cannot be used for this. argparse hands the converter the raw
    string, and ``bool('False')`` is True because every non-empty string is
    truthy.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, default='ResNet', choices=['SNN', 'IM', 'BNN', 'ResNet'], help='model used')
parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['FashionMNIST', 'CIFAR10'], help='dataset used')
parser.add_argument('--stoch_varianz', default=0.05, type=float, help='Only applicable for SNN models - variance of noise added to the input')
parser.add_argument('--device', default=0, type=int, help='If you have more than one gpu, select the one on which the code is run')
parser.add_argument('--n_samples', default=100, type=int, help='Amount of samples used during inference')
parser.add_argument('--droprate', type=float, default=0.6, help='Only applicable for ResNet, specifies the dropout probability')
# Accepts `--smooth`, `--smooth true` or `--smooth false`; omitting the flag gives False.
parser.add_argument('--smooth', type=_str2bool, nargs='?', const=True, default=False)
args = parser.parse_args()
# get_config is project code; presumably it fills in remaining settings
# (batch_size, root_dir, epochs, randseed, layer, ...) used below — confirm.
args = get_config(args)
torch.cuda.set_device(args.device)


def main(args):
    """Load a trained model and save its per-example predictive variances.

    Loads the test split for ``args.dataset`` and restores the model weights
    from ``<root_dir>/models/<dataset>/``. It then writes the list returned by
    ``get_variances`` to ``<root_dir>/results/CLT_Var_<...>.npy``, creating the
    results directory if needed.

    Note: ``args.randseed`` is incremented by 1000 *in place* before sampling.
    The incremented value is what appears in the results filename.
    """
    # Only the test loader is used.
    _, _, red_test_loader = load_data(args.dataset, args.batch_size, args.root_dir)

    model = get_network(args)

    # The checkpoint name encodes a per-family hyper-parameter:
    # args.layer and the input-noise variance for SNNs, args.droprate otherwise.
    model_dir = Path(args.root_dir, 'models', args.dataset)
    if args.net == 'SNN':
        model_file = f'model_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{args.layer}_{args.stoch_varianz}.bin'
    else:
        model_file = f'model_{args.net}_{args.dataset}_{args.epochs}_{args.randseed}_{args.droprate}.bin'
    parameter = torch.load(model_dir / model_file)
    model.load_state_dict(parameter)

    if args.net == 'ResNet':
        # Freeze BatchNorm statistics only. Every other module stays in the mode
        # get_network returned it in (nn.Module defaults to training mode).
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    # Use a different seed than during training.
    args.randseed += 1000
    np.random.seed(args.randseed)
    torch.manual_seed(args.randseed)
    random.seed(args.randseed)
    variances = get_variances(red_test_loader, model)

    results_dir = Path(args.root_dir, 'results')
    # np.save does not create missing parent directories.
    results_dir.mkdir(parents=True, exist_ok=True)
    np.save(results_dir / f'CLT_Var_{args.dataset}_{args.net}_{args.droprate}_{args.layer}_{args.stoch_varianz}_{args.randseed}.npy',
            variances)


def get_variances(test_loader, net, n_points=1000):
    """Compute the predictive variance for the first ``n_points`` test examples.

    Examples are read directly from ``test_loader.dataset``: ``list_IDs.data``
    supplies the inputs and ``labels.data`` the targets. The loader is not
    iterated, so each example is processed on its own with a leading batch
    dimension of 1 added via ``unsqueeze(0)``.

    Args:
        test_loader: loader whose ``dataset`` exposes indexable ``list_IDs`` and
            ``labels`` (presumably tensors, given ``.data``/``unsqueeze`` — confirm).
        net: model passed through to ``calculate_variance``.
        n_points: number of examples to evaluate. Defaults to 1000, the value
            that was previously hard-coded, and is clamped to the dataset size.

    Returns:
        A list with one variance value per evaluated example.
    """
    inputs = test_loader.dataset.list_IDs.data
    targets = test_loader.dataset.labels.data
    # Avoid an IndexError when the test set holds fewer than n_points samples.
    n_points = min(n_points, len(inputs))

    variances = []
    for idx in range(n_points):
        if idx % 10 == 0:
            print(idx)  # progress indicator
        sample = inputs[idx].unsqueeze(0)
        variances.append(calculate_variance(sample, targets[idx], net))
    return variances


def predict_model(model, test_data, apply_softmax=None):
    """Run one forward pass and return the predictions as a CPU numpy array.

    Args:
        model: the network to evaluate.
        test_data: input batch passed to the model.
        apply_softmax: whether to apply a softmax over dim 1 to the outputs.
            ``None`` (the default) keeps the original rule: softmax for every
            network type except 'SNN', read from the module-level ``args.net``.

    Returns:
        numpy array holding the (optionally softmaxed) model outputs.
    """
    if apply_softmax is None:
        apply_softmax = args.net != 'SNN'
    # Inference only: don't build an autograd graph on every sampled pass.
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        pred = model(test_data)
        if apply_softmax:
            pred = F.softmax(pred, dim=1)
    return pred.detach().cpu().numpy()


def calculate_variance(input, label, model, n_samples=1000):
    """Estimate the predictive variance of ``model`` for a single input.

    Runs ``n_samples`` forward passes on the same input. It takes the variance
    of the predictions across passes (per output column) and returns the entry
    selected by ``label``.

    Args:
        input: one example with a leading batch dimension of 1 (as built by
            ``get_variances`` via ``unsqueeze(0)``). It is moved to CUDA.
        label: index into the per-class variance vector (presumably a scalar
            class index — confirm against the dataset's ``labels``).
        model: the network to sample from. It is moved to the current CUDA device.
        n_samples: number of forward passes. Defaults to 1000, the value that
            was previously hard-coded.

    Returns:
        The variance across passes of the prediction at index ``label``.
    """
    model.cuda()
    if args.net == 'ResNet':
        # MC dropout: keep dropout stochastic by leaving the model in training
        # mode and freezing only BatchNorm, mirroring main(). A blanket
        # model.eval() would put nn.Dropout layers in eval mode and make every
        # pass identical. Assumes the ResNet's sampling noise is dropout gated
        # by module training mode — confirm in networks.py.
        model.train()
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
    else:
        model.eval()

    # Neither the input nor the module modes change between passes, so set
    # them up once.
    input = input.cuda()
    with torch.no_grad():
        preds = [predict_model(model, input) for _ in range(n_samples)]
    pred = np.concatenate(preds, 0)
    var_preds = np.var(pred, axis=0)
    return var_preds[label]


if __name__ == '__main__':
    # Script entry point: evaluate with the arguments parsed at module import.
    main(args=args)
