"""
inference.py
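Circuit embedding inference script: loads a trained CircuitDistanceModel,
extracts an embedding for every circuit in a test directory, and writes the
results into one output folder per circuit.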
"""

import torch
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple
import logging
import argparse
from data_loader import CircuitDataset
from model import CircuitDistanceModel
import matplotlib.pyplot as plt
import seaborn as sns
import shutil  # 添加shutil用于文件复制

# Logging setup: configure the root logger at INFO level and create the
# module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_model(model_path: str, device: torch.device) -> CircuitDistanceModel:
    """Load a trained model from a checkpoint and prepare it for inference.

    The checkpoint is read with ``torch.load`` using ``device`` as the
    ``map_location``. Its optional ``'config'`` entry supplies the model's
    constructor arguments; any argument missing from the config falls back to
    the default listed in the table below. The weights come from the
    checkpoint's ``'model_state_dict'`` entry.

    Args:
        model_path: Path to the checkpoint file.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        Exception: Any error raised while loading is logged and then
            re-raised unchanged.
    """
    # Constructor arguments paired with the fallback used when the
    # checkpoint's config does not provide a value.
    hyperparameter_defaults = (
        ('node_feature_dim', 7),
        ('embedding_dim', 256),
        ('fusion_method', 'attention'),
        ('num_attention_heads', 8),
    )

    try:
        ckpt = torch.load(model_path, map_location=device)
        saved_config = ckpt.get('config', {})

        # Resolve every constructor argument from the saved config.
        model_kwargs = {
            name: saved_config.get(name, fallback)
            for name, fallback in hyperparameter_defaults
        }
        net = CircuitDistanceModel(**model_kwargs)

        # Restore the weights, then place the model on the target device and
        # switch it to eval mode.
        net.load_state_dict(ckpt['model_state_dict'])
        net.to(device)
        net.eval()

        logger.info(f"成功加载模型: {model_path}")
        return net

    except Exception as e:
        logger.error(f"加载模型失败: {e}")
        raise

def extract_embeddings(model: CircuitDistanceModel, dataset: CircuitDataset, device: torch.device, batch_size: int = 32) -> List[Tuple[str, np.ndarray]]:
    """Compute an embedding for every circuit in ``dataset``.

    Circuits are fed through the model one at a time under
    ``torch.no_grad()``; the original author chose this to avoid data-type
    problems when batching.

    Args:
        model: Model called as ``model(graph_data, matrix)``.
        dataset: Dataset whose items unpack to ``(graph_data, matrix, vector)``
            and whose ``circuit_data[idx]['id']`` gives the circuit id.
        device: Device the inputs are moved to before the forward pass.
        batch_size: Accepted for interface compatibility; not used by this
            function, which always processes a single circuit per step.

    Returns:
        A list of ``(circuit_id, embedding)`` tuples in dataset order, with
        each embedding converted to a NumPy array on the CPU.
    """
    total = len(dataset)
    embeddings = []

    with torch.no_grad():
        for idx in range(total):
            # The third element of each sample is not needed for inference.
            graph, matrix, _vector = dataset[idx]
            circuit_id = dataset.circuit_data[idx]['id']

            graph = graph.to(device)
            # Give the matrix a leading dimension of size 1 for the model.
            batched_matrix = matrix.unsqueeze(0).to(device)

            output = model(graph, batched_matrix)
            # squeeze(0) drops the leading dimension again when it has size 1.
            embeddings.append((circuit_id, output.squeeze(0).cpu().numpy()))

            # Report progress every 10 circuits and after the last one.
            done = idx + 1
            if done == total or done % 10 == 0:
                logger.info(f"已处理 {done}/{total} 个电路")

    return embeddings

def save_results_per_circuit(results: List[Tuple[str, np.ndarray]], output_dir: Path, test_dir: Path) -> None:
    """Save each circuit's embedding in its own folder and copy its vector.npy.

    For every ``(circuit_id, embedding)`` pair this function:

    * creates ``output_dir / circuit_id`` (including any missing parents),
    * writes the embedding to ``embedding.npy`` inside that folder,
    * copies ``test_dir / circuit_id / vector.npy`` next to it when that file
      exists, and logs a warning when it does not.

    Args:
        results: ``(circuit_id, embedding)`` pairs to write.
        output_dir: Root directory that receives one sub-folder per circuit.
        test_dir: Directory holding the source ``<circuit_id>/vector.npy``
            files.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    for circuit_id, embedding in results:
        circuit_dir = output_dir / circuit_id
        # parents=True keeps this working when the id has several path
        # components (e.g. "group/circuit"); without it mkdir raises
        # FileNotFoundError because the intermediate folder is missing.
        circuit_dir.mkdir(parents=True, exist_ok=True)

        np.save(circuit_dir / "embedding.npy", embedding)

        # A missing source vector is reported but does not abort the run.
        source_vector_path = test_dir / circuit_id / "vector.npy"
        if source_vector_path.exists():
            shutil.copy2(source_vector_path, circuit_dir / "vector.npy")
            logger.info(f"已复制vector.npy到 {circuit_dir}")
        else:
            logger.warning(f"未找到源vector.npy文件: {source_vector_path}")

    logger.info(f"所有电路的嵌入和vector.npy文件已保存到: {output_dir}")

def main():
    """Command-line entry point: load a model, embed test circuits, save results.

    Parses the command-line options, loads the checkpoint, builds the test
    dataset, extracts one embedding per circuit and writes the per-circuit
    output folders. When the dataset is empty an error is logged and the
    function returns without producing output. Any exception raised during
    the run is logged and then re-raised.
    """
    # Evaluated once up front: prefer CUDA when it is available.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword options) for every supported CLI option.
    cli_options = (
        ("--model_path", {"type": str, "required": True, "help": "模型文件路径 (.pt)"}),
        ("--test_dir", {"type": str, "default": "./data/test", "help": "测试数据目录"}),
        ("--output_dir", {"type": str, "default": "./inference_results", "help": "输出目录"}),
        ("--batch_size", {"type": int, "default": 16, "help": "批处理大小"}),
        ("--device", {"type": str, "default": default_device, "help": "设备: 'cuda' 或 'cpu'"}),
    )

    parser = argparse.ArgumentParser(description="电路嵌入提取推理")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    device = torch.device(args.device)
    logger.info(f"使用设备: {device}")

    try:
        model = load_model(args.model_path, device)

        test_dir = Path(args.test_dir)
        # matrix_size is set to match the value used during training,
        # according to the original author's note.
        test_dataset = CircuitDataset(
            data_dir=test_dir,
            mode='test',
            normalize_features=True,
            matrix_size=(64, 64),
        )

        num_circuits = len(test_dataset)
        if num_circuits == 0:
            logger.error(f"测试目录中没有找到电路数据: {test_dir}")
            return

        logger.info(f"找到 {num_circuits} 个测试电路")

        # One embedding per circuit, then one output folder per circuit.
        results = extract_embeddings(model, test_dataset, device, args.batch_size)
        save_results_per_circuit(results, Path(args.output_dir), test_dir)

        logger.info("嵌入提取和文件复制完成！每个电路已单独保存。")

    except Exception as e:
        logger.error(f"推理过程中发生错误: {e}")
        raise

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()