import copy
import os
import pickle
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm

from subspace_inference.posteriors.importance_sampler import ImportanceSampler
from subspace_inference import models, losses, utils


def _results_exist(task_type, pkl_name):
    """Return True when every result pickle written for ``task_type`` already exists.

    Task types 40/41 (CIFAR-10) also produce the CIFAR-10-C matrices ``_ece_m``
    and ``_acc_m``; task types 42/43 (CIFAR-100) only produce AUC/NLL/accuracy.
    Any other task type is never considered cached.
    """
    if task_type in [40, 41]:  # for CIFAR-10
        suffixes = ('_auc', '_nll', '_acc', '_ece_m', '_acc_m')
    elif task_type in [42, 43]:  # for CIFAR-100
        suffixes = ('_auc', '_nll', '_acc')
    else:
        return False
    return all(os.path.exists(pkl_name + suffix + '.pkl') for suffix in suffixes)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


@torch.no_grad()
def _predict_probs(model, loader, device, non_blocking=False):
    """Run ``model`` over ``loader``; return (softmax probabilities on ``device``, CPU labels).

    Both tensors are concatenated along dim 0 in the loader's iteration order.
    """
    prob_chunks, label_chunks = [], []
    for data, target in loader:
        output = model(data.to(device, non_blocking=non_blocking))
        prob_chunks.append(F.softmax(output, dim=1))
        label_chunks.append(target)
    return torch.cat(prob_chunks, 0), torch.cat(label_chunks, 0)


def _weighted_ensemble(models, weights, loader, device, non_blocking=False, progress=False):
    """Return (sum_k weights[k] * softmax_k over ``loader``, labels of the last pass).

    The weighted sum is accumulated model by model, so only one extra prediction
    tensor is alive at a time. This assumes ``loader`` yields samples in the same
    order on every pass (e.g. ``shuffle=False``); otherwise predictions for
    different samples would be mixed together.
    """
    ensemble, labels = None, None
    pairs = zip(weights, models)
    if progress:
        pairs = tqdm(pairs, total=len(models))
    for w_k, model in pairs:
        probs, labels = _predict_probs(model, loader, device, non_blocking)
        weighted = probs * w_k
        ensemble = weighted if ensemble is None else ensemble + weighted
    return ensemble, labels


def _expected_calibration_error(probs, labels, n_bins=15):
    """Return the expected calibration error (ECE) of ``probs`` w.r.t. ``labels``.

    Confidence is the top-class probability. Samples are grouped into ``n_bins``
    equal-width bins over (0, 1], and ECE is the sample-fraction-weighted mean of
    |mean confidence - accuracy| across non-empty bins. Returns 0.0 for no samples.
    """
    confidences, predictions = probs.max(dim=1)
    # Guard against float round-off pushing an ensemble confidence just above 1.
    confidences = confidences.clamp(0.0, 1.0)
    correct = predictions.eq(labels.to(predictions.device)).float()
    n_total = confidences.numel()
    if n_total == 0:
        return 0.0
    edges = torch.linspace(0.0, 1.0, n_bins + 1, device=confidences.device)
    ece = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        count = int(in_bin.sum().item())
        if count:
            gap = (confidences[in_bin].mean() - correct[in_bin].mean()).abs().item()
            ece += gap * count / n_total
    return ece


def cifar_eval(args, device, space_model, test_loader, train_loader, proposal_sample, pkl_name, weights=None):
    """Evaluate a weighted ensemble of subspace samples on CIFAR and pickle the metrics.

    Each row of ``proposal_sample`` is mapped through ``space_model`` to a flat
    weight vector. That vector is loaded into a fresh base model, BN statistics
    are recomputed on ``train_loader``, and a deep copy is kept. The copies are
    then combined with ``weights`` (normalized to sum to 1; uniform when None)
    to produce:

    * ``args.svhn == 1``: OOD AUROC of CIFAR test vs. SVHN test using
      1 - max ensemble probability as the score -> ``<pkl_name>_auc.pkl``.
    * ``args.nll == 1``: test NLL and accuracy of the weighted ensemble
      -> ``<pkl_name>_nll.pkl`` and ``<pkl_name>_acc.pkl``.
    * ``args.cifar_c == 1`` and task type 40/41: per-severity (5 rows) x
      per-corruption accuracy and ECE on CIFAR-10-C
      -> ``<pkl_name>_acc_m.pkl`` and ``<pkl_name>_ece_m.pkl``.

    If all result files for the task type already exist, returns immediately.

    Args:
        args: experiment namespace. This function reads ``task_type``, ``model_cfg``,
            ``inference_criterion``, ``proposal_var``, ``temperature``,
            ``prior_scale``, ``svhn``, ``nll``, ``cifar_c`` and ``save_path``.
        device: torch device used for all models and tensors.
        space_model: callable mapping a float32 subspace point to a flat parameter vector.
        test_loader: CIFAR test loader. The AUC path assumes a deterministic order.
        train_loader: loader used for BN-statistics updates.
        proposal_sample: numpy array; each ``proposal_sample[i]`` is one subspace point.
        pkl_name: path prefix for the output pickles.
        weights: optional per-sample ensemble weights (length ``proposal_sample.shape[0]``).

    Raises:
        ValueError: if there are no samples, ``weights`` has the wrong length or a
            non-positive sum, or ``test_loader`` is empty when NLL is requested.
    """
    if _results_exist(args.task_type, pkl_name):
        print(pkl_name + " exists.")
        return

    sample_size = proposal_sample.shape[0]
    if sample_size == 0:
        raise ValueError("proposal_sample contains no samples")
    if weights is None:
        weights = torch.ones(sample_size, device=device) / sample_size
    else:
        weights = torch.as_tensor(weights, dtype=torch.float32, device=device).reshape(-1)
        if weights.numel() != sample_size:
            raise ValueError("expected %d weights, got %d" % (sample_size, weights.numel()))
        weight_sum = weights.sum()
        if not weight_sum > 0:
            raise ValueError("weights must sum to a positive value")
        # Normalize so the weighted average is a probability distribution; this
        # leaves argmax and AUC rankings unchanged but is required for NLL/ECE.
        weights = weights / weight_sum

    sampler = ImportanceSampler(base=args.model_cfg.base, criterion=args.inference_criterion, proposal_var=args.proposal_var,
                                temperature=args.temperature,
                                loader=train_loader, subspace=space_model, data=None, proposal_type="gaussian", deg_f=None,
                                device=device, prior_scale=args.prior_scale, *args.model_cfg.args, **args.model_cfg.kwargs)

    all_models = []
    for i in tqdm(range(sample_size)):
        # Load the flat parameter vector for this sample into the base model.
        w = space_model(torch.from_numpy(proposal_sample[i]).type(torch.float32).to(device))
        offset = 0
        for param in sampler.base_model.parameters():
            param.data.copy_(w[offset:offset + param.numel()].view(param.size()).to(device))
            offset += param.numel()
        # Batch normalization layer update.
        utils.bn_update(train_loader, sampler.base_model, subset=1.0, device=device)
        sampler.base_model.eval()
        # Copy model parameters for all layers (including BN buffers).
        all_models.append(copy.deepcopy(sampler.base_model))
    print("BN layer update finished.")

    if args.svhn == 1:
        from sklearn.metrics import roc_auc_score
        # NOTE(review): torch.load unpickles arbitrary objects; only point this at a trusted, locally produced file.
        svhn_test_loader = torch.load(os.path.join(args.save_path, 'svhn', 'svhn_test_loader.pt'))
        probs_in, _ = _weighted_ensemble(all_models, weights, test_loader, device, progress=True)
        probs_out, _ = _weighted_ensemble(all_models, weights, svhn_test_loader, device, progress=True)
        conf_in = probs_in.max(dim=1).values.cpu().numpy()
        conf_out = probs_out.max(dim=1).values.cpu().numpy()
        y_true = np.concatenate([np.zeros(len(conf_in)), np.ones(len(conf_out))])
        y_score = np.concatenate([conf_in, conf_out])
        # Low max-probability => more likely out-of-distribution.
        auc = roc_auc_score(y_true, 1 - y_score)
        print("Weighted AUC: ", auc)
        _dump_pickle(auc, pkl_name + '_auc.pkl')

    if args.nll == 1:
        total_nll = torch.zeros((), device=device)
        correct_predictions = 0
        total_samples = 0
        with torch.no_grad():
            # Batch-outer loop: robust even if test_loader shuffles.
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)
                avg_pred = sum(w_k * F.softmax(model(data), dim=1) for w_k, model in zip(weights, all_models))
                correct_predictions += (torch.argmax(avg_pred, dim=1) == target).sum().item()
                # Clamp so a probability that underflowed to 0 does not make the NLL infinite.
                log_pred = torch.log(avg_pred.clamp_min(torch.finfo(avg_pred.dtype).tiny))
                total_nll += F.nll_loss(log_pred, target, reduction='sum')
                total_samples += data.shape[0]
        if total_samples == 0:
            raise ValueError("test_loader yielded no samples")
        nll_value = (total_nll / total_samples).cpu().numpy()
        accuracy = correct_predictions / total_samples
        print("NLL: ", nll_value)
        print("Accuracy: ", accuracy)
        _dump_pickle(nll_value, pkl_name + '_nll.pkl')
        _dump_pickle(accuracy, pkl_name + '_acc.pkl')

    if args.cifar_c == 1 and args.task_type in [40, 41]:
        corruption = ['gaussian_noise', 'motion_blur', 'fog']
        n_severity = 5
        n_per_severity = 10000  # CIFAR-10-C stores severities as consecutive 10k-image blocks
        acc_m = np.zeros((n_severity, len(corruption)))
        ece_m = np.zeros((n_severity, len(corruption)))
        test_set = copy.deepcopy(test_loader.dataset)
        cifar_c_dir = os.path.join(args.save_path, "cifar-c")
        cifar_c_data = {c: np.load(os.path.join(cifar_c_dir, "%s.npy" % c)) for c in corruption}
        cifar_c_labels = np.load(os.path.join(cifar_c_dir, "labels.npy"))

        for j in range(n_severity):
            rows = slice(j * n_per_severity, (j + 1) * n_per_severity)
            test_set.targets = cifar_c_labels[rows]
            for i, name in enumerate(corruption):
                test_set.data = cifar_c_data[name][rows]
                testloader = torch.utils.data.DataLoader(
                    test_set,
                    batch_size=1000,
                    shuffle=False,
                    num_workers=8,
                    pin_memory=True,
                    persistent_workers=True
                )
                probs, labels = _weighted_ensemble(all_models, weights, testloader, device, non_blocking=True)
                labels = labels.to(probs.device)
                acc_m[j, i] = torch.argmax(probs, dim=1).eq(labels).sum().item() / labels.numel()
                ece_m[j, i] = _expected_calibration_error(probs, labels)
                del testloader  # release this loader's persistent workers before building the next
            print(j, np.mean(acc_m[j, :]), np.mean(ece_m[j, :]))
        print('acc', np.mean(acc_m, axis=1))
        print('ece', np.mean(ece_m, axis=1))
        _dump_pickle(acc_m, pkl_name + '_acc_m.pkl')
        _dump_pickle(ece_m, pkl_name + '_ece_m.pkl')

    # Drop every model reference so GPU memory can actually be reclaimed.
    all_models.clear()
    del all_models
    del sampler

    # Clear the CUDA cache only when running on a CUDA device (torch.cuda.device rejects CPU devices).
    if torch.cuda.is_available() and torch.device(device).type == 'cuda':
        with torch.cuda.device(device):
            torch.cuda.empty_cache()
