"""
Emotion Retrieval System - DAC Retrieval Agent Implementation

情感原型检索系统，用于:
1. 构建情感嵌入数据库
2. 基于文本指令检索相似情感样本
3. 基于音频情感嵌入检索相似样本
"""

import os
import json
import logging
import pickle
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F

# Module-wide logging: INFO level, timestamped "time - level - message" records.
# NOTE(review): basicConfig at import time configures the *root* logger for every
# importer of this module; consider moving it to the application entry point.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)


@dataclass()
class AudioResult:
    """One retrieval hit: the matched sample's metadata together with its scores."""

    # Location of the matched audio sample and the instruction text stored for it.
    audio_path: str
    instruction: str
    # Similarity between the query and the matched database entry.
    similarity_score: float
    # Dataset / speaker labels recorded for the entry.
    dataset_name: str
    speaker_name: str
    # Path of the JSON file the instruction was read from.
    json_path: str
    # Optional extra scores; default to 0.0 unless the producer fills them in.
    emotion_score: float = 0.0
    combined_score: float = 0.0


class EmotionRetrievalSystem:
    """
    Emotion prototype retrieval system (DAC Retrieval Agent implementation).

    Two retrieval modes are supported:
    1. Text-instruction retrieval: rank database entries by cosine similarity
       between an embedded query text and the stored instruction embeddings.
    2. Audio-emotion retrieval: rank database entries by cosine similarity
       between a query emotion embedding and the stored emotion embeddings.

    Invariant: row ``i`` of ``embeddings`` and ``emotion_embeddings`` describes
    the same sample as ``audio_metadata[i]``. ``build_database`` only appends a
    sample once all of its parts have loaded, and ``load_database`` reports
    failure when row counts disagree, so search hits never point at the wrong file.
    """

    # Width of the all-zero placeholder emotion embedding used when no sample in
    # the database has a real emotion embedding to infer the width from.
    DEFAULT_EMOTION_DIM = 1280

    # Filename suffixes (before ".pt") produced by the upstream embedding pipeline.
    INSTRUCTION_EMBED_SUFFIX = "-instruction_embed"
    EMOTION_EMBED_SUFFIX = "-emotion_embed"

    # Audio extensions probed, in order, when pairing an embedding with its audio file.
    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.m4a', '.ogg')

    def __init__(self, embedding_model_name: str = None, device: str = None):
        """
        Initialise an empty retrieval system.

        Args:
            embedding_model_name: Optional sentence-transformers model name used to
                embed text queries. Without it (or if loading fails),
                ``search_by_instruction`` falls back to random query vectors.
            device: Torch device for the database tensors; defaults to "cuda"
                when available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model = None
        self.embedding_model_name = embedding_model_name

        # Database components (row-aligned, see class docstring).
        self.audio_metadata: List[Dict] = []
        self.embeddings: Optional[torch.Tensor] = None           # instruction embeddings, L2-normalised rows
        self.emotion_embeddings: Optional[torch.Tensor] = None   # emotion embeddings, L2-normalised rows

        if embedding_model_name:
            self._load_embedding_model(embedding_model_name)

        logger.info(f"EmotionRetrievalSystem 初始化完成, device={self.device}")

    def _load_embedding_model(self, model_name: str):
        """Best-effort load of a sentence-transformers model; failures are logged, not raised."""
        try:
            from sentence_transformers import SentenceTransformer
            self.embedding_model = SentenceTransformer(model_name, device=self.device)
            logger.info(f"加载嵌入模型: {model_name}")
        except ImportError:
            logger.warning("sentence_transformers 未安装，文本检索将使用模拟嵌入")
        except Exception as e:
            logger.warning(f"加载嵌入模型失败: {e}")

    def load_database(self, db_path: str) -> bool:
        """
        Load a database previously written by ``save_database``.

        Each of ``metadata.pkl``, ``embeddings.pt`` and ``emotion_embeddings.pt``
        is loaded only if present; a missing file leaves the corresponding
        attribute unchanged. Loaded embedding rows are L2-normalised.

        Args:
            db_path: Database directory.

        Returns:
            True if at least one metadata entry is present and every present
            embedding matrix has exactly one row per metadata entry.
        """
        db_path = Path(db_path)
        metadata_path = db_path / "metadata.pkl"
        embeddings_path = db_path / "embeddings.pt"
        emotion_embeddings_path = db_path / "emotion_embeddings.pt"

        try:
            # SECURITY: pickle.load / torch.load can execute arbitrary code from a
            # crafted file -- only load databases from trusted locations.
            if metadata_path.exists():
                with open(metadata_path, 'rb') as f:
                    self.audio_metadata = pickle.load(f)
                logger.info(f"加载 {len(self.audio_metadata)} 条元数据")

            if embeddings_path.exists():
                self.embeddings = torch.load(embeddings_path, map_location=self.device)
                self.embeddings = F.normalize(self.embeddings, p=2, dim=1)
                logger.info(f"加载指令嵌入: {self.embeddings.shape}")

            if emotion_embeddings_path.exists():
                self.emotion_embeddings = torch.load(emotion_embeddings_path, map_location=self.device)
                self.emotion_embeddings = F.normalize(self.emotion_embeddings, p=2, dim=1)
                logger.info(f"加载情感嵌入: {self.emotion_embeddings.shape}")

            return len(self.audio_metadata) > 0 and self._rows_aligned()

        except Exception as e:
            logger.error(f"加载数据库失败: {e}")
            return False

    def _rows_aligned(self) -> bool:
        """Return True if every present embedding matrix has one row per metadata entry; log otherwise."""
        n_meta = len(self.audio_metadata)
        for name, matrix in (("embeddings", self.embeddings), ("emotion_embeddings", self.emotion_embeddings)):
            if matrix is not None and matrix.shape[0] != n_meta:
                logger.error(f"{name} 行数 ({matrix.shape[0]}) 与元数据条数 ({n_meta}) 不一致")
                return False
        return True

    def build_database(self, data_root: str, save_path: str = None) -> bool:
        """
        Build the database from per-sample files found under ``data_root``.

        A sample is a ``<stem>-instruction_embed.pt`` file plus ``<stem>.json`` and
        an audio file ``<stem><ext>`` in the same directory (see
        ``_find_embedding_files``); ``<stem>-emotion_embed.pt`` is optional.
        A sample any of whose files fails to load is skipped as a whole.
        Samples without an emotion embedding get an all-zero placeholder row as
        wide as the real emotion embeddings (``DEFAULT_EMOTION_DIM`` if there are
        none); the zero row stays zero after normalisation, so it scores 0
        against any emotion query.

        Args:
            data_root: Root directory, searched recursively.
            save_path: If given, the built database is also saved there.

        Returns:
            True if at least one sample loaded and the matrices could be built.
        """
        logger.info(f"从 {data_root} 构建数据库...")

        embedding_files = self._find_embedding_files(data_root)
        if not embedding_files:
            logger.error("未找到嵌入文件")
            return False

        logger.info(f"找到 {len(embedding_files)} 个嵌入文件")

        embeddings_list: List[np.ndarray] = []
        emotion_embeddings_list: List[Optional[np.ndarray]] = []
        metadata_list: List[Dict] = []

        for embed_path, audio_path, json_path in embedding_files:
            try:
                embedding = self._as_row_vector(torch.load(embed_path, map_location='cpu'))
                emo_emb = self._load_emotion_embedding(audio_path)
                metadata = self._build_metadata(audio_path, json_path)
            except Exception as e:
                logger.warning(f"处理 {embed_path} 失败: {e}")
                continue
            # Append only after every part loaded so the three lists stay row-aligned.
            embeddings_list.append(embedding)
            emotion_embeddings_list.append(emo_emb)
            metadata_list.append(metadata)

        if not embeddings_list:
            logger.error("没有有效的嵌入数据")
            return False

        # Placeholder width follows the first real emotion embedding, if any.
        emotion_dim = next(
            (emo.shape[1] for emo in emotion_embeddings_list if emo is not None),
            self.DEFAULT_EMOTION_DIM,
        )
        emotion_rows = [
            emo if emo is not None else np.zeros((1, emotion_dim), dtype=np.float32)
            for emo in emotion_embeddings_list
        ]

        try:
            embeddings = torch.from_numpy(np.vstack(embeddings_list)).to(self.device)
            emotion_embeddings = torch.from_numpy(np.vstack(emotion_rows)).to(self.device)
        except ValueError as e:
            # np.vstack raises ValueError when samples have inconsistent widths.
            logger.error(f"嵌入维度不一致, 无法构建数据库: {e}")
            return False

        self.embeddings = F.normalize(embeddings, p=2, dim=1)
        self.emotion_embeddings = F.normalize(emotion_embeddings, p=2, dim=1)
        self.audio_metadata = metadata_list

        logger.info(f"数据库构建完成: {len(self.embeddings)} 条记录")

        if save_path:
            self.save_database(save_path)

        return True

    @staticmethod
    def _as_row_vector(value: Any) -> np.ndarray:
        """
        Convert one loaded embedding (tensor or array-like) to a float32 array of shape (1, dim).

        Accepts ``(dim,)``, ``(1, dim)``, or any shape that has a single row once
        the leading axes are flattened. Raises ValueError for multi-row input,
        which would otherwise add several rows for one metadata entry.
        """
        if isinstance(value, torch.Tensor):
            value = value.detach().to(device='cpu', dtype=torch.float32).numpy()
        arr = np.asarray(value, dtype=np.float32)
        original_shape = arr.shape
        arr = arr.reshape(-1, arr.shape[-1])
        if arr.shape[0] != 1:
            raise ValueError(f"期望单个嵌入向量, 实际形状 {original_shape}")
        return arr

    def _load_emotion_embedding(self, audio_path: str) -> Optional[np.ndarray]:
        """Load ``<audio stem>-emotion_embed.pt`` beside the audio file as a (1, dim) array, or None if absent."""
        audio = Path(audio_path)
        emotion_embed_path = audio.parent / (audio.stem + self.EMOTION_EMBED_SUFFIX + ".pt")
        if not emotion_embed_path.exists():
            return None
        return self._as_row_vector(torch.load(str(emotion_embed_path), map_location='cpu'))

    def _build_metadata(self, audio_path: str, json_path: str) -> Dict:
        """Build the metadata record for one sample; dataset/speaker come from the audio path's parent directories."""
        instruction = self._load_instruction_from_json(json_path)
        path_parts = Path(audio_path).parts
        return {
            'audio_path': str(audio_path),
            'json_path': str(json_path),
            'instruction': instruction or "",
            'dataset_name': path_parts[-3] if len(path_parts) >= 3 else "unknown",
            'speaker_name': path_parts[-2] if len(path_parts) >= 2 else "unknown",
        }

    def _find_embedding_files(self, data_root: str) -> List[Tuple[str, str, str]]:
        """
        Find complete samples under ``data_root``.

        Returns (instruction-embedding path, audio path, JSON path) triples, sorted
        by embedding path so repeated builds produce the same row order. A
        ``<stem>-instruction_embed.pt`` file is kept only if ``<stem>.json`` and an
        audio file ``<stem><ext>`` (first existing ext in AUDIO_EXTENSIONS) sit
        in the same directory.
        """
        suffix_len = len(self.INSTRUCTION_EMBED_SUFFIX)
        embedding_files = []

        for embed_file in sorted(Path(data_root).rglob(f"*{self.INSTRUCTION_EMBED_SUFFIX}.pt")):
            # The glob guarantees the stem ends with the suffix; strip only that
            # trailing occurrence (str.replace would also remove earlier ones).
            stem = embed_file.stem[:-suffix_len]

            audio_path = None
            for ext in self.AUDIO_EXTENSIONS:
                candidate = embed_file.parent / f"{stem}{ext}"
                if candidate.exists():
                    audio_path = candidate
                    break

            json_path = embed_file.parent / f"{stem}.json"

            if audio_path is not None and json_path.exists():
                embedding_files.append((str(embed_file), str(audio_path), str(json_path)))

        return embedding_files

    def _load_instruction_from_json(self, json_path: str) -> Optional[str]:
        """
        Return the ``fine_grained_instructions`` field of a sample's JSON file.

        Returns None if the file cannot be read or parsed, is not a JSON object,
        or lacks the field.
        """
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"读取 {json_path} 失败: {e}")
            return None
        if not isinstance(data, dict):
            return None
        # NOTE(review): the value is returned as stored; confirm upstream always writes a string.
        return data.get('fine_grained_instructions', None)

    @staticmethod
    def _rank(query: torch.Tensor, bank: torch.Tensor, top_k: int) -> List[Tuple[float, int]]:
        """
        Return up to ``top_k`` (cosine similarity, row index) pairs against ``bank``, best first.

        ``bank`` rows must already be L2-normalised. ``query`` may be ``(dim,)`` or
        ``(n, dim)``; only its first row is scored. Non-positive ``top_k`` yields [].
        """
        k = min(top_k, bank.shape[0])
        if k <= 0:
            return []
        query = query.to(device=bank.device, dtype=bank.dtype)
        if query.dim() == 1:
            query = query.unsqueeze(0)
        query = F.normalize(query[:1], p=2, dim=1)
        similarities = torch.matmul(query, bank.t())[0]
        values, indices = torch.topk(similarities, k=k)
        return list(zip(values.tolist(), indices.tolist()))

    def _make_result(self, idx: int, score: float, emotion: bool = False) -> AudioResult:
        """Wrap metadata row ``idx`` as an AudioResult; emotion hits also fill emotion/combined scores."""
        metadata = self.audio_metadata[idx]
        return AudioResult(
            audio_path=metadata['audio_path'],
            instruction=metadata['instruction'],
            similarity_score=score,
            emotion_score=score if emotion else 0.0,
            combined_score=score if emotion else 0.0,
            dataset_name=metadata['dataset_name'],
            speaker_name=metadata['speaker_name'],
            json_path=metadata['json_path'],
        )

    def search_by_instruction(self, query: str, top_k: int = 5) -> List[AudioResult]:
        """
        Retrieve samples whose instruction embedding is most similar to ``query``.

        Args:
            query: Query text.
            top_k: Maximum number of results.

        Returns:
            AudioResult list, best first; [] if no database is loaded.
            Without an embedding model the query vector is random, so the
            ranking is effectively arbitrary (mock mode).
        """
        if self.embeddings is None:
            logger.error("数据库未加载")
            return []

        if self.embedding_model:
            encoded = self.embedding_model.encode([query])
            query_embedding = torch.from_numpy(np.asarray(encoded, dtype=np.float32))
        else:
            # Mock embedding: random query vector.
            query_embedding = torch.randn(1, self.embeddings.shape[1], device=self.embeddings.device)

        ranked = self._rank(query_embedding, self.embeddings, top_k)
        return [self._make_result(idx, score) for score, idx in ranked]

    def search_by_emotion(self, emotion_embedding: torch.Tensor, top_k: int = 5) -> List[AudioResult]:
        """
        Retrieve samples whose emotion embedding is most similar to ``emotion_embedding``.

        Args:
            emotion_embedding: Query embedding of shape (dim,) or (1, dim); only
                the first row of a batched query is used.
            top_k: Maximum number of results.

        Returns:
            AudioResult list, best first, with similarity/emotion/combined scores
            all set to the cosine similarity; [] if no emotion embeddings are loaded.
        """
        if self.emotion_embeddings is None:
            logger.error("情感嵌入未加载")
            return []

        ranked = self._rank(emotion_embedding, self.emotion_embeddings, top_k)
        return [self._make_result(idx, score, emotion=True) for score, idx in ranked]

    def retrieve(self, user_request: str, top_k: int = 5, rerank: bool = False) -> List[AudioResult]:
        """
        Unified retrieval entry point (OmniTTS-compatible signature).

        Args:
            user_request: User request text.
            top_k: Maximum number of results.
            rerank: Accepted for interface compatibility; currently ignored
                (no reranking is performed).

        Returns:
            The result of ``search_by_instruction(user_request, top_k)``.
        """
        return self.search_by_instruction(user_request, top_k)

    def save_database(self, save_path: str):
        """
        Save the database to ``save_path`` in the layout read by ``load_database``.

        Always writes ``metadata.pkl``; writes ``embeddings.pt`` and
        ``emotion_embeddings.pt`` (moved to CPU) only for matrices that exist.
        Creates the directory if needed.
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        with open(save_path / "metadata.pkl", 'wb') as f:
            pickle.dump(self.audio_metadata, f)

        # Skip absent matrices instead of crashing on None.cpu(); load_database
        # already treats a missing file as "not present".
        for filename, matrix in (("embeddings.pt", self.embeddings),
                                 ("emotion_embeddings.pt", self.emotion_embeddings)):
            if matrix is not None:
                torch.save(matrix.cpu(), save_path / filename)

        logger.info(f"数据库已保存到 {save_path}")
