#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统一可视化模块 - 支持多设备、多方法的消融实验对比
Unified Visualizer - Multi-device, multi-method ablation study comparison
"""

import os
import json
import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Dict, Tuple, Optional
from pel_nas.search.evaluator import ArchitectureInfo, filter_valid_architectures, get_robust_axis_limits
from pel_nas.data.pareto_metrics import ParetoMetrics, evaluate_pareto_performance
from experiments.ablation.ablation_studies import AblationResult

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Publication-quality figure defaults. NOTE: this mutates matplotlib's global
# rcParams at import time, so it affects every figure created afterwards in
# the same process.
plt.rcParams.update({
    # Typography
    'font.family': 'Liberation Sans',
    'font.size': 20,
    'axes.titlesize': 20,
    'axes.labelsize': 20,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,

    # Axes frame: slightly heavier lines, no top/right spines
    'axes.linewidth': 1.2,
    'axes.spines.top': False,
    'axes.spines.right': False,

    # Resolution and cropping of on-screen and saved figures
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.15,

    # Legend: plain, borderless, slightly translucent box
    'legend.fontsize': 18,
    'legend.frameon': True,
    'legend.framealpha': 0.9,
    'legend.edgecolor': 'none',
    'legend.fancybox': False,
    'legend.shadow': False,
})

class UnifiedAblationVisualizer:
    """Unified visualizer for multi-device, multi-method ablation studies.

    Produces one figure with a 2x3 grid of latency/accuracy panels (one per
    hardware device). Each panel overlays the monotone 2D Pareto front of
    every ablation variant on top of the dataset's reference front. The
    per-device metrics are additionally aggregated into a JSON summary and a
    plain-text report written next to the figure.
    """

    def __init__(self, dataset: str = 'cifar100'):
        """Initialise the visualizer.

        Args:
            dataset: Dataset name. It selects the accuracy axis range (see
                ``dataset_axis_ranges``), is passed to the metrics calculator
                when loading the reference front, and is embedded in every
                output file name.
        """
        self.dataset = dataset
        self.metrics_calculator = ParetoMetrics()

        # The six hardware devices, in panel order (row-major in the grid).
        self.devices = ['edgegpu', 'raspi4', 'edgetpu', 'pixel3', 'eyeriss', 'fpga']

        # Colour / line / marker configuration for the four compared methods.
        # Only results whose ``method_name`` is a key here are plotted.
        self.method_styles = {
            'PEL-NAS (Ours)': {
                'color': '#FF0000',
                'linestyle': '-',
                'marker': '*',
                'markersize': 12,
                'linewidth': 3,
                'alpha': 0.9,
                'zorder': 7
            },
            'Without Partitioning': {
                'color': '#FF8C00',
                'linestyle': '--',
                'marker': 'o',
                'markersize': 8,
                'linewidth': 2.5,
                'alpha': 0.8,
                'zorder': 6
            },
            'Without LLM (uses PEA)': {
                'color': '#32CD32',
                'linestyle': '-.',
                'marker': 's',
                'markersize': 8,
                'linewidth': 2.5,
                'alpha': 0.8,
                'zorder': 5
            },
            'Without ZC Ensemble (uses Synflow)': {
                'color': '#4169E1',
                'linestyle': ':',
                'marker': '^',
                'markersize': 8,
                'linewidth': 2.5,
                'alpha': 0.8,
                'zorder': 4
            }
        }

        # Device key -> display name used as the in-panel title.
        self.device_names = {
            'edgegpu': 'EdgeGPU',
            'raspi4': 'Raspi4',
            'edgetpu': 'EdgeTPU',
            'pixel3': 'Pixel3',
            'eyeriss': 'Eyeriss',
            'fpga': 'FPGA'
        }

        # Dataset-specific y-axis (accuracy, %) ranges. Datasets not listed
        # here fall back to (40, 100) in ``_style_subplot``.
        self.dataset_axis_ranges = {
            'cifar10': (70, 100),
            'cifar100': (50, 80),
            'ImageNet16-120': (20, 50),
        }

        logger.info(f"🎨 统一消融实验可视化器初始化完成 - {dataset}")

    def create_unified_comparison_figure(self,
                                         all_device_results: "Dict[str, List[AblationResult]]",
                                         output_dir: str) -> Tuple[str, Dict]:
        """Create the unified six-device ablation comparison figure.

        The figure is saved as PNG and, best effort, as PDF. A JSON metrics
        summary and a text report are written to the same directory.

        Args:
            all_device_results: Ablation results keyed by device name. Devices
                from ``self.devices`` that are missing get an empty, styled
                panel and a warning.
            output_dir: Output directory; created if it does not exist.

        Returns:
            A tuple ``(png_path, all_metrics)`` where ``all_metrics`` maps each
            plotted device to its per-method metrics dictionary.

        Raises:
            Exception: Whatever the PNG save or the summary writers raise; only
                a failing PDF save is downgraded to a warning.
        """
        logger.info("🎨 创建统一消融实验对比图...")

        os.makedirs(output_dir, exist_ok=True)

        figure_path = os.path.join(output_dir, f'unified_ablation_comparison_{self.dataset}.png')
        pdf_path = os.path.join(output_dir, f'unified_ablation_comparison_{self.dataset}.pdf')

        # 2 rows x 3 columns; panels within a row share the accuracy axis.
        fig, axes = plt.subplots(2, 3, figsize=(18, 12), sharey='row')
        axes = axes.flatten()

        all_metrics = {}

        # Everything that touches the figure goes through ``fig`` explicitly
        # (never the pyplot "current figure"), and the figure is closed even
        # if plotting or saving fails, so it cannot leak.
        try:
            for i, device in enumerate(self.devices):
                ax = axes[i]

                # The first panel of each row carries the shared y-axis label,
                # whether or not that device has results.
                if i % 3 == 0:
                    ax.set_ylabel('Accuracy (%)', fontsize=20, fontweight='normal')

                if device not in all_device_results:
                    logger.warning(f"⚠️ 设备 {device} 没有消融实验结果")
                    # Keep the empty panel visually consistent with the rest.
                    self._style_subplot(ax, device)
                    continue

                all_metrics[device] = self._create_device_subplot(
                    ax, device, all_device_results[device]
                )

            self._create_global_legend(fig)

            # Leave room at the bottom for the figure-level legend.
            fig.tight_layout()
            fig.subplots_adjust(bottom=0.12, top=0.95)

            fig.savefig(figure_path, dpi=300, bbox_inches='tight', facecolor='white')
            try:
                fig.savefig(pdf_path, dpi=300, bbox_inches='tight', facecolor='white')
            except Exception as e:
                # The PDF is a convenience copy; the PNG is the primary output.
                logger.warning(f"Failed to save PDF figure: {e}")
        finally:
            plt.close(fig)

        logger.info(f"✅ 统一消融对比图已保存: {os.path.basename(figure_path)}")

        # Writes the JSON summary and the text report as a side effect.
        self._generate_metrics_summary(all_metrics, output_dir)

        return figure_path, all_metrics

    def _load_reference_front(self, device: str) -> list:
        """Load the dataset's reference Pareto front for ``device``.

        Loading failures are logged and treated as "no baseline available".

        Returns:
            The dataset front after ``filter_valid_architectures``; an empty
            list when loading failed or the front is empty.
        """
        dataset_pareto = []
        try:
            dataset_result = self.metrics_calculator.compute_dataset_front(self.dataset, device)
            dataset_sample = dataset_result.dataset_architectures
            dataset_pareto = dataset_result.dataset_front
            logger.info(
                "数据集基线加载状态 - %s: sample=%d, pareto=%d",
                device,
                len(dataset_sample),
                len(dataset_pareto),
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("加载设备 %s 基线数据失败: %s", device, exc)

        if not dataset_pareto:
            return []
        return filter_valid_architectures(dataset_pareto, hardware_device=device) or []

    @staticmethod
    def _fallback_metrics(pareto_size: int) -> Dict:
        """Return placeholder metrics for a method whose metrics are unavailable."""
        return {
            'hypervolume_ratio': 0.0,
            'igd': float('inf'),
            'generated_pareto_size': pareto_size
        }

    def _create_device_subplot(self, ax, device: str, device_results: "List[AblationResult]") -> Dict:
        """Draw one device's panel and compute its per-method metrics.

        Args:
            ax: Matplotlib axes to draw on.
            device: Device key (see ``self.devices``).
            device_results: Ablation results for this device. Results whose
                ``method_name`` is not in ``self.method_styles`` are skipped,
                as are results without any valid Pareto-front architecture.

        Returns:
            A dictionary with ``'baseline_pareto_size'`` plus one entry per
            plotted method holding that method's metrics.
        """
        # Reference front of the dataset, drawn as a grey line above the methods.
        valid_dataset_pareto = self._load_reference_front(device)
        if valid_dataset_pareto:
            ref_lat, ref_acc = self._build_monotone_2d_front(
                [arch.latency for arch in valid_dataset_pareto],
                [arch.accuracy for arch in valid_dataset_pareto],
            )
            ax.plot(
                ref_lat,
                ref_acc,
                color='#666666',
                linewidth=2.8,
                alpha=0.95,
                zorder=9,
                label='Theoretical Pareto (Dataset)'
            )

        device_metrics = {'baseline_pareto_size': len(valid_dataset_pareto)}

        for result in device_results:
            method_name = result.method_name
            style = self.method_styles.get(method_name)
            if style is None:
                continue

            valid_pareto = [arch for arch in result.pareto_front if arch.is_valid]
            if not valid_pareto:
                logger.warning(f"方法 {method_name} 在设备 {device} 上没有有效的帕累托前沿")
                continue

            # Only the Pareto line is drawn (no scatter of all generated
            # architectures): project to 2D and force monotonicity.
            front_lat, front_acc = self._build_monotone_2d_front(
                [arch.latency for arch in valid_pareto],
                [arch.accuracy for arch in valid_pareto],
            )
            ax.plot(front_lat, front_acc,
                    color=style['color'],
                    linestyle=style['linestyle'],
                    linewidth=style['linewidth'],
                    alpha=style['alpha'],
                    marker=style['marker'],
                    markersize=style['markersize'],
                    markerfacecolor=style['color'],
                    markeredgecolor='white',
                    markeredgewidth=1,
                    zorder=style['zorder'],
                    label=method_name)

            # Metrics need a reference front; without one (or if the metric
            # computation fails) record placeholder values instead.
            try:
                if valid_dataset_pareto:
                    device_metrics[method_name] = evaluate_pareto_performance(
                        valid_dataset_pareto, valid_pareto
                    )
                else:
                    device_metrics[method_name] = self._fallback_metrics(len(valid_pareto))
            except Exception as e:
                logger.warning(f"计算 {method_name} 在 {device} 上的指标失败: {e}")
                device_metrics[method_name] = self._fallback_metrics(len(valid_pareto))

        self._style_subplot(ax, device)

        return device_metrics

    def _style_subplot(self, ax, device: str):
        """Apply the common panel styling (background, grid, limits, title)."""
        ax.set_facecolor('#f8f8f8')

        ax.set_xlabel('Latency (ms)', fontsize=20, fontweight='normal')

        # White grid on the light-grey background.
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5, color='white')

        # Dataset-specific accuracy range, with a generic default.
        ymin, ymax = self.dataset_axis_ranges.get(self.dataset, (40, 100))
        ax.set_ylim(ymin, ymax)
        ax.set_xlim(left=0)  # latency axis starts at 0

        ax.tick_params(axis='both', which='major', labelsize=15)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_color('#cccccc')
        ax.spines['bottom'].set_color('#cccccc')

        # Device name in the top-left corner, in axes coordinates.
        device_name = self.device_names.get(device, device)
        ax.text(0.05, 0.95, device_name, transform=ax.transAxes,
                fontsize=25, fontweight='bold', verticalalignment='top')

    def _create_global_legend(self, fig):
        """Add one figure-level legend shared by all panels.

        Proxy artists are built from ``self.method_styles`` so the legend does
        not depend on which panels actually contain which methods.
        """
        legend_elements = [
            plt.Line2D([0], [0], color='#666666', linewidth=3,
                       label='Theoretical Pareto (Dataset)')
        ]

        # Marker face/edge and alpha mirror what _create_device_subplot draws,
        # so legend entries look like the plotted lines.
        for method_name, style in self.method_styles.items():
            legend_elements.append(
                plt.Line2D([0], [0],
                           color=style['color'],
                           linestyle=style['linestyle'],
                           linewidth=style['linewidth'],
                           alpha=style['alpha'],
                           marker=style['marker'],
                           markersize=style['markersize'],
                           markerfacecolor=style['color'],
                           markeredgecolor='white',
                           markeredgewidth=1,
                           label=method_name)
            )

        fig.legend(handles=legend_elements,
                   loc='lower center',
                   bbox_to_anchor=(0.5, -0.02),
                   ncol=len(legend_elements),  # single row
                   fontsize=16,
                   frameon=True,
                   fancybox=False,
                   shadow=False)

    def _build_monotone_2d_front(self, latencies: List[float],
                                 accuracies: List[float]) -> Tuple[List[float], List[float]]:
        """Build a monotone 2D front: latency ascending, accuracy non-decreasing.

        Points with a ``None``, NaN or infinite coordinate are dropped. Among
        the rest, a point is kept only if its accuracy is at least the best
        accuracy seen at any lower-or-equal latency; for several points at the
        same latency only the most accurate one survives, independent of input
        order. Points that coincide after rounding to 6 decimals are emitted
        once, so markers do not stack.

        Args:
            latencies: Latency of each point.
            accuracies: Accuracy of each point, aligned with ``latencies``.

        Returns:
            ``(latencies, accuracies)`` of the front as two parallel lists;
            two empty lists if the inputs are missing, empty or of different
            lengths.
        """
        if latencies is None or accuracies is None:
            return [], []
        if len(latencies) == 0 or len(latencies) != len(accuracies):
            return [], []

        points = [
            (lat, acc)
            for lat, acc in zip(latencies, accuracies)
            if lat is not None and acc is not None
            and np.isfinite(lat) and np.isfinite(acc)
        ]
        # Secondary key -accuracy: at equal latency the most accurate point is
        # visited first, so the dominated ones below it are rejected.
        points.sort(key=lambda p: (p[0], -p[1]))

        front_lat: List[float] = []
        front_acc: List[float] = []
        seen = set()
        best_acc = -float('inf')

        for lat, acc in points:
            if acc < best_acc:
                # Dominated in 2D; skipping keeps the line from dipping.
                continue
            best_acc = acc

            key = (round(lat, 6), round(acc, 6))
            if key in seen:
                continue
            seen.add(key)
            front_lat.append(lat)
            front_acc.append(acc)

        return front_lat, front_acc

    @staticmethod
    def _json_default(value):
        """``json`` fallback converting numpy scalars/arrays to native types."""
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def _generate_metrics_summary(self, all_metrics: Dict, output_dir: str) -> Dict:
        """Aggregate per-device metrics per method and write them to disk.

        Writes ``ablation_metrics_summary_<dataset>.json`` and, through
        ``_generate_text_report``, the text report.

        HV ratios that are not positive finite numbers and IGD values that are
        not finite (the placeholders recorded on failure, or NaN) are excluded
        from the respective mean/std. ``valid_devices`` counts the devices
        that contributed an HV ratio.

        Args:
            all_metrics: Per-device metrics as returned by
                ``_create_device_subplot``, keyed by device.
            output_dir: Directory the files are written to.

        Returns:
            The per-method summary dictionary.
        """
        logger.info("📊 生成消融实验指标总结...")

        # Collect the raw per-device values of every method.
        collected = {
            method_name: {'hv_ratios': [], 'igd_values': [], 'pareto_sizes': []}
            for method_name in self.method_styles
        }
        for device_metrics in all_metrics.values():
            for method_name, bucket in collected.items():
                if method_name not in device_metrics:
                    continue
                metrics = device_metrics[method_name]
                bucket['hv_ratios'].append(metrics.get('hypervolume_ratio', 0.0))
                bucket['igd_values'].append(metrics.get('igd', float('inf')))
                bucket['pareto_sizes'].append(metrics.get('generated_pareto_size', 0))

        summary = {}
        for method_name, bucket in collected.items():
            hv_ratios = [hv for hv in bucket['hv_ratios'] if np.isfinite(hv) and hv > 0]
            igd_values = [igd for igd in bucket['igd_values'] if np.isfinite(igd)]
            pareto_sizes = bucket['pareto_sizes']

            summary[method_name] = {
                'avg_hv_ratio': float(np.mean(hv_ratios)) if hv_ratios else 0.0,
                'std_hv_ratio': float(np.std(hv_ratios)) if hv_ratios else 0.0,
                'avg_igd': float(np.mean(igd_values)) if igd_values else float('inf'),
                'std_igd': float(np.std(igd_values)) if igd_values else 0.0,
                'avg_pareto_size': float(np.mean(pareto_sizes)) if pareto_sizes else 0.0,
                'valid_devices': len(hv_ratios)
            }

        # Serialise fully before opening the file so a serialisation error
        # cannot leave a truncated JSON file behind. Non-finite floats keep
        # Python's default encoding (Infinity / NaN).
        payload = json.dumps({
            'dataset': self.dataset,
            'method_summary': summary,
            'device_details': all_metrics
        }, indent=2, ensure_ascii=False, default=self._json_default)

        summary_file = os.path.join(output_dir, f'ablation_metrics_summary_{self.dataset}.json')
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write(payload)

        self._generate_text_report(summary, all_metrics, output_dir)

        logger.info(f"📊 指标总结已保存: {os.path.basename(summary_file)}")
        return summary

    def _generate_text_report(self, summary: Dict, all_metrics: Dict, output_dir: str):
        """Write the plain-text report ``ablation_report_<dataset>.txt``.

        The report contains the methods ranked by average HV ratio (highest
        first), the per-device metrics and a short "key insights" section.

        Args:
            summary: Per-method summary from ``_generate_metrics_summary``.
            all_metrics: Per-device metrics keyed by device.
            output_dir: Directory the report is written to.
        """
        report_file = os.path.join(output_dir, f'ablation_report_{self.dataset}.txt')

        # Rank methods by average HV ratio, best first.
        sorted_methods = sorted(summary.items(),
                                key=lambda item: item[1]['avg_hv_ratio'],
                                reverse=True)

        lines = [
            f"Ablation Study Results Report - {self.dataset.upper()}\n",
            "=" * 60 + "\n\n",
            "OVERALL PERFORMANCE RANKING (by Average HV Ratio):\n",
            "-" * 50 + "\n",
        ]

        for rank, (method_name, metrics) in enumerate(sorted_methods, 1):
            lines.append(f"{rank}. {method_name}:\n")
            lines.append(f"   • Average HV Ratio: {metrics['avg_hv_ratio']:.4f} ± {metrics['std_hv_ratio']:.4f}\n")
            lines.append(f"   • Average IGD: {metrics['avg_igd']:.6f} ± {metrics['std_igd']:.6f}\n")
            lines.append(f"   • Average Pareto Size: {metrics['avg_pareto_size']:.1f}\n")
            lines.append(f"   • Valid Devices: {metrics['valid_devices']}/{len(self.devices)}\n\n")

        lines.append("DEVICE-SPECIFIC RESULTS:\n")
        lines.append("-" * 50 + "\n")

        for device in self.devices:
            if device not in all_metrics:
                continue

            lines.append(f"\n{device.upper()}:\n")
            device_metrics = all_metrics[device]

            for method_name in self.method_styles:
                if method_name not in device_metrics:
                    continue
                metrics = device_metrics[method_name]
                lines.append(f"  {method_name}:\n")
                lines.append(f"    HV Ratio: {metrics.get('hypervolume_ratio', 0.0):.4f}\n")
                lines.append(f"    IGD: {metrics.get('igd', float('inf')):.6f}\n")
                lines.append(f"    Pareto Size: {metrics.get('generated_pareto_size', 0)}\n")

        lines.append("\nKEY INSIGHTS:\n")
        lines.append("-" * 50 + "\n")

        best_method = sorted_methods[0][0] if sorted_methods else "None"
        lines.append(f"• Best Overall Method: {best_method}\n")

        if len(sorted_methods) >= 2:
            # Gap between the two best average HV ratios, in percentage points.
            improvement = (sorted_methods[0][1]['avg_hv_ratio'] -
                           sorted_methods[1][1]['avg_hv_ratio']) * 100
            lines.append(f"• Performance Improvement: {improvement:.2f}% HV ratio over second-best\n")

        lines.append(f"• Dataset: {self.dataset}\n")
        lines.append(f"• Devices Tested: {len(self.devices)}\n")
        lines.append(f"• Methods Compared: {len(self.method_styles)}\n")

        with open(report_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        logger.info(f"📄 文本报告已保存: {os.path.basename(report_file)}")
