"""
tester.py
模型测试和评估
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, mean_absolute_error
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr
import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)


class CircuitTester:
    """Evaluate a circuit-similarity embedding model on a test set.

    For every test pair the tester embeds both circuits with the model, uses
    the Euclidean distance between the two embeddings as the predicted
    distance, and compares it with the target distance supplied by the
    loader. Metrics, plots, raw predictions and a text report are written to
    ``save_dir``.
    """

    def __init__(self,
                 model: nn.Module,
                 test_loader: DataLoader,
                 device: Optional[torch.device] = None,
                 save_dir: Optional[str] = None):
        """
        Args:
            model: Model to evaluate; it is called as ``model(graph, matrix)``.
            test_loader: Loader whose batches unpack into
                ``(graph1, matrix1, graph2, matrix2, distances)``.
            device: Device to run on. Defaults to CUDA when available,
                otherwise CPU.
            save_dir: Output directory (created if missing). Defaults to
                ``./test_results``.
        """
        self.model = model
        self.test_loader = test_loader
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.model.to(self.device)

        # Output directory for plots, predictions and reports.
        self.save_dir = Path(save_dir) if save_dir else Path('./test_results')
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Populated by _collect_predictions(); embeddings is a dict because
        # it is later expanded as keyword arguments for np.savez.
        self.predictions = []
        self.targets = []
        self.embeddings = {}

    def test(self) -> Dict:
        """Run the full evaluation pipeline.

        Collects predictions, computes metrics, renders the plots and saves
        everything to ``save_dir``.

        Returns:
            Dictionary of evaluation metrics.
        """
        logger.info("Starting model evaluation...")

        self._collect_predictions()
        metrics = self._compute_comprehensive_metrics()
        self._generate_visualizations()
        self._save_results(metrics)

        logger.info("Evaluation completed!")
        return metrics

    def _collect_predictions(self):
        """Run the model over the test loader and store predictions.

        Fills ``self.predictions``, ``self.targets`` and ``self.embeddings``
        (keys ``'embedding1'`` / ``'embedding2'``).

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.model.eval()

        pred_chunks = []
        target_chunks = []
        emb1_chunks = []
        emb2_chunks = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc='Testing'):
                # Move every tensor of the batch to the evaluation device.
                graph1, matrix1, graph2, matrix2, distances = [x.to(self.device) for x in batch]

                embedding1 = self.model(graph1, matrix1)
                embedding2 = self.model(graph2, matrix2)

                # Predicted distance = L2 norm of the embedding difference.
                pred_distances = torch.norm(embedding1 - embedding2, p=2, dim=1)

                pred_chunks.append(pred_distances.cpu().numpy())
                # Flatten targets so a (B, 1) column cannot broadcast against
                # the (N,) predictions into an (N, N) error matrix later on.
                target_chunks.append(distances.cpu().numpy().reshape(-1))
                emb1_chunks.append(embedding1.cpu().numpy())
                emb2_chunks.append(embedding2.cpu().numpy())

        if not pred_chunks:
            raise ValueError("test_loader yielded no batches; nothing to evaluate")

        self.predictions = np.concatenate(pred_chunks)
        self.targets = np.concatenate(target_chunks)
        self.embeddings = {
            'embedding1': np.vstack(emb1_chunks),
            'embedding2': np.vstack(emb2_chunks)
        }

        logger.info(f"Collected {len(self.predictions)} test samples")

    def _compute_comprehensive_metrics(self) -> Dict:
        """Compute regression, correlation and error statistics.

        Returns:
            Dictionary with ``mae``, ``mse``, ``rmse``, ``r2``, the Pearson
            and Spearman coefficients with p-values (only when more than one
            sample is available), signed-error statistics, ``mape`` (in
            percent) and ``range_analysis``.
        """
        from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

        metrics = {}

        # Basic regression metrics (cast so the values are plain floats).
        metrics['mae'] = float(mean_absolute_error(self.targets, self.predictions))
        metrics['mse'] = float(mean_squared_error(self.targets, self.predictions))
        metrics['rmse'] = float(np.sqrt(metrics['mse']))
        metrics['r2'] = float(r2_score(self.targets, self.predictions))

        # Correlation needs at least two samples.
        if len(self.predictions) > 1:
            pearson_corr, pearson_pval = pearsonr(self.targets, self.predictions)
            spearman_corr, spearman_pval = spearmanr(self.targets, self.predictions)
            metrics['pearson_corr'] = float(pearson_corr)
            metrics['pearson_pval'] = float(pearson_pval)
            metrics['spearman_corr'] = float(spearman_corr)
            metrics['spearman_pval'] = float(spearman_pval)

        # Signed error statistics.
        errors = self.predictions - self.targets
        abs_errors = np.abs(errors)
        metrics['mean_error'] = float(np.mean(errors))
        metrics['std_error'] = float(np.std(errors))
        metrics['median_error'] = float(np.median(errors))
        metrics['percentile_95_error'] = float(np.percentile(abs_errors, 95))

        # Mean Absolute Percentage Error. The denominator uses |target| so a
        # non-positive target cannot flip the sign or cancel the epsilon.
        relative_errors = abs_errors / (np.abs(self.targets) + 1e-8)
        metrics['mape'] = float(np.mean(relative_errors) * 100)

        # Per-quartile breakdown.
        metrics['range_analysis'] = self._analyze_by_distance_range()

        logger.info("Computed comprehensive metrics")
        return metrics

    def _analyze_by_distance_range(self) -> Dict:
        """Break performance down by quartile of the target distance.

        Returns:
            Mapping from a range label such as ``"0-25%"`` to a dict with
            ``count``, ``mae`` and ``correlation`` (Pearson; ``0.0`` when the
            range holds fewer than two samples). Empty ranges are omitted and
            an empty target set yields ``{}``.
        """
        if np.size(self.targets) == 0:
            return {}

        # Quartile edges of the target distances.
        percentiles = [0, 25, 50, 75, 100]
        bins = np.percentile(self.targets, percentiles)

        range_metrics = {}
        last = len(bins) - 2
        for i in range(len(bins) - 1):
            lower, upper = bins[i], bins[i + 1]
            if i == last:
                # The final range is closed on the right to include the max.
                mask = (self.targets >= lower) & (self.targets <= upper)
            else:
                mask = (self.targets >= lower) & (self.targets < upper)

            if not np.any(mask):
                continue

            range_name = f"{percentiles[i]}-{percentiles[i+1]}%"
            range_targets = self.targets[mask]
            range_preds = self.predictions[mask]

            if len(range_targets) > 1:
                correlation = float(pearsonr(range_targets, range_preds)[0])
            else:
                correlation = 0.0

            range_metrics[range_name] = {
                'count': int(np.sum(mask)),
                'mae': float(mean_absolute_error(range_targets, range_preds)),
                'correlation': correlation
            }

        return range_metrics

    def _generate_visualizations(self):
        """Render every diagnostic plot into ``save_dir``."""
        # 1. Predicted vs. true scatter plot
        self._plot_predictions_vs_targets()

        # 2. Error distribution
        self._plot_error_distribution()

        # 3. Embedding space projection
        self._plot_embedding_space()

        # 4. Correlation heatmap
        self._plot_correlation_heatmap()

        logger.info("Generated all visualizations")

    def _plot_predictions_vs_targets(self):
        """Save a scatter plot of predicted vs. true distances."""
        from sklearn.metrics import r2_score

        plt.figure(figsize=(8, 8))

        plt.scatter(self.targets, self.predictions, alpha=0.5, s=10)

        # Diagonal marking a perfect prediction.
        min_val = min(self.targets.min(), self.predictions.min())
        max_val = max(self.targets.max(), self.predictions.max())
        plt.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect Prediction')

        r2 = r2_score(self.targets, self.predictions)

        plt.xlabel('True Distance')
        plt.ylabel('Predicted Distance')
        plt.title(f'Predictions vs Targets (R² = {r2:.4f})')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'predictions_vs_targets.png', dpi=100)
        plt.close()

    def _plot_error_distribution(self):
        """Save an error histogram and an error-vs-target scatter plot."""
        errors = self.predictions - self.targets

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: histogram of signed errors.
        axes[0].hist(errors, bins=50, edgecolor='black', alpha=0.7)
        axes[0].axvline(x=0, color='r', linestyle='--', label='Zero Error')
        axes[0].set_xlabel('Prediction Error')
        axes[0].set_ylabel('Frequency')
        axes[0].set_title('Error Distribution')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Right panel: signed error against the true distance.
        axes[1].scatter(self.targets, errors, alpha=0.5, s=10)
        axes[1].axhline(y=0, color='r', linestyle='--', label='Zero Error')
        axes[1].set_xlabel('True Distance')
        axes[1].set_ylabel('Prediction Error')
        axes[1].set_title('Error vs True Distance')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'error_distribution.png', dpi=100)
        plt.close()

    def _plot_embedding_space(self):
        """Save a 2-D t-SNE projection of the collected embeddings.

        Only the first 100 embeddings of each side are projected to keep
        t-SNE fast. Nothing is drawn when fewer than two points exist.
        """
        max_points = 100
        first = self.embeddings['embedding1'][:max_points]
        second = self.embeddings['embedding2'][:max_points]
        all_embeddings = np.vstack([first, second])

        if len(all_embeddings) < 2:
            return

        # Perplexity is capped below the number of samples.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(all_embeddings)-1))
        embeddings_2d = tsne.fit_transform(all_embeddings)

        plt.figure(figsize=(10, 8))

        # Rows [0, n) belong to the first circuit of each pair, the rest to
        # the second.
        n = len(first)
        plt.scatter(embeddings_2d[:n, 0], embeddings_2d[:n, 1],
                   c='blue', alpha=0.5, label='Circuit Set 1')
        plt.scatter(embeddings_2d[n:, 0], embeddings_2d[n:, 1],
                   c='red', alpha=0.5, label='Circuit Set 2')

        plt.xlabel('t-SNE Component 1')
        plt.ylabel('t-SNE Component 2')
        plt.title('Embedding Space Visualization (t-SNE)')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'embedding_space.png', dpi=100)
        plt.close()

    def _plot_correlation_heatmap(self):
        """Save a heatmap of correlations between predictions, targets and errors."""
        import pandas as pd

        errors = self.predictions - self.targets
        data = {
            'Predictions': self.predictions,
            'Targets': self.targets,
            'Errors': errors,
            'Abs_Errors': np.abs(errors)
        }

        # Pairwise correlation matrix of the four series.
        df = pd.DataFrame(data)
        corr_matrix = df.corr()

        plt.figure(figsize=(8, 6))
        sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
                   center=0, square=True, linewidths=1)
        plt.title('Correlation Heatmap')
        plt.tight_layout()
        plt.savefig(self.save_dir / 'correlation_heatmap.png', dpi=100)
        plt.close()

    def _save_results(self, metrics: Dict):
        """Persist raw predictions, metrics and the text report.

        Writes ``predictions.npz`` (predictions, targets and embeddings),
        ``metrics.json`` and ``test_report.txt`` into ``save_dir``.

        Args:
            metrics: Metrics dictionary as produced by
                ``_compute_comprehensive_metrics``.
        """
        # Raw arrays.
        results_path = self.save_dir / 'predictions.npz'
        np.savez(results_path,
                predictions=self.predictions,
                targets=self.targets,
                **self.embeddings)

        # Machine-readable metrics; numpy values are converted first because
        # the json module cannot encode them.
        metrics_path = self.save_dir / 'metrics.json'
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(self._make_serializable(metrics), f, indent=2, ensure_ascii=False)

        # Human-readable report.
        self._generate_text_report(metrics)

        logger.info(f"Results saved to {self.save_dir}")

    def _make_serializable(self, obj):
        """Recursively convert numpy values into plain Python objects.

        Arrays become lists, numpy scalars of any kind become the matching
        Python scalar, and dicts, lists and tuples are converted element by
        element (tuples come back as lists). Anything else is returned as is.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, dict):
            return {k: self._make_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._make_serializable(item) for item in obj]
        return obj

    def _generate_text_report(self, metrics: Dict):
        """Write a plain-text summary of ``metrics`` to ``test_report.txt``.

        Args:
            metrics: Metrics dictionary. Correlation entries are optional and
                default to 0 (coefficient) and 1 (p-value) when absent.
        """
        report_path = self.save_dir / 'test_report.txt'

        banner = "=" * 60 + "\n"
        rule = "-" * 40 + "\n"

        lines = [
            banner,
            "CIRCUIT SIMILARITY MODEL TEST REPORT\n",
            banner + "\n",
            # Overall metrics
            "OVERALL METRICS:\n",
            rule,
            f"MAE:  {metrics['mae']:.6f}\n",
            f"MSE:  {metrics['mse']:.6f}\n",
            f"RMSE: {metrics['rmse']:.6f}\n",
            f"R²:   {metrics['r2']:.6f}\n",
            f"MAPE: {metrics['mape']:.2f}%\n\n",
            # Correlation metrics
            "CORRELATION METRICS:\n",
            rule,
            f"Pearson:  {metrics.get('pearson_corr', 0):.6f} (p={metrics.get('pearson_pval', 1):.6f})\n",
            f"Spearman: {metrics.get('spearman_corr', 0):.6f} (p={metrics.get('spearman_pval', 1):.6f})\n\n",
            # Error analysis
            "ERROR ANALYSIS:\n",
            rule,
            f"Mean Error:     {metrics['mean_error']:.6f}\n",
            f"Std Error:      {metrics['std_error']:.6f}\n",
            f"Median Error:   {metrics['median_error']:.6f}\n",
            f"95% Percentile: {metrics['percentile_95_error']:.6f}\n\n",
        ]

        # Per-range breakdown
        if 'range_analysis' in metrics:
            lines.append("PERFORMANCE BY DISTANCE RANGE:\n")
            lines.append(rule)
            for range_name, range_metrics in metrics['range_analysis'].items():
                lines.append(f"\n{range_name} percentile:\n")
                lines.append(f"  Count: {range_metrics['count']}\n")
                lines.append(f"  MAE:   {range_metrics['mae']:.6f}\n")
                lines.append(f"  Corr:  {range_metrics['correlation']:.6f}\n")

        # Explicit UTF-8: the report contains the non-ASCII "R²" label, so
        # the platform default encoding must not be relied on.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        logger.info(f"Text report saved to {report_path}")


def run_comprehensive_test(model_path: str, test_loader: DataLoader, save_dir: str = './test_results') -> Dict:
    """Load a trained model from disk and run the full evaluation on it.

    Args:
        model_path: Path to the saved model weights. The file may hold either
            a checkpoint dict with a ``'model_state_dict'`` entry or a bare
            state dict.
        test_loader: Loader providing the test pairs.
        save_dir: Directory that receives the evaluation outputs.

    Returns:
        Metrics dictionary returned by ``CircuitTester.test``.
    """
    from model import CircuitDistanceModel

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CircuitDistanceModel()

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed here.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both a full training checkpoint and a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    # CircuitTester moves the model to `device` itself.
    tester = CircuitTester(model, test_loader, device, save_dir)

    return tester.test()