# rendering/td_binaural_renderer.py

import os
import numpy as np
import soundfile as sf
from scipy.io import loadmat
from scipy.signal import convolve

# 尝试导入tqdm库用于显示进度条
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, **kwargs: x


class TimeDomainBinauralRenderer:
    """Time-domain binaural renderer for HOA signals.

    Every HOA channel is convolved with its left-/right-ear filter pair taken
    from a spherical-harmonic-domain HRIR set, and the per-channel results are
    summed into a two-channel (binaural) signal.
    """

    def __init__(self, hrir_file, fs=48000):
        """
        Create a renderer and eagerly load its HRIR filters.

        Args:
            hrir_file (str): Path to the .mat file containing the
                spherical-harmonic-domain HRIR filters (variable ``hnm``).
            fs (int): Sample rate, used when writing the output file.

        Raises:
            FileNotFoundError: If ``hrir_file`` does not exist.
            ValueError: If the file has no ``hnm`` variable or it has the
                wrong shape.
            IOError: If the file cannot be read/parsed as a .mat file.
        """
        self.fs = fs
        self.hrir_file = hrir_file

        # Filled by _load_hrirs(); shape (filter_len, max_channels, 2).
        self.hrir_filters = None
        self._load_hrirs()

    def _load_hrirs(self):
        """Load the ``hnm`` filter array from ``self.hrir_file``.

        Raises:
            FileNotFoundError: If the file does not exist.
            IOError: If ``loadmat`` fails; the original error is chained.
            ValueError: If ``hnm`` is missing or is not shaped
                (filter_len, n_channels, 2).
        """
        if not os.path.exists(self.hrir_file):
            raise FileNotFoundError(f"找不到HRIR(双耳渲染)滤波器文件: {self.hrir_file}")

        # Keep each try body minimal so every failure maps to the right
        # exception type, and chain the original cause for debugging.
        try:
            data = loadmat(self.hrir_file)
        except Exception as e:
            raise IOError(f"加载或解析HRIR文件失败: {e}") from e

        try:
            # The filter variable is named 'hnm' by the script that
            # generated the file.
            filters = data['hnm']
        except KeyError as err:
            raise ValueError(f"在HRIR文件 '{self.hrir_file}' 中找不到 'hnm' 变量。") from err

        # render() indexes the filters as [:, channel, ear]; reject anything
        # else now instead of failing later with an opaque unpacking error.
        if filters.ndim != 3 or filters.shape[2] != 2:
            raise ValueError(
                f"HRIR滤波器 'hnm' 的形状应为 (filter_len, n_channels, 2), 实际为 {filters.shape}。"
            )

        self.hrir_filters = filters
        print(f"成功加载时域HRIR滤波器: {self.hrir_filters.shape}")

    def render(self, hoa_signals):
        """
        Render time-domain HOA signals to a binaural signal.

        Each HOA channel is convolved ('full' mode) with its left- and
        right-ear filters, and the results are summed per ear. The mix is then
        peak-normalized to 1.0, unless its peak is at or below 1e-8 (treated
        as silence and returned unscaled).

        Args:
            hoa_signals (array_like): Time-domain HOA signals of shape
                (n_samples, n_channels). A 1-D array is treated as a single
                channel, i.e. shape (n_samples, 1).

        Returns:
            np.ndarray: float64 binaural signal of shape
            (n_samples + filter_len - 1, 2).

        Raises:
            ValueError: If the input is not 1-D/2-D, contains no samples, or
                has more channels than the HRIR set provides.
        """
        print("--- 开始双耳渲染 (时域卷积法) ---")

        # --- Prepare input and output ---
        signals = np.asarray(hoa_signals)
        if signals.ndim == 1:
            # Mono input: treat it as a single (first) HOA channel.
            signals = signals[:, np.newaxis]
        elif signals.ndim != 2:
            raise ValueError(
                f"HOA信号应为形状 (n_samples, n_channels) 的二维数组, 实际维度为 {signals.ndim}。"
            )

        n_samples, n_channels_input = signals.shape
        filter_len, n_channels_max, _ = self.hrir_filters.shape

        print(f"音频输入: {n_samples}采样点, {n_channels_input}通道。")
        print(f"HRIR滤波器库: {filter_len}点, 支持最高{n_channels_max}通道。")

        if n_channels_input > n_channels_max:
            raise ValueError(f"输入信号有 {n_channels_input} 个通道, 但HRIR库仅支持 {n_channels_max} 个。")
        if n_samples == 0:
            # convolve() cannot handle an empty signal; fail with a clear message.
            raise ValueError("HOA信号不包含任何采样点。")

        # Only use the filters for the channels actually present in the input.
        active_filters = self.hrir_filters[:, :n_channels_input, :]
        output_len = n_samples + filter_len - 1
        binaural_output = np.zeros((output_len, 2), dtype=np.float64)

        # --- Time-domain convolution (core loop) ---
        print(f"🚀 开始对 {n_channels_input} 个通道进行卷积...")

        # tqdm shows a progress bar when installed; the file's import
        # fallback returns the iterable unchanged otherwise.
        for ch in tqdm(range(n_channels_input), desc="渲染进度", unit="ch"):
            signal_ch = signals[:, ch]
            # Accumulate this channel's contribution to each ear.
            binaural_output[:, 0] += convolve(signal_ch, active_filters[:, ch, 0], mode='full')
            binaural_output[:, 1] += convolve(signal_ch, active_filters[:, ch, 1], mode='full')

        # --- Peak normalization ---
        peak_val = np.max(np.abs(binaural_output))
        if peak_val > 1e-8:
            binaural_output /= peak_val
            print(f"\n✅ 渲染完成！输出信号已进行峰值归一化 (原峰值为 {peak_val:.4f})。")
        else:
            print("\n✅ 渲染完成！输出信号为静音。")

        return binaural_output

    @staticmethod
    def _to_pcm16(signal):
        """Quantize a float signal nominally in [-1, 1] to int16 PCM.

        Samples are clipped to [-1, 1] first so that out-of-range values
        saturate instead of wrapping around. They are then scaled by 32767
        and rounded to the nearest integer; a plain ``np.int16`` cast would
        truncate toward zero.
        """
        clipped = np.clip(signal, -1.0, 1.0)
        return np.round(clipped * 32767).astype(np.int16)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it is missing.

        A bare filename has no directory component (``os.path.dirname``
        returns ''), so nothing needs to be created in that case.
        ``exist_ok`` avoids a check-then-create race.
        """
        output_dir = os.path.dirname(path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def process_and_save(self, hoa_signals, output_path):
        """
        Convenience wrapper: render the HOA signals and write a 16-bit WAV.

        Args:
            hoa_signals (array_like): Time-domain HOA signals
                (n_samples, n_channels); see ``render``.
            output_path (str): Destination .wav path. Missing parent
                directories are created; a bare filename writes to the
                current working directory.
        """
        binaural_signal = self.render(hoa_signals)

        # Convert the normalized float signal to 16-bit integer PCM.
        audio_to_save = self._to_pcm16(binaural_signal)

        self._ensure_parent_dir(output_path)

        sf.write(output_path, audio_to_save, self.fs)
        print(f"🎉 处理成功！最终的双耳音频已保存到: '{output_path}'")