"""Utilities for padding and restoring dimensions in tensors."""

import torch
from torch import Tensor
from typing import Optional, Tuple


def take_along_valid_dims(data: Tensor, mask: Optional[Tensor], dim: int = -1) -> Tensor:
    """Select valid dims along `dim` of `data` based on `mask`.

    With a 1-dim mask the same entries are selected for the whole tensor. With a
    2-dim (batched) mask, row ``mask[b]`` selects entries of ``data[b]``; every
    row must contain the same number of True entries so the per-sample results
    can be stacked back into one tensor.

    Args:
        data: Data with data.shape[dim] == d_max. For a batched mask,
            data.shape[0] == B and `dim` must not refer to the batch dimension.
        mask: Boolean mask of valid dims, [d_max] or [B, d_max]. If None,
            `data` is returned unchanged.
        dim: Dimension of `data` (negative or non-negative) along which to select.

    Returns:
        data_valid: data_valid.shape[dim] == d_valid

    Raises:
        ValueError: If a batched mask is combined with a `dim` that points at
            the batch dimension, or if the mask has more than 2 dims.
    """
    if mask is None:
        return data

    assert mask.dtype == torch.bool
    assert mask.shape[-1] == data.shape[dim]

    if mask.ndim == 1:
        valid_indices = mask.nonzero(as_tuple=True)[0]
        return torch.index_select(data, dim, valid_indices)
    elif mask.ndim == 2:  # Batched mask
        B = mask.shape[0]
        assert data.shape[0] == B, f"Batch mismatch: {data.shape[0]} != {B}"
        # Iterating over `data` strips the batch dim, so `dim` must be
        # re-expressed relative to a single sample. Normalizing first keeps
        # non-negative dims correct (dim=1 on [B, N, D] is dim 0 of [N, D]).
        full_dim = dim % data.ndim
        if full_dim == 0:
            raise ValueError("Cannot select along the batch dimension with a batched mask.")
        sample_dim = full_dim - 1
        valid_data_list = [
            torch.index_select(d, sample_dim, m.nonzero(as_tuple=True)[0])
            for d, m in zip(data, mask)
        ]
        # torch.stack fails if rows have different numbers of valid dims.
        return torch.stack(valid_data_list, dim=0)  # [B, ..., d_valid, ...]
    else:
        raise ValueError(f"Mask ndim {mask.ndim} not supported.")


def restore_full_dims(
    data: Optional[Tensor], mask: Optional[Tensor], dim: int = -1
) -> Optional[Tensor]:
    """Restore full dimensions of data along dim given a mask.

    Counterpart of `take_along_valid_dims`: entries of `data` are written to
    the positions where `mask` is True and every other position is zero.

    - If data or mask is None, return data directly.
    - If the mask is all True (no dims were dropped), return data directly.
    - Otherwise fill zeros for invalid dims.

    Examples:
        data: [[[1.0, 2.0]]], mask: [True, False, True], data_full: [[[1.0, 0.0, 2.0]]]

    Args:
        data: Data with data.shape[dim] == valid_dim, where valid_dim is the
            number of True entries in the mask (per row for a batched mask).
        mask: Mask to indicate valid dims, [max_dim] or [B, max_dim]. For a
            batched mask, data.shape[0] == B, every row must contain the same
            number of True entries, and `dim` must not be the batch dimension.
        dim: Dimension along which to restore

    Returns:
        data_full: Data with data_full.shape[dim] == max_dim

    Raises:
        ValueError: If data.shape[dim] does not match the number of valid dims,
            on a batch-size mismatch, or if a batched mask is used along the
            batch dimension.
        NotImplementedError: If the mask has more than 2 dims.
    """
    if data is None or mask is None:
        return data

    assert mask.dtype == torch.bool
    if mask.ndim == 1:
        valid_dim = mask.int().sum().item()
        max_dim = mask.shape[-1]

        if data.shape[dim] != valid_dim:
            raise ValueError(
                f"Mismatch between data and mask: data.shape[{dim}]={data.shape[dim]} != {valid_dim}"
            )

        # An all-True mask means nothing was dropped: data is already full.
        if valid_dim == max_dim:
            return data

        desired_shape = list(data.shape)
        desired_shape[dim] = max_dim
        data_full = torch.zeros(desired_shape, device=data.device, dtype=data.dtype)
        valid_indices = mask.nonzero(as_tuple=True)[0]
        data_full.index_copy_(dim, valid_indices, data)

        return data_full
    elif mask.ndim == 2:  # Batched mask
        B, max_dim = mask.shape
        if data.shape[0] != B:
            raise ValueError(f"Batch mismatch: {data.shape[0]} != {B}")

        valid_dim = data.shape[dim]  # raises IndexError for an out-of-range dim
        full_dim = dim % data.ndim
        if full_dim == 0:
            raise ValueError("Cannot restore along the batch dimension with a batched mask.")

        valid_counts = mask.sum(dim=-1)
        if not bool((valid_counts == valid_dim).all()):
            raise ValueError(
                f"Mismatch between data and mask: data.shape[{dim}]={valid_dim} "
                f"!= per-row valid counts {valid_counts.tolist()}"
            )

        if valid_dim == max_dim:
            return data

        desired_shape = list(data.shape)
        desired_shape[full_dim] = max_dim
        data_full = torch.zeros(desired_shape, device=data.device, dtype=data.dtype)
        # data_full[b] is a view, so the in-place copy writes into data_full.
        # Inside a single sample the batch dim is gone, hence full_dim - 1.
        for b in range(B):
            valid_indices = mask[b].nonzero(as_tuple=True)[0]
            data_full[b].index_copy_(full_dim - 1, valid_indices, data[b])

        return data_full
    else:
        raise NotImplementedError("Only support 1-dim or 2-dim mask for now.")


def _generate_top_k_mask(k_batch: Tensor, max_dim: int, device: str) -> Tensor:
    """Generate a mask of shape [B, max_dim] with top-k dims as True."""
    B = k_batch.shape[0]
    dim_indices = torch.arange(max_dim, dtype=torch.long, device=device)
    dim_indices = dim_indices.unsqueeze(0).expand(B, -1)
    mask = dim_indices < k_batch.unsqueeze(1)
    return mask


def _get_random_permutation(B, dim, device) -> Tensor:
    random_indices = torch.argsort(torch.rand((B, dim), device=device), dim=-1)
    return random_indices


def gather_by_indices(data: Tensor, indices: Optional[Tensor]) -> Tensor:
    """Gather `data` along its last dimension using one index row per batch.

    ``out[b, ..., j] = data[b, ..., indices[b, j]]``: the same index row is
    applied to every leading slice of ``data[b]``.

    Args:
        data: Tensor of shape [B, ..., D] with at least 2 dims
            (e.g. the common [B, N, D] layout, but also [B, D] or [B, N, M, D]).
        indices: Integer index tensor of shape [B, D] (or [1, D], which is
            broadcast over the batch), or None.

    Returns:
        Tensor with the same shape as `data`; `data` itself if indices is None.
    """
    if indices is None:
        return data

    # View indices as [B, 1, ..., 1, D] so they broadcast over every middle
    # dim of `data`; for 3-D data this equals indices.unsqueeze(1).
    view_shape = (indices.shape[0],) + (1,) * (data.ndim - 2) + (indices.shape[-1],)
    indices_exp = indices.reshape(view_shape).expand_as(data)
    return torch.gather(data, dim=-1, index=indices_exp)


def generate_dim_mask(
    max_dim: int,
    device: str,
    k: Optional[Tensor | int] = None,
    dim_scatter_mode: str = "random_k",
) -> Tuple[Tensor, Tensor]:
    """Generate mask given maximum dimension size and specified number of valid dimensions.

    Args:
        max_dim: Maximum dimension size
        device: Computational device
        k: Number of valid dimensions, None | int | [B]
        dim_scatter_mode: ["random_k", "top_k"]. Default as "random_k"

    Returns:
        mask: Generated mask, [max_dim] | [max_dim] | [B, max_dim]
        valid_dim_indices: Optional valid dimension indices of shape [B, max_dim] when dim_scatter_mode is "random_k".
    """
    if dim_scatter_mode not in ["random_k", "top_k"]:
        raise ValueError(f"Invalid dim_scatter_mode: {dim_scatter_mode}.")

    valid_dim_indices = None
    if k is None:
        mask = torch.ones((max_dim,), dtype=torch.bool, device=device)
    else:
        if isinstance(k, int):
            make_single_mask = True
            k_batch = torch.tensor([k], dtype=torch.long, device=device)  # [1]
        else:
            make_single_mask = False
            if not (isinstance(k, Tensor) and k.ndim == 1):
                raise ValueError(f"Invalid k: {k}. Expected k to be a 1-dim Tensor.")

            k_batch = k.to(device=device)  # [B]

        B = k_batch.shape[0]
        mask = _generate_top_k_mask(k_batch=k_batch, max_dim=max_dim, device=device)

        if dim_scatter_mode == "random_k":
            valid_dim_indices = _get_random_permutation(B=B, dim=max_dim, device=device)
            mask = torch.gather(mask, dim=-1, index=valid_dim_indices)

        if make_single_mask:
            mask = mask.squeeze(0)  # [1, max_dim] -> [max_dim]

    mask.requires_grad_(False)
    return mask, valid_dim_indices


def get_q_mask(x_dim, use_factorized_policy: bool, device: str) -> Tensor:
    """Return a boolean mask marking the dimensions that get their own query set.

    A True entry means that dimension is sampled independently; all False
    entries share one joint query set.

    Examples:
        q_mask = [False, False, False]: a single query set over the entire search space
        q_mask = [True, True, True]: query sets for each dim
        q_mask = [False, False, True]: one query set over dim 0 and 1, one query set for dim 2

    Args:
        x_dim: input dimension
        use_factorized_policy: if truthy, every dim is True (fully factorized);
            otherwise every dim is False (one joint query set).
        device: computational device

    Returns:
        q_mask: boolean tensor of shape [x_dim]
    """
    factory = torch.ones if use_factorized_policy else torch.zeros
    return factory(x_dim, device=device, dtype=torch.bool)
