import math
import torch


def _get_sinc_resample_kernel(
    orig_freq, new_freq, gcd, lowpass_filter_width=6, rolloff=0.99, resampling_method="sinc_interp_hann",
    beta=None,
):
    """Build the windowed-sinc polyphase kernel used for band-limited resampling.

    The kernel is evaluated in float64 and returned as float32.

    Args:
        orig_freq: Original sample rate (integer-valued).
        new_freq: Target sample rate (integer-valued).
        gcd: Common divisor applied to both rates before building the kernel
            (normally ``math.gcd(orig_freq, new_freq)``); must be positive.
        lowpass_filter_width: Number of sinc zero crossings kept on each side
            of the centre; must be positive.
        rolloff: Fraction of the lower of the two (reduced) rates used as the
            cutoff; must be positive.
        resampling_method: Window applied to the sinc: ``"sinc_interp_hann"``
            (squared cosine) or ``"sinc_interp_kaiser"``.
        beta: Kaiser window shape parameter, only used with
            ``"sinc_interp_kaiser"``; defaults to 14.769656459379492.

    Returns:
        ``(kernels, width)`` where ``kernels`` is a float32 tensor of shape
        ``(new_freq // gcd, 1, 2 * width + orig_freq // gcd)`` (one row per
        output phase, laid out as conv1d weights) and ``width`` is the
        one-sided padding, in input samples, the kernel requires.

    Raises:
        ValueError: If ``resampling_method`` is unknown, or if
            ``lowpass_filter_width``, ``rolloff``, ``gcd`` or either reduced
            rate is not positive.
    """
    if resampling_method not in ("sinc_interp_hann", "sinc_interp_kaiser"):
        raise ValueError(f"Invalid resampling method: {resampling_method}")
    if lowpass_filter_width <= 0:
        raise ValueError("Low pass filter width should be positive.")
    if rolloff <= 0:
        raise ValueError("Rolloff should be positive.")
    if gcd <= 0:
        raise ValueError("gcd should be positive.")

    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd
    if orig_freq <= 0 or new_freq <= 0:
        raise ValueError("Original and new frequencies should be positive.")

    base_freq = min(orig_freq, new_freq) * rolloff

    # Number of input samples the kernel reaches on each side of its centre.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    idx_dtype = torch.float64

    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype)[None, None] / orig_freq

    # One row per output phase. Built in float64 too; an int64 arange divided
    # by an int would otherwise produce float32 offsets.
    phase = torch.arange(0, -new_freq, -1, dtype=idx_dtype)[:, None, None] / new_freq
    t = (phase + idx) * base_freq
    t = t.clamp(-lowpass_filter_width, lowpass_filter_width)

    if resampling_method == "sinc_interp_hann":
        window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
    else:
        if beta is None:
            beta = 14.769656459379492
        beta_tensor = torch.tensor(float(beta), dtype=idx_dtype)
        # clamp(min=0) guards against tiny negative values from rounding at |t| == width.
        arg = torch.sqrt((1 - (t / lowpass_filter_width) ** 2).clamp(min=0))
        window = torch.i0(beta_tensor * arg) / torch.i0(beta_tensor)

    t = t * math.pi

    scale = base_freq / orig_freq
    # sinc(t) with the removable singularity at t == 0 filled in as 1.
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale

    return kernels.to(dtype=torch.float32), width

def _apply_sinc_resample_kernel(
    waveform, orig_freq, new_freq,
):
    """Resample the last dimension of ``waveform`` from ``orig_freq`` to ``new_freq``.

    Leading dimensions are flattened into a batch, each signal is filtered with
    the windowed-sinc kernel from ``_get_sinc_resample_kernel`` via a strided
    conv1d, and the leading dimensions are restored afterwards.

    Args:
        waveform: Floating-point tensor whose last dimension is time.
        orig_freq: Original sample rate (positive, integer-valued).
        new_freq: Target sample rate (positive, integer-valued).

    Returns:
        Tensor with the same leading dimensions, dtype and device as
        ``waveform`` and last dimension ``ceil(new_freq * length / orig_freq)``.

    Raises:
        TypeError: If ``waveform`` is not a floating-point tensor.
        ValueError: If a rate is not a positive integer value.
    """
    if not waveform.is_floating_point():
        raise TypeError(f"Expected floating point type for waveform tensor, but received {waveform.dtype}.")
    if int(orig_freq) != orig_freq or int(new_freq) != new_freq:
        raise ValueError("Original and new frequencies should be integer-valued.")
    if orig_freq <= 0 or new_freq <= 0:
        raise ValueError("Original and new frequencies should be positive.")

    gcd = math.gcd(int(orig_freq), int(new_freq))
    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd)
    # Match the input, otherwise conv1d rejects float64/half or non-CPU waveforms.
    kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)

    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd

    # pack batch (reshape, not view, so non-contiguous inputs are accepted)
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    # conv1d yields (num_wavs, phases, frames); interleave the phases along time.
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    # Exact integer ceil(new_freq * length / orig_freq).
    target_length = (new_freq * length + orig_freq - 1) // orig_freq
    resampled = resampled[..., :target_length]

    # unpack batch
    return resampled.reshape(shape[:-1] + resampled.shape[-1:])
