import torch
from torch import Tensor
from gpytorch.kernels import RBFKernel, MaternKernel, Kernel
from gpytorch.models import IndependentModelList
from gpytorch.likelihoods import (
    MultitaskGaussianLikelihood,
    GaussianLikelihood,
    LikelihoodList,
)
from gpytorch.settings import fast_pred_var, cholesky_jitter, cholesky_max_tries
import random
import numpy as np
import scipy.stats as sps
from torch.distributions.bernoulli import Bernoulli
import sobol_seq
from data.gpytorch_utils import (
    MultitaskGPModel,
    ExactGPModel,
    RotatedARDKernel,
    sample_orthonormal_matrix,
)
from data.function_preprocessing import make_range_tensor, min_max_scale
from typing import List, Tuple, Optional, Union
from utils.types import FloatListOrNestedOrTensor
from utils.config import get_train_y_range
from einops import repeat, rearrange

# Default target bounds used by `scale_y`; obtained from
# `utils.config.get_train_y_range()` once, at import time.
Y_RANGE = get_train_y_range()
# Minimum size of the Sobol pool generated by `sample_x` when `grid=False`.
GRID_SIZE = 1000
# Default candidate data kernels (see `_get_data_kernel`) and their relative
# sampling weights for `random.choices`.
DATA_KERNEL_TYPE_LIST = ["rbf", "matern32", "matern52"]
SAMPLE_KERNEL_WEIGHTS = [1, 1, 1]
# Relative weights of the entries of `sampler_list` in `factorized_sampler`.
SAMPLER_WEIGHTS = [1, 1]
# Bounds of the truncated log-normal lengthscale prior in `_sample_lengthscale`
# (applied before the sqrt(x_dim) rescaling).
L_RANGE = [0.1, 2.0]
# Per-output standard deviations are drawn uniformly from this range.
STD_RANGE = [0.1, 1.0]
# Default lower bound for the task-kernel rank in `multi_task_gp_prior_sampler`.
MIN_RANK = 1
# Probability that a multi-dimensional lengthscale vector is made isotropic.
P_ISO = 0.5
# Cholesky jitter and maximum number of Cholesky retries used while sampling.
JITTER = 1e-6
MAX_TRIES = 6


def create_q_mask(factorize: bool, dim: int, device: str) -> Tensor:
    """Build the boolean mask of input dimensions that are sampled independently.

    Args:
        factorize: if truthy, every dimension is its own independent subspace
            and the mask is all True; otherwise all dimensions form a single
            joint subspace and the mask is all False.
        dim: length of the mask (number of input dimensions).
        device: device on which the mask is allocated.

    Returns:
        Bool tensor of shape [dim] filled with ``bool(factorize)``.
    """
    return torch.full((dim,), bool(factorize), dtype=torch.bool, device=device)


def scale_y(y, domains: list = Y_RANGE):
    """Rescale outputs with ``min_max_scale`` using per-column min and max.

    Minima and maxima are taken over dim -2, which is the datapoint axis of
    the [num_datapoints, num_tasks] outputs produced by the samplers in this
    file. Each output column therefore gets its own statistics.

    Args:
        y: tensor of function values, reduced over dim -2.
        domains: target bounds forwarded to ``min_max_scale`` as
            ``target_bounds`` (defaults to ``Y_RANGE``). A list or a tensor
            is accepted.

    Returns:
        The result of ``min_max_scale`` (presumably ``y`` rescaled into
        ``domains``; confirm in ``data.function_preprocessing``).
    """
    y_maxs = torch.max(y, dim=-2).values
    y_mins = torch.min(y, dim=-2).values

    # Put the bounds on the same device as `y`. The previous
    # `torch.tensor(domains)` always allocated on the CPU, which mismatches
    # CUDA samples. `as_tensor` also accepts an existing tensor without the
    # copy-construct warning.
    domains_t = torch.as_tensor(domains, device=y.device)

    y_scaled = min_max_scale(
        data=y,
        mins=y_mins,
        maxs=y_maxs,
        sigma=0.0,
        target_bounds=domains_t,
    )
    return y_scaled


def _get_data_kernel(kernel_type: str, x_dim: int) -> Kernel:
    """Construct a data kernel with a separate lengthscale per input dimension.

    Args:
        kernel_type: one of "rbf", "matern12", "matern32", "matern52" or
            "rotated_ard".
        x_dim: number of input dimensions.

    Returns:
        A freshly constructed kernel instance.

    Raises:
        NotImplementedError: for any other ``kernel_type``.
    """
    # Factories are lambdas so that only the requested kernel is built.
    factories = {
        "rbf": lambda: RBFKernel(ard_num_dims=x_dim),
        "matern12": lambda: MaternKernel(nu=0.5, ard_num_dims=x_dim),
        "matern32": lambda: MaternKernel(nu=1.5, ard_num_dims=x_dim),
        "matern52": lambda: MaternKernel(nu=2.5, ard_num_dims=x_dim),
        "rotated_ard": lambda: RotatedARDKernel(num_dims=x_dim),
    }

    build = factories.get(kernel_type)
    if build is None:
        raise NotImplementedError(f"Unsupported data kernel type: {kernel_type}")
    return build()


def _sample_lengthscale(
    x_dim: int,
    lengthscale_range: List,
    p_iso: float,
    mu=np.log(2 / 3),
    sigma=0.5,
) -> Tensor:
    """Sample x_dim lengthscales."""
    # Sample from truncated normal distribution
    a = (np.log(lengthscale_range[0]) - mu) / sigma
    b = (np.log(lengthscale_range[1]) - mu) / sigma

    rv = sps.truncnorm(a, b, loc=mu, scale=sigma)
    lengthscale = Tensor(np.exp(rv.rvs(size=(x_dim))))

    # Scale lengthscale with x dims: http://arxiv.org/abs/2402.02229
    lengthscale *= torch.tensor(x_dim, device=lengthscale.device).sqrt()

    if x_dim > 1:
        # isotropic kernel with same lengthscale for each dimension
        is_iso = Bernoulli(p_iso).sample()  # (1)
        if is_iso:
            lengthscale[:] = lengthscale[0]

    return lengthscale


def sample_x(x_range: Tensor, num_datapoints: int, grid: bool) -> Tensor:
    """Draw input locations from a randomly permuted Sobol sequence.

    Args:
        x_range: per-dimension bounds, shape [x_dim, 2] (column 0 is the
            minimum, column 1 the maximum).
        num_datapoints: number of points to return.
        grid: if True, exactly ``num_datapoints`` Sobol points are generated
            and all of them are returned in random order. If False, a pool of
            ``max(GRID_SIZE, num_datapoints)`` Sobol points is generated and a
            random subset of ``num_datapoints`` of them is returned.

    Returns:
        Detached tensor of shape [num_datapoints, x_dim], clamped to
        ``x_range``.
    """
    assert x_range.ndim == 2 and x_range.shape[1] == 2
    device = x_range.device
    pool_size = num_datapoints if grid else max(GRID_SIZE, num_datapoints)

    # Sobol points in the unit cube [0, 1]^x_dim
    unit_points = torch.from_numpy(
        sobol_seq.i4_sobol_generate(dim_num=x_range.shape[0], n=pool_size)
    ).to(device=device, dtype=torch.float32)

    # Shuffle the pool, then keep the first `num_datapoints`
    shuffled = unit_points[torch.randperm(pool_size, device=device)]
    chosen = shuffled[:num_datapoints]

    # Map the unit cube onto the requested box and clamp into its bounds
    lower, upper = x_range[:, 0], x_range[:, 1]
    scaled = lower + (upper - lower) * chosen
    scaled = torch.max(torch.min(scaled, upper), lower)

    return scaled.detach()


def _check_samples(
    x: Tensor,
    y: Tensor,
    kernel_type_list: List[str],
    lengthscale_list: List[str],
    std_list: List[Tensor],
    covar_list: List[Tensor],
):
    """Check if the samples contain NaNs or Infs."""
    num_kernels = len(kernel_type_list)

    if not torch.all(torch.isfinite(x)):
        print(f"x contains NaNs or Infs: {x.shape}\n{x}")

    if not torch.all(torch.isfinite(y)):
        print(f"y contains NaNs or Infs: {y.shape}\n{y}")
        print(f"x: {x.shape}\n{x}")

        for k in range(num_kernels):
            kernel_type = kernel_type_list[k]
            lengthscale = lengthscale_list[k]
            std = std_list[k]
            cov = covar_list[k]

            print(f"[Model {k}]")
            print(f"data kernel: {kernel_type}")
            print(f"lengthscale: {lengthscale}")
            print(f"std: {std}")

            if torch.isnan(cov).any():
                print("NaN in kernel matrix")

            eigvals = torch.linalg.eigvalsh(cov)
            print("min eig:", eigvals.min().item())
            print("max eig:", eigvals.max().item())
            print("condition number:", (eigvals.max() / eigvals.min()).item())


def multi_task_gp_prior_sampler(
    x_range: Union[List, Tensor],
    x_dim: int,
    num_datapoints: int,
    num_tasks: int,
    data_kernel_type_list: List = DATA_KERNEL_TYPE_LIST,
    sample_kernel_weights: List = SAMPLE_KERNEL_WEIGHTS,
    lengthscale_range: List = L_RANGE,
    std_range: List = STD_RANGE,
    min_rank: int = MIN_RANK,  # Lower rank, higher task correlation
    max_rank: Optional[int] = None,
    p_iso: float = P_ISO,  # Probability of using isotropic kernel
    standardize: bool = True,  # If True, rescale outputs with `scale_y`
    grid: bool = False,
    device: str = "cuda",
    x: Optional[Tensor] = None,
    jitter: float = JITTER,
    max_tries: int = MAX_TRIES,
    **kwargs,
) -> Tuple[Tensor, Tensor]:
    """Sample from a multi-task GP prior with a single data kernel and a task kernel.

    One data kernel type is drawn from ``data_kernel_type_list`` (weighted by
    ``sample_kernel_weights``) and shared by all tasks. The tasks are coupled
    through the task kernel of ``MultitaskGPModel``, whose rank is drawn
    uniformly from [min_rank, max_rank].

    Args:
        x_range: input bounds, x_dim x [2] (list or tensor).
        x_dim: number of input dimensions; must equal ``len(x_range)``.
        num_datapoints: number of inputs drawn with ``sample_x`` when ``x`` is None.
        num_tasks: number of outputs.
        data_kernel_type_list: candidate data kernels (see ``_get_data_kernel``).
        sample_kernel_weights: relative weights of ``data_kernel_type_list``.
        lengthscale_range: bounds passed to ``_sample_lengthscale``.
        std_range: per-task standard deviations are uniform in this range.
        min_rank: inclusive lower bound of the task-kernel rank.
        max_rank: inclusive upper bound of the rank; defaults to ``num_tasks``
            and must not exceed it.
        p_iso: probability of an isotropic lengthscale vector.
        standardize: if True, ``y`` is rescaled with ``scale_y``.
        grid: forwarded to ``sample_x`` when ``x`` is None.
        device: device for the sampled hyperparameters (and for ``x`` when it
            is sampled here).
        x: optional inputs at which to evaluate the prior, [num_points, x_dim].
        jitter: Cholesky jitter used while sampling.
        max_tries: maximum number of Cholesky retries.
        **kwargs: ignored; accepted for API compatibility.

    Returns:
        x: [num_points, x_dim] inputs.
        y: [num_points, num_tasks] prior sample.
    """
    assert lengthscale_range[0] > 0
    assert std_range[0] > 0
    assert x_dim == len(x_range), "Unmatched `x_dim` and `x_range`."

    if isinstance(x_range, list):
        x_range = torch.tensor(x_range)
    x_range = x_range.to(device)

    # Sample inputs: [num_datapoints, x_dim]
    if x is None:
        x = sample_x(x_range=x_range, num_datapoints=num_datapoints, grid=grid)

    # Sample the data kernel shared by all tasks
    data_kernel_type = random.choices(
        population=data_kernel_type_list, weights=sample_kernel_weights, k=1
    )[0]
    data_kernel = _get_data_kernel(kernel_type=data_kernel_type, x_dim=x_dim)

    # Sample lengthscales for data kernel: [x_dim]
    lengthscale = _sample_lengthscale(
        x_dim=x_dim, lengthscale_range=lengthscale_range, p_iso=p_iso
    ).to(device)

    # Sample per-task std, uniform in std_range: [num_tasks]
    std = torch.rand(num_tasks, device=device)
    std = std * (std_range[1] - std_range[0]) + std_range[0]

    # Sample rank for task kernel (random.randint is inclusive on both ends)
    if max_rank is not None:
        assert max_rank <= num_tasks, "`max_rank` should be no more than `num_tasks`."
    else:
        max_rank = num_tasks
    rank = random.randint(min_rank, max_rank)

    # Setup likelihood and model (no training data, so the model acts as a prior)
    likelihood = MultitaskGaussianLikelihood(num_tasks=num_tasks)
    model = MultitaskGPModel(
        train_x=None,
        train_y=None,
        likelihood=likelihood,
        kernel=data_kernel,
        num_tasks=num_tasks,
        rank=rank,
    )
    if data_kernel_type == "rotated_ard":
        # NOTE(review): the sampled values are written to the *raw* lengthscale
        # parameter of RotatedARDKernel; confirm whether that kernel applies a
        # constraint transform to it.
        sampled_R = sample_orthonormal_matrix(x_dim).to(device)
        model.covar_module.data_covar_module.raw_lengthscales.data = lengthscale
        model.covar_module.data_covar_module.R.data = sampled_R
    else:
        # Set lengthscales for the data kernel
        model.covar_module.data_covar_module.lengthscale = lengthscale

    # Give each task its own variance (std**2) via the `v` vector of the
    # IndexKernel. Assumes the `var` setter stores the value through the
    # parameter constraint, so the effective variance equals std**2 (TODO
    # confirm against the installed gpytorch version).
    model.covar_module.task_covar_module.var = std**2

    # Put model and likelihood in evaluation mode and move them to x's device/dtype
    model.eval()
    likelihood.eval()

    model.to(x)
    likelihood.to(x)

    # Sample from the prior distribution
    with torch.no_grad(), fast_pred_var(), cholesky_jitter(jitter), cholesky_max_tries(
        max_tries
    ):
        prior_dist = model(x)
        y = prior_dist.sample(torch.Size([1])).squeeze(0)  # [num_datapoints, num_tasks]

    # Build the dense covariance for diagnostics only when the sample is
    # invalid. Evaluating it on every call costs memory and time, and outside
    # no_grad it also records an autograd graph.
    if not (torch.isfinite(x).all() and torch.isfinite(y).all()):
        with torch.no_grad():
            covar = model.covar_module(x).evaluate()
        _check_samples(
            x=x,
            y=y,
            kernel_type_list=[data_kernel_type],
            lengthscale_list=[lengthscale],
            std_list=[std],
            covar_list=[covar],
        )

    if standardize:
        y = scale_y(y)

    # Drop references so the model's buffers can be freed
    del model, likelihood

    return x, y


def multi_output_gp_prior_sampler(
    x_range: Union[List, Tensor],
    x_dim: int,
    num_tasks: int,
    num_datapoints: int,
    data_kernel_type_list: List = DATA_KERNEL_TYPE_LIST,
    sample_kernel_weights: List = SAMPLE_KERNEL_WEIGHTS,
    lengthscale_range: List = L_RANGE,
    std_range: List = STD_RANGE,
    p_iso: float = P_ISO,
    standardize: bool = True,  # If True, rescale outputs with `scale_y`
    grid: bool = False,
    device: str = "cuda",
    x: Optional[Tensor] = None,
    jitter: float = JITTER,
    max_tries: int = MAX_TRIES,
    **kwargs,
) -> Tuple[Tensor, Tensor]:
    """Sample from a multi-output GP prior with independent outputs, each with its own kernel.

    For every output, a data kernel type, a lengthscale vector and a
    standard deviation are drawn independently. Each output is modelled by a
    separate ``ExactGPModel`` inside an ``IndependentModelList``.

    Args:
        x_range: input bounds, x_dim x [2] (list or tensor).
        x_dim: number of input dimensions; must equal ``len(x_range)``.
        num_tasks: number of independent outputs.
        num_datapoints: number of inputs drawn with ``sample_x`` when ``x`` is None.
        data_kernel_type_list: candidate data kernels (see ``_get_data_kernel``).
        sample_kernel_weights: relative weights of ``data_kernel_type_list``.
        lengthscale_range: bounds passed to ``_sample_lengthscale``.
        std_range: each output's standard deviation is uniform in this range.
        p_iso: probability of an isotropic lengthscale vector.
        standardize: if True, ``y`` is rescaled with ``scale_y``.
        grid: forwarded to ``sample_x`` when ``x`` is None.
        device: device for the sampled hyperparameters (and for ``x`` when it
            is sampled here).
        x: optional inputs at which to evaluate the prior, [num_points, x_dim].
        jitter: Cholesky jitter used while sampling.
        max_tries: maximum number of Cholesky retries.
        **kwargs: ignored; accepted for API compatibility.

    Returns:
        x: [num_points, x_dim] inputs.
        y: [num_points, num_tasks] prior sample.
    """
    assert lengthscale_range[0] > 0
    assert std_range[0] > 0
    assert x_dim == len(x_range), "Unmatched `x_dim` and `x_range`."

    if isinstance(x_range, list):
        x_range = torch.tensor(x_range)
    x_range = x_range.to(device)

    # Sample inputs: [num_datapoints, x_dim]
    if x is None:
        x = sample_x(x_range=x_range, num_datapoints=num_datapoints, grid=grid)

    # Sample one data kernel type per output
    kernel_types = random.choices(
        population=data_kernel_type_list, weights=sample_kernel_weights, k=num_tasks
    )

    models = []
    likelihoods = []
    # Sampled hyperparameters, kept for diagnostics
    lengthscales = []
    stds = []

    # Build one independent single-output GP per output
    for kernel_type in kernel_types:
        # Sample lengthscale: [x_dim]
        lengthscale = _sample_lengthscale(
            x_dim=x_dim, lengthscale_range=lengthscale_range, p_iso=p_iso
        ).to(device)

        # Sample std: [1]
        std = torch.rand(1, device=device)
        std = std * (std_range[1] - std_range[0]) + std_range[0]

        # Build the data kernel
        data_kernel = _get_data_kernel(kernel_type=kernel_type, x_dim=x_dim)
        if kernel_type == "rotated_ard":
            # NOTE(review): sampled values go into the *raw* lengthscale parameter.
            sampled_R = sample_orthonormal_matrix(x_dim).to(device)
            data_kernel.raw_lengthscales.data = lengthscale
            data_kernel.R.data = sampled_R
        else:
            data_kernel.lengthscale = lengthscale

        # Setup likelihood and model
        likelihood = GaussianLikelihood()
        likelihoods.append(likelihood)
        model = ExactGPModel(
            kernel=data_kernel, likelihood=likelihood, train_x=None, train_y=None
        )
        # Output scale is a variance, hence std**2
        model.covar_module.outputscale = std**2
        models.append(model)

        lengthscales.append(lengthscale)
        stds.append(std)

    # Setup likelihood and model lists
    model = IndependentModelList(*models)
    likelihood = LikelihoodList(*likelihoods)

    model.eval()
    likelihood.eval()

    model.to(x)
    likelihood.to(x)

    # Sample from the prior distribution
    with torch.no_grad(), fast_pred_var(), cholesky_jitter(jitter), cholesky_max_tries(
        max_tries
    ):
        # One prior distribution per output, all evaluated at the same inputs
        prior_dist_list = model(*[x for _ in range(num_tasks)])

        # num_tasks x [(num_datapoints)] -> [num_datapoints, num_tasks]
        ys = [
            prior_dist.sample(torch.Size([1])).squeeze(0)
            for prior_dist in prior_dist_list
        ]
        y = torch.stack(ys, dim=-1)

    # Diagnostics (and the dense per-output covariances they need) only when
    # the sample is invalid
    if not (torch.isfinite(x).all() and torch.isfinite(y).all()):
        with torch.no_grad():
            covars = [submodel.covar_module(x).evaluate() for submodel in model.models]
        _check_samples(
            x=x,
            y=y,
            kernel_type_list=kernel_types,
            lengthscale_list=lengthscales,
            std_list=stds,
            covar_list=covars,
        )

    if standardize:
        y = scale_y(y)

    # Drop references so the models' buffers can be freed
    del model, likelihood, models, likelihoods
    return x, y


def _discretize_joint_subspace(x_range: Tensor, d: int, grid: bool) -> Tensor:
    """Draw ``d`` points that jointly span all D dimensions of ``x_range``.

    Args:
        x_range (Tensor): per-dimension bounds, shape [D, 2]
        d (int): number of points to draw
        grid (bool): forwarded to ``sample_x``

    Returns:
        Tensor of shape [1, d, D]; a leading singleton axis is added.
    """
    points = sample_x(x_range=x_range, num_datapoints=d, grid=grid)  # [d, D]
    return points[None, ...]


def _discretize_ind_subspaces(x_range: Tensor, d: int, grid: bool) -> Tensor:
    """Draw ``d`` one-dimensional points separately for each of the D dimensions.

    Args:
        x_range (Tensor): per-dimension bounds, shape [D, 2]
        d (int): number of points drawn for every dimension
        grid (bool): forwarded to ``sample_x``

    Returns:
        Tensor of shape [D, d, 1]; slice ``[i]`` holds the points of dimension i.
    """
    per_dim_points = []
    for row in range(x_range.shape[0]):
        bounds = x_range[row : row + 1, :]  # [1, 2]
        per_dim_points.append(
            sample_x(x_range=bounds, num_datapoints=d, grid=grid)  # [d, 1]
        )
    return torch.stack(per_dim_points, dim=0)


def sample_subspaces_efficient(
    x_range: Tensor,
    x_dim_mask: Tensor,
    q_dim_mask: Tensor,
    d: int,
    grid: bool,
    **kwargs,
) -> Tuple[Tensor, Tensor]:
    """Discretise the valid input dimensions into non-overlapping chunks.

    Each valid dimension flagged in ``q_dim_mask`` becomes its own
    one-dimensional chunk, and these fill the first rows of ``chunk_mask``.
    All remaining valid dimensions are grouped into a single joint chunk,
    which is the last row. Because no two chunks share a valid dimension, the
    ``d`` points of every chunk fit in one [d, dx_max] tensor, and
    ``chunk_mask`` records which columns belong to which chunk.

    Args:
        x_range (Tensor): range of each dimension, shape [dx_max, 2]
        x_dim_mask (Tensor): bool mask of valid x dimensions, shape [dx_max]
        q_dim_mask (Tensor): bool mask of dimensions sampled independently, shape [dx_max]
        d (int): number of points sampled in each chunk
        grid (bool): forwarded to the discretisation helpers (see ``sample_x``)
        **kwargs: ignored

    Returns:
        chunks: float32 tensor [d, dx_max]; columns of invalid dims stay zero.
        chunk_mask: bool tensor [n, dx_max], where n is the number of
            independent valid dims plus 1 if any joint valid dim exists.
            Row i marks the columns owned by chunk i.
    """
    dx_max = x_dim_mask.shape[-1]
    device = x_range.device

    # Valid dimensions of the independent and the joint subspaces: [dx_max]
    valid_q_ind = q_dim_mask & x_dim_mask
    valid_q_joint = (~q_dim_mask) & x_dim_mask

    # Count valid dimensions
    count_ind = valid_q_ind.int().sum().item()
    count_joint = valid_q_joint.int().sum().item()
    count_chunk = count_ind + int(count_joint > 0)

    # Column indices of each subspace type: [count_ind] and [count_joint]
    dims = torch.arange(dx_max, device=device)
    dims_ind = dims[valid_q_ind]
    dims_joint = dims[valid_q_joint]

    # Initialise chunks and mask
    chunks = torch.zeros([d, dx_max], device=device, dtype=torch.float32)
    chunk_mask = torch.zeros([count_chunk, dx_max], device=device, dtype=torch.bool)

    # Discretise the joint subspace (last chunk)
    if dims_joint.numel() > 0:
        x_range_joint = x_range[valid_q_joint, :]

        # [1, d, count_joint] -> [d, count_joint]
        x_joint = _discretize_joint_subspace(x_range=x_range_joint, d=d, grid=grid)
        x_joint = x_joint.squeeze(0)

        # Expand column indices: [count_joint] -> [d, count_joint]
        dims_joint_expanded = repeat(dims_joint, "dim -> d dim", d=d)

        # Scatter into chunks. scatter_ requires src to have chunks' dtype, but
        # sample_x returns float64 for a float64 x_range, hence the cast.
        chunks.scatter_(
            dim=-1, index=dims_joint_expanded, src=x_joint.to(chunks.dtype)
        )

        # The joint subspace owns all of its columns in the last chunk
        chunk_mask[-1:, dims_joint] = True

    # Discretise the independent subspaces (first count_ind chunks)
    if dims_ind.numel() > 0:
        x_range_ind = x_range[valid_q_ind, :]

        # [count_ind, d, 1] -> [count_ind, d] -> [d, count_ind]
        x_ind = _discretize_ind_subspaces(x_range=x_range_ind, d=d, grid=grid)
        x_ind = x_ind.squeeze(2)
        x_ind = rearrange(x_ind, "dim d -> d dim")

        # Expand column indices: [count_ind] -> [d, count_ind]
        dims_ind_expanded = repeat(dims_ind, "dim -> d dim", d=d)

        # Scatter into chunks (same dtype requirement as above)
        chunks.scatter_(dim=-1, index=dims_ind_expanded, src=x_ind.to(chunks.dtype))

        # Chunk i (i < count_ind) owns exactly column dims_ind[i]
        indices_ind = torch.arange(count_ind, device=device)
        chunk_mask[indices_ind, dims_ind] = True

    return chunks, chunk_mask


def get_sample_indices_from_chunk_indices(
    chunk_indices: Tensor,  # [B, M, N]
    chunks: Tensor,  # [B, N, D, H]
) -> Tensor:  # [B, M, 1]
    """Map per-chunk point indices to flat indices of the full cartesian grid.

    Each row of N indices is read as the digits of a base-D number, with the
    first chunk most significant:
    [i_0, ..., i_{N-1}] -> sum_k i_k * D**(N-1-k).
    This matches the row order of ``torch.cartesian_prod`` over N ranges of
    length D.

    Args:
        chunk_indices: [B, M, N] index of the selected point within each chunk.
        chunks: [B, N, D, H]; only its shape (N and D) is used.

    Returns:
        [B, M, 1] flat indices.
    """
    _, num_chunks, base, _ = chunks.shape

    # Place values, most significant first: [D^(N-1), ..., D^0]
    exponents = torch.arange(num_chunks, device=chunk_indices.device).flip(0)
    place_values = base ** exponents

    return torch.sum(chunk_indices * place_values, dim=-1, keepdim=True)


def get_all_sample_from_chunks_efficient(
    chunks: Tensor, chunk_mask: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    """Full combination of chunks into samples on the full search space (efficient version).
    NOTE Pattern exists in outputs - be careful with slices.

    Row m of ``samples`` takes, from each chunk i, the columns owned by that
    chunk (``chunk_mask[i]``) of point ``chunk_coord_grid[m, i]``. Rows follow
    ``torch.cartesian_prod`` order (the first chunk varies slowest), so the
    outputs are highly structured rather than random.

    Args:
        chunks (Tensor): chunks with non-overlapping valid dimensions, shape [d, dx_max]
        chunk_mask (Tensor): mask of the columns owned by each chunk, shape [N, dx_max]

    Returns:
        samples [M, dx_max] with M = d**N (columns owned by no chunk are zero),
        mask [dx_max] (union of the chunk masks),
        chunk_coord_grid [M, N] (per-chunk point index of every sample).

    Raises:
        AssertionError: on shape mismatch or overlapping chunk masks.
    """
    assert chunk_mask.shape[-1] == chunks.shape[-1]
    assert chunk_mask.ndim == 2 and chunks.ndim == 2
    # Each column may belong to at most one chunk
    assert torch.all(chunk_mask.int().sum(dim=0) <= 1)

    d, dx_max = chunks.shape
    N = chunk_mask.shape[0]
    M = d**N

    # Cartesian product of per-chunk point indices: [M, N]
    # chunk_coord_grid[m] = [i_0, ..., i_{N-1}], with i_k in [0, d) indexing chunk k
    if N == 0:
        # cartesian_prod needs at least one input; the product of zero chunks
        # is a single empty combination.
        chunk_coord_grid = torch.zeros(1, 0, device=chunks.device, dtype=torch.long)
    else:
        point_ids = torch.arange(d, device=chunks.device, dtype=torch.long)
        # cartesian_prod returns a flat vector when N == 1, hence the view
        chunk_coord_grid = torch.cartesian_prod(*([point_ids] * N)).view(-1, N)
    assert chunk_coord_grid.shape == (M, N)

    samples = torch.zeros(M, dx_max, device=chunks.device, dtype=chunks.dtype)
    for i in range(N):
        owned = chunk_mask[i].bool()  # [dx_max]

        # Zero the columns chunk i does not own. masked_fill is used instead of
        # multiplying by the mask so that NaN/Inf padding in non-owned columns
        # cannot leak into samples (NaN * 0 and Inf * 0 are NaN).
        chunk = chunks.masked_fill(~owned, 0)  # [d, dx_max]

        # chunk_sample[m] = chunk[chunk_coord_grid[m, i]]: [M, dx_max]
        # Columns are disjoint across chunks, so summing assembles each sample.
        samples += chunk[chunk_coord_grid[:, i]]

    # Union of chunk masks (disjoint by the assertion above): [dx_max]
    mask = chunk_mask.any(dim=0)

    return samples, mask, chunk_coord_grid


def factorized_sampler(
    B: int,
    d: int,  # Number of bins
    grid: bool,
    max_x_dim: int,
    max_y_dim: int,
    x_dim_list: List,
    y_dim_list: List,
    x_range: List,  # max_x_dim x [2]
    factorize: bool,
    device: str,
    x_dim: Optional[int] = None,
    y_dim: Optional[int] = None,
    sampler_list: Union[List[str], Tuple[str, ...]] = (
        "multi_task_gp_prior_sampler",
        "multi_output_gp_prior_sampler",
    ),
    sampler_weights: List = SAMPLER_WEIGHTS,
    data_kernel_type_list: List = DATA_KERNEL_TYPE_LIST,
    sample_kernel_weights: List = SAMPLE_KERNEL_WEIGHTS,
    lengthscale_range: List = L_RANGE,
    std_range: List = STD_RANGE,
    min_rank: int = MIN_RANK,
    max_rank: Optional[int] = None,
    p_iso: float = P_ISO,
    dim_scatter_mode: str = "random_k",
    jitter: float = JITTER,
    max_tries: int = MAX_TRIES,
    **kwargs,
) -> Tuple[
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
]:
    """Gaussian Process prior sampler over (factorized) subspaces of the search space.

    The valid x dimensions are split into chunks (see
    ``sample_subspaces_efficient``). The sample locations are every
    combination of chunk points (M = d**n of them, see
    ``get_all_sample_from_chunks_efficient``) and are shared by the whole
    batch. Each batch element draws its function values from a sampler chosen
    from ``sampler_list``.

    Args:
        B (int): batch size
        d (int): number of query points in each subspace
        grid (bool): whether to sample from a grid or random locations
        max_x_dim (int): maximum number of x dimensions
        max_y_dim (int): maximum number of y dimensions
        x_dim_list (List): list of possible valid x dimensions
        y_dim_list (List): list of possible valid y dimensions
        x_range (List): range of each dimension, max_x_dim x [2] (or a single
            [2] range shared by all dimensions)
        factorize (bool): whether to sample from factorized subspaces
        device (str): "cuda" or "cpu"
        x_dim (Optional[int]): number of x dimensions to sample, if None, randomly select from `x_dim_list`
        y_dim (Optional[int]): number of y dimensions to sample, if None, randomly select from `y_dim_list`
        sampler_list (list or tuple): names of the gp sampler functions to use
        sampler_weights (List): weights for each sampler in `sampler_list`
        data_kernel_type_list (List): list of data kernel types to sample from
        sample_kernel_weights (List): weights for each data kernel type in `data_kernel_type_list`
        lengthscale_range (List): range of lengthscales for data kernel
        std_range (List): range of standard deviations for task kernel
        min_rank (int): minimum rank for task kernel
        max_rank (Optional[int]): maximum rank for task kernel, if None, set to num_tasks
        p_iso (float): probability of using isotropic kernel
        dim_scatter_mode (str): "random_k" picks random valid dimensions; any
            other value uses the first k dimensions
        jitter (float): jitter for Cholesky decomposition
        max_tries (int): maximum number of Cholesky retries while sampling

    Returns:
        y_stacked (Tensor): sampled function values, shape [B, M, max_y_dim]
        x_mask (Tensor): mask indicating valid x dims, shape [max_x_dim]
        y_mask (Tensor): mask indicating valid y dims, shape [max_y_dim]
        chunks (Tensor): sampled points, shape [d, max_x_dim]
        chunk_mask (Tensor): mask indicating valid x dims in each chunk, shape [n, max_x_dim]
    """
    for sampler in sampler_list:
        assert sampler in [
            "multi_output_gp_prior_sampler",
            "multi_task_gp_prior_sampler",
        ], f"sampler `{sampler}` is not supported."
    assert len(x_range[0]) == 2
    assert max(x_dim_list) <= max_x_dim
    assert max(y_dim_list) <= max_y_dim
    assert (
        len(x_range) == max_x_dim or len(x_range) == 1
    ), f"x_range must be of length max_x_dim or 1, got {len(x_range)}"

    # Create sampler dictionary
    sampler_dict = {
        "multi_task_gp_prior_sampler": multi_task_gp_prior_sampler,
        "multi_output_gp_prior_sampler": multi_output_gp_prior_sampler,
    }

    # Dimensions sampled independently: [max_x_dim]
    q_mask = create_q_mask(
        factorize=factorize,
        dim=max_x_dim,
        device=device,
    )

    # Broadcast a single range to all dims: max_x_dim x [2] -> [max_x_dim, 2].
    # float32 matches the float32 `chunks` buffer filled from these ranges.
    x_range = [x_range[0] for _ in range(max_x_dim)] if len(x_range) == 1 else x_range
    x_range_tensor = torch.tensor(x_range, device=device, dtype=torch.float32)

    # Valid x_dim and y_dim for the batch
    x_dim = random.choice(x_dim_list) if x_dim is None else x_dim
    y_dim = random.choice(y_dim_list) if y_dim is None else y_dim

    # Create masks: [max_x_dim] and [max_y_dim]
    x_mask = torch.zeros(max_x_dim, dtype=torch.bool, device=device)
    y_mask = torch.zeros(max_y_dim, dtype=torch.bool, device=device)

    if dim_scatter_mode == "random_k":
        # Randomly select k valid dimensions
        x_perm = torch.argsort(torch.rand(max_x_dim, device=device))
        y_perm = torch.argsort(torch.rand(max_y_dim, device=device))

        x_mask[x_perm[:x_dim]] = True
        y_mask[y_perm[:y_dim]] = True
    else:
        # Otherwise use first k dimensions
        x_mask[:x_dim] = True
        y_mask[:y_dim] = True

    # Sample chunks and mask: [d, max_x_dim] and [n, max_x_dim]
    chunks, chunk_mask = sample_subspaces_efficient(
        x_range=x_range_tensor,
        x_dim_mask=x_mask,
        q_dim_mask=q_mask,
        d=d,
        grid=grid,
    )

    # n = one chunk per valid independent dim + one joint chunk if any valid
    # dim is not independent. (The chunk count is unrelated to d, the number
    # of points per chunk, so it must not be compared with chunks.shape[0].)
    num_chunks = int((x_mask & q_mask).sum().item()) + int(
        bool((x_mask & ~q_mask).any().item())
    )
    assert chunks.shape == (d, max_x_dim), f"{chunks.shape}"
    assert chunk_mask.shape == (num_chunks, max_x_dim), f"{chunk_mask.shape}"

    assert not torch.isnan(chunks).any(), f"q_chunks has NaNs: {chunks}"
    assert not torch.isinf(chunks).any(), f"q_chunks has infs: {chunks}"

    # Combine chunks into samples: [M, max_x_dim]
    x = get_all_sample_from_chunks_efficient(chunks=chunks, chunk_mask=chunk_mask)[0]
    assert x.ndim == 2, f"x should be 2D, got {x.ndim}D"
    M = x.shape[0]

    # NOTE Same points in a batch: [B, M, max_x_dim]
    x = x.unsqueeze(0).expand(B, -1, -1)

    # Indices of valid x dimensions: [x_dim]
    dx_valid = x_mask.nonzero().squeeze(-1)

    # Valid x and x_range: [B, M, x_dim] and [x_dim, 2]
    x_valid = x[:, :, dx_valid]
    x_range_valid = x_range_tensor[dx_valid]

    # Sample from gp: [B, M, max_y_dim]
    y_list = []
    for b in range(B):
        # Sample a sampler function
        sampler = random.choices(population=sampler_list, weights=sampler_weights, k=1)
        sampler_func = sampler_dict[sampler[0]]

        y = torch.zeros((M, max_y_dim), device=x.device, dtype=x.dtype)

        # Resample until the values are finite (no upper bound on retries)
        while True:
            _, y_valid = sampler_func(
                x=x_valid[b],  # [M, x_dim]
                x_range=x_range_valid,
                x_dim=x_dim,
                num_datapoints=M,
                num_tasks=y_dim,
                data_kernel_type_list=data_kernel_type_list,
                sample_kernel_weights=sample_kernel_weights,
                lengthscale_range=lengthscale_range,
                std_range=std_range,
                min_rank=min_rank,
                max_rank=max_rank,
                p_iso=p_iso,
                grid=grid,
                device=q_mask.device,
                jitter=jitter,
                max_tries=max_tries,
            )
            if not torch.isnan(y_valid).any() and not torch.isinf(y_valid).any():
                break

        # Scatter valid y values (columns in ascending y_mask order) into the output
        y[:, y_mask] = y_valid
        y_list.append(y)

    y_stacked = torch.stack(y_list, dim=0)  # [B, M, max_y_dim]
    return (
        y_stacked.float(),  # [B, M, max_y_dim]
        x_mask,  # [max_x_dim]
        y_mask,  # [max_y_dim]
        chunks.float(),  # [d, max_x_dim]
        chunk_mask,  # [n, max_x_dim]
    )


def sample_nc(x_dim: int, min_nc: int = 2, max_nc: int = 50, warmup: bool = False):
    """Choose a context-set size that grows with the input dimension.

    The upper bound ``max_nc`` is multiplied by 1 for x_dim <= 1, by 2 for
    1 < x_dim <= 3, and by 4 for x_dim > 3. In warmup mode that scaled bound
    is returned directly, giving a fixed and large context for stable
    training. Otherwise a size is drawn uniformly from
    [min_nc, scaled bound], both ends inclusive.

    With the defaults (min_nc=2, max_nc=50):
        x_dim=1: warmup 50,  otherwise in [2, 50]
        x_dim=2: warmup 100, otherwise in [2, 100]
        x_dim=3: warmup 100, otherwise in [2, 100]
        x_dim=4: warmup 200, otherwise in [2, 200]
    """
    if x_dim > 3:
        scale_factor = 4
    elif x_dim > 1:
        scale_factor = 2
    else:
        scale_factor = 1

    upper = int(max_nc * scale_factor)
    return upper if warmup else random.randint(min_nc, upper)


def get_num_categories(x_dim: int, factorize: bool = True):
    """Return the number of categories (discretisation points) per subspace.

    Args:
        x_dim: input dimensionality, from 1 to 4.
        factorize: True when every dimension is its own subspace, False when
            all dimensions form one joint subspace.

    Returns:
        The category count for ``x_dim`` in the chosen mode.

    Raises:
        AssertionError: if ``x_dim`` is outside [1, 4].
    """
    assert 0 < x_dim <= 4, "Only support x_dim in [1, 2, 3, 4]"

    # x_dim -> (count when factorized, count when joint). With x_dim == 1
    # there is a single one-dimensional subspace in either mode, and both
    # modes use 128.
    category_table = {
        1: (128, 128),
        2: (32, 128),
        3: (32, 256),
        4: (32, 512),
    }
    factorized_count, joint_count = category_table[x_dim]
    return factorized_count if factorize else joint_count


def gp_sampler(
    x: Tensor,
    x_range: FloatListOrNestedOrTensor,
    y_dim: int,
    sampler_list: list,
    sampler_weights: list,
    data_kernel_type_list: list,
    sample_kernel_weights: list,
    lengthscale_range: tuple,
    std_range: tuple,
    min_rank: int,
    max_rank: int,
    p_iso: float,
    grid: bool,
    jitter: float,
    max_tries: int,
    standardize: bool,
    device: str = "cuda",
):
    """Draw one GP-prior function per batch element at the given inputs.

    For every batch element, a sampler is chosen from ``sampler_list``
    (weighted by ``sampler_weights``) and called on that element's inputs.
    The draw is repeated until the outputs contain no NaN or Inf.

    Args:
        x: inputs, [B, N, x_dim].
        x_range: input ranges, converted with ``make_range_tensor`` to [x_dim, 2].
        y_dim: number of outputs (tasks) per function.
        sampler_list: names of the samplers to choose from
            ("multi_task_gp_prior_sampler", "multi_output_gp_prior_sampler").
        sampler_weights: relative weights of ``sampler_list``.
        data_kernel_type_list, sample_kernel_weights, lengthscale_range,
        std_range, min_rank, max_rank, p_iso, grid, jitter, max_tries,
        standardize: forwarded unchanged to the chosen sampler.
        device: device for the range tensor; also passed to the sampler.

    Returns:
        Tensor of shape [B, N, y_dim].
    """
    batch_size, num_points, x_dim = x.shape

    bounds = make_range_tensor(x_range, x_dim).to(device)
    assert bounds.shape == (x_dim, 2)

    samplers_by_name = {
        "multi_task_gp_prior_sampler": multi_task_gp_prior_sampler,
        "multi_output_gp_prior_sampler": multi_output_gp_prior_sampler,
    }

    draws = []
    for batch_idx in range(batch_size):
        chosen_name = random.choices(
            population=sampler_list, weights=sampler_weights, k=1
        )[0]
        draw_fn = samplers_by_name[chosen_name]

        # Redraw until the sample is free of NaNs/Infs (no retry limit): [N, y_dim]
        y = None
        while y is None or torch.isnan(y).any() or torch.isinf(y).any():
            _, y = draw_fn(
                x=x[batch_idx],
                x_range=bounds,
                x_dim=x_dim,
                num_datapoints=num_points,
                num_tasks=y_dim,
                data_kernel_type_list=data_kernel_type_list,
                sample_kernel_weights=sample_kernel_weights,
                lengthscale_range=lengthscale_range,
                std_range=std_range,
                min_rank=min_rank,
                max_rank=max_rank,
                p_iso=p_iso,
                grid=grid,
                standardize=standardize,
                jitter=jitter,
                max_tries=max_tries,
                device=device,
            )

        draws.append(y)

    # [B, N, y_dim]
    return torch.stack(draws, dim=0)
