import librosa
import numpy as np
import torch
import torchaudio
from torch import Tensor
from torch.nn import functional as F


def add(
    base_audio: Tensor,
    target_audio: Tensor,
    loc: str | float | None = None,
    base_sr: int = 48_000,
):
    if isinstance(loc, float):
        left_pad = round(loc * base_sr)
    elif isinstance(loc, str):
        match loc:
            case "start":
                left_pad = 0
            case "middle":
                left_pad = (len(base_audio) - len(target_audio)) // 2
            case "end":
                left_pad = len(base_audio) - len(target_audio)
            case _:
                raise NotImplementedError()
    else:
        left_pad = 0
    right_pad = len(base_audio) - len(target_audio) - left_pad
    target_audio = F.pad(target_audio, (left_pad, right_pad), mode="constant")
    return base_audio, base_audio + target_audio


def drop(base_audio: Tensor, target_audio: Tensor):
    """Pair a mixture with its clean base so a model learns to remove a source.

    Returns:
        ``(base_audio + target_audio, base_audio)``.
    """
    mixture = base_audio + target_audio
    clean = base_audio
    return mixture, clean


def replace(base_audio: Tensor, target_audio_1: Tensor, target_audio_2: Tensor):
    """Pair two mixtures of the same base that differ in the overlaid source.

    Returns:
        ``(base_audio + target_audio_1, base_audio + target_audio_2)``.
    """
    before, after = (base_audio + overlay for overlay in (target_audio_1, target_audio_2))
    return before, after


def inpaint(base_audio: Tensor):
    rng = np.random.default_rng()
    audio_length = base_audio.shape[1]
    random_length = round(rng.uniform(0.1, 0.9) * audio_length)
    random_start = round(rng.integers(audio_length - random_length))
    target_audio = base_audio.clone()
    target_audio[:, random_start : random_start + random_length] = 0
    return target_audio, base_audio


def superres(base_audio: Tensor, base_sr: int = 48_000, target_sr: int = 24_000):
    """Build a super-resolution pair by band-limiting ``base_audio``.

    The audio is resampled from ``base_sr`` down to ``target_sr`` and back up
    to ``base_sr``, so the degraded copy keeps the original sample rate.

    Returns:
        ``(degraded_audio, base_audio)``.
    """
    resample = torchaudio.functional.resample
    downsampled = resample(base_audio, orig_freq=base_sr, new_freq=target_sr)
    degraded = resample(downsampled, orig_freq=target_sr, new_freq=base_sr)
    return degraded, base_audio


def noise(base_audio: Tensor, scale: float = 1e-2):
    """Build a denoising pair by adding Gaussian noise.

    Args:
        base_audio: Clean audio.
        scale: Standard deviation multiplier applied to standard-normal noise.

    Returns:
        ``(noisy_audio, base_audio)``.
    """
    jitter = torch.randn_like(base_audio) * scale
    noisy = base_audio + jitter
    return noisy, base_audio


def pitch(base_audio: Tensor, base_sr: int = 48_000, steps: int = 4):
    """Pitch-shift ``base_audio`` by ``steps`` with librosa.

    Each channel is shifted independently with ``librosa.effects.pitch_shift``.

    Args:
        base_audio: Audio of shape ``(channels, samples)``; a 1-D
            ``(samples,)`` tensor is treated as a single channel.
        base_sr: Sample rate of ``base_audio``.
        steps: Passed to librosa as ``n_steps``.

    Returns:
        ``(base_audio, shifted_audio)``. ``shifted_audio`` has the same dtype
        and device as ``base_audio``.
    """
    # librosa works on NumPy arrays: detach from autograd and copy to host memory.
    np_audio = base_audio.detach().cpu().numpy()
    is_mono = np_audio.ndim == 1
    channels = np_audio[None] if is_mono else np_audio

    shifted = np.stack(
        [
            librosa.effects.pitch_shift(channel, sr=base_sr, n_steps=steps)
            for channel in channels
        ]
    )
    if is_mono:
        shifted = shifted[0]

    # Return the result where the input lived, not unconditionally on CPU.
    shifted_audio = torch.from_numpy(shifted).to(
        device=base_audio.device, dtype=base_audio.dtype
    )
    return base_audio, shifted_audio


def speed(base_audio: Tensor, factor: float = 1.5):
    """Time-stretch ``base_audio`` by ``factor`` using an STFT phase vocoder.

    ``torchaudio.functional.phase_vocoder`` changes speed without modifying
    pitch; ``factor > 1`` shortens the audio and ``factor < 1`` lengthens it.

    Args:
        base_audio: Audio accepted by ``torch.stft`` (1-D, or batched with
            time on the last axis).
        factor: Stretch rate passed to the phase vocoder.

    Returns:
        ``(base_audio, stretched_audio)``.
    """
    n_fft = 2048
    hop_length = n_fft // 4
    device, dtype = base_audio.device, base_audio.dtype

    # Build the window next to the audio so stft/istft never mix devices or dtypes.
    window = torch.hann_window(window_length=n_fft, device=device, dtype=dtype)
    stft = torch.stft(
        base_audio,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )

    # Expected phase advance per hop for bin k: hop * 2*pi*k / n_fft, shaped
    # (freq, 1) so it broadcasts over frames.
    bin_freqs = torch.arange(n_fft // 2 + 1, device=device, dtype=dtype) * (
        2 * torch.pi / n_fft
    )
    phase_advance = (hop_length * bin_freqs).unsqueeze(dim=1)

    stft_stretch = torchaudio.functional.phase_vocoder(
        stft, rate=factor, phase_advance=phase_advance
    )
    target_audio = torch.istft(
        stft_stretch, n_fft=n_fft, hop_length=hop_length, window=window
    )
    return base_audio, target_audio


def high_pass(base_audio: Tensor, base_sr: int = 48_000, cutoff_freq: int = 1_000):
    """Apply a biquad high-pass filter at ``cutoff_freq`` Hz.

    Returns:
        ``(base_audio, filtered_audio)``.
    """
    biquad = torchaudio.functional.highpass_biquad
    filtered = biquad(base_audio, sample_rate=base_sr, cutoff_freq=cutoff_freq)
    return base_audio, filtered


def low_pass(base_audio: Tensor, base_sr: int = 48_000, cutoff_freq: int = 8_000):
    """Apply a biquad low-pass filter at ``cutoff_freq`` Hz.

    Returns:
        ``(base_audio, filtered_audio)``.
    """
    biquad = torchaudio.functional.lowpass_biquad
    filtered = biquad(base_audio, sample_rate=base_sr, cutoff_freq=cutoff_freq)
    return base_audio, filtered


def swap(base_audio: Tensor, target_audio: Tensor):
    """Pair two concatenations of the same clips in opposite order.

    Concatenation is along dim 1.

    Returns:
        ``(cat(base, target), cat(target, base))``.
    """
    forward = (base_audio, target_audio)
    backward = forward[::-1]
    return torch.cat(forward, dim=1), torch.cat(backward, dim=1)


def loop(base_audio: Tensor, num_loop: int = 3):
    """Pair ``base_audio`` with itself repeated ``num_loop`` times along dim 1.

    Returns:
        ``(base_audio, looped_audio)``.
    """
    copies = tuple(base_audio for _ in range(num_loop))
    looped = torch.cat(copies, dim=1)
    return base_audio, looped
