"""Sampling utilities for function environments."""

import torch
from torch import Tensor
from einops import rearrange, repeat
import sobol_seq
from typing import Tuple, Optional
from utils.types import FloatListOrNestedOrTensor
from data.function_preprocessing import make_range_tensor
from data.data_masking import get_q_mask

GRID_SIZE = 1000


def _sample_joint_subspace(x_range: Tensor, d: int, grid: bool) -> Tensor:
    """Draw `d` points over all dims of `x_range` at once -> x: [d, x_dim]."""
    return generate_sobol_samples(x_range=x_range, num_datapoints=d, grid=grid)


def _sample_ind_subspaces(x_range: Tensor, d: int, grid: bool) -> Tensor:
    """Draw `d` points for every dim separately, so dims are sampled independently.

    Args:
        x_range: [D, 2], range for each dim
        d: Number of points per dim
        grid: Forwarded to `generate_sobol_samples`

    Returns:
        x: [d, D], column k holding the samples drawn for dim k
    """
    columns = []
    for dim_index in range(x_range.shape[0]):
        # Index with a list to keep a 2-D [1, 2] range -> one sample column [d, 1]
        column = generate_sobol_samples(
            x_range=x_range[[dim_index], :], num_datapoints=d, grid=grid
        )
        columns.append(column)
    return torch.cat(columns, dim=1)


def _sample_subspace_n_scatter_(
    sample_joint: bool,
    chunks: Tensor,
    chunk_mask: Tensor,
    x_range: Tensor,
    x_dim_indices: Tensor,
    chunk_index_slice: Tensor | slice,
    d: int,
    grid: bool,
):
    """Sample one subspace, write it into `chunks` and flag it in `chunk_mask`.

    Args:
        sample_joint: Sample all subspace dims jointly (True) or each dim
            independently (False)
        chunks: [d, dx_max]; columns `x_dim_indices` are overwritten with samples
        chunk_mask: [count_chunk, dx_max]; entries selected by
            (`chunk_index_slice`, `x_dim_indices`) are set to True
        x_range: [count_subspace, 2], range for subspaces
        x_dim_indices: [count_subspace], dimension indices for subspaces
        chunk_index_slice: Rows of `chunk_mask` to flag. A slice flags every
            dim of `x_dim_indices` in each selected row; an index tensor of
            length count_subspace is paired element-wise with `x_dim_indices`.
        d (int): Number of samples
        grid: Sample from a grid or randomly

    Returns nothing: `chunks` and `chunk_mask` are modified in place.
    """
    sampler = _sample_joint_subspace if sample_joint else _sample_ind_subspaces
    samples = sampler(x_range=x_range, d=d, grid=grid)  # [d, count_subspace]

    # Same column targets for every sample row: [count_subspace] -> [d, count_subspace]
    column_index = x_dim_indices.unsqueeze(0).expand(d, -1)
    chunks.scatter_(dim=-1, index=column_index, src=samples)

    chunk_mask[chunk_index_slice, x_dim_indices] = True


def sample_factorized_space_efficient(
    x_range: FloatListOrNestedOrTensor,
    x_dim_mask: Tensor,
    q_dim_mask: Tensor,
    num_subspace_points: int,
    use_grid_sampling: bool,
) -> Tuple[Tensor, Tensor]:
    """Sample a factorized input space into one shared chunk tensor (efficient version).

    Valid dims (`x_dim_mask`) are split into subspaces:
      * each valid dim with `q_dim_mask` set is its own 1-D subspace, flagged in
        rows 0 .. count_ind - 1 of `chunk_mask`;
      * all other valid dims form a single joint subspace, flagged in the last
        row of `chunk_mask` (only present if there is at least one such dim).
    Subspaces never share a dim, so all their samples live side by side in one
    [num_subspace_points, dx_max] tensor and `chunk_mask` records which columns
    belong to which chunk.

    Args:
        x_range: Input range, shape [x_dim, 2] (normalized via `make_range_tensor`)
        x_dim_mask: Mask for valid x dims, shape [x_dim]
        q_dim_mask: Mask for dims that are independently queried, shape [x_dim]
        num_subspace_points: Number of samples in each subspace
        use_grid_sampling: Whether to sample from a grid or random locations

    Returns:
        chunks: [num_subspace_points, dx_max], float32, zero in non-valid dims
        chunk_mask: [count_chunk, dx_max]
    """
    dev = x_dim_mask.device
    dx_max = x_dim_mask.shape[-1]

    dim_ids = torch.arange(dx_max, device=dev)
    ranges = make_range_tensor(x_range, num_dim=dx_max).to(dev)  # [dx_max, 2]

    # Valid dims split into independently sampled and jointly sampled ones
    ind_mask = x_dim_mask & q_dim_mask
    joint_mask = x_dim_mask & ~q_dim_mask
    n_ind = ind_mask.int().sum().item()
    n_joint = joint_mask.int().sum().item()
    n_chunks = n_ind + (1 if n_joint > 0 else 0)

    chunks = torch.zeros(
        [num_subspace_points, dx_max], device=dev, dtype=torch.float32
    )
    chunk_mask = torch.zeros([n_chunks, dx_max], device=dev, dtype=torch.bool)

    # Entries: (sample jointly?, dims of the subspace, chunk_mask rows to flag).
    # The joint subspace is sampled first; this order fixes the order of random draws.
    plan = []
    if n_joint > 0:
        plan.append((True, joint_mask, slice(-1, None)))
    if n_ind > 0:
        plan.append((False, ind_mask, torch.arange(n_ind, device=dev)))

    for joint, dim_mask, rows in plan:
        _sample_subspace_n_scatter_(
            sample_joint=joint,
            chunks=chunks,
            chunk_mask=chunk_mask,
            x_range=ranges[dim_mask, :],
            x_dim_indices=dim_ids[dim_mask],
            chunk_index_slice=rows,
            d=num_subspace_points,
            grid=use_grid_sampling,
        )

    # Sanity checks on the sampled values
    assert not torch.isnan(chunks).any(), f"chunks has NaNs: {chunks}"
    assert not torch.isinf(chunks).any(), f"chunks has infs: {chunks}"

    return chunks, chunk_mask


def get_all_sample_from_chunks_efficient(
    chunks: Tensor, chunk_mask: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    """Get samples on full space from chunks (efficient version).

    Every combination that picks one row of `chunks` per chunk is formed
    (Cartesian product). For each combination, the selected rows are restricted
    to their chunk's valid dims and summed into one full-space sample.

    NOTE: combinations follow `torch.cartesian_prod` order, i.e. the index of
    the last chunk varies fastest, so there is a pattern; be careful with slices.

    Args:
        chunks: Data chunks, shape [d, dx_max]
        chunk_mask: Mask for valid dims in each chunk, shape [n_chunks, dx_max]
            No overlapping valid dims between chunks

    Returns:
        samples: shape [M, dx_max] with M = d ** n_chunks; dims that no chunk
            covers are 0. With n_chunks == 0 there is a single all-zero sample.
        mask: shape [dx_max], True where some chunk covers the dim
        chunk_coord_grid: shape [M, n_chunks], the row of `chunks` used by each
            chunk for each sample

    Raises:
        ValueError: if two chunks share a valid dim.
    """
    assert chunk_mask.shape[-1] == chunks.shape[-1]
    assert chunk_mask.ndim == 2 and chunks.ndim == 2

    valid_dim_counts = chunk_mask.int().sum(dim=0)  # [dx_max]
    if not torch.all(valid_dim_counts <= 1):
        raise ValueError("`chunks` have overlapping valid dims.")

    d, dx_max = chunks.shape
    n_chunks = chunk_mask.shape[0]
    M = d**n_chunks

    if n_chunks == 0:
        # The product over zero chunks is one empty combination; `cartesian_prod`
        # cannot be called without tensors, so build that grid directly.
        chunk_coord_grid = torch.zeros(
            (1, 0), device=chunks.device, dtype=torch.long
        )
    else:
        # Each chunk has its own index range: [0, 1, ..., d-1]
        index_range = torch.arange(d, device=chunks.device, dtype=torch.long)
        chunk_coord_grid = torch.cartesian_prod(*([index_range] * n_chunks))
        # NOTE `cartesian_prod` returns a flat vector when given a single tensor
        chunk_coord_grid = chunk_coord_grid.view(-1, n_chunks)
    assert chunk_coord_grid.shape == (M, n_chunks)

    samples = torch.zeros(M, dx_max, device=chunks.device, dtype=chunks.dtype)
    for i in range(n_chunks):
        # Keep only the i-th chunk's valid dims. `masked_fill` (instead of
        # multiplying by the mask) stops inf/NaN stored in other dims from
        # turning into NaN (inf * 0) and leaking into the samples.
        chunk = chunks.masked_fill(chunk_mask[i].logical_not(), 0)  # [d, dx_max]
        samples += chunk[chunk_coord_grid[:, i]]  # [M, dx_max]

    # Mask over full space: a dim is valid if any chunk covers it
    mask = chunk_mask.any(dim=0)

    return samples, mask, chunk_coord_grid


def generate_sobol_samples(x_range: Tensor, num_datapoints: int, grid: bool) -> Tensor:
    """Draw points from a Sobol sequence and rescale them into `x_range`.

    A pool of Sobol points in the unit cube is generated, shuffled with
    `torch.randperm`, and the first `num_datapoints` are kept. With `grid=True`
    (and more than one point) the pool holds exactly `num_datapoints` points,
    so the whole low-discrepancy set is returned in random order. Otherwise the
    pool holds max(GRID_SIZE, num_datapoints) points and a random subset is taken.

    Args:
        x_range: Range to sample from, [D, 2] as (min, max) per dim
        num_datapoints: Number of datapoints to sample
        grid: Whether to sample from a grid over ranges, or random locations

    Returns:
        x: [num_datapoints, D]
    """
    assert x_range.ndim == 2 and x_range.shape[1] == 2
    n_dims = x_range.shape[0]
    dev = x_range.device

    # NOTE a single "grid" point falls back to the large pool; otherwise it
    # would always land on the same location
    exact_pool = grid and num_datapoints > 1
    pool_size = num_datapoints if exact_pool else max(GRID_SIZE, num_datapoints)

    # Unit-cube Sobol points: [pool_size, n_dims]
    unit_points = torch.from_numpy(
        sobol_seq.i4_sobol_generate(dim_num=n_dims, n=pool_size)
    ).to(device=dev, dtype=torch.float32)

    shuffle = torch.randperm(pool_size, device=dev)
    picked = unit_points[shuffle][:num_datapoints]

    lower = x_range[:, 0]
    upper = x_range[:, 1]
    scaled = lower + (upper - lower) * picked
    # Clamp into [lower, upper], guarding against rounding overshoot
    return torch.max(torch.min(scaled, upper), lower)


def sample_factorized_subspaces(
    d: int,
    x_mask: Tensor,
    input_bounds: FloatListOrNestedOrTensor,
    use_grid_sampling: bool = False,
    use_factorized_policy: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Sample `d` points per subspace of the factorized input space.

    The per-dim query mask comes from `get_q_mask`; sampling itself is done by
    `sample_factorized_space_efficient`.

    Args:
        d: Number of points in each subspace
        x_mask: Mask for valid x dims, shape [max_x_dim]
        input_bounds: Input range, convertible to [max_x_dim, 2]
        use_grid_sampling: Whether to sample from a grid or random locations
        use_factorized_policy: Forwarded to `get_q_mask`

    Returns:
        chunks: [d, max_x_dim]
        chunk_mask: [n, max_x_dim]
    """
    query_mask = get_q_mask(
        x_dim=x_mask.shape[-1],
        use_factorized_policy=use_factorized_policy,
        device=x_mask.device,
    )
    return sample_factorized_space_efficient(
        x_range=input_bounds,
        x_dim_mask=x_mask,
        q_dim_mask=query_mask,
        num_subspace_points=d,
        use_grid_sampling=use_grid_sampling,
    )


def sample_full_inputs_from_subspaces(
    d: int,
    x_mask: Tensor,
    input_bounds: FloatListOrNestedOrTensor,
    use_grid_sampling: bool = False,
    use_factorized_policy: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Sample subspace chunks and expand them into all full-space inputs.

    Args:
        d: Number of points in each subspace
        x_mask: Mask for valid x dims, shape [dx]
        input_bounds: Input range, convertible to [dx, 2]
        use_grid_sampling: Whether to sample from a grid or random locations
        use_factorized_policy: Forwarded to `sample_factorized_subspaces`

    Returns:
        x: [m, dx], every combination of chunk rows
        chunks: [d, dx]
        chunk_mask: [n, dx]
    """
    chunks, chunk_mask = sample_factorized_subspaces(
        d=d,
        x_mask=x_mask,
        input_bounds=input_bounds,
        use_grid_sampling=use_grid_sampling,
        use_factorized_policy=use_factorized_policy,
    )
    x, _, _ = get_all_sample_from_chunks_efficient(chunks, chunk_mask)
    return x, chunks, chunk_mask


def get_sample_from_chunk_indices(chunks: Tensor, chunk_indices: Tensor) -> Tensor:
    """Combine rows of `chunks` selected by `chunk_indices` into full-space samples.

    Args:
        chunks: shape [B, N, d, dx_max]
        chunk_indices: shape [B, num_index, N]; entry [b, i, n] picks a row of
            chunks[b, n]

    Returns:
        full_samples: shape [B, num_index, dx_max], with
            full_samples[b, i] = sum_n chunks[b, n, chunk_indices[b, i, n]]

    Notes on fancy indexing:
        A batch index of shape [B, 1, 1] and a chunk index of shape [1, 1, N]
        broadcast against `chunk_indices` [B, num_index, N], so
        picked[b, i, n] = chunks[b, n, chunk_indices[b, i, n]].
    """
    batch_count, _, chunk_count = chunk_indices.shape
    dev = chunks.device

    batch_ids = torch.arange(batch_count, device=dev, dtype=torch.long).view(-1, 1, 1)
    chunk_ids = torch.arange(chunk_count, device=dev, dtype=torch.long).view(1, 1, -1)

    # [B, N, d, dx_max] -> [B, num_index, N, dx_max]
    picked = chunks[batch_ids, chunk_ids, chunk_indices]

    # Chunks cover disjoint dims, so summing merges them: [B, num_index, dx_max]
    return picked.sum(dim=2)


def get_sample_indices_from_chunk_indices(
    chunk_indices: Tensor, n: int, d: int
) -> Tensor:
    """Flatten per-chunk indices into one index over the full product space.

    The index vector is read as a base-`d` number with the first chunk as the
    most significant digit:
        [i_0, ..., i_{n-1}] -> sum_k i_k * d ** (n - 1 - k)
    This matches `torch.cartesian_prod` ordering (last index varies fastest).

    Args:
        chunk_indices: shape [..., n]
        n: Number of chunks
        d: Number of points in each subspace

    Returns:
        full_indices: shape [..., 1]
    """
    exponents = torch.arange(n, device=chunk_indices.device).flip(0)  # [n-1, ..., 0]
    place_values = d**exponents  # [n]
    weighted = chunk_indices * place_values  # [..., n]
    return torch.sum(weighted, dim=-1, keepdim=True)


def get_num_subspace_points(
    x_dim: int,
    use_factorized_policy: bool = True,
    epoch: Optional[int] = None,
    total_epochs: Optional[int] = None,
) -> int:
    """Return the number of points to sample in each subspace.

    Args:
        x_dim: Number of input dims; must be in [1, 2, 3, 4]
        use_factorized_policy: Use the factorized-policy table (128 for 1 dim,
            32 otherwise) instead of the joint table (100 * x_dim)
        epoch: Currently unused
        total_epochs: Currently unused

    Returns:
        Number of points per subspace.
    """
    assert 0 < x_dim <= 4, "x_dim must be in [1, 2, 3, 4]"

    factorized_counts = {1: 128, 2: 32, 3: 32, 4: 32}
    joint_counts = {1: 100, 2: 200, 3: 300, 4: 400}
    table = factorized_counts if use_factorized_policy else joint_counts
    return table[x_dim]
