import sys
sys.path.append('')
import os
import numpy as np
import torch
import torch.utils.data
import argparse
import matplotlib.pyplot as plt
import matplotlib.cm
import pickle

from tqdm import tqdm
from subspace_inference import models, losses, utils
from subspace_inference.posteriors.proj_model import SubspaceModel
from subspace_inference.posteriors.importance_sampler import ImportanceSampler
import subspace_inference.posteriors.subspaces_mod as Subspace
from evaluate_predictive import eval_predict
from evaluate_cifar import cifar_eval

# bayesian_benchmarks dataset names of the small UCI regression tasks, keyed by task_type.
_UCI_SMALL_TASKS = {
    20: 'boston',
    21: 'concrete',
    22: 'energy',
    23: 'naval',
    24: 'yacht',
}

# Large UCI regression tasks: task_type -> (dataset name, swag_lr, batch_size, temperature).
_UCI_LARGE_TASKS = {
    30: ('wilson_elevators', 1e-4, 2000, 1000),
    31: ('wilson_protein', 1e-4, 5000, 1000),
    32: ('wilson_pol', 1e-4, 5000, 1000),
    33: ('wilson_keggdirected', 1e-4, 10000, 1000),
    34: ('wilson_keggundirected', 1e-4, 10000, 1000),
    35: ('wilson_skillcraft', 1e-4, 1000, 100),
}

# CIFAR tasks: task_type -> per-task attributes written onto ``args``.
_CIFAR_TASKS = {
    # CIFAR10 with VGG16
    40: dict(_dataset='CIFAR10', _model='VGG16', lr_init=0.05, swag_lr=0.01, wd=5e-4,
             heatmap_x_size=50, heatmap_y_size=50, heatmap_x_range=20, heatmap_y_range=20,
             proposal_var=3.0 * 3.0),
    # CIFAR10 with PreResNet164
    41: dict(_dataset='CIFAR10', _model='PreResNet164', lr_init=0.1, swag_lr=0.01, wd=3e-4,
             heatmap_x_size=50, heatmap_y_size=50, heatmap_x_range=20, heatmap_y_range=20,
             proposal_var=3.0 * 3.0),
    # CIFAR100 with VGG16
    42: dict(_dataset='CIFAR100', _model='VGG16', lr_init=0.05, swag_lr=0.01, wd=5e-4,
             heatmap_x_size=50, heatmap_y_size=50, heatmap_x_range=10, heatmap_y_range=10,
             proposal_var=3.0 * 3.0),
    # CIFAR100 with PreResNet164
    43: dict(_dataset='CIFAR100', _model='PreResNet164', lr_init=0.1, swag_lr=0.05, wd=3e-4,
             heatmap_x_size=50, heatmap_y_size=50, heatmap_x_range=5, heatmap_y_range=5,
             proposal_var=1.0 * 1.0),
}

# Directory added to sys.path so that ``bayesian_benchmarks`` can be imported.
_UCI_EXPS_PATH = './experiments/uci_exps'


def _import_get_regression_data():
    """Return ``bayesian_benchmarks.data.get_regression_data``.

    The UCI experiments directory is appended to ``sys.path`` at most once, so repeated
    calls do not keep growing the import path.
    """
    if _UCI_EXPS_PATH not in sys.path:
        sys.path.append(_UCI_EXPS_PATH)
    from bayesian_benchmarks.data import get_regression_data
    return get_regression_data


def _set_uci_common_params(args):
    """Write the training/inference hyper-parameters shared by all UCI regression tasks."""
    args.epochs = 1000
    args.swag_start = 900
    args.lr_init = 1e-3
    args.wd = 1e-3
    args.momentum = 0.8
    # Two separate likelihood instances, one for training and one for inference.
    args.criterion = losses.GaussianLikelihood(noise_var=None)
    args.inference_criterion = losses.GaussianLikelihood(noise_var=None)
    args.heatmap_x_size = 200
    args.heatmap_y_size = 200
    args.heatmap_x_range = 200
    args.heatmap_y_range = 400
    args.proposal_var = 50.0 * 50.0


def _tensor_loader(x, y, batch_size, shuffle, device):
    """Wrap numpy arrays ``x``/``y`` as float32 tensors on ``device`` inside a DataLoader."""
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x.astype(np.float32)).to(device),
        torch.from_numpy(y.astype(np.float32)).to(device),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _build_regression_task(args, dataset, dimensions, device):
    """Build a RegNet plus train/test loaders for a bayesian_benchmarks regression dataset.

    Also records ``full_dataset``, ``model_cfg`` and ``prior_scale`` on ``args``.
    Requires ``args.batch_size`` to be set already.
    """
    model_cfg = getattr(models, 'RegNet')
    # NOTE(review): kwargs is mutated in place on the config object fetched from `models`;
    # every key used here is overwritten on each call.
    model_cfg.kwargs['dimensions'] = dimensions
    model_cfg.kwargs['input_dim'] = dataset.D
    # Two network outputs per input; presumably consumed by GaussianLikelihood(noise_var=None) — confirm.
    model_cfg.kwargs['output_dim'] = 2
    train_loader = _tensor_loader(dataset.X_train, dataset.Y_train, args.batch_size, True, device)
    test_loader = _tensor_loader(dataset.X_test, dataset.Y_test, args.batch_size, False, device)
    model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
    args.full_dataset = dataset
    args.model_cfg = model_cfg
    args.prior_scale = 100.0
    return train_loader, test_loader, model.to(device)


def _build_cifar_task(args, device):
    """Configure ``args`` and build loaders/model for a CIFAR task (task_type 40-43)."""
    from subspace_inference import data
    for attr_name, value in _CIFAR_TASKS[args.task_type].items():
        setattr(args, attr_name, value)

    args.epochs = 300
    args.swag_start = 160
    args.batch_size = 500
    args.momentum = 0.9
    args.criterion = losses.cross_entropy
    args.inference_criterion = losses.cross_entropy
    args.temperature = 1000
    model_cfg = getattr(models, args._model)
    args.num_workers = 16
    if args.load_ckpt:
        # No data loading when resuming from a checkpoint; loaders are returned as None.
        num_classes = 10 if args._dataset == 'CIFAR10' else 100
        loaders = {'train': None, 'test': None}
    else:
        loaders, num_classes = data.loaders(
            args._dataset,
            args.save_path,
            args.batch_size,
            args.num_workers,
            model_cfg.transform_train,
            model_cfg.transform_test,
            device=device,
            use_validation=True,
            split_classes=None
        )
    model_cfg.kwargs['num_classes'] = num_classes
    model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)
    args.model_cfg = model_cfg
    args.prior_scale = 100.0
    return loaders['train'], loaders['test'], model.to(device)


def get_dataset(args, device):
    """Build data loaders and the model for ``args.task_type``, configuring ``args`` in place.

    Supported task types:
      * 20-24: small UCI regression (boston, concrete, energy, naval, yacht)
      * 30-35: large UCI regression (elevators, protein, pol, keggD, keggU, skillcraft)
      * 40-43: CIFAR10/CIFAR100 with VGG16 or PreResNet164

    Training hyper-parameters (epochs, learning rates, batch size, criterion, heatmap
    ranges, proposal variance, prior scale, ...) are written onto ``args`` as side effects.

    Args:
        args: argparse-style namespace; ``task_type`` and ``seed`` are read (CIFAR tasks
            also read ``load_ckpt`` and ``save_path``).
        device: torch device for the tensors and the model.

    Returns:
        ``(train_loader, test_loader, model)``. For CIFAR with ``args.load_ckpt`` set,
        both loaders are ``None``.

    Raises:
        ValueError: if ``args.task_type`` is not a supported task.
    """
    task = args.task_type
    if task in _UCI_SMALL_TASKS:
        get_regression_data = _import_get_regression_data()
        dataset = get_regression_data(_UCI_SMALL_TASKS[task], split=args.seed)
        _set_uci_common_params(args)
        args.swag_lr = 5e-4
        args.batch_size = 100
        args.temperature = 10
        return _build_regression_task(args, dataset, [50], device)

    if task in _UCI_LARGE_TASKS:
        get_regression_data = _import_get_regression_data()
        name, swag_lr, batch_size, temperature = _UCI_LARGE_TASKS[task]
        args.swag_lr = swag_lr
        args.batch_size = batch_size
        args.temperature = temperature
        dataset = get_regression_data(name, split=args.seed)
        _set_uci_common_params(args)
        # Deeper network for datasets with more than 6000 points.
        dimensions = [1000, 1000, 500, 50] if dataset.N > 6000 else [1000, 500, 50]
        return _build_regression_task(args, dataset, dimensions, device)

    if task in _CIFAR_TASKS:
        return _build_cifar_task(args, device)

    raise ValueError("task_type not found")


def _build_subspace_sampler(model_cfg, mean, cov_factor, criterion, device,
                            proposal_var, temperature, prior_scale, bn_loader):
    """Create an ImportanceSampler over the subspace defined by ``mean`` and ``cov_factor``.

    If ``bn_loader`` is not None, the sampler's base model is first loaded with the weights
    produced by the subspace at the all-zero coordinate vector. Its BatchNorm statistics are
    then recomputed with ``utils.bn_update`` on that loader.
    """
    subspace = SubspaceModel(mean.to(device), cov_factor.to(device))
    # NOTE(review): *model_cfg.args is unpacked after keyword arguments, so any positional
    # model args bind to ImportanceSampler's leading positional parameters — confirm it is
    # empty or matches the sampler signature.
    sampler = ImportanceSampler(base=model_cfg.base, criterion=criterion, proposal_var=proposal_var,
                                temperature=temperature, loader=None, subspace=subspace, data=None,
                                proposal_type="gaussian", deg_f=None, device=device,
                                prior_scale=prior_scale, *model_cfg.args, **model_cfg.kwargs)
    if bn_loader is not None:
        # Zero coordinate vector with one entry per row of cov_factor. This assumes the
        # subspace rank equals cov_factor.size(0), as the k x k Gram matrix
        # cov_factor @ cov_factor.t() used for the Jacobian implies — confirm against SubspaceModel.
        origin = torch.zeros(cov_factor.size(0), device=device)
        flat_weights = subspace(origin)
        offset = 0
        for param in sampler.base_model.parameters():
            numel = param.numel()
            param.data.copy_(flat_weights[offset:offset + numel].view(param.size()).to(device))
            offset += numel
        utils.bn_update(bn_loader, sampler.base_model, subset=1.0, device=device)
    return sampler


def calc_posterior_heatmap(model_cfg, mean, cov_factor, criterion, heatmap_input, heatmap_loader, device,
                           proposal_var=1., temperature=1., prior_scale=1., enable_tqdm=True, bn_loader=None):
    """Evaluate unnormalized log-posterior values at subspace points ``heatmap_input``.

    The log-marginal values from ``ImportanceSampler.calc_marginal`` are shifted by the
    log-determinant of the Jacobian of the subspace map, ``0.5 * logdet(P @ P.T)`` with
    ``P = cov_factor``. NaN entries are replaced by ``-inf``.

    Args:
        model_cfg: model config providing ``base``, ``args`` and ``kwargs``.
        mean, cov_factor: subspace mean and projection matrix.
        criterion: likelihood criterion passed to the sampler.
        heatmap_input: subspace coordinates to evaluate.
        heatmap_loader: data loader used for the likelihood.
        device: torch device.
        proposal_var, temperature, prior_scale: forwarded to ImportanceSampler.
        enable_tqdm: show a progress bar inside ``calc_marginal``.
        bn_loader: if given, used to recompute BatchNorm statistics at the subspace origin.

    Returns:
        Tensor of Jacobian-corrected log values; presumably one per row of
        ``heatmap_input`` — confirm against ``calc_marginal``.
    """
    sampler = _build_subspace_sampler(model_cfg, mean, cov_factor, criterion, device,
                                      proposal_var, temperature, prior_scale, bn_loader)
    log_weights = sampler.calc_marginal(heatmap_input, heatmap_loader, enable_tqdm=enable_tqdm)
    # Change of variables from weight space to subspace coordinates.
    log_jacobian = 0.5 * torch.logdet(cov_factor @ cov_factor.t())
    normalized_log_weights = log_weights + log_jacobian
    # Treat NaN evaluations as zero posterior mass.
    normalized_log_weights[torch.isnan(normalized_log_weights)] = -float('inf')
    return normalized_log_weights


def calc_mean_loss_heatmap(model_cfg, mean, cov_factor, criterion, heatmap_input, heatmap_loader, device,
                           proposal_var=1., temperature=1., prior_scale=1., enable_tqdm=True, bn_loader=None):
    """Evaluate the sampler's mean loss at subspace points ``heatmap_input``.

    Uses the same sampler construction and optional BatchNorm reset as
    ``calc_posterior_heatmap``. Returns whatever ``ImportanceSampler.calc_mean_loss`` returns.
    """
    sampler = _build_subspace_sampler(model_cfg, mean, cov_factor, criterion, device,
                                      proposal_var, temperature, prior_scale, bn_loader)
    return sampler.calc_mean_loss(heatmap_input, heatmap_loader, enable_tqdm=enable_tqdm)


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    if args.gpu == 0:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.gpu - 1)) if os.name == 'posix' else torch.device("cuda:0")
    print("device: ", device)

    # create data loaders and model
    train_loader, test_loader, model = get_dataset(args, device)
    num_params = sum(param.numel() for param in model.parameters())
    print('Number of parameters: ' + str(num_params) + ', seed: ' + str(args.seed), ', task: ' + str(args.task_type))

    # parse subspace
    if args.enable_ft:
        subspace_ft = Subspace.CompleteDataSpace(num_params,
                                                 pca_rank=args.rank, total_traj_num=args.epochs - args.swag_start)
    if args.enable_tt:
        subspace_tt = Subspace.TrailingSpace(num_params,
                                             max_rank=args.M, pca_rank=args.rank)
    if args.enable_ba:
        subspace_ba = Subspace.ThinningBlockAveragingSpace(num_params,
                                                           max_rank=args.M, pca_rank=args.rank, total_traj_num=args.epochs - args.swag_start)

    if args.train == 1:
        # train model
        loss_hist = torch.zeros(args.epochs)
        loss_valid_hist = torch.zeros(args.epochs)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.momentum, weight_decay=args.wd)
        range_generator = tqdm(range(args.epochs)) if args.tqdm else range(args.epochs)
        model.train()
        for epoch in range_generator:
            t = (epoch + 1) / args.swag_start
            lr_ratio = args.swag_lr / args.lr_init
            if t <= 0.5:
                factor = 1.0
            elif t <= 0.9:
                factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
            else:
                factor = lr_ratio
            lr = factor * args.lr_init
            utils.adjust_learning_rate(optimizer, lr)
            train_res = utils.train_epoch(train_loader, model, args.criterion, optimizer, device=device, regression=True)
            valid_res = utils.eval_simp(test_loader, model, args.criterion, device=device)

            loss_hist[epoch] = train_res['loss']
            loss_valid_hist[epoch] = valid_res['loss']
            if epoch >= args.swag_start:
                # collect model
                model_params = utils.flatten([param.detach() for param in model.parameters()]).cpu()
                if args.enable_ft:
                    subspace_ft.collect_vector(model_params)
                if args.enable_tt:
                    subspace_tt.collect_vector(model_params)
                if args.enable_ba:
                    subspace_ba.collect_vector(model_params)
            if (epoch % (args.epochs // 5) == 0 or epoch == args.epochs - 1):
                print('Epoch %d. LR: %g. Loss: %.8f' % (epoch, lr, train_res['loss']))
        print('Training finished')

        # save trajectory
        if args.enable_ft:
            torch.save(subspace_ft.state_dict(), args.save_path + os.sep +
                       'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt')
        if args.enable_tt:
            torch.save(subspace_tt.state_dict(), args.save_path + os.sep +
                       'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt')
        if args.enable_ba:
            torch.save(subspace_ba.state_dict(), args.save_path + os.sep +
                       'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt')
        print('subspace model saved')

        # save train and test loader
        torch.save(train_loader.dataset, args.save_path + os.sep +
                   'dataloader_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_train.pt')
        torch.save(test_loader.dataset, args.save_path + os.sep +
                   'dataloader_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_test.pt')
        print('dataloader saved')

    else:
        # load trained subspace model and data loader
        if args.enable_ft:
            subspace_ft.load_state_dict(torch.load(args.save_path + os.sep +
                                                   'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt'))
        if args.enable_tt:
            subspace_tt.load_state_dict(torch.load(args.save_path + os.sep +
                                                   'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt'))
        if args.enable_ba:
            subspace_ba.load_state_dict(torch.load(args.save_path + os.sep +
                                                   'model_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt'))
        print('subspace model loaded')

    # get projection matrix and mean
    if args.eval_evidence == 1 or args.eval_cifar == 1:
        read_proj_mat = True
        if args.enable_ft:
            try:
                if not read_proj_mat:
                    raise FileNotFoundError
                proj_mat_ft = torch.load('project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt',
                                         map_location=device)
                mean_ft = torch.load('project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt', map_location=device)
                print('subspace mean and projection matrix loaded.')
            except FileNotFoundError:
                proj_mat_ft, mean_ft = subspace_ft.get_space()
                torch.save(proj_mat_ft, 'project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt')
                torch.save(mean_ft, 'project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ft.pt')
            space_model_ft = SubspaceModel(mean_ft.to(device), proj_mat_ft.to(device))
        if args.enable_tt:
            try:
                if not read_proj_mat:
                    raise FileNotFoundError
                proj_mat_tt = torch.load('project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt',
                                         map_location=device)
                mean_tt = torch.load('project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt', map_location=device)
                print('subspace mean and projection matrix loaded.')
            except FileNotFoundError:
                proj_mat_tt, mean_tt = subspace_tt.get_space()
                torch.save(proj_mat_tt, 'project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt')
                torch.save(mean_tt, 'project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_tt.pt')
            space_model_tt = SubspaceModel(mean_tt.to(device), proj_mat_tt.to(device))
        if args.enable_ba:
            try:
                if not read_proj_mat:
                    raise FileNotFoundError
                proj_mat_ba = torch.load('project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt',
                                         map_location=device)
                mean_ba = torch.load('project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt', map_location=device)
                print('subspace mean and projection matrix loaded.')
            except FileNotFoundError:
                proj_mat_ba, mean_ba = subspace_ba.get_space()
                torch.save(proj_mat_ba, 'project_mat' + os.sep + 'proj_mat_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt')
                torch.save(mean_ba, 'project_mat' + os.sep + 'mean_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_ba.pt')
            space_model_ba = SubspaceModel(mean_ba.to(device), proj_mat_ba.to(device))
        print('subspace mean and projection matrix obtained.')
        pass
        # return

    # calculate subspace angle
    calc_subspace_angle = True
    if calc_subspace_angle:
        from scipy.linalg import subspace_angles
        proj_mat_combined = [proj_mat_ft, proj_mat_tt, proj_mat_ba]
        result = np.zeros((3, 3, 2))
        for i in range(0, 3):
            for j in range(0, 3):
                result[i, j] = np.rad2deg(
                    subspace_angles(proj_mat_combined[i].t().numpy(), proj_mat_combined[j].t().numpy()))

        for i in range(0, 3):
            for j in range(i + 1, 3):
                print("i: %d, j: %d" % (i, j))
                print(
                    np.rad2deg(subspace_angles(proj_mat_combined[i].t().numpy(), proj_mat_combined[j].t().numpy())))

    # calculate Bayesian evidence
    if args.eval_evidence == 1:
        x_size = args.heatmap_x_size
        y_size = args.heatmap_y_size
        heatmap_x = torch.linspace(-args.heatmap_x_range, args.heatmap_x_range, x_size)
        heatmap_y = torch.linspace(-args.heatmap_y_range, args.heatmap_y_range, y_size)
        heatmap_X, heatmap_Y = torch.meshgrid(heatmap_x, heatmap_y)
        heatmap_input = torch.stack((heatmap_X.reshape(-1), heatmap_Y.reshape(-1)), dim=1).to(device)

        # calculate posterior heatmap on training set
        if True:
            heatmap_loader = train_loader
            stack_density = []
            if args.enable_ft:
                heatmap_posterior_density_ft = calc_posterior_heatmap(args.model_cfg, mean_ft, proj_mat_ft, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_ft.clone())
                torch.save(heatmap_posterior_density_ft, args.save_path + os.sep + 'heatmap_posterior_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ft.pt')
            else:
                stack_density.append(torch.full((x_size * y_size,), -float('inf')))
            if args.enable_tt:
                heatmap_posterior_density_tt = calc_posterior_heatmap(args.model_cfg, mean_tt, proj_mat_tt, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_tt.clone())
                torch.save(heatmap_posterior_density_tt, args.save_path + os.sep + 'heatmap_posterior_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_tt.pt')
            if args.enable_ba:
                heatmap_posterior_density_ba = calc_posterior_heatmap(args.model_cfg, mean_ba, proj_mat_ba, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_ba.clone())
                torch.save(heatmap_posterior_density_ba, args.save_path + os.sep + 'heatmap_posterior_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ba.pt')

        # plot combined figure on training set
        if True:
            # normalize density
            normalize_constant = torch.max(torch.stack(stack_density))
            for i in range(len(stack_density)):
                stack_density[i] = torch.exp(stack_density[i] - normalize_constant)
            max_val = torch.max(torch.stack(stack_density))
            min_val = torch.min(torch.stack(stack_density))
            if args.enable_ft:
                sum_density = torch.sum(stack_density[0])
                for i in range(len(stack_density)):
                    print(torch.sum(stack_density[i]) / sum_density)
            else:
                print(torch.sum(stack_density[1]) / torch.sum(stack_density[2]))
            fig, ax = plt.subplots(1, 3, figsize=(15, 3.55))
            cmap = matplotlib.colormaps['viridis']
            normalizer = matplotlib.colors.Normalize(min_val, max_val)
            im = matplotlib.cm.ScalarMappable(norm=normalizer, cmap=cmap)
            if args.enable_ft:
                ax[0].pcolormesh(heatmap_X, heatmap_Y, stack_density[0].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[0].set_title("Full subspace")

            if args.enable_tt:
                ax[1].pcolormesh(heatmap_X, heatmap_Y, stack_density[1].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[1].set_title("Tail subspace")

            if args.enable_ba:
                ax[2].pcolormesh(heatmap_X, heatmap_Y, stack_density[2].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[2].set_title("Block subspace")

            fig.colorbar(im, ax=ax.ravel().tolist())
            plt.savefig('figures/heatmap_posterior_task_' + str(args.task_type) + '_seed_' + str(args.seed) +
                        '_M=' + str(args.M) + '.png', dpi=300)
            plt.close()

        # calculate posterior heatmap on testing set
        if True:
            heatmap_loader = test_loader
            stack_density = []
            if args.enable_ft:
                heatmap_posterior_density_ft = calc_posterior_heatmap(args.model_cfg, mean_ft, proj_mat_ft, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_ft.clone())
                torch.save(heatmap_posterior_density_ft, args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ft.pt')
            else:
                stack_density.append(torch.full((x_size * y_size,), -float('inf')))
            if args.enable_tt:
                heatmap_posterior_density_tt = calc_posterior_heatmap(args.model_cfg, mean_tt, proj_mat_tt, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_tt.clone())
                torch.save(heatmap_posterior_density_tt, args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_tt.pt')
            if args.enable_ba:
                heatmap_posterior_density_ba = calc_posterior_heatmap(args.model_cfg, mean_ba, proj_mat_ba, args.inference_criterion,
                                                                      heatmap_input, heatmap_loader, device, temperature=args.temperature,
                                                                      prior_scale=args.prior_scale, bn_loader=train_loader).cpu()
                stack_density.append(heatmap_posterior_density_ba.clone())
                torch.save(heatmap_posterior_density_ba, args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(args.task_type) +
                           '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ba.pt')

        # plot combined figure on testing set
        if True:
            # normalize density
            normalize_constant = torch.max(torch.stack(stack_density))
            for i in range(len(stack_density)):
                stack_density[i] = torch.exp(stack_density[i] - normalize_constant)
            max_val = torch.max(torch.stack(stack_density))
            min_val = torch.min(torch.stack(stack_density))
            if args.enable_ft:
                sum_density = torch.sum(stack_density[0])
                for i in range(len(stack_density)):
                    print(torch.sum(stack_density[i]) / sum_density)
            else:
                print(torch.sum(stack_density[1]) / torch.sum(stack_density[2]))
            fig, ax = plt.subplots(1, 3, figsize=(15, 3.55))
            cmap = matplotlib.colormaps['viridis']
            normalizer = matplotlib.colors.Normalize(min_val, max_val)
            im = matplotlib.cm.ScalarMappable(norm=normalizer, cmap=cmap)
            if args.enable_ft:
                ax[0].pcolormesh(heatmap_X, heatmap_Y, stack_density[0].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[0].set_title("Full subspace")

            if args.enable_tt:
                ax[1].pcolormesh(heatmap_X, heatmap_Y, stack_density[1].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[1].set_title("Tail subspace")

            if args.enable_ba:
                ax[2].pcolormesh(heatmap_X, heatmap_Y, stack_density[2].reshape(x_size, y_size).numpy(), cmap=cmap,
                                 shading='auto', norm=normalizer)
            ax[2].set_title("Block subspace")

            fig.colorbar(im, ax=ax.ravel().tolist())
            plt.savefig('figures/heatmap_posterior_test_task_' + str(args.task_type) + '_seed_' + str(args.seed) +
                        '_M=' + str(args.M) + '.png', dpi=300)
            plt.close()

        pass

    # calcualte predictive
    if args.eval_predict == 1:
        # evaluate model log-lik, RMSE, and Calibration
        pkl_folder = "pkl_folder"
        detect_file = False

        def calc_and_save_pkl(space_model, train_loader, test_loader, device, args, filename):
            """Run ``eval_predict`` on one subspace model and pickle the enabled results.

            ``eval_predict`` is called once and returns four results (ESS, QMC, NUTS, VI).
            Each result whose ``args.calc_*`` flag is truthy is written to
            ``filename + '_<method>.pkl'``.

            Args:
                space_model: subspace model forwarded to ``eval_predict``.
                train_loader: training data loader forwarded to ``eval_predict``.
                test_loader: test data loader forwarded to ``eval_predict``.
                device: device forwarded to ``eval_predict``.
                args: parsed CLI namespace; its ``calc_ess`` / ``calc_qmc`` /
                    ``calc_nuts`` / ``calc_vi`` flags choose which results are saved.
                filename: output path prefix (no method suffix, no extension).
            """
            result_ess, result_qmc, result_nuts, result_vi = eval_predict(space_model, train_loader, test_loader, device, args)

            # (args flag name, file suffix, result) per inference method, in the original write order.
            outputs = (
                ('calc_ess', '_ess', result_ess),
                ('calc_qmc', '_qmc', result_qmc),
                ('calc_nuts', '_nuts', result_nuts),
                ('calc_vi', '_vi', result_vi),
            )
            for flag_name, suffix, result in outputs:
                if not getattr(args, flag_name):
                    continue
                # The output folder may not exist yet. Create it instead of losing the
                # (possibly expensive) evaluation above to a FileNotFoundError in open().
                out_dir = os.path.dirname(filename)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                with open(filename + suffix + '.pkl', 'wb') as f:
                    pickle.dump(result, f)

        def detect_file_complete(filename):
            """Return True only if every result pickle for ``filename`` exists.

            Checks ``filename`` + ``'_ess.pkl'``, ``'_qmc.pkl'``, ``'_nuts.pkl'`` and
            ``'_vi.pkl'`` in that order, stopping at the first missing file.
            """
            required_tails = ('_ess.pkl', '_qmc.pkl', '_nuts.pkl', '_vi.pkl')
            return all(os.path.exists(filename + tail) for tail in required_tails)

        if args.enable_ft:
            filename = pkl_folder + os.sep + 'result_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ft'
            if (not detect_file) or (not detect_file_complete(filename)):
                calc_and_save_pkl(space_model_ft, train_loader, test_loader, device, args, filename)
        if args.enable_tt:
            filename = pkl_folder + os.sep + 'result_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_tt'
            if (not detect_file) or (not detect_file_complete(filename)):
                calc_and_save_pkl(space_model_tt, train_loader, test_loader, device, args, filename)
        if args.enable_ba:
            filename = pkl_folder + os.sep + 'result_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ba'
            if (not detect_file) or (not detect_file_complete(filename)):
                calc_and_save_pkl(space_model_ba, train_loader, test_loader, device, args, filename)
        print('evaluation finished')

    # calc bayes factor and prior predictive
    if True:
        flag_calc_factor = True
        # read all result
        task_type_list = [
            20, 21, 22, 23, 24,
            # 30, 31, 32, 33, 34, 35,
            # 40, 41, 42, 43
        ]
        print_name_list = ['boston', 'concrete', 'energy', 'naval', 'yacht']
        file_str = 'uci_small'
        # print_name_list = ['elevators', 'protein', 'pol', 'keggD', 'keggU', 'skillcraft']
        # file_str = 'uci_large'
        # print_name_list = ['VGG_CF10', 'Res_CF10', 'VGG_CF100', 'Res_CF100']
        model_nums = 20

        def nanstd(x, dim=1):
            """NaN-aware population standard deviation of ``x`` along ``dim``.

            NaN entries are ignored both when computing the mean and when averaging
            the squared deviations. The divisor is the number of non-NaN entries
            (ddof=0). The reduced dimension is dropped from the result.
            """
            centre = torch.nanmean(x, dim=dim, keepdim=True)
            squared_dev = (x - centre) ** 2
            return torch.nanmean(squared_dev, dim=dim).sqrt()

        if flag_calc_factor:
            # 1. Bayes Factors: tt/ba; tt/ft; ba/ft
            result_bayes_factor = torch.zeros(len(task_type_list), model_nums, 3)  # (task, seed, 3)
            # read heatmap
            for i, task in enumerate(task_type_list):
                for j, iter in enumerate(range(model_nums)):
                    heatmap_tt = torch.load(
                        args.save_path + os.sep + 'heatmap_posterior_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(args.M) + '_tt.pt')
                    heatmap_ba = torch.load(
                        args.save_path + os.sep + 'heatmap_posterior_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(args.M) + '_ba.pt')
                    normalize_constant = torch.max(heatmap_ba)
                    heatmap_tt_sum = torch.sum(torch.exp(heatmap_tt - normalize_constant))
                    heatmap_ba_sum = torch.sum(torch.exp(heatmap_ba - normalize_constant))
                    result_bayes_factor[i, j, 0] = heatmap_tt_sum / heatmap_ba_sum
                    if args.enable_ft:
                        heatmap_ft = torch.load(
                            args.save_path + os.sep + 'heatmap_posterior_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(args.M) + '_ft.pt')
                        heatmap_ft_sum = torch.sum(torch.exp(heatmap_ft - normalize_constant))
                        result_bayes_factor[i, j, 1] = heatmap_tt_sum / heatmap_ft_sum
                        result_bayes_factor[i, j, 2] = heatmap_ba_sum / heatmap_ft_sum
            print("calc bayes factor success.")
            mean_val_bf = torch.nanmean(result_bayes_factor, dim=1)
            sd_val_bf = nanstd(result_bayes_factor, dim=1)

            # 2. prior predictive: tt/ba; tt/ft; ba/ft
            result_bayes_prior_pred = torch.zeros(len(task_type_list), model_nums, 3)  # (task, seed, 3)
            # read heatmap
            for i, task in enumerate(task_type_list):
                for j, iter in enumerate(range(model_nums)):
                    heatmap_tt = torch.load(
                        args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(args.M) + '_tt.pt')
                    heatmap_ba = torch.load(
                        args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(args.M) + '_ba.pt')
                    normalize_constant = torch.max(heatmap_ba)
                    heatmap_tt_sum = torch.sum(torch.exp(heatmap_tt - normalize_constant))
                    heatmap_ba_sum = torch.sum(torch.exp(heatmap_ba - normalize_constant))
                    result_bayes_prior_pred[i, j, 0] = heatmap_tt_sum / heatmap_ba_sum
                    if args.enable_ft:
                        heatmap_ft = torch.load(
                            args.save_path + os.sep + 'heatmap_posterior_test_task_' + str(task) + '_seed_' + str(10000 + iter) + '_M=' + str(
                                args.M) + '_ft.pt')
                        heatmap_ft_sum = torch.sum(torch.exp(heatmap_ft - normalize_constant))
                        result_bayes_prior_pred[i, j, 1] = heatmap_tt_sum / heatmap_ft_sum
                        result_bayes_prior_pred[i, j, 2] = heatmap_ba_sum / heatmap_ft_sum
            print("calc prior pred success.")
            mean_val_pp = torch.nanmean(result_bayes_prior_pred, dim=1)
            sd_val_pp = nanstd(result_bayes_prior_pred, dim=1)

            # print result
            print_str = "Bayes Factor:"
            print_str_2 = "Evidence Ratio:"
            for i, task in enumerate(task_type_list):
                print_str += "& %.3f $\pm$ %.3f " % (mean_val_bf[i, 0].item(), sd_val_bf[i, 0].item())
                print_str_2 += "& %.3f $\pm$ %.3f " % (mean_val_pp[i, 0].item(), sd_val_pp[i, 0].item())
            print(print_str)
            print(print_str_2)
            pass

    if args.eval_cifar == 1:
        # read pkl
        # cifar_eval(args, device, space_model_tt, test_loader, train_loader)
        pkl_folder = "pkl_folder"
        if args.enable_tt:
            filename = pkl_folder + os.sep + 'result_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_tt'
            if args.calc_ess:
                pkl_name = filename + '_ess'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_tt, test_loader, train_loader, post_samples, pkl_name)
            if args.calc_qmc:
                pkl_name = filename + '_qmc'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples_org = result['samples']
                weights_org = result['weights']
                sample_size = 30
                # select top 30 weights samples
                select_idx = torch.argsort(torch.from_numpy(weights_org), descending=True)[:sample_size]
                post_samples = post_samples_org[select_idx]
                weights = weights_org[select_idx] / np.sum(weights_org[select_idx])
                cifar_eval(args, device, space_model_tt, test_loader, train_loader, post_samples, pkl_name, weights)
            if args.calc_nuts:
                pkl_name = filename + '_nuts'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_tt, test_loader, train_loader, post_samples, pkl_name)
            if args.calc_vi:
                pkl_name = filename + '_vi'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_tt, test_loader, train_loader, post_samples, pkl_name)
        if args.enable_ba:
            filename = pkl_folder + os.sep + 'result_task_' + str(args.task_type) + '_seed_' + str(args.seed) + '_M=' + str(args.M) + '_ba'
            if args.calc_ess:
                pkl_name = filename + '_ess'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_ba, test_loader, train_loader, post_samples, pkl_name)
            if args.calc_qmc:
                pkl_name = filename + '_qmc'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples_org = result['samples']
                weights_org = result['weights']
                sample_size = 30
                # select top 30 weights samples
                select_idx = torch.argsort(torch.from_numpy(weights_org), descending=True)[:sample_size]
                post_samples = post_samples_org[select_idx]
                weights = weights_org[select_idx] / np.sum(weights_org[select_idx])
                cifar_eval(args, device, space_model_ba, test_loader, train_loader, post_samples, pkl_name, weights)
            if args.calc_nuts:
                pkl_name = filename + '_nuts'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_ba, test_loader, train_loader, post_samples, pkl_name)
            if args.calc_vi:
                pkl_name = filename + '_vi'
                with open(pkl_name + '.pkl', 'rb') as f:
                    result = pickle.load(f)
                post_samples = result['samples']
                cifar_eval(args, device, space_model_ba, test_loader, train_loader, post_samples, pkl_name)


if __name__ == "__main__":
    # Each option is registered as add_argument(flag, type=..., default=..., help=...),
    # in the order listed. On/off switches are ints (0 = off, 1 = on).
    #
    # --gpu: 0 -> cpu; 1 -> cuda:0, 2 -> cuda:1, ...
    # --M: noted default setting is 20 (the CLI default below is 5)
    # --rank: max rank, default 2
    # --task_type:
    #   20: UCI small - boston          30: UCI large - elevators
    #   21: UCI small - concrete        31: UCI large - protein
    #   22: UCI small - energy          32: UCI large - pol
    #   23: UCI small - naval           33: UCI large - keggdirected
    #   24: UCI small - yacht           34: UCI large - keggundirected
    #                                   35: UCI large - skillcraft
    #   40: CIFAR10 with VGG16          41: CIFAR10 with PreResNet164
    #   42: CIFAR100 with VGG16         43: CIFAR100 with PreResNet164
    _CLI_OPTIONS = (
        ('--gpu', int, 1, 'gpu_available'),
        ('--M', int, 5, 'M value'),
        ('--rank', int, 2, 'max_rank'),
        ('--task_type', int, 20, 'dataset'),
        ('--enable_ft', int, 0, 'enable full trajectory subspace inference'),
        ('--enable_tt', int, 1, 'enable tail trajectory subspace inference'),
        ('--enable_ba', int, 1, 'enable block-averaging subspace inference'),
        ('--train', int, 1, 'train model and save subspace'),
        ('--eval_evidence', int, 1, 'evaluate model evidence'),
        ('--eval_predict', int, 1, 'evaluate model log-lik, RMSE, and Calibration'),
        ('--eval_cifar', int, 1, 'evaluate CIFAR model'),
        ('--save_path', str, 'ckpts', 'model save path'),
        ('--calc_ess', int, 1, 'enable ESS inference'),
        ('--calc_qmc', int, 1, 'enable QMC inference'),
        ('--calc_nuts', int, 1, 'enable NUTS inference'),
        ('--calc_vi', int, 1, 'enable VI inference'),
        ('--svhn', int, 1, 'use SVHN dataset for OOD detection'),
        ('--cifar-c', int, 1, 'use CIFAR-C dataset to calculate ACC and ECE'),
        ('--nll', int, 1, 'test NLL on CIFAR'),
        ('--load_ckpt', int, 0, 'load dataset from checkpoint'),
        ('--tqdm', int, 1, 'enable tqdm'),
        ('--dbg', int, 1, 'debug flag'),
        ('--seed', int, 10000, 'seed'),
    )

    parser = argparse.ArgumentParser()
    for _flag, _type, _default, _help in _CLI_OPTIONS:
        parser.add_argument(_flag, type=_type, default=_default, help=_help)
    args = parser.parse_args()

    # Single run using the seed given on the command line.
    main(args)

    # To run 20 times with seeds 10000, 10001, ..., 10019 instead:
    # for run in range(20):
    #     args.seed = 10000 + run
    #     main(args)
