"""To visualize GP prior samples."""

import gc
import torch
from matplotlib import pyplot as plt
from data.sampler import (
    multi_task_gp_prior_sampler,
    multi_output_gp_prior_sampler,
)
import numpy as np
from scipy.interpolate import griddata

# Registry from a sampler name (selected via `sampler_name` in the script body)
# to the corresponding prior-sampling function imported from `data.sampler`.
Sampler_Dict = dict(
    Multi_Task_GP_prior=multi_task_gp_prior_sampler,
    Multi_Output_GP_prior=multi_output_gp_prior_sampler,
)
if __name__ == "__main__":
    # genereate gp prior samples with specified parameters
    B = 8
    lengthscale_range = [1.95, 2.0]
    std_range = [0.1, 1.0]
    x_dim = 1
    y_dim = 2
    x_range = [[-5.0, 5.0] for _ in range(x_dim)]
    data_kernel_type_list = ["rbf", "matern52", "matern32"]
    num_datapoints = 500
    sampler_name = "Multi_Output_GP_prior"

    fig = plt.figure(figsize=(10 * y_dim, 10 * B))
    ncol = y_dim
    nrow = B

    grid_res = 100
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_default_device(device)

    for j in range(len(data_kernel_type_list)):
        for b in range(B):
            # Linearly spaced 1-dimensional x for better visualization: [num_datapoints, x_dim]
            if x_dim == 1:
                x = (
                    torch.linspace(x_range[0][0], x_range[0][1], num_datapoints)
                    .unsqueeze(1)
                    .expand(-1, x_dim)
                )
            else:
                x = None
            
            # [num_datapoints, x_dim] and [num_datapoints, y_dim]
            x, y = Sampler_Dict[sampler_name](
                x_range=x_range,
                x_dim=x_dim,
                num_tasks=y_dim,
                x=x,
                num_datapoints=num_datapoints,
                data_kernel_type_list=[data_kernel_type_list[j]],
                sample_kernel_weights=[1],
                lengthscale_range=lengthscale_range,
                std_range=std_range,
                standardize=False,
            )
            x, y = x.detach().cpu().numpy(), y.detach().cpu().numpy()

            # plot each sample
            for i in range(y_dim):
                ax = fig.add_subplot(nrow, ncol, b * y_dim + i + 1)
                ax.set_title(f"dataset {b}")
                if x_dim == 1:
                    ax.plot(x, y[:, i])
                elif x_dim == 2:
                    xi = np.linspace(
                        x[:, 0].min(),
                        x[:, 0].max(),
                        grid_res,
                    )
                    yi = np.linspace(
                        x[:, 1].min(),
                        x[:, 1].max(),
                        grid_res,
                    )
                    xi, yi = np.meshgrid(xi, yi)
                    y_grid = griddata(
                        x,
                        y[:, i],
                        (xi, yi),
                        method="cubic",
                    )
                    _ = ax.contourf(
                        xi, yi, y_grid, levels=20, cmap="viridis", alpha=0.7
                    )

            if device == "cuda":
                gc.collect()
                torch.cuda.empty_cache()

        plt.tight_layout()
        print(
            f"Saving figure for {sampler_name} with kernel {data_kernel_type_list[j]} and x_dim {x_dim}..."
        )
        fig.savefig(
            f"{sampler_name}_{data_kernel_type_list[j]}_dx{x_dim}_l{lengthscale_range[0]}_{lengthscale_range[1]}.png",
            dpi=300,
            bbox_inches="tight",
        )
