#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multi-Category Convolution-based Visualization module for NAS search results
"""

import os
import json
import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Dict, Tuple
from pel_nas.search.evaluator import ArchitectureInfo, filter_valid_architectures, get_robust_axis_limits, calculate_robust_pareto_front
from pel_nas.core.config import SEARCH_CONFIG
from pel_nas.data.pareto_metrics import ParetoMetrics

logger = logging.getLogger(__name__)

class NASVisualizer:
    """Multi-category convolution-based visualizer for NAS search results"""
    
    def __init__(self, dataset: str = 'cifar10', hardware_device: str = 'edgegpu'):
        """Store the search context and set up plot styling.

        Args:
            dataset: Name of the dataset the search ran on. Used in plot
                titles and when requesting the dataset baseline.
            hardware_device: Name of the target hardware device. Forwarded to
                the filtering / Pareto helpers and shown in plot titles.

        Note:
            This updates the process-wide ``plt.rcParams``, so the defaults
            below also apply to figures created outside this class.
        """
        self.dataset = dataset
        self.hardware_device = hardware_device

        self.metrics = ParetoMetrics()

        # Global matplotlib defaults for more readable plots.
        plt.rcParams.update({
            'figure.figsize': (12, 8),
            'font.size': 11,
            'axes.labelsize': 12,
            'axes.titlesize': 14,
            'legend.fontsize': 10,
        })

        # One row per convolution category: (name, colour, marker).
        category_styles = (
            ('no_convolution', '#FF6B6B', 'o'),     # red
            ('pure_lightweight', '#4ECDC4', 's'),   # teal
            ('basic_mixed', '#45B7D1', '^'),        # blue
            ('medium_mixed', '#96CEB4', 'D'),       # green
            ('heavy_mixed', '#FECA57', 'v'),        # yellow
            ('extreme_compute', '#FF9FF3', '*'),    # pink
        )

        # Category -> colour, and category -> marker, in the table's order.
        self.category_colors = {name: colour for name, colour, _ in category_styles}
        self.category_markers = {name: marker for name, _, marker in category_styles}
        
        
    def create_all_visualizations(self, 
                                all_architectures: List[ArchitectureInfo],
                                category_pareto_fronts: Dict[str, List[ArchitectureInfo]],
                                output_dir: str) -> Dict[str, str]:
        """Create every plot for the search results and return their paths.

        Args:
            all_architectures: All generated architectures.
            category_pareto_fronts: Mapping from convolution category name to
                that category's Pareto-front architectures.
            output_dir: Directory the images are written to; created if it
                does not exist.

        Returns:
            A dict with the keys ``'pareto_plot'``, ``'distribution_plot'``,
            ``'evolution_plot'`` and ``'pattern_plot'`` mapping to the saved
            file paths. If any plot fails, an empty dict is returned and the
            error (with traceback) is logged.

        Raises:
            OSError: If ``output_dir`` cannot be created (this happens before
                the best-effort section).
        """
        logger.info("🎨 Creating multi-category visualizations...")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Remember which figures already exist so that, on failure, only the
        # figures opened by the plot methods below are cleaned up.
        figures_before = set(plt.get_fignums())

        try:
            # Dict literals evaluate in order, so the plots are produced in
            # the sequence: Pareto comparison, category distribution,
            # performance evolution, convolution pattern.
            plot_paths = {
                'pareto_plot': self.create_convolution_pareto_plot(
                    all_architectures, category_pareto_fronts, output_dir
                ),
                'distribution_plot': self.create_category_distribution_plot(
                    all_architectures, output_dir
                ),
                'evolution_plot': self.create_performance_evolution_plot(
                    category_pareto_fronts, output_dir
                ),
                'pattern_plot': self.create_convolution_pattern_plot(
                    all_architectures, output_dir
                ),
            }
        except Exception as e:
            # Visualization is best-effort: report the failure with its
            # traceback instead of propagating it to the caller.
            logger.exception("❌ Visualization creation failed: %s", e)
            # A plot method that raised may have left its figure open.
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)
            return {}

        logger.info("✅ All visualizations created successfully")
        return plot_paths
    
    
    def create_convolution_pareto_plot(self, 
                                     all_architectures: List[ArchitectureInfo],
                                     category_pareto_fronts: Dict[str, List[ArchitectureInfo]],
                                     output_dir: str) -> str:
        """Create the accuracy/latency Pareto plot with a dataset baseline.

        Layers are drawn in this order: dataset sample, dataset Pareto points
        and line, all generated architectures, each category's Pareto points,
        and the overall Pareto line of the generated architectures. Every
        list is passed through ``filter_valid_architectures`` before plotting.

        Args:
            all_architectures: All generated architectures.
            category_pareto_fronts: Mapping from category name to that
                category's Pareto-front architectures. Categories without a
                configured colour/marker fall back to black circles.
            output_dir: Existing directory the PNG is written to.

        Returns:
            Path of the saved ``convolution_pareto_comparison.png``.
        """
        logger.info("🎨 Creating Pareto comparison with dataset baseline...")

        fig, ax = plt.subplots(1, 1, figsize=(14, 10))
        try:
            dataset_archs: List[ArchitectureInfo] = []
            dataset_pareto: List[ArchitectureInfo] = []
            try:
                dataset_result = self.metrics.compute_dataset_front(self.dataset, self.hardware_device)
                dataset_archs = dataset_result.dataset_architectures
                dataset_pareto = dataset_result.dataset_front
            except Exception as exc:  # pragma: no cover - defensive
                # The baseline is optional; plot the generated data without it.
                logger.warning("Failed to compute dataset baseline: %s", exc)

            # Filtered once here and reused for the axis limits further down.
            valid_dataset_archs: List[ArchitectureInfo] = []
            if dataset_archs:
                valid_dataset_archs = filter_valid_architectures(dataset_archs, hardware_device=self.hardware_device)
                logger.info(
                    "🧹 数据集架构过滤: %d -> %d 个有效架构",
                    len(dataset_archs),
                    len(valid_dataset_archs),
                )

                if valid_dataset_archs:
                    ax.scatter(
                        [arch.accuracy for arch in valid_dataset_archs],
                        [arch.latency for arch in valid_dataset_archs],
                        c='lightgray',
                        alpha=0.2,
                        s=8,
                        label=f'Dataset Sample ({len(valid_dataset_archs)} archs)'
                    )

            if dataset_pareto:
                valid_dataset_pareto = filter_valid_architectures(dataset_pareto, hardware_device=self.hardware_device)
                logger.info(
                    "🧹 数据集帕累托过滤: %d -> %d 个有效架构",
                    len(dataset_pareto),
                    len(valid_dataset_pareto),
                )

                if valid_dataset_pareto:
                    # Sort by accuracy so the connecting line does not zig-zag.
                    sorted_dataset_pareto = sorted(valid_dataset_pareto, key=lambda x: x.accuracy)
                    sorted_acc = [arch.accuracy for arch in sorted_dataset_pareto]
                    sorted_lat = [arch.latency for arch in sorted_dataset_pareto]

                    ax.scatter(
                        sorted_acc,
                        sorted_lat,
                        c='navy',
                        s=50,
                        alpha=0.8,
                        marker='x',
                        label=f'Dataset Pareto ({len(valid_dataset_pareto)} archs)'
                    )
                    ax.plot(
                        sorted_acc,
                        sorted_lat,
                        'navy',
                        linewidth=2,
                        alpha=0.6,
                        linestyle='--',
                        label='Dataset Pareto Line'
                    )

            # Filter outliers out of the generated architectures. The result
            # is reused for the overall Pareto line and the axis limits.
            valid_generated: List[ArchitectureInfo] = []
            if all_architectures:
                valid_generated = filter_valid_architectures(all_architectures, hardware_device=self.hardware_device)
                logger.info(f"🧹 生成架构过滤: {len(all_architectures)} -> {len(valid_generated)} 个有效架构")

                if valid_generated:
                    all_acc = [arch.accuracy for arch in valid_generated]
                    all_lat = [arch.latency for arch in valid_generated]
                    ax.scatter(all_acc, all_lat,
                               c='orange', alpha=0.4, s=25,
                               label=f'LLM Generated ({len(all_acc)} archs)')

            # Filter outliers out of each category's Pareto front and plot it.
            for category, pareto_archs in category_pareto_fronts.items():
                if not pareto_archs:
                    continue

                valid_pareto_archs = filter_valid_architectures(pareto_archs, hardware_device=self.hardware_device)
                logger.info(f"🧹 {category}帕累托过滤: {len(pareto_archs)} -> {len(valid_pareto_archs)} 个有效架构")

                if not valid_pareto_archs:
                    continue

                # No SEARCH_CONFIG lookup here: styling has fallbacks, so an
                # unconfigured category must not abort the whole plot.
                ax.scatter(
                    [arch.accuracy for arch in valid_pareto_archs],
                    [arch.latency for arch in valid_pareto_archs],
                    c=self.category_colors.get(category, 'black'),
                    marker=self.category_markers.get(category, 'o'),
                    s=100, alpha=0.9,
                    edgecolor='black', linewidth=1,
                    label=f'{category.replace("_", " ").title()} Pareto ({len(valid_pareto_archs)})'
                )

            # Compute and draw the overall Pareto front of the generated set.
            if valid_generated:
                overall_pareto = calculate_robust_pareto_front(valid_generated, hardware_device=self.hardware_device)

                if overall_pareto:
                    sorted_overall_pareto = sorted(overall_pareto, key=lambda x: x.accuracy)
                    overall_acc = [arch.accuracy for arch in sorted_overall_pareto]
                    overall_lat = [arch.latency for arch in sorted_overall_pareto]

                    ax.plot(overall_acc, overall_lat, 'red', linewidth=3, alpha=0.9,
                            label=f'Generated Overall Pareto ({len(overall_pareto)})')

            # Formatting
            ax.set_xlabel('Accuracy (%)')
            ax.set_ylabel('Latency (ms)')
            ax.set_title(f'Multi-Category Convolution-based NAS Results\n'
                         f'Dataset: {self.dataset.upper()}, Hardware: {self.hardware_device.upper()}\n'
                         f'Dataset Baseline vs LLM Generated Architectures')
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
            ax.grid(True, alpha=0.3)

            # Axis limits come from the valid dataset sample plus the valid
            # generated architectures.
            all_valid_archs = [*valid_dataset_archs, *valid_generated]

            if all_valid_archs:
                acc_range, lat_range = get_robust_axis_limits(all_valid_archs, hardware_device=self.hardware_device)
                ax.set_xlim(acc_range[0], acc_range[1])
                ax.set_ylim(lat_range[0], lat_range[1])
                logger.info(f"📊 应用健壮坐标轴范围: 准确率{acc_range}, 延迟{lat_range}")
            else:
                # Default range when there is nothing valid to plot.
                ax.set_xlim(0, 100)
                ax.set_ylim(0, 10)
                logger.warning("⚠️  没有有效数据，使用默认坐标轴范围")

            fig.tight_layout()

            # Save plot
            plot_path = os.path.join(output_dir, 'convolution_pareto_comparison.png')
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        logger.info(f"📊 Convolution Pareto plot saved: {os.path.basename(plot_path)}")
        return plot_path
    
    
    def _calculate_overall_pareto_front(self, architectures: List[ArchitectureInfo]) -> List[ArchitectureInfo]:
        """Return the overall Pareto front of ``architectures``.

        Delegates to ``calculate_robust_pareto_front`` for this visualizer's
        hardware device. An empty or ``None`` input yields an empty list
        without calling the helper.
        """
        if architectures:
            return calculate_robust_pareto_front(
                architectures, hardware_device=self.hardware_device
            )
        return []
    
    
    def create_category_distribution_plot(self, 
                                        all_architectures: List[ArchitectureInfo],
                                        output_dir: str) -> str:
        """Plot per-category architecture counts and accuracy distributions.

        The top subplot compares, for every category in
        ``SEARCH_CONFIG['conv_categories']``, the number of valid generated
        architectures with the category's configured ``target_count``. The
        bottom subplot is a box plot of accuracies for each category that has
        at least one valid architecture.

        Args:
            all_architectures: All generated architectures. Entries that are
                not ``is_valid`` or whose ``conv_category`` is not configured
                are ignored.
            output_dir: Existing directory the PNG is written to.

        Returns:
            Path of the saved ``category_distribution_analysis.png``.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
        try:
            category_config = SEARCH_CONFIG['conv_categories']
            categories = list(category_config)

            # Accuracies grouped by category; each list's length is the count.
            category_accuracies = {category: [] for category in categories}
            for arch in all_architectures:
                if arch.is_valid and arch.conv_category in category_accuracies:
                    category_accuracies[arch.conv_category].append(arch.accuracy)

            # Top subplot: Category counts
            counts = [len(category_accuracies[cat]) for cat in categories]
            targets = [category_config[cat]['target_count'] for cat in categories]

            x_pos = np.arange(len(categories))
            colors = [self.category_colors.get(cat, 'gray') for cat in categories]

            bars1 = ax1.bar(x_pos - 0.2, counts, 0.4, label='Generated', color=colors, alpha=0.7)
            bars2 = ax1.bar(x_pos + 0.2, targets, 0.4, label='Target', color=colors, alpha=0.3, edgecolor='black')

            ax1.set_xlabel('Convolution Categories')
            ax1.set_ylabel('Architecture Count')
            ax1.set_title('Architecture Generation by Convolution Category')
            ax1.set_xticks(x_pos)
            ax1.set_xticklabels([cat.replace('_', '\n') for cat in categories], rotation=0, ha='center')
            ax1.legend()
            ax1.grid(True, alpha=0.3, axis='y')

            # Add count labels on bars
            for bar1, bar2, count, target in zip(bars1, bars2, counts, targets):
                ax1.text(bar1.get_x() + bar1.get_width()/2, bar1.get_height() + 0.1,
                         str(count), ha='center', va='bottom', fontweight='bold')
                ax1.text(bar2.get_x() + bar2.get_width()/2, bar2.get_height() + 0.1,
                         str(target), ha='center', va='bottom', fontweight='bold')

            # Bottom subplot: Accuracy distribution by category
            valid_categories = [cat for cat in categories if category_accuracies[cat]]
            if valid_categories:
                box_data = [category_accuracies[cat] for cat in valid_categories]
                box_colors = [self.category_colors.get(cat, 'gray') for cat in valid_categories]

                bp = ax2.boxplot(box_data, patch_artist=True)

                # Tick labels are set explicitly instead of through boxplot's
                # ``labels`` keyword, which matplotlib 3.9 deprecated in
                # favour of ``tick_labels``; this form works on both.
                # Boxes are drawn at positions 1..N by default.
                ax2.set_xticks(range(1, len(valid_categories) + 1))
                ax2.set_xticklabels([cat.replace('_', '\n') for cat in valid_categories])

                for patch, color in zip(bp['boxes'], box_colors):
                    patch.set_facecolor(color)
                    patch.set_alpha(0.7)

            ax2.set_xlabel('Convolution Categories')
            ax2.set_ylabel('Accuracy (%)')
            ax2.set_title('Accuracy Distribution by Category')
            ax2.grid(True, alpha=0.3, axis='y')

            fig.tight_layout()

            # Save plot
            plot_path = os.path.join(output_dir, 'category_distribution_analysis.png')
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        logger.info(f"📊 Category distribution plot saved: {os.path.basename(plot_path)}")
        return plot_path
    
    
    def create_performance_evolution_plot(self, 
                                        category_pareto_fronts: Dict[str, List[ArchitectureInfo]],
                                        output_dir: str) -> str:
        """Plot best accuracy and average latency of each category's front.

        One bar per category in ``SEARCH_CONFIG['conv_categories']``. A
        category with no Pareto architectures gets a zero-height bar and no
        value label.

        Args:
            category_pareto_fronts: Mapping from category name to that
                category's Pareto-front architectures. Missing categories are
                treated as empty.
            output_dir: Existing directory the PNG is written to.

        Returns:
            Path of the saved ``performance_evolution.png``.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        try:
            categories = list(SEARCH_CONFIG['conv_categories'])

            # Prepare data (energy is no longer reported).
            category_labels = [cat.replace('_', ' ').title() for cat in categories]
            category_colors_list = [self.category_colors.get(cat, 'gray') for cat in categories]
            best_accuracies = []
            avg_latencies = []

            for category in categories:
                pareto_archs = category_pareto_fronts.get(category, [])
                if pareto_archs:
                    best_accuracies.append(max(arch.accuracy for arch in pareto_archs))
                    avg_latencies.append(sum(arch.latency for arch in pareto_archs) / len(pareto_archs))
                else:
                    best_accuracies.append(0)
                    avg_latencies.append(0)

            positions = range(len(categories))

            # Left plot: Best accuracy by category
            bars1 = ax1.bar(positions, best_accuracies, color=category_colors_list, alpha=0.7)
            ax1.set_xlabel('Convolution Categories')
            ax1.set_ylabel('Best Accuracy (%)')
            ax1.set_title('Best Accuracy by Category')
            ax1.set_xticks(positions)
            ax1.set_xticklabels(category_labels, rotation=45, ha='right')
            ax1.grid(True, alpha=0.3, axis='y')

            # Add value labels on bars
            for bar, acc in zip(bars1, best_accuracies):
                if acc > 0:
                    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                             f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

            # Right plot: Average latency by category (bar chart)
            bars2 = ax2.bar(positions, avg_latencies, color=category_colors_list, alpha=0.7)
            ax2.set_xlabel('Convolution Categories')
            ax2.set_ylabel('Average Latency (ms)')
            ax2.set_title('Average Latency by Category')
            ax2.set_xticks(positions)
            ax2.set_xticklabels(category_labels, rotation=45, ha='right')
            ax2.grid(True, alpha=0.3, axis='y')
            for bar, lat in zip(bars2, avg_latencies):
                if lat > 0:
                    # The bar height equals ``lat`` and is therefore positive
                    # here, so the label sits 2% of the height above the bar.
                    height = bar.get_height()
                    ax2.text(bar.get_x() + bar.get_width()/2, height + 0.02 * height,
                             f'{lat:.2f}', ha='center', va='bottom')

            fig.tight_layout()

            # Save plot
            plot_path = os.path.join(output_dir, 'performance_evolution.png')
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        logger.info(f"📊 Performance evolution plot saved: {os.path.basename(plot_path)}")
        return plot_path
    
    
    def create_convolution_pattern_plot(self, 
                                      all_architectures: List[ArchitectureInfo],
                                      output_dir: str) -> str:
        """Plot convolution-count patterns of the valid architectures.

        The left subplot scatters 3x3 against 1x1 convolution counts; the
        right subplot scatters total convolution count against accuracy. Both
        are coloured per category. Categories without an entry in
        ``self.category_colors`` are not drawn, but still contribute to the
        left subplot's axis limits.

        Categories are drawn in the order of ``self.category_colors`` so that
        legend and draw order are the same on every run.

        Args:
            all_architectures: All generated architectures; entries that are
                not ``is_valid`` are ignored.
            output_dir: Existing directory the PNG is written to.

        Returns:
            Path of the saved ``convolution_pattern_analysis.png``.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        try:
            # Collect convolution data in one pass, grouped by category.
            # Each group holds parallel lists: (3x3 counts, 1x1 counts, accuracies).
            conv_3x3_counts = []
            conv_1x1_counts = []
            grouped = {}

            for arch in all_architectures:
                if not arch.is_valid:
                    continue
                conv_3x3_counts.append(arch.conv_3x3_count)
                conv_1x1_counts.append(arch.conv_1x1_count)
                n_3x3, n_1x1, accs = grouped.setdefault(arch.conv_category, ([], [], []))
                n_3x3.append(arch.conv_3x3_count)
                n_1x1.append(arch.conv_1x1_count)
                accs.append(arch.accuracy)

            # Deterministic order; only categories with a configured colour.
            plot_order = [cat for cat in self.category_colors if cat in grouped]

            # Left plot: 3x3 vs 1x1 convolution counts
            for category in plot_order:
                n_3x3, n_1x1, _ = grouped[category]
                ax1.scatter(n_3x3, n_1x1,
                            c=self.category_colors[category],
                            marker=self.category_markers.get(category, 'o'),
                            s=60, alpha=0.7,
                            label=category.replace('_', ' ').title())

            ax1.set_xlabel('3x3 Convolution Count')
            ax1.set_ylabel('1x1 Convolution Count')
            ax1.set_title('Convolution Pattern Distribution')
            if plot_order:
                # Skipped when nothing is labelled to avoid matplotlib's
                # "no artists with labels" warning.
                ax1.legend()
            ax1.grid(True, alpha=0.3)
            if conv_3x3_counts and conv_1x1_counts:
                ax1.set_xlim(-0.5, max(conv_3x3_counts) + 0.5)
                ax1.set_ylim(-0.5, max(conv_1x1_counts) + 0.5)

            # Right plot: Total convolutions vs accuracy
            for category in plot_order:
                n_3x3, n_1x1, accs = grouped[category]
                total_convs_cat = [a + b for a, b in zip(n_3x3, n_1x1)]
                ax2.scatter(total_convs_cat, accs,
                            c=self.category_colors[category],
                            marker=self.category_markers.get(category, 'o'),
                            s=60, alpha=0.7,
                            label=category.replace('_', ' ').title())

            ax2.set_xlabel('Total Convolution Count')
            ax2.set_ylabel('Accuracy (%)')
            ax2.set_title('Convolution Count vs Accuracy')
            if plot_order:
                ax2.legend()
            ax2.grid(True, alpha=0.3)

            fig.tight_layout()

            # Save plot
            plot_path = os.path.join(output_dir, 'convolution_pattern_analysis.png')
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        logger.info(f"📊 Convolution pattern plot saved: {os.path.basename(plot_path)}")
        return plot_path
