# audio_encoder.py
# Lightweight AudioEncoder wrapper.
# Provides encode_continuous -> frame-level embeddings (torch tensor)
# and encode_discrete -> discrete token ids (list of ints)

from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import numpy as np

from preprocess_audio import extract_mel_spectrogram
from config import AUDIO_EMBED_DIM

# Optional sklearn KMeans for discrete tokenization.
# scikit-learn is treated as an optional dependency: if importing it fails for
# any reason (Exception is caught, not just ImportError), KMeans is set to None
# and AudioEncoder.encode_discrete uses its simple quantization fallback.
try:
    from sklearn.cluster import KMeans
except Exception:
    KMeans = None

class SimpleAudioEmbedder(nn.Module):
    """
    Minimal frame-wise embedder for mel-spectrograms.

    Every time frame (a vector of ``in_mels`` mel bins) is mapped independently
    to an ``emb_dim``-dimensional embedding by one linear projection, so the
    number of frames is preserved: (batch, n_mels, T) -> (batch, T, emb_dim).

    Args:
        in_mels: number of mel bins per input frame.
        emb_dim: size of each output embedding.
    """
    def __init__(self, in_mels: int = 80, emb_dim: int = AUDIO_EMBED_DIM):
        super().__init__()
        # NOTE: this conv stack is NOT used by forward(). It is kept so that the
        # module's parameters (state_dict keys) and their initialization order
        # stay exactly as before.
        self.conv = nn.Sequential(
            nn.Conv1d(in_mels, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, emb_dim, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool1d(1),
        )
        # Frame-wise linear projection; this is what forward() actually applies.
        self.proj = nn.Linear(in_mels, emb_dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Embed a mel-spectrogram frame by frame.

        Args:
            mel: tensor of shape (n_mels, time) or (batch, n_mels, time).

        Returns:
            Tensor of shape (batch, time, emb_dim). A 2-D input is treated as a
            batch of one, so its result has shape (1, time, emb_dim); the batch
            axis is never dropped here.

        Raises:
            ValueError: if ``mel`` is neither 2-D nor 3-D.
        """
        rank = mel.dim()
        if rank == 2:
            mel = mel.unsqueeze(0)  # (1, n_mels, time)
        elif rank != 3:
            raise ValueError(
                "mel must have shape (n_mels, time) or (batch, n_mels, time), "
                f"got {tuple(mel.shape)}"
            )
        # nn.Linear acts on the last axis, so move the mel bins there.
        frames = mel.permute(0, 2, 1).contiguous()  # (batch, time, n_mels)
        return self.proj(frames)  # (batch, time, emb_dim)

class AudioEncoder:
    """
    High-level audio encoder offering continuous frame embeddings and a
    prototype discrete tokenization built on top of them.
    For a production system, swap in whisper / soundstream implementations.

    Args:
        device: torch device string; when omitted, "cuda" is used if available,
            otherwise "cpu".
        emb_dim: embedding size passed to SimpleAudioEmbedder.
    """
    def __init__(self, device: Optional[str] = None, emb_dim: int = AUDIO_EMBED_DIM):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.embedder = SimpleAudioEmbedder(emb_dim=emb_dim).to(self.device)
        # KMeans model fitted by the most recent encode_discrete call (if any).
        self._kmeans_model = None

    def encode_continuous(self, waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
        """
        Return continuous frame-level embeddings as a torch tensor on the CPU.

        Args:
            waveform: audio tensor handed to extract_mel_spectrogram
                (documented by the original author as shape (1, T)).
            sample_rate: sample rate forwarded to extract_mel_spectrogram.

        Returns:
            Tensor of shape (time, emb_dim) when the embedder output has a
            leading batch axis of size 1 (that axis is squeezed away).
        """
        # NOTE(review): the helper is expected to return (n_mels, time); it is
        # defined in preprocess_audio and not visible here - confirm.
        mel = extract_mel_spectrogram(waveform, sample_rate=sample_rate)
        mel = mel.to(self.device)
        with torch.no_grad():  # inference only, no autograd graph needed
            emb = self.embedder(mel)  # (1, time, emb_dim)
        return emb.squeeze(0).cpu()  # (time, emb_dim) on CPU

    def encode_discrete(self, waveform: torch.Tensor, n_tokens: int = 64, sample_rate: int = 16000) -> List[int]:
        """
        Convert audio into discrete token ids by clustering frame embeddings.

        With scikit-learn available, a KMeans model is fitted on this call's
        frames (and stored in ``self._kmeans_model``); because it is refitted
        on every call, ids are not comparable across calls. Without
        scikit-learn, frames are quantized into equal-width bins along the
        first embedding dimension.

        Args:
            waveform: audio tensor, forwarded to encode_continuous.
            n_tokens: upper bound on the number of distinct token ids (>= 1).
            sample_rate: sample rate forwarded to encode_continuous.

        Returns:
            One token id per frame; an empty list when there are no frames.

        Raises:
            ValueError: if ``n_tokens`` is smaller than 1.
        """
        if n_tokens < 1:
            raise ValueError(f"n_tokens must be >= 1, got {n_tokens}")

        emb = self.encode_continuous(waveform, sample_rate=sample_rate)  # (time, emb_dim)
        X = emb.numpy()
        n_frames = X.shape[0]
        if n_frames == 0:
            # Nothing to cluster; KMeans and min()/max() would both fail.
            return []

        # At most one cluster per two frames, but never more clusters than
        # frames (matters for a single frame, where max(2, ...) would give 2).
        k = min(n_tokens, max(2, n_frames // 2), n_frames)

        if KMeans is not None:
            # fit kmeans (fast prototype)
            km = KMeans(n_clusters=k, random_state=0)
            labels = km.fit_predict(X)
            # store model for later use (optional)
            self._kmeans_model = km
            return labels.tolist()

        # Fallback: quantize the mean-centred first embedding dimension.
        if X.shape[1] > 0:
            comp = (X - X.mean(axis=0, keepdims=True))[:, 0]
        else:
            comp = np.zeros(n_frames)
        lo = comp.min()
        span = comp.max() - lo  # ndarray.ptp() no longer exists in NumPy 2.x
        if span == 0:
            bins = np.zeros_like(comp, dtype=int)
        else:
            # Scale to [0, 1) and spread over bin ids 0 .. k-1.
            norm = (comp - lo) / (span + 1e-9)
            bins = (norm * (k - 1)).astype(int)
        return bins.tolist()
