import torch 

def dst1(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Discrete Sine Transform, Type-I (DST-I), implemented via FFT (CUDA-compatible).

    Matches SciPy's ``scipy.fft.dst(x, type=1, norm='backward')`` (equivalently
    ``scipy.fftpack.dst(x, type=1)`` with its default ``norm=None``) on the chosen axis.

    Definition (for length N):
        y_k = 2 * sum_{n=1..N} x_n * sin(pi * k * n / (N+1)),  k=1..N

    Args:
        x: real-valued tensor. Integer/bool tensors are promoted to
           ``torch.get_default_dtype()`` (the result is not integer-valued).
        dim: dimension along which to apply DST-I

    Returns:
        Real-valued tensor of the same shape as x. Its dtype is x's floating
        dtype, or the default float dtype for non-floating input.

    Raises:
        TypeError: if x is complex.
    """
    if x.is_complex():
        raise TypeError("dst1 expects a real-valued tensor.")

    # Casting the result back to an integer dtype would truncate the
    # coefficients, so non-floating inputs are promoted instead.
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    out_dtype = x.dtype
    # FFTs in fp16/bf16 are unsupported or restricted (and imprecise), so
    # everything except fp64 is transformed in fp32 and cast back afterwards.
    compute_dtype = torch.float64 if out_dtype == torch.float64 else torch.float32

    x = x.movedim(dim, -1).to(compute_dtype)
    n = x.shape[-1]

    # Odd extension v = [0, x_1..x_N, 0, -x_N..-x_1] of length 2(N+1).
    # Its DFT satisfies V_k = -i * y_k, so y_k = -Im(V_k) for k = 1..N.
    zero = x.new_zeros((*x.shape[:-1], 1))
    v = torch.cat([zero, x, zero, -torch.flip(x, dims=(-1,))], dim=-1)

    # v is real, so the one-sided rfft (bins 0..N+1) already contains bins 1..N.
    spectrum = torch.fft.rfft(v, dim=-1)
    y = (-spectrum.imag[..., 1 : n + 1]).to(out_dtype)

    return y.movedim(-1, dim)




def idst1(y: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Inverse Discrete Sine Transform, Type-I (IDST-I), the inverse of dst1().

    The DST-I matrix S, with S[k, n] = 2 * sin(pi * k * n / (N+1)) for
    k, n = 1..N, is symmetric and satisfies S @ S = 2*(N+1) * I. Under SciPy's
    'backward' normalization the inverse is therefore a rescaled forward
    transform:
        idst1(y) = dst1(y) / (2*(N+1))

    Args:
        y: real-valued tensor
        dim: dimension along which to apply IDST-I

    Returns:
        real-valued tensor of same shape as y
    """
    length = y.shape[dim]
    scale = 2.0 * (length + 1)
    # dst1 handles `dim` itself; the division is elementwise, so no axis
    # juggling is needed here.
    return dst1(y, dim=dim) / scale



def torch_to_idst(x: torch.Tensor, D) -> torch.Tensor:
    """
    Apply IDST-I to ``D * x`` along the last dimension, then reverse the axis order.

    Torch counterpart of ``scipy.fft.idst(D * x, type=1, norm="backward").T``.
    It uses this module's idst1() (which has the same 'backward' scaling), so
    the computation stays on x's device instead of round-tripping through NumPy.

    Args:
        x: real-valued tensor.
        D: tensor or scalar multiplied elementwise with x before the transform;
           must broadcast against x. (Its meaning is up to the caller —
           presumably per-mode weights; confirm.)

    Returns:
        ``idst1(D * x, dim=-1)`` with all dimensions reversed. This matches NumPy
        ``.T`` semantics and is a plain transpose for 2-D input.
    """
    z = idst1(D * x, dim=-1)
    # Tensor.T is deprecated for ndim != 2. Reversing all dims with permute
    # reproduces NumPy's .T for any rank (a no-op for 1-D).
    return z.permute(*reversed(range(z.ndim)))


