import torch
import numpy as np
import os
from audio_processing import AudioProcessor
from pose_library import PoseLibrary
from retrieval import RetrievalEngine
from gan_refiner import PoseGAN
from scipy.ndimage import gaussian_filter1d
from scipy.signal import savgol_filter

class Smoother:
    """
    Temporal smoothing filters for pose sequences.

    Every method takes a NumPy array of shape [T, 1, 6], filters each of the
    6 channels independently along the time axis, and returns an array with
    the same shape and dtype as the input. Empty sequences (T == 0) are
    returned unchanged (as a copy).
    """

    def __init__(self):
        # Stateless: every filter parameter is supplied per call.
        pass

    def gaussian_smooth(self, poses, sigma=2.0):
        """
        Apply Gaussian smoothing to a pose sequence.

        Args:
            poses: numpy array of shape [T, 1, 6]
            sigma: standard deviation (in frames) of the Gaussian kernel

        Returns:
            Smoothed poses of the same shape and dtype.
        """
        if poses.shape[0] == 0:
            return poses.copy()

        poses_flat = poses.squeeze(1)  # [T, 6]
        # Filter along time (axis 0); each channel is processed independently.
        smoothed = gaussian_filter1d(poses_flat, sigma=sigma, axis=0)
        return smoothed.astype(poses_flat.dtype, copy=False).reshape(poses.shape)

    def savgol_smooth(self, poses, window_length=15, polyorder=3):
        """
        Apply a Savitzky-Golay filter for smoothing.

        Args:
            poses: numpy array of shape [T, 1, 6]
            window_length: length of the filter window. An even value is bumped
                to the next odd value. If it exceeds T, it falls back to the
                largest odd length <= T, capped at 5.
            polyorder: order of the fitted polynomial; clamped to
                window_length - 1 because savgol_filter requires
                polyorder < window_length.

        Returns:
            Smoothed poses of the same shape and dtype.
        """
        num_frames = poses.shape[0]
        if num_frames == 0:
            return poses.copy()

        if window_length % 2 == 0:
            window_length += 1  # Ensure window length is odd

        if window_length > num_frames:
            # Largest odd length not exceeding T, but never more than 5.
            window_length = min(num_frames - (num_frames % 2 == 0), 5)

        # Clamp in every case (not only in the short-sequence fallback) so that
        # e.g. window_length=3, polyorder=3 does not make scipy raise.
        polyorder = min(polyorder, window_length - 1)

        poses_flat = poses.squeeze(1)  # [T, 6]
        smoothed = savgol_filter(poses_flat, window_length, polyorder, axis=0)
        return smoothed.astype(poses_flat.dtype, copy=False).reshape(poses.shape)

    def adaptive_ema(self, poses, alpha_min=0.1, alpha_max=0.5, motion_threshold=0.01):
        """
        Apply an exponential moving average whose smoothing factor adapts to
        the amount of motion.

        For t >= 1:
            motion_t = mean |x[t] - x[t-1]| over the channels
            alpha_t  = alpha_min + (alpha_max - alpha_min) * min(motion_t / motion_threshold, 1)
            y[t]     = alpha_t * x[t] + (1 - alpha_t) * y[t-1]
        with y[0] = x[0].

        Args:
            poses: numpy array of shape [T, 1, 6]
            alpha_min: smallest smoothing factor (heaviest smoothing, used when still)
            alpha_max: largest smoothing factor (lightest smoothing, used when moving)
            motion_threshold: motion level at or above which alpha_max is used

        Returns:
            Smoothed poses of the same shape and dtype.

        Raises:
            ValueError: if motion_threshold is not positive.
        """
        if motion_threshold <= 0:
            raise ValueError("motion_threshold must be positive")
        if poses.shape[0] == 0:
            return poses.copy()

        poses_flat = poses.squeeze(1)  # [T, 6]
        smoothed = np.zeros_like(poses_flat)

        # Per-step motion magnitude and the resulting alpha for frames 1..T-1.
        motion = np.abs(np.diff(poses_flat, axis=0)).mean(axis=1)  # [T-1]
        alphas = alpha_min + (alpha_max - alpha_min) * np.minimum(motion / motion_threshold, 1.0)

        # The recursion itself is inherently sequential.
        smoothed[0] = poses_flat[0]
        for t, alpha in enumerate(alphas, start=1):
            smoothed[t] = alpha * poses_flat[t] + (1 - alpha) * smoothed[t - 1]

        return smoothed.reshape(poses.shape)

    def bilateral_filter(self, poses, sigma_space=2.0, sigma_motion=0.05):
        """
        Apply bilateral filtering that preserves sharp transitions while smoothing jitter.

        Each output frame is a weighted average of frames within
        int(3 * sigma_space) frames of it. Each weight is the product of a
        Gaussian in the time offset and a Gaussian in the squared Euclidean
        distance between the two pose vectors.

        Args:
            poses: numpy array of shape [T, 1, 6]
            sigma_space: temporal sigma (in frames)
            sigma_motion: sigma for the pose-difference (motion similarity) weight

        Returns:
            Smoothed poses of the same shape and dtype.

        Raises:
            ValueError: if sigma_space or sigma_motion is not positive.
        """
        if sigma_space <= 0 or sigma_motion <= 0:
            raise ValueError("sigma_space and sigma_motion must be positive")
        num_frames = poses.shape[0]
        if num_frames == 0:
            return poses.copy()

        poses_flat = poses.squeeze(1)  # [T, 6]
        smoothed = np.zeros_like(poses_flat)

        # Loop invariants, hoisted out of the per-frame loop.
        radius = int(sigma_space * 3)
        space_denom = 2 * sigma_space ** 2
        motion_denom = 2 * sigma_motion ** 2

        for t in range(num_frames):
            lo = max(0, t - radius)
            hi = min(num_frames, t + radius + 1)
            neighbours = poses_flat[lo:hi]  # [K, 6]
            offsets = np.arange(lo, hi) - t

            w_space = np.exp(-(offsets ** 2) / space_denom)
            sq_dist = np.sum((neighbours - poses_flat[t]) ** 2, axis=1)
            w_motion = np.exp(-sq_dist / motion_denom)
            weights = w_space * w_motion

            # The centre frame always has weight 1, so the sum is never zero.
            smoothed[t] = weights @ neighbours / weights.sum()

        return smoothed.reshape(poses.shape)

def _pad_to_length(x, length):
    """
    Return tensor ``x`` padded along dim 0 to ``length`` rows by repeating its last row.

    Tensors that already have at least ``length`` rows are returned unchanged.
    Works for any rank: only dim 0 is repeated.
    """
    missing = length - x.shape[0]
    if missing <= 0:
        return x
    tail = x[-1:].repeat(missing, *([1] * (x.dim() - 1)))
    return torch.cat([x, tail], dim=0)


def _load_exp_coeffs(exp_coeffs_path):
    """
    Load the 'exp_coeff' array from an .npz file and close the file afterwards.

    Raises:
        ValueError: if the loaded data has no 'exp_coeff' entry.
    """
    exp_data = np.load(exp_coeffs_path)
    try:
        if 'exp_coeff' not in exp_data:
            raise ValueError("exp_coeff not found in npz file.")
        # NpzFile reads members lazily, so index before the file is closed.
        return exp_data['exp_coeff']
    finally:
        # For .npz input np.load returns an NpzFile holding an open file handle.
        close = getattr(exp_data, 'close', None)
        if close is not None:
            close()


def main(wav_path, lib_dir, save_path='output_pose.npy', window_size=32, stride=16,
         smooth_method='bilateral', smooth_strength='medium', exp_coeffs_path=None,
         gan_ckpt_path='/home/shizhaoxin/codebase/AudioDrivenGaussian/checkpoints/retrievalpose_Alex/audio2pose_epoch1430.pt'):
    """
    Retrieve a pose sequence for an audio clip and save it with ``np.save``.

    Steps: load the pose library, embed the audio, retrieve pose segments,
    load expression coefficients, pad poses / audio embeddings / expressions
    to their common maximum length, and save the padded retrieved poses.

    Args:
        wav_path: audio file passed to AudioProcessor.load_wav.
        lib_dir: directory passed to PoseLibrary.load_library.
        save_path: output .npy path; its parent directory is created if missing.
        window_size: window size forwarded to PoseLibrary.
        stride: stride forwarded to PoseLibrary.
        smooth_method: NOTE(review): currently unused by this function.
        smooth_strength: NOTE(review): currently unused by this function.
        exp_coeffs_path: .npz file containing an 'exp_coeff' array
            (presumably [T, 64], per the original comment — confirm).
        gan_ckpt_path: checkpoint passed to PoseGAN.load_models.

    Raises:
        ValueError: if exp_coeffs_path is None, or the file lacks 'exp_coeff'.
    """
    # Validate before any expensive work (library load, audio embedding, retrieval).
    if exp_coeffs_path is None:
        raise ValueError("exp_coeffs_path must be provided for expression features.")

    if torch.cuda.is_available():
        # Prefer the second GPU as before, but fall back to the first on
        # single-GPU hosts, where 'cuda:1' is an invalid device ordinal.
        device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
    else:
        device = 'cpu'

    os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)

    lib = PoseLibrary(window_size=window_size, stride=stride)
    lib.load_library(lib_dir)

    processor = AudioProcessor(device=device)
    wav, num_frames = processor.load_wav(wav_path)
    mel = processor.compute_mel(wav)
    embeddings = processor.extract_embedding(mel, num_frames)

    retriever = RetrievalEngine(lib)
    pose_segments = retriever.retrieve_sequence(embeddings)
    pose_retrieved = torch.tensor(np.array(pose_segments), dtype=torch.float32).to(device)

    # 1. Load expression coefficients.
    exp_feat = torch.tensor(_load_exp_coeffs(exp_coeffs_path), dtype=torch.float32).to(device)

    # NOTE(review): the GAN is constructed and its weights loaded, but
    # gan_model is never used below — no refinement is applied.
    gan_model = PoseGAN(
        pose_dim=6,
        audio_dim=512,
        hidden_dim=256,
        device=device
    )
    gan_model.load_models(gan_ckpt_path)

    # 2. Align lengths: pad every stream to the longest one by repeating its last frame.
    max_length = max(pose_retrieved.shape[0], embeddings.shape[0], exp_feat.shape[0])
    pose_retrieved = _pad_to_length(pose_retrieved, max_length)
    # as_tensor accepts either a NumPy array or an existing tensor without a copy warning.
    embeddings = _pad_to_length(
        torch.as_tensor(embeddings, dtype=torch.float32, device=device), max_length)
    exp_feat = _pad_to_length(exp_feat, max_length)

    # NOTE(review): neither GAN refinement nor smoothing is applied here; the
    # padded retrieved poses are saved as-is.
    refined_pose_np = pose_retrieved.cpu().numpy()
    np.save(save_path, refined_pose_np)

if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option maps one-to-one onto a parameter of main().
    cli = argparse.ArgumentParser()

    # Inputs and output.
    cli.add_argument('--wav', type=str,
                     default='/data/shizhaoxin/selected/Alex/Alex.wav',
                     help="音频文件路径")
    cli.add_argument('--lib', type=str,
                     default='/data/shizhaoxin/selected/Alex',
                     help="姿态库目录")
    cli.add_argument('--out', type=str,
                     default='/data/shizhaoxin/selected/Alex/output_pose_ablation.npy',
                     help="输出姿态文件")

    # Pose-library windowing.
    cli.add_argument('--window', type=int, default=32, help="姿态窗口大小")
    cli.add_argument('--stride', type=int, default=16, help="姿态窗口步长")

    # Smoothing options.
    cli.add_argument('--smooth', type=str, default='adaptive_ema',
                     choices=['gaussian', 'savgol', 'adaptive_ema', 'bilateral'],
                     help="平滑算法类型")
    cli.add_argument('--smooth_strength', type=str, default='strong',
                     choices=['light', 'medium', 'strong'],
                     help="平滑强度")

    # Path to the .npz file holding the expression coefficients.
    cli.add_argument('--exp_coeffs_path', type=str,
                     default='/data/shizhaoxin/selected/Alex/predicted_exp_coeffs_ablation.npz',
                     help="表情参数npz路径")

    opts = cli.parse_args()

    main(
        wav_path=opts.wav,
        lib_dir=opts.lib,
        save_path=opts.out,
        window_size=opts.window,
        stride=opts.stride,
        smooth_method=opts.smooth,
        smooth_strength=opts.smooth_strength,
        exp_coeffs_path=opts.exp_coeffs_path,
    )