"""
Audio Processing Utilities

Tools for loading, saving, and preprocessing audio.
"""

import logging
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import torch

# Configure the root logger at import time: INFO level, timestamped records.
# NOTE(review): calling basicConfig from a library module changes the logging
# setup of whatever application imports it (it is a no-op only when the root
# logger already has handlers) -- confirm this side effect is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)


def load_audio(path: str, sr: int = 16000) -> torch.Tensor:
    """
    Load an audio file as a mono waveform tensor.

    Backends are tried in order: librosa first, then torchaudio. A backend is
    skipped only when an ImportError is raised while using it; any other
    error propagates to the caller. When neither backend is usable, a warning
    is logged and random noise of length ``sr * 3`` is returned instead of
    real audio.

    Args:
        path: Path of the audio file.
        sr: Target sample rate.

    Returns:
        Float tensor of shape (1, T).
    """
    # First choice: librosa resamples and down-mixes to mono while loading.
    try:
        import librosa
        samples, _ = librosa.load(path, sr=sr, mono=True)
        waveform = torch.from_numpy(samples)
        return waveform.unsqueeze(0).float()
    except ImportError:
        pass

    # Second choice: torchaudio; resample if needed, then average the channels.
    try:
        import torchaudio
        waveform, native_sr = torchaudio.load(path)
        if native_sr != sr:
            resampler = torchaudio.transforms.Resample(native_sr, sr)
            waveform = resampler(waveform)
        mono = waveform.mean(dim=0, keepdim=True)
        return mono.float()
    except ImportError:
        pass

    # No backend available: warn and hand back placeholder noise.
    logger.warning(f"无法加载音频 {path}，返回随机数据")
    placeholder_len = sr * 3
    return torch.randn(1, placeholder_len)


def save_audio(audio: torch.Tensor, path: str, sr: int = 16000) -> None:
    """
    Save a waveform tensor to an audio file.

    Backends are tried in order: torchaudio first, then soundfile. A backend
    is skipped only when an ImportError is raised while using it; any other
    error propagates to the caller. When neither backend is usable, a warning
    is logged and nothing is written.

    Args:
        audio: Waveform tensor, either 1-D (T,) or channels-first (C, T).
        path: Destination file path.
        sr: Sample rate.
    """
    try:
        import torchaudio
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        # detach() lets tensors that track gradients be written out.
        torchaudio.save(path, audio.detach().cpu(), sr)
        return
    except ImportError:
        pass

    try:
        import soundfile as sf
        samples = audio.detach().cpu().numpy().squeeze()
        if samples.ndim == 2:
            # soundfile expects (frames, channels); this module's tensors are
            # channels-first, so multi-channel data has to be transposed.
            samples = samples.T
        sf.write(path, samples, sr)
        return
    except ImportError:
        pass

    logger.warning(f"无法保存音频到 {path}")


class AudioProcessor:
    """
    Audio processor.

    Bundles pre- and post-processing helpers (loading, saving, mel
    spectrograms, silence trimming, peak normalization and resampling) around
    a fixed sample rate and number of mel bands.
    """

    def __init__(self, sample_rate: int = 16000, n_mels: int = 80):
        """
        Store the processing parameters.

        Args:
            sample_rate: Sample rate used for loading, saving and as the
                default resampling target.
            n_mels: Number of mel bands produced by ``compute_mel``.
        """
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    def load(self, path: str) -> torch.Tensor:
        """Load the file at ``path`` with ``load_audio`` at this processor's sample rate."""
        return load_audio(path, self.sample_rate)

    def save(self, audio: torch.Tensor, path: str) -> None:
        """Save ``audio`` to ``path`` with ``save_audio`` at this processor's sample rate."""
        save_audio(audio, path, self.sample_rate)

    def compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Compute a mel spectrogram.

        Uses torchaudio's ``MelSpectrogram`` (n_fft=1024, hop_length=256). If
        torchaudio cannot be imported, random values of a comparable shape
        are returned as a stand-in; that output is NOT a real spectrogram.

        Args:
            audio: Waveform tensor (B, T).

        Returns:
            Mel spectrogram (B, n_mels, T').
        """
        try:
            import torchaudio
            mel_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=1024,
                hop_length=256,
                n_mels=self.n_mels
            )
            # The transform's buffers are created on CPU; move them to the
            # input's device so that non-CPU tensors do not hit a mismatch.
            return mel_transform.to(audio.device)(audio)
        except ImportError:
            # Simulated mel spectrogram: one frame per 256-sample hop.
            T = audio.shape[-1] // 256
            return torch.randn(audio.shape[0], self.n_mels, T)

    def trim_silence(self, audio: torch.Tensor, threshold: float = 0.01) -> torch.Tensor:
        """
        Trim leading and trailing silence.

        A sample position counts as non-silent when the absolute amplitude in
        any row (channel / batch item) exceeds ``threshold``. All rows are
        cut to the same [first, last] non-silent span, so they stay aligned.

        Args:
            audio: Waveform tensor, 1-D (T,) or with time as the last axis.
            threshold: Absolute-amplitude threshold.

        Returns:
            The trimmed audio. A 1-D input is returned with a leading
            dimension of size 1. If no sample exceeds the threshold, the
            audio is returned untrimmed.
        """
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        if audio.numel() == 0:
            return audio

        # Collapse every leading axis so the mask is indexed by time only;
        # indexing a 2-D mask was what broke multi-row input.
        active = (audio.abs() > threshold).flatten(0, -2).any(dim=0)
        active_idx = active.nonzero()
        if active_idx.numel() == 0:
            return audio

        start = active_idx[0].item()
        end = active_idx[-1].item() + 1
        return audio[..., start:end]

    def normalize(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Peak-normalize audio.

        Args:
            audio: Waveform tensor.

        Returns:
            The audio scaled so that its largest absolute value is 0.95.
            Empty or all-zero input is returned unchanged.
        """
        # max() of an empty tensor raises, so handle that case up front.
        if audio.numel() == 0:
            return audio
        max_val = audio.abs().max()
        if max_val > 0:
            return audio / max_val * 0.95
        return audio

    def resample(self, audio: torch.Tensor, orig_sr: int, target_sr: Optional[int] = None) -> torch.Tensor:
        """
        Resample audio to a new sample rate.

        Uses torchaudio's ``Resample`` when available; otherwise falls back
        to linear interpolation along the last (time) axis.

        Args:
            audio: Waveform tensor with time as the last axis.
            orig_sr: Original sample rate.
            target_sr: Target sample rate; defaults to this processor's
                sample rate when omitted.

        Returns:
            The resampled audio. The input object itself is returned when the
            two rates are equal.
        """
        target_sr = target_sr or self.sample_rate

        if orig_sr == target_sr:
            return audio

        try:
            import torchaudio
            resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
            return resampler(audio)
        except ImportError:
            # Simple interpolation fallback.
            ratio = target_sr / orig_sr
            new_len = int(audio.shape[-1] * ratio)
            # Linear interpolate() needs a 3-D (N, C, L) input: fold every
            # leading axis into the channel axis, then restore the original
            # leading shape. This also makes 1-D input work.
            leading_shape = audio.shape[:-1]
            folded = audio.reshape(1, -1, audio.shape[-1])
            stretched = torch.nn.functional.interpolate(
                folded, size=new_len, mode='linear', align_corners=False
            )
            return stretched.reshape(*leading_shape, new_len)
