# Test for RQMC-IS in subspace inference

import sys
import os

import matplotlib.cm

sys.path.append('')
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import argparse

from tqdm import tqdm
from subspace_inference import models, losses, utils
from subspace_inference.posteriors.proj_model import SubspaceModel, ProjectedModel
from subspace_inference.posteriors.ess import EllipticalSliceSampling
from subspace_inference.posteriors.importance_sampler import ImportanceSampler
from visualization import plot_predictive, plot_predictive_with_weight

import subspace_inference.posteriors.subspaces_mod as Subspace


def main(args):
    """Run the synthetic-regression subspace comparison experiment.

    Loads the regression data and one saved weight trajectory, builds three
    subspaces from that trajectory (flags 'FT', 'TT' and 'BA'), and for each
    of them draws posterior-predictive plots (ESS and RQMC-IS), evaluates
    posterior / likelihood heatmaps on a fixed 2-D grid and runs an
    importance-sampling pass.  Afterwards the three heatmaps are plotted side
    by side and the pairwise principal angles between the subspaces are saved.

    Args:
        args: Parsed command-line namespace.  This function reads
            ``args.seed``, ``args.gpu``, ``args.data`` and ``args.M``.

    Raises:
        ValueError: If ``args.data`` is 0 (deliberate stop after loading
            ``ckpts/model.pt``) or is not in ``1..20``.
    """
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # --gpu 0 selects the CPU; k >= 1 selects cuda:(k-1) on POSIX and cuda:0 elsewhere.
    if args.gpu == 0:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.gpu - 1)) if os.name == 'posix' else torch.device("cuda:0")
    print(device)

    def features(x):
        # Two input features per point: x / 2 and (x / 2) ** 2.
        half = x[:, None] / 2.0
        return np.hstack([half, half ** 2])

    data = np.load("ckpts/data.npy")
    x, y = data[:, 0], data[:, 1]
    y = y[:, None]
    f = features(x)
    # Evaluation grid for the predictive plots.
    z = np.linspace(-10, 10, 200)
    inp = torch.from_numpy(features(z).astype(np.float32)).to(device)
    dataset = torch.utils.data.TensorDataset(torch.from_numpy(f.astype(np.float32)).to(device),
                                             torch.from_numpy(y.astype(np.float32)).to(device))
    test_ratio = 0.10
    test_size = int(len(dataset) * test_ratio)
    train_size = len(dataset) - test_size
    # The split is replaced by the stored one when save_dataset is False; the call is
    # kept so the sequence of random draws stays the same as before.
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    save_dataset = False
    if save_dataset:
        torch.save(train_dataset, 'ckpts/train_dataset.pt')
        torch.save(test_dataset, 'ckpts/test_dataset.pt')
    else:
        train_dataset = torch.load('ckpts/train_dataset.pt', map_location=device)
        test_dataset = torch.load('ckpts/test_dataset.pt', map_location=device)
    loader = torch.utils.data.DataLoader(train_dataset, batch_size=500, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(test_dataset, batch_size=500, shuffle=False)
    model_cfg = models.ToyRegNet
    # NOTE(review): `model` and `valid_loader` are not used below; the model construction
    # is kept because it presumably draws from the RNG — confirm before removing.
    model = model_cfg.base(*model_cfg.args, **model_cfg.kwargs)

    # Load the full trajectory.
    if args.data == 0:
        # Deliberate stop: the checkpoint is loaded and the run is aborted.
        full_traj = torch.load('ckpts/model.pt')
        raise ValueError("stop")
    if args.data not in range(1, 21):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError("--data must be an integer in 1..20, got %r" % (args.data,))
    full_traj = torch.load('ckpts/linear2_traj_' + str(args.data) + '.pt')
    print("load traj %d." % args.data)
    traj_sample_size = full_traj.shape[0]
    num_params = full_traj.shape[1]
    M = args.M
    print("M value: %d" % M)

    # Switches selecting which parts of the experiment run.
    draw_pp = True
    do_ess_sampling = True
    draw_heatmap_posterior = True
    draw_heatmap_likelihood = True
    draw_post_int = True
    calc_subspace_angle = True
    do_method_FT = True
    do_method_TT = True
    do_method_BA = True
    proposal_var = 30.0
    print("proposal var: %.4f" % proposal_var)
    pca_rank = 2
    num_samples = 64
    criterion = losses.GaussianLikelihood(noise_var=0.05)
    temperature = 1.5

    # Grid of 2-D subspace coordinates on which the heatmaps are evaluated.
    x_size = 500
    y_size = 500
    heatmap_x = torch.linspace(-50, 50, x_size)
    heatmap_y = torch.linspace(-200, 200, y_size)
    heatmap_X, heatmap_Y = torch.meshgrid(heatmap_x, heatmap_y)
    heatmap_input = torch.stack((heatmap_X.reshape(-1), heatmap_Y.reshape(-1)), dim=1)

    def build_sampler(mean, cov_factor, sampler_proposal_var):
        # Wrap (mean, cov_factor) in a SubspaceModel and build the matching ImportanceSampler.
        subspace = SubspaceModel(mean.to(device), cov_factor.to(device))
        sampler = ImportanceSampler(*model_cfg.args, base=model_cfg.base, criterion=criterion,
                                    proposal_var=sampler_proposal_var, temperature=temperature,
                                    loader=loader, subspace=subspace, data=data, proposal_type="gaussian",
                                    deg_f=6, device=device, **model_cfg.kwargs)
        return subspace, sampler

    def save_predictive_figure(title, suffix):
        # Fix the axes of the current figure, then save it as PNG and PDF and close it.
        ax = plt.gca()
        ax.set_xlim([-10, 10])
        ax.set_ylim([-0.75, 1.50])
        plt.savefig("figures/predictive/predictive_" + str(args.data) + title + "_" + suffix + ".png", dpi=300)
        plt.savefig("figures/predictive/predictive_" + str(args.data) + title + "_" + suffix + ".pdf")
        plt.close()

    def draw_posterior_predictive(mean, cov_factor, title, proposal_var=20.0 * 20.0):
        """Plot posterior predictives for one subspace with ESS and with RQMC-IS.

        Args:
            mean: Subspace origin passed to ``SubspaceModel``.
            cov_factor: Projection matrix passed to ``SubspaceModel``.
            title: Plot title prefix; also embedded in the output file names.
            proposal_var: Proposal variance for the sampler and scale for ESS.
                The previous code ignored this argument and always used
                ``20.0 * 20.0``; that value is now the default.
        """
        subspace, Sampler = build_sampler(mean, cov_factor, proposal_var)

        # ESS
        ess_sample_size = 100
        print("draw posterior predictive with ESS.")
        ess_model = EllipticalSliceSampling(
            *model_cfg.args,
            base=model_cfg.base,
            subspace=subspace,
            var=None,
            loader=loader,
            criterion=criterion,
            num_samples=1000,
            use_cuda=True,
            device=device,
            **model_cfg.kwargs
        )
        ess_model.fit(temperature=temperature, scale=proposal_var)
        # Keep only the last `ess_sample_size` rows of the transposed sample matrix.
        ess_samples = torch.from_numpy(ess_model.all_samples).t()[-ess_sample_size:].type(torch.FloatTensor).to(device)
        trajectories = torch.zeros((ess_sample_size, inp.shape[0]))
        for i in range(ess_sample_size):
            trajectories[i, :] = Sampler.gen_data(inp, ess_samples[i]).squeeze()

        plot_data_ess = plot_predictive(data, trajectories.numpy(), z, title=title + ", ESS")
        save_predictive_figure(title, "ESS")
        torch.save(plot_data_ess, 'ckpts/plot_data_ess' + title + '.pt')

        # QMC-IS
        qmc_sample_size = 1024
        print("draw posterior predictive with QMC-IS.")
        proposal_sample, weights = Sampler.sampling_with_weights(qmc_sample_size, enable_qmc=True, enable_tqdm=True)
        trajectories_all = torch.zeros((qmc_sample_size, inp.shape[0]))
        for i in range(qmc_sample_size):
            trajectories_all[i, :] = Sampler.gen_data(inp, proposal_sample[i]).squeeze()
        plot_data_qmc = plot_predictive_with_weight(data, trajectories_all.numpy(), z, weights, title=title + ", RQMC-IS")
        save_predictive_figure(title, "RQMC-IS")
        torch.save(plot_data_qmc, 'ckpts/plot_data_qmc' + title + '.pt')

    def calc_posterior_heatmap(mean, cov_factor, heatmap_input, heatmap_loader, enable_tqdm=True):
        """Return ``calc_marginal`` of the sampler evaluated at the grid points.

        ``enable_tqdm`` is now forwarded (it used to be hard-coded to True).
        """
        _, Sampler_ = build_sampler(mean, cov_factor, proposal_var)
        return Sampler_.calc_marginal(heatmap_input.to(device), heatmap_loader, enable_tqdm=enable_tqdm)

    def calc_likelihood_heatmap(mean, cov_factor, heatmap_input, heatmap_loader, enable_tqdm=True):
        """Return ``calc_likelihood`` of the sampler evaluated at the grid points.

        ``enable_tqdm`` is now forwarded (it used to be hard-coded to True).
        """
        _, Sampler_ = build_sampler(mean, cov_factor, proposal_var)
        return Sampler_.calc_likelihood(heatmap_input.to(device), heatmap_loader, enable_tqdm=enable_tqdm)

    def calc_marginal_is(mean, cov_factor, heatmap_loader, num_sample=50000, enable_tqdm=True):
        """Return ``calc_marginal`` minus the proposal log-density for random proposal draws.

        The proposal is a multivariate normal with the sampler prior's mean and
        one fifth of its covariance.
        """
        _, Sampler = build_sampler(mean, cov_factor, proposal_var)
        proposal = torch.distributions.MultivariateNormal(Sampler.prior.loc, Sampler.prior.covariance_matrix / 5.0)
        samples = proposal.sample((num_sample,)).to(device)
        proposal_weight = proposal.log_prob(samples)
        log_weights = Sampler.calc_marginal(samples, heatmap_loader, enable_tqdm=enable_tqdm)
        return log_weights - proposal_weight

    def parse_subspaces(subspace_flag, num_params, max_rank=20, pca_rank=2, total_traj_num=1000):
        """Build the subspace collector for flag 'FT', 'TT' or 'BA'.

        Raises:
            ValueError: If ``subspace_flag`` is not one of the three known flags
                (previously this surfaced as an UnboundLocalError).
        """
        if subspace_flag == 'FT':
            return Subspace.CompleteDataSpace(num_params, pca_rank=pca_rank, total_traj_num=total_traj_num)
        if subspace_flag == 'TT':
            return Subspace.TrailingSpace(num_params, max_rank=max_rank, pca_rank=pca_rank)
        if subspace_flag == 'BA':
            return Subspace.ThinningBlockAveragingSpace(num_params, max_rank=max_rank, pca_rank=pca_rank,
                                                        centering=False, total_traj_num=total_traj_num)
        raise ValueError("unknown subspace flag: %r (expected 'FT', 'TT' or 'BA')" % (subspace_flag,))

    def run_method(subspace_flag, title, num_vectors):
        """Build one subspace from the trajectory and run every enabled analysis on it.

        Args:
            subspace_flag: 'FT', 'TT' or 'BA'.
            title: Title used for the posterior-predictive plots.
            num_vectors: Count used in the ``(num_vectors - 1) / s ** 2`` scaling.

        Returns:
            ``(proj_mat, posterior_density, likelihood_density)``; a density is
            ``None`` when its heatmap switch is off.
        """
        subspace_ = parse_subspaces(subspace_flag, num_params, max_rank=M, pca_rank=pca_rank,
                                    total_traj_num=traj_sample_size)
        for i in range(traj_sample_size):
            subspace_.collect_vector(full_traj[i])
        proj_mat, mean = subspace_.get_space()
        sing = subspace_.singular_values
        # NOTE(review): the projected samples are not consumed anywhere in this function;
        # the computation is kept unchanged from the previous version.
        proj_samples = (torch.diag(torch.from_numpy((num_vectors - 1) / sing ** 2)) @ proj_mat @ (full_traj - mean).t()).t()

        posterior_density = None
        likelihood_density = None
        if draw_pp:
            draw_posterior_predictive(mean, proj_mat, title=title)
        if do_ess_sampling:
            # NOTE(review): `ess_sampling` is neither defined nor imported in the visible
            # source; this call raises NameError unless it is provided elsewhere — TODO
            # restore the helper or switch `do_ess_sampling` off.
            ess_sampling(mean, proj_mat, num_samples, proposal_var)
        # NOTE(review): `heatmap_loader` is likewise not defined in the visible source —
        # TODO define it (presumably a DataLoader) before the heatmap steps can run.
        if draw_heatmap_posterior:
            posterior_density = calc_posterior_heatmap(mean, proj_mat, heatmap_input, heatmap_loader)
        if draw_heatmap_likelihood:
            likelihood_density = calc_likelihood_heatmap(mean, proj_mat, heatmap_input, heatmap_loader)
        if draw_post_int:
            # Result is not used afterwards; the call is kept because it draws random samples.
            log_weights = calc_marginal_is(mean, proj_mat, heatmap_loader)
        return proj_mat, posterior_density, likelihood_density

    def plot_density_panels(density, panel_titles, out_path):
        # Print each panel's mass relative to the first, then draw the three heatmaps
        # with a colour scale taken from the first panel and save the figure.
        max_val = torch.max(density[0])
        min_val = torch.min(density[0])
        sum_density = torch.sum(density[0])
        for i in range(density.size(0)):
            print(torch.sum(density[i]) / sum_density)
        fig, ax = plt.subplots(1, 3, figsize=(15, 3.55))
        # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        cmap = plt.get_cmap('viridis')
        normalizer = matplotlib.colors.Normalize(min_val, max_val)
        im = matplotlib.cm.ScalarMappable(norm=normalizer, cmap=cmap)
        for k, panel_title in enumerate(panel_titles):
            ax[k].pcolormesh(heatmap_X, heatmap_Y, density[k].reshape(x_size, y_size).numpy(),
                             cmap=cmap, shading='auto', norm=normalizer)
            ax[k].set_title(panel_title)
            if k > 0:
                # Only the left-most panel keeps its y tick labels.
                ax[k].set_yticklabels([])
        fig.colorbar(im, ax=ax.ravel().tolist())
        plt.savefig(out_path, dpi=300)
        plt.close()

    def half_logdets():
        # 0.5 * logdet(P P^T) for each of the three projection matrices.
        return tuple(0.5 * torch.logdet(p @ p.t()) for p in (proj_mat1, proj_mat2, proj_mat3))

    # 1: mean and covariance factor from the full SWA trajectory.
    if do_method_FT:
        proj_mat1, heatmap_posterior_density1, heatmap_likelihood_density1 = run_method(
            'FT', "Full trajectory subspace", traj_sample_size)

    # 2: subspace constructed using mean (from full SWA trajectory) and covariance/deviation from last $M$ points
    if do_method_TT:
        proj_mat2, heatmap_posterior_density2, heatmap_likelihood_density2 = run_method(
            'TT', "Tail trajectory subspace", M)

    # 3: subspace constructed using mean (from full SWA trajectory) and covariance/deviation
    # from $M$ points obtained through block (part) means.
    if do_method_BA:
        proj_mat3, heatmap_posterior_density3, heatmap_likelihood_density3 = run_method(
            'BA', "Block averaging subspace", M)

    # NOTE(review): `heatmap_str` (used in the output file names below) is not defined in
    # the visible source — TODO define it before the heatmap sections can run.
    if draw_heatmap_posterior:
        stack_density = torch.stack((heatmap_posterior_density1, heatmap_posterior_density2, heatmap_posterior_density3), dim=0).cpu()
        val1, val2, val3 = half_logdets()
        # Add each subspace's 0.5 * logdet(P P^T) to its log-density.
        stack_density[0] += val1
        stack_density[1] += val2
        stack_density[2] += val3
        print("Det: %.4f, %.4f, %.4f" % (val1, val2, val3))
        torch.save(stack_density, "experiments/synthetic_regression/stack_density_" + str(heatmap_x.size(0)) + "_" + str(heatmap_y.size(0)) +
                   "_log_" + heatmap_str + "11=" + str(M) + "_" + str(args.data) + ".pt")

        # Shift by the maximum of the first panel before exponentiating.
        stack_density = torch.exp(stack_density - torch.max(stack_density[0]))
        plot_density_panels(stack_density, ("Full subspace", "Tail subspace", "Block subspace"),
                            "figures/heatmap_posterior_" + heatmap_str + "_M=" + str(M) + ".png")

    if draw_heatmap_likelihood:
        stack_density = torch.stack((heatmap_likelihood_density1, heatmap_likelihood_density2, heatmap_likelihood_density3), dim=0).cpu()
        # Here the log-determinants are only reported, not added to the densities.
        val1, val2, val3 = half_logdets()
        print("Det: %.4f, %.4f, %.4f" % (val1, val2, val3))
        torch.save(stack_density, "experiments/synthetic_regression/likelihood_" + str(heatmap_x.size(0)) + "_" + str(heatmap_y.size(0)) +
                   "_log_" + heatmap_str + "=" + str(M) + "_" + str(args.data) + ".pt")

        stack_density = torch.exp(stack_density)
        plot_density_panels(stack_density, ("Full subspace", "Tail subspace", "Block-avg subspace"),
                            "figures/heatmap_likelihood_" + heatmap_str + "_M=" + str(M) + ".png")

    if calc_subspace_angle:
        from scipy.linalg import subspace_angles
        proj_mat_combined = [proj_mat1, proj_mat2, proj_mat3]
        result = np.zeros((3, 3, pca_rank))
        # Compute every pairwise set of angles (in degrees) once and reuse it for printing.
        angles = {}
        for i in range(0, 3):
            for j in range(0, 3):
                angles[(i, j)] = np.rad2deg(subspace_angles(proj_mat_combined[i].t().numpy(),
                                                            proj_mat_combined[j].t().numpy()))
                result[i, j] = angles[(i, j)]
        torch.save(torch.from_numpy(result), "ckpts/subspace_angle_" + str(args.data) + ".pt")

        for i in range(0, 3):
            for j in range(i + 1, 3):
                print("i: %d, j: %d" % (i, j))
                print(angles[(i, j)])

        print("%-------------------%")


if __name__ == "__main__":
    # Command-line entry point: parse the experiment options and hand them to main().
    parser = argparse.ArgumentParser()
    # 0: cpu; 1: cuda:0, 2: cuda:1, ...
    parser.add_argument('--gpu', type=int, default=1, help='device index: 0 = cpu, 1 = cuda:0, 2 = cuda:1, ...')
    parser.add_argument('--M', type=int, default=20, help='M value')
    # main() accepts 1..20 and aborts on 0.
    parser.add_argument('--data', type=int, default=1, help='index of the trajectory checkpoint to load (1-20)')
    # The options below are still accepted so existing command lines keep working, but
    # main() does not read them; their help text used to be a copy-pasted 'dataset'.
    parser.add_argument('--dbg', type=int, default=2, help='not read by main() in this script')
    parser.add_argument('--dbg_method1', type=int, default=1, help='not read by main() in this script')
    parser.add_argument('--dbg_method2', type=int, default=1, help='not read by main() in this script')
    parser.add_argument('--rs', type=int, default=1, help='not read by main() in this script')
    parser.add_argument('--seed', type=int, default=1, help='seed')
    parser.add_argument('--re', type=int, default=21, help='not read by main() in this script')
    args = parser.parse_args()
    main(args)
