import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
import glob
from torch.utils.data import Dataset


class Audio2PoseDataset(Dataset):
    """Sliding-window dataset pairing mel-spectrogram patches with frame parameters.

    One audio file is converted to a normalized mel spectrogram, and one
    ``params.npz`` file per video frame is read from numerically named
    sub-directories of ``param_path``. Each item is a window of
    ``time_length`` consecutive frames containing, for every frame, a
    mel patch centered on that frame plus its ``'pose'`` and ``'exp_coeff'``
    arrays.

    Args:
        cfg: Mapping indexed with ``'wav_path'`` (audio file),
            ``'time_length'`` (frames per item) and ``'param_path'``
            (directory holding ``<frame number>/params.npz``).

    Raises:
        ValueError: If no ``'pose'`` or no ``'exp_coeff'`` array could be loaded.
    """

    # Video frame rate used to map a frame index to a position in the audio.
    FPS = 25
    # Number of mel columns in the patch extracted around each frame.
    MEL_WINDOW = 16

    def __init__(self, cfg):
        super().__init__()

        self.wav_path = cfg['wav_path']
        self.time_length = cfg['time_length']
        self.param_path = cfg['param_path']

        # Mel-spectrogram settings.
        self.sr = 16000
        self.n_fft = 800
        self.hop_length = 200
        self.n_mels = 80

        self.audio, _ = librosa.load(self.wav_path, sr=self.sr)
        self.mel_spectrogram = self._audio_to_mel(self.audio)

        # Order the per-frame files by the integer name of their directory.
        self.param_files = sorted(
            glob.glob(os.path.join(self.param_path, "[0-9]*", "params.npz")),
            key=lambda path: int(os.path.basename(os.path.dirname(path))),
        )

        self.num_frames = len(self.param_files)
        self.pose_params = self._load_pose_params()
        self.exp_params = self._load_exp_params()

        # The loaders skip unreadable files, so the arrays can be shorter than
        # the file list. __len__ accounts for this; make it visible as well.
        if not (len(self.pose_params) == len(self.exp_params) == self.num_frames):
            print(
                f"Warning: found {self.num_frames} parameter files but loaded "
                f"{len(self.pose_params)} pose and {len(self.exp_params)} "
                f"expression entries; frames may be misaligned with the audio."
            )

    @staticmethod
    def _min_max_normalize(values):
        """Scale ``values`` linearly to [0, 1]; a constant input maps to zeros."""
        low = values.min()
        span = values.max() - low
        if span == 0:
            # A flat input (e.g. digital silence) would otherwise produce NaNs.
            return np.zeros_like(values)
        return (values - low) / span

    def _audio_to_mel(self, audio):
        """Convert a waveform to a dB-scaled mel spectrogram normalized to [0, 1]."""
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
        )
        mel = librosa.power_to_db(mel, ref=np.max)
        return self._min_max_normalize(mel)

    def _load_params(self, key, description):
        """Collect array ``key`` from every file in ``self.param_files``.

        Files that cannot be read, or that lack ``key``, are reported on
        stdout and skipped.

        Args:
            key: Name of the array inside each ``.npz`` archive.
            description: Human-readable name used in the error message.

        Returns:
            ``np.ndarray`` stacking the loaded arrays in file order.

        Raises:
            ValueError: If nothing could be loaded.
        """
        loaded = []
        for param_path in self.param_files:
            try:
                # np.load keeps the archive open; the context manager closes it.
                with np.load(param_path) as param:
                    if key in param:
                        loaded.append(param[key])
                    else:
                        print(f"Warning: '{key}' not found in {param_path}")
                        print(f"Available keys: {list(param.keys())}")
            except Exception as e:  # deliberate best-effort: report and continue
                print(f"Error loading {param_path}: {e}")

        if not loaded:
            raise ValueError(
                f"No {description} parameters were loaded. "
                "Check file paths and parameter names."
            )

        return np.array(loaded)

    def _load_pose_params(self):
        """Load the ``'pose'`` array from every ``params.npz`` file."""
        return self._load_params('pose', 'pose')

    def _load_exp_params(self):
        """Load the ``'exp_coeff'`` array from every ``params.npz`` file."""
        return self._load_params('exp_coeff', 'expression')

    def _get_mel_for_frame(self, frame_idx):
        """Return the ``MEL_WINDOW``-column mel patch centered on a video frame.

        The window is shifted to stay inside the spectrogram. If the
        spectrogram is narrower than the window, the available columns are
        edge-padded on both sides; an empty spectrogram yields zeros.

        Args:
            frame_idx: Zero-based video frame index.

        Returns:
            ``np.ndarray`` with ``MEL_WINDOW`` columns and as many rows as
            ``self.mel_spectrogram``.
        """
        window = self.MEL_WINDOW
        total = self.mel_spectrogram.shape[1]

        # Integer arithmetic: a float product can round to just below the
        # intended sample index and then be truncated one sample short.
        audio_center_idx = int(frame_idx * self.sr // self.FPS)
        mel_center_idx = audio_center_idx // self.hop_length

        # Slide the window so it lies within [0, total] whenever it fits.
        start = min(max(mel_center_idx - window // 2, 0), max(total - window, 0))
        end = min(start + window, total)
        patch = self.mel_spectrogram[:, start:end]

        missing = window - patch.shape[1]
        if missing > 0:
            if patch.shape[1] == 0:
                # 'edge' padding cannot extend an empty axis.
                return np.zeros(
                    (self.mel_spectrogram.shape[0], window),
                    dtype=self.mel_spectrogram.dtype,
                )
            pad_left = missing // 2
            patch = np.pad(patch, ((0, 0), (pad_left, missing - pad_left)), 'edge')

        return patch

    def __len__(self):
        """Return the number of complete ``time_length``-frame windows."""
        # Bound by what was actually loaded so no window is ever truncated.
        usable = min(self.num_frames, len(self.pose_params), len(self.exp_params))
        return max(0, usable - self.time_length + 1)

    def __getitem__(self, idx):
        """Return the window of ``time_length`` frames starting at frame ``idx``.

        Returns:
            dict with float32 tensors ``'mel'`` of shape
            ``[time_length, 1, n_mels, MEL_WINDOW]``, and ``'pose_gt'`` /
            ``'exp_gt'`` holding the parameter slices with a size-1 second
            axis squeezed out.

        Raises:
            IndexError: If ``idx`` is negative or not less than ``len(self)``.
        """
        length = len(self)
        if idx < 0 or idx >= length:
            raise IndexError(f"Index {idx} out of range for dataset of length {length}")

        stop = idx + self.time_length
        # Original author's shape notes: [T, 1, 6] and [T, 1, 64] — these
        # depend on the contents of the params files; verify against the data.
        pose_sequence = self.pose_params[idx:stop]
        exp_sequence = self.exp_params[idx:stop]

        mel_sequence = np.zeros(
            (self.time_length, 1, self.n_mels, self.MEL_WINDOW), dtype=np.float32
        )
        for offset in range(self.time_length):
            mel_sequence[offset, 0] = self._get_mel_for_frame(idx + offset)

        return {
            'mel': torch.from_numpy(mel_sequence),
            'pose_gt': torch.tensor(pose_sequence, dtype=torch.float32).squeeze(1),
            'exp_gt': torch.tensor(exp_sequence, dtype=torch.float32).squeeze(1),
        }