import torch
from torch_dct import dct, dct_2d

def truncated_dct(x: torch.Tensor,
                  num_coeffs: int,
                  batch_size: int = 128,
                  to_cpu: bool = False
                  ) -> torch.Tensor:
    """
    Apply a 1D DCT (type-II, ``norm='ortho'``) to a batch of vectors and keep
    only the first `num_coeffs` (lowest-index) coefficients of each vector.

    The transform is computed in chunks of `batch_size` rows, and every chunk
    is truncated before it is stored, so the full (B, D) coefficient matrix is
    never materialized.

    Args:
        x: Tensor of shape (B, D) representing B samples of D-dimensional vectors.
        num_coeffs: Number of leading DCT coefficients to keep (0 <= num_coeffs <= D).
        batch_size: Number of rows transformed per chunk. Must be positive.
        to_cpu: If True, each truncated chunk is moved to the CPU before
            concatenation, so the result lives on the CPU.

    Returns:
        Tensor of shape (B, num_coeffs) with the first `num_coeffs` DCT coefficients.

    Raises:
        ValueError: If `x` is not 2-dimensional, if `num_coeffs` is negative or
            greater than D, or if `batch_size` is not positive.
    """
    if x.dim() != 2:
        raise ValueError(f"Expected input of shape (B, D), but got {x.shape}")
    if num_coeffs > x.shape[1]:
        raise ValueError(f"num_coeffs ({num_coeffs}) cannot be greater than D ({x.shape[1]})")
    if num_coeffs < 0:
        # A negative value would silently slice off trailing coefficients instead.
        raise ValueError(f"num_coeffs ({num_coeffs}) must be non-negative")
    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be a positive integer")

    num_samples = x.shape[0]
    if num_samples == 0:
        # torch.cat rejects an empty list, so handle the empty batch explicitly.
        # Assumes the DCT output keeps x's dtype -- TODO confirm for non-float inputs.
        empty = x.new_empty((0, num_coeffs))
        return empty.cpu() if to_cpu else empty

    chunks = []
    for start in range(0, num_samples, batch_size):
        # Truncate right away: only the kept coefficients are stored or transferred.
        coeffs = dct(x[start:start + batch_size], norm='ortho')[:, :num_coeffs]
        if to_cpu:
            coeffs = coeffs.cpu()
        chunks.append(coeffs)

    return torch.cat(chunks, dim=0)  # (B, num_coeffs)

def truncated_2d_dct(x: torch.Tensor,
                     num_coeffs: int,
                     batch_size: int = 128,
                     to_cpu: bool = False
                     ) -> torch.Tensor:
    """
    Apply a 2D DCT (``norm='ortho'``) over the last two dimensions of a batch
    of images and keep only the first `num_coeffs` coefficients along each of
    those two dimensions.

    The transform is computed in chunks of `batch_size` images, and every chunk
    is truncated before it is stored, so the full (B, C, H, W) coefficient
    tensor is never materialized.

    Args:
        x: Tensor of shape (B, C, H, W) representing B samples of (C, H, W)-dimensional images.
        num_coeffs: Number of leading DCT coefficients PER DIMENSION to keep
            (0 <= num_coeffs <= min(H, W)).
        batch_size: Number of images transformed per chunk. Must be positive.
        to_cpu: If True, each truncated chunk is moved to the CPU before
            concatenation, so the result lives on the CPU.

    Returns:
        Tensor of shape (B, C, num_coeffs, num_coeffs) with the leading DCT coefficients.

    Raises:
        ValueError: If `x` is not 4-dimensional, if `num_coeffs` is negative or
            greater than H or W, or if `batch_size` is not positive.
    """
    if x.dim() != 4:
        raise ValueError(f"Expected input of shape (B, C, H, W), but got {x.shape}")
    if num_coeffs > x.shape[-1] or num_coeffs > x.shape[-2]:
        raise ValueError(f"num_coeffs ({num_coeffs}) cannot be greater than H ({x.shape[-2]}) or W ({x.shape[-1]}).")
    if num_coeffs < 0:
        # A negative value would silently slice off trailing coefficients instead.
        raise ValueError(f"num_coeffs ({num_coeffs}) must be non-negative")
    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be a positive integer")

    num_samples = x.shape[0]
    if num_samples == 0:
        # torch.cat rejects an empty list, so handle the empty batch explicitly.
        # Assumes the DCT output keeps x's dtype -- TODO confirm for non-float inputs.
        empty = x.new_empty((0, x.shape[1], num_coeffs, num_coeffs))
        return empty.cpu() if to_cpu else empty

    chunks = []
    for start in range(0, num_samples, batch_size):
        full = dct_2d(x[start:start + batch_size], norm='ortho')
        # Truncate right away: only the kept coefficients are stored or transferred.
        coeffs = full[:, :, :num_coeffs, :num_coeffs]
        if to_cpu:
            coeffs = coeffs.cpu()
        chunks.append(coeffs)

    return torch.cat(chunks, dim=0)  # (B, C, num_coeffs, num_coeffs)