# preprocess_audio.py
# Audio preprocessing: energy-based VAD and framing / mel spectrogram extraction.

from typing import Tuple
import numpy as np
import torch

try:
    import torchaudio
    from torchaudio.transforms import MelSpectrogram, Resample
except Exception:
    torchaudio = None
    MelSpectrogram = None
    Resample = None

from config import AUDIO_SAMPLE_RATE, AUDIO_FRAME_MS, AUDIO_HOP_MS

def apply_vad(waveform: torch.Tensor, sample_rate: int = AUDIO_SAMPLE_RATE, frame_ms: int = AUDIO_FRAME_MS, hop_ms: int = AUDIO_HOP_MS, energy_threshold: float = 1e-4) -> torch.Tensor:
    """
    Energy-based VAD: keep only the samples covered by a voiced window.

    The signal is scanned with windows of ``frame_ms`` milliseconds advanced
    by ``hop_ms`` milliseconds; windows that run past the end of the signal
    are truncated. A window is voiced when its RMS energy is
    >= ``energy_threshold``. Every sample is emitted at most once and in its
    original order, even when windows overlap (hop shorter than frame).

    waveform: (T,) or (1, T); only the first row of a 2-D input is analysed.
    Returns a float32 CPU tensor of shape (1, T_voiced). If no window is
    voiced, the input waveform is returned unchanged (as a 2-D tensor).
    Raises ValueError if the frame or hop length is not at least one sample.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    assert waveform.dim() == 2  # (1, T)
    frame_len = int(sample_rate * frame_ms / 1000)
    hop_len = int(sample_rate * hop_ms / 1000)
    if frame_len <= 0 or hop_len <= 0:
        raise ValueError(
            f"frame and hop must span at least one sample "
            f"(frame_len={frame_len}, hop_len={hop_len})"
        )
    # detach()/cpu() so tensors that require grad or live on another device
    # can still be handed to numpy.
    x = waveform[0].detach().cpu().numpy()
    # Squaring integer PCM in its own dtype would overflow; measure energy on
    # a float copy and leave floating-point input untouched.
    energy_src = x if np.issubdtype(x.dtype, np.floating) else x.astype(np.float32)
    num_samples = x.shape[0]
    # Per-sample mask instead of a list of windows: overlapping voiced windows
    # must not duplicate the samples they share.
    keep = np.zeros(num_samples, dtype=bool)
    for start in range(0, num_samples, hop_len):
        stop = start + frame_len  # slicing clamps at the end of the signal
        window = energy_src[start:stop]
        rms = np.sqrt(np.mean(window ** 2))
        if rms >= energy_threshold:
            keep[start:stop] = True
    if not keep.any():
        # nothing voiced: return original waveform
        return waveform
    voiced = x[keep]
    return torch.from_numpy(voiced.astype("float32")).unsqueeze(0)

def frame_audio(waveform: torch.Tensor, sample_rate: int = AUDIO_SAMPLE_RATE, frame_ms: int = AUDIO_FRAME_MS, hop_ms: int = AUDIO_HOP_MS) -> torch.Tensor:
    """
    Split a waveform into fixed-length, possibly overlapping frames.

    waveform: (T,) or (C, T); only the first row of a 2-D input is framed.
    Frames are ``frame_ms`` milliseconds long and start every ``hop_ms``
    milliseconds. Trailing samples that do not fill a complete frame are
    dropped. A signal shorter than one frame is zero-padded on the right to
    produce exactly one frame.

    Returns a float32 CPU tensor of shape (num_frames, frame_len) that does
    not share memory with the input.
    Raises ValueError if the waveform is not 1-D or 2-D, or if the frame or
    hop length is not at least one sample.
    """
    if waveform.dim() == 2:
        waveform = waveform[0]
    if waveform.dim() != 1:
        raise ValueError(
            f"waveform must be 1-D or 2-D, got shape {tuple(waveform.shape)}"
        )
    frame_len = int(sample_rate * frame_ms / 1000)
    hop_len = int(sample_rate * hop_ms / 1000)
    if frame_len <= 0 or hop_len <= 0:
        raise ValueError(
            f"frame and hop must span at least one sample "
            f"(frame_len={frame_len}, hop_len={hop_len})"
        )
    # detach()/cpu() so tensors that require grad or live on another device
    # are accepted; the result is always float32.
    samples = waveform.detach().cpu().to(torch.float32)
    num_samples = samples.shape[0]
    if num_samples < frame_len:
        # pad if too short
        padded = torch.zeros(frame_len, dtype=torch.float32)
        padded[:num_samples] = samples
        return padded.unsqueeze(0)  # (1, frame_len)
    # unfold yields a strided view over `samples`; clone into contiguous
    # memory so the frames never alias the caller's tensor.
    frames = samples.unfold(0, frame_len, hop_len)
    return frames.clone(memory_format=torch.contiguous_format)  # (F, frame_len)

def extract_mel_spectrogram(waveform: torch.Tensor, sample_rate: int = AUDIO_SAMPLE_RATE, n_mels: int = 80) -> torch.Tensor:
    """
    Compute a mel spectrogram, using torchaudio when it is available.

    With torchaudio: applies ``MelSpectrogram(sample_rate, n_mels)`` (all
    other settings are the library defaults) to the waveform, adding a
    leading channel dimension to 1-D input and squeezing it off the result.

    Without torchaudio: a numpy approximation. A Hann-windowed magnitude STFT
    (512-sample window, 256-sample hop) is computed and its frequency bins are
    averaged in ``n_mels`` contiguous, roughly equal-sized groups. This is a
    linear band average, not a true mel scale, and ``sample_rate`` is not
    used. Signals shorter than one window are zero-padded to a single window.

    Returns (n_mels, time) for mono input.
    """
    if torchaudio is not None and MelSpectrogram is not None:
        # ensure shape (1, T)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        transform = MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        mel = transform(waveform)  # (1, n_mels, time)
        return mel.squeeze(0)
    # fallback: simple STFT-based mel-like features using numpy
    win = 512
    hop = 256
    # detach()/cpu() so tensors that require grad or live on another device
    # can be converted; atleast_1d guards against a single-sample input
    # collapsing to a 0-d array after squeeze().
    x = np.atleast_1d(waveform.detach().cpu().numpy().squeeze())
    if x.ndim == 1 and x.shape[0] < win:
        # sliding_window_view rejects windows longer than the signal;
        # zero-pad so short clips still yield one frame.
        x = np.pad(x, (0, win - x.shape[0]))
    # very simple: compute magnitude STFT and reduce to n_mels bands
    windows = np.lib.stride_tricks.sliding_window_view(x, win)[::hop]
    stft = np.abs(np.fft.rfft(windows * np.hanning(win), axis=-1))
    # reduce to n_mels by grouping frequencies
    stft = stft.T  # (freq_bins, time_frames)
    freq_bins, num_frames = stft.shape
    groups = np.array_split(np.arange(freq_bins), n_mels)
    # array_split yields empty groups when n_mels > freq_bins; those bands
    # are filled with zeros.
    mel = np.stack([stft[g].mean(axis=0) if len(g) else np.zeros(num_frames) for g in groups])
    return torch.from_numpy(mel.astype("float32"))
