import glob
import os
import warnings
import zipfile

import librosa
import numpy as np
import torch
from tqdm import tqdm

from audio_encoder import AudioEncoder

class BaseGenerator:
    """Build per-frame audio embeddings aligned with per-frame pose parameters.

    The audio at ``wav_path`` is turned into a min-max normalised log-mel
    spectrogram.  For every usable ``<param_dir>/<number>/params.npz`` file a
    fixed-width mel window around the corresponding video frame is cut out,
    all windows are encoded in a single batch by ``AudioEncoder``, and the
    results are written to ``<save_prefix>_embed.npy`` and
    ``<save_prefix>_pose.npy``.
    """

    # Errors that mean "this parameter file is unreadable"; such files are skipped.
    _LOAD_ERRORS = (OSError, ValueError, EOFError, zipfile.BadZipFile)

    def __init__(self, wav_path, param_dir, save_prefix,
                 sr=16000, hop_length=200, n_fft=800, n_mels=80, device=None,
                 fps=25, mel_window=16):
        """Load the encoder, compute the mel spectrogram and list parameter files.

        Args:
            wav_path: Path of the audio file to encode.
            param_dir: Directory holding numerically named sub-directories,
                each containing a ``params.npz`` file.
            save_prefix: Output path prefix; ``_embed.npy`` / ``_pose.npy``
                are appended to it.
            sr: Sample rate the audio is resampled to on load.
            hop_length: Mel-spectrogram hop length, in samples.
            n_fft: FFT size for the mel spectrogram.
            n_mels: Number of mel bands.
            device: Torch device the encoder and its inputs are moved to.
            fps: Video frame rate used to map a frame index to audio time.
            mel_window: Number of mel columns cut out per frame.  Presumably
                this must match what ``AudioEncoder`` was built for; confirm
                before changing it.
        """
        self.wav_path = wav_path
        self.param_dir = param_dir
        self.save_prefix = save_prefix
        self.sr = sr
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.device = device
        self.fps = fps
        self.mel_window = mel_window

        self.audio_encoder = AudioEncoder().to(self.device).eval()

        self.mel_spectrogram = self._load_mel()

        def frame_number(path):
            # ".../<number>/params.npz" -> <number>
            return os.path.basename(os.path.dirname(path))

        pattern = os.path.join(self.param_dir, "[0-9]*", "params.npz")
        # "[0-9]*" also matches names such as "12abc"; keep purely numeric ones
        # so the int() sort key below cannot raise.
        numeric = [p for p in glob.glob(pattern) if frame_number(p).isdecimal()]
        self.param_files = sorted(numeric, key=lambda p: int(frame_number(p)))

    def _load_mel(self):
        """Return the log-mel spectrogram of ``wav_path`` scaled into [0, 1].

        The result has shape ``(n_mels, num_mel_frames)``.  Constant input,
        such as all-silent audio, yields an all-zero array instead of NaNs.
        """
        wav, _ = librosa.load(self.wav_path, sr=self.sr)
        mel = librosa.feature.melspectrogram(
            y=wav,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            power=1.0
        )

        # NOTE(review): power=1.0 gives a magnitude spectrogram, for which
        # librosa.amplitude_to_db is the matching conversion. power_to_db is
        # kept because the encoder may depend on these exact features; confirm.
        mel = librosa.power_to_db(mel, ref=np.max)

        lo = mel.min()
        span = mel.max() - lo
        if span > 0:
            return (mel - lo) / span
        # Flat spectrogram: min-max scaling would divide by zero.
        return np.zeros_like(mel)

    def _get_mel_for_frame(self, frame_idx):
        """Return the ``(n_mels, mel_window)`` mel slice around video frame ``frame_idx``.

        Near either edge the window is shifted inward rather than truncated.
        If the whole spectrogram is narrower than the window, it is
        edge-padded symmetrically.
        """
        window = self.mel_window
        num_cols = self.mel_spectrogram.shape[1]

        # Integer numerator keeps the division exact when the true sample
        # index is whole, avoiding float truncation off-by-ones.
        audio_center_idx = int(frame_idx * self.sr / self.fps)
        mel_center_idx = audio_center_idx // self.hop_length

        start = mel_center_idx - window // 2
        # Clamp so that [start, start + window) lies inside the spectrogram
        # whenever the spectrogram is wide enough.
        start = max(0, min(start, num_cols - window))
        mel_sample = self.mel_spectrogram[:, start:start + window]

        missing = window - mel_sample.shape[1]
        if missing > 0:
            pad_left = missing // 2
            mel_sample = np.pad(mel_sample, ((0, 0), (pad_left, missing - pad_left)), 'edge')

        return mel_sample

    def _collect_pose_params(self):
        """Load every usable 6-dim ``pose`` entry from ``self.param_files``.

        Returns:
            A tuple ``(positions, poses)``.  ``positions`` lists the index into
            ``self.param_files`` of each kept file.  ``poses`` stacks the kept
            pose arrays and has shape ``(0, 6)`` when none were found.

        A file's position is used as its video frame index. This assumes one
        parameter file per frame starting at frame 0 (TODO confirm). Using
        the position keeps audio and poses aligned when some files are skipped.
        """
        positions = []
        poses = []
        for position, path in enumerate(self.param_files):
            try:
                data = np.load(path)
                if not hasattr(data, "files"):
                    # A plain .npy array, not an archive: there is no named 'pose'.
                    continue
                with data:  # close the underlying zip file handle
                    pose = data['pose'] if 'pose' in data else None
            except self._LOAD_ERRORS as err:
                warnings.warn(f"Skipping unreadable parameter file {path!r}: {err}")
                continue
            if pose is None or pose.ndim == 0 or pose.shape[-1] != 6:
                continue
            positions.append(position)
            poses.append(pose)

        if not poses:
            return positions, np.empty((0, 6))
        return positions, np.array(poses)

    def _load_pose_params(self):
        """Return the stacked usable pose arrays, shape ``(T, 6)`` for 1-D poses."""
        return self._collect_pose_params()[1]

    def run(self):
        """Encode audio windows for every usable frame and save embeddings and poses.

        Writes ``<save_prefix>_embed.npy``, the encoder output with its leading
        batch axis squeezed, and ``<save_prefix>_pose.npy``, the stacked poses.

        Raises:
            RuntimeError: If no usable ``params.npz`` file was found.
        """
        frame_indices, pose_params = self._collect_pose_params()
        if not frame_indices:
            raise RuntimeError(f"No usable pose parameters found in {self.param_dir!r}")

        # (T, n_mels, mel_window) -> (T, 1, n_mels, mel_window)
        windows = np.stack([self._get_mel_for_frame(i) for i in frame_indices])
        batch = torch.from_numpy(windows.astype(np.float32)).unsqueeze(1)
        # The encoder receives one sequence: (1, T, 1, n_mels, mel_window).
        batch = batch.unsqueeze(0).to(self.device)

        with torch.no_grad():
            embeddings = self.audio_encoder(batch)
        # NOTE(review): the original comments state (T, 512) after squeezing; confirm.
        embeddings = embeddings.squeeze(0).cpu().numpy()
        print(f"音频嵌入的形状: {embeddings.shape}")

        np.save(f"{self.save_prefix}_embed.npy", embeddings)
        np.save(f"{self.save_prefix}_pose.npy", pose_params)


if __name__ == "__main__":
    import argparse

    # Command-line entry point: encode one wav file against one parameter directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--wav', type=str,  default='/data/shizhaoxin/visualization_processed/EN_Liu/EN_Liu.wav', help="wav 文件路径")
    parser.add_argument('--params', type=str,  default='/data/shizhaoxin/visualization_processed/EN_Liu/params', help="参数文件夹路径")
    parser.add_argument('--out', type=str, default='/data/shizhaoxin/visualization_processed/EN_Liu/database', help="输出文件前缀")
    # Previously hard-coded to 'cuda:0'; kept as the default for compatibility.
    parser.add_argument('--device', type=str, default='cuda:0', help="torch 设备, 例如 cuda:0 或 cpu")
    args = parser.parse_args()

    generator = BaseGenerator(
        wav_path=args.wav,
        param_dir=args.params,
        save_prefix=args.out,
        device=args.device,
    )
    generator.run()
