import numpy as np
import torch
import librosa
import torch.nn.functional as F
from audio_encoder import AudioEncoder

class AudioProcessor:
    """Convert a waveform into one audio embedding per video frame.

    The pipeline is: load audio -> normalised log-mel spectrogram -> one
    fixed-width mel window centred on every video frame -> ``AudioEncoder``.
    """

    def __init__(self, device='cpu'):
        """Set the feature-extraction parameters and load the encoder.

        Args:
            device: torch device (string or ``torch.device``) on which the
                encoder runs and to which batches are moved.
        """
        self.sr = 16000         # sample rate the audio is resampled to (Hz)
        self.n_fft = 800        # FFT size used for the mel spectrogram
        self.hop_length = 200   # samples between consecutive mel columns
        self.n_mels = 80        # number of mel bands (rows of the spectrogram)
        self.window_size = 16   # mel columns handed to the encoder per frame
        self.fps = 25           # video frame rate the audio is aligned to
        self.device = device

        self.encoder = AudioEncoder().to(device)
        self.encoder.eval()

    def load_wav(self, filepath):
        """Load an audio file and work out how many video frames it covers.

        Args:
            filepath: path of the audio file, passed to ``librosa.load``.

        Returns:
            ``(wav, num_frames)`` where ``wav`` is the waveform resampled to
            ``self.sr`` and ``num_frames`` is the number of whole video
            frames (at ``self.fps``) that fit into the audio duration.
        """
        wav, _ = librosa.load(filepath, sr=self.sr)

        # Multiply before dividing: the product is an exact integer, so the
        # truncation below cannot lose a frame to floating-point rounding.
        num_frames = int(len(wav) * self.fps / self.sr)
        return wav, num_frames

    def compute_mel(self, wav):
        """Compute a min-max normalised log-mel spectrogram.

        Args:
            wav: 1-D waveform sampled at ``self.sr``.

        Returns:
            Array of shape ``(n_mels, T)`` scaled into the range [0, 1].
        """
        mel = librosa.feature.melspectrogram(y=wav, sr=self.sr,
                                             n_fft=self.n_fft,
                                             hop_length=self.hop_length,
                                             n_mels=self.n_mels,
                                             power=1.0)
        # NOTE(review): power=1.0 yields a magnitude spectrogram, while
        # power_to_db is meant for power spectrograms (amplitude_to_db would
        # be the matching call). Left unchanged on purpose because the
        # encoder presumably expects features with this exact scaling --
        # confirm before changing.
        mel = librosa.power_to_db(mel, ref=np.max)

        # Min-max normalise; the epsilon guards against a constant input.
        mel = (mel - mel.min()) / (mel.max() - mel.min() + 1e-8)
        return mel  # shape: (n_mels, T)

    def _get_mel_for_frame(self, frame_idx: int, mel: np.ndarray) -> np.ndarray:
        """Cut the mel window that belongs to one video frame.

        The window is ``self.window_size`` columns wide and centred on the
        mel column matching the frame's timestamp. Near the start or end of
        the spectrogram the window is shifted inwards so it stays inside the
        data; if the spectrogram is shorter than one window, the result is
        padded on both sides by repeating the edge columns.

        Args:
            frame_idx: zero-based video frame index.
            mel: spectrogram of shape ``(n_mels, T)``.

        Returns:
            Array of shape ``(n_mels, self.window_size)``.
        """
        window = self.window_size
        total_cols = mel.shape[1]

        # Integer product first, so the sample index is exact whenever
        # sr is a multiple of fps (no accumulated float error).
        audio_center_idx = int(frame_idx * self.sr / self.fps)
        mel_center_idx = audio_center_idx // self.hop_length
        start = mel_center_idx - window // 2
        end = start + window

        if start < 0:
            # Window sticks out on the left: slide it to the beginning.
            start, end = 0, min(window, total_cols)
        elif end > total_cols:
            # Window sticks out on the right: slide it to the end.
            start, end = max(0, total_cols - window), total_cols

        mel_sample = mel[:, start:end]

        # After clamping, the slice can only be too narrow, never too wide.
        missing = window - mel_sample.shape[1]
        if missing > 0:
            pad_left = missing // 2
            pad_right = missing - pad_left
            mel_sample = np.pad(mel_sample, ((0, 0), (pad_left, pad_right)), 'edge')

        return mel_sample

    def extract_embedding(self, mel, num_frames):
        """Encode one mel window per video frame.

        Args:
            mel: spectrogram of shape ``(n_mels, T)``, e.g. from
                ``compute_mel``.
            num_frames: number of video frames to produce embeddings for;
                must be at least 1, because an empty list of windows cannot
                be stacked into a batch.

        Returns:
            NumPy array holding the encoder output with its leading batch
            axis squeezed away -- presumably ``(num_frames, 512)``; confirm
            against ``AudioEncoder``.
        """
        windows = [
            torch.as_tensor(self._get_mel_for_frame(i, mel), dtype=torch.float32)
            for i in range(num_frames)
        ]

        with torch.no_grad():
            # (T, n_mels, W) -> (T, 1, n_mels, W) -> (1, T, 1, n_mels, W)
            batch = torch.stack(windows, dim=0).unsqueeze(1).unsqueeze(0)
            batch = batch.to(self.device)
            embeddings = self.encoder(batch)
            embeddings = embeddings.squeeze(0).cpu().numpy()

        return embeddings