#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multi-Category Convolution-based Neural Architecture Searcher
"""

import json
import time
import logging
import os
from datetime import datetime
from typing import Dict, List
from dataclasses import asdict
from collections import defaultdict

from pel_nas.core.config import SEARCH_CONFIG
from pel_nas.llm.llm_client import MultiCategoryLLMClient
from pel_nas.search.evaluator import ArchitectureEvaluator, ArchitectureInfo
from pel_nas.visualization.visualizer import NASVisualizer

logger = logging.getLogger(__name__)

class ConvolutionParetoManager:
    """Maintain one Pareto front per convolution category.

    Categories come from ``SEARCH_CONFIG['conv_categories']``. The manager keeps
    the full architecture history (``self.all_architectures``) so reference
    selection can fall back to a category's best-accuracy architectures, and the
    per-category fronts (``self.category_pareto``) computed from its valid subset.
    """

    def __init__(self, hardware_device: str = None):
        """Initialize the manager with an empty front for every configured category.

        Args:
            hardware_device: Device name forwarded to the robust Pareto front
                calculation. When ``None``, the ``hardware_device`` attribute of
                the architectures themselves (if present) is used instead.
        """
        self.all_architectures = []
        self.hardware_device = hardware_device
        # One (initially empty) Pareto front per configured category.
        self.category_pareto = {category: [] for category in SEARCH_CONFIG['conv_categories']}

    def update_pareto_fronts(self, all_architectures: List["ArchitectureInfo"]):
        """Recompute the Pareto front of every configured category.

        Only architectures with ``is_valid`` set and whose ``conv_category`` is a
        configured category contribute; all others are ignored.

        Args:
            all_architectures: Every architecture evaluated so far. It is also
                stored as ``self.all_architectures`` for reference selection.
        """
        logger.info("🔄 Updating convolution-based Pareto fronts...")

        self.all_architectures = all_architectures

        # Group valid architectures by their (configured) category.
        category_groups = {category: [] for category in SEARCH_CONFIG['conv_categories']}
        for arch in all_architectures:
            if arch.is_valid and arch.conv_category in category_groups:
                category_groups[arch.conv_category].append(arch)

        for category, archs in category_groups.items():
            self.category_pareto[category] = self._calculate_pareto_front(archs)

        logger.info("📊 Convolution-based Pareto front counts:")
        for category, archs in self.category_pareto.items():
            logger.info(f"   {category:15s}: {len(archs):2d} architectures")

    def _calculate_pareto_front(self, architectures: List["ArchitectureInfo"]) -> List["ArchitectureInfo"]:
        """Return the robust Pareto front of *architectures* (invalid data filtered).

        Delegates to ``calculate_robust_pareto_front``, which filters outliers
        before computing the front. The hardware device passed along is
        ``self.hardware_device`` when set, otherwise the ``hardware_device``
        attribute of the first architecture (if it has one).

        Args:
            architectures: Candidate architectures of a single category.

        Returns:
            The Pareto-optimal subset, or ``[]`` for empty input.
        """
        if not architectures:
            return []

        from pel_nas.search.evaluator import calculate_robust_pareto_front

        # ``self.hardware_device`` always exists (set in __init__) but may be
        # None, so test for None rather than attribute presence; otherwise the
        # per-architecture fallback could never be reached.
        hardware_device = getattr(self, 'hardware_device', None)
        if hardware_device is None:
            hardware_device = getattr(architectures[0], 'hardware_device', None)

        return calculate_robust_pareto_front(architectures, hardware_device=hardware_device)

    def generate_feedback_by_category(self, iteration: int) -> Dict[str, Dict[str, str]]:
        """Build per-category LLM feedback with reference architectures.

        Categories with an empty front get initial guidance built from their
        config (``description`` and ``target_count``). Otherwise, reference
        architectures are listed with accuracy/latency, plus mutation,
        crossover, and generation instructions.

        Args:
            iteration: Current search iteration (not used by this method).

        Returns:
            Mapping ``category -> {'performance_feedback': str,
            'generation_instructions': str}``.
        """
        feedback_config = SEARCH_CONFIG['feedback_config']
        ref_count = feedback_config['reference_architectures_count']
        gen_strategy = feedback_config['generation_strategy']

        combined_feedback = {}
        for category, config in SEARCH_CONFIG['conv_categories'].items():
            pareto_archs = self.category_pareto.get(category, [])

            if not pareto_archs:
                # No architectures yet - simple initial guidance.
                performance_feedback = (
                    f"No architectures found in {category} category yet.\n"
                    f"Focus on constraints: {config['description']}"
                )
                generation_instructions = (
                    f"Generate {config['target_count']} diverse architectures "
                    f"exploring different operation combinations."
                )
            else:
                reference_archs = self._select_reference_architectures(
                    pareto_archs, ref_count, category
                )
                performance_feedback = "\n".join(
                    f"{i}. {arch.arch_str} | Acc: {arch.accuracy:.1f}% | Lat: {arch.latency:.3f}ms"
                    for i, arch in enumerate(reference_archs, 1)
                )
                generation_instructions = self._generate_instructions(reference_archs, gen_strategy)

            combined_feedback[category] = {
                'performance_feedback': performance_feedback,
                'generation_instructions': generation_instructions,
            }

        return combined_feedback

    def _select_reference_architectures(self, pareto_archs: List, ref_count: int, category: str) -> List:
        """Pick up to *ref_count* reference architectures for *category*.

        Pareto-front members are sampled at random first. Any remaining slots
        are filled with the highest-accuracy valid architectures of the
        category that were not already chosen.

        Args:
            pareto_archs: Current Pareto front of the category.
            ref_count: Maximum number of references to return.
            category: Convolution category name.

        Returns:
            At most ``ref_count`` architectures.
        """
        import random

        # Invalid architectures are excluded from the Pareto fronts
        # (update_pareto_fronts), so exclude them from the fallback pool too.
        category_archs = [
            arch for arch in self.all_architectures
            if getattr(arch, 'conv_category', None) == category
            and getattr(arch, 'is_valid', True)
        ]
        if not category_archs:
            category_archs = list(pareto_archs)

        # First priority: random selection from the Pareto front.
        sample_size = max(0, min(len(pareto_archs), ref_count))
        reference_archs = random.sample(list(pareto_archs), sample_size) if sample_size else []

        # Then top up with the best-accuracy architectures not already chosen.
        if len(reference_archs) < ref_count:
            for arch in sorted(category_archs, key=lambda a: a.accuracy, reverse=True):
                if len(reference_archs) >= ref_count:
                    break
                if arch not in reference_archs:
                    reference_archs.append(arch)

        return reference_archs[:ref_count]

    def _generate_instructions(self, reference_archs: List, gen_strategy: Dict) -> str:
        """Render mutation/crossover/pure-generation instructions.

        Args:
            reference_archs: Reference architectures (not used in the text).
            gen_strategy: Counts under ``mutation_count``, ``crossover_count``
                and ``pure_generation_count``; a missing key counts as 0.

        Returns:
            One bullet line per positive count, joined with newlines.
        """
        mutation_count = gen_strategy.get('mutation_count', 0)
        crossover_count = gen_strategy.get('crossover_count', 0)
        pure_count = gen_strategy.get('pure_generation_count', 0)

        instructions = []
        if mutation_count > 0:
            instructions.append(f"• Generate {mutation_count} architecture(s) by mutating one of the reference architectures above (change 1-2 operations)")
        if crossover_count > 0:
            instructions.append(f"• Generate {crossover_count} architecture(s) by combining features from two different reference architectures above")
        if pure_count > 0:
            instructions.append(f"• Generate {pure_count} completely new architecture(s) using novel operation combinations")

        return "\n".join(instructions)

class MultiCategoryNASSearcher:
    """Multi-category convolution-based neural architecture searcher.

    Each iteration builds per-category feedback from the Pareto manager, asks
    the LLM client for new architectures for every category, evaluates them,
    and refreshes the Pareto fronts. At the end, a JSON results file,
    visualizations, and a Markdown summary are written to the output directory.
    """

    def __init__(self, dataset: str = 'cifar10', hardware_device: str = 'edgegpu', use_predictor: bool = True):
        """Create the evaluator, LLM client, Pareto manager and visualizer.

        Args:
            dataset: Dataset name forwarded to the evaluator and visualizer.
            hardware_device: Target device forwarded to every component.
            use_predictor: Forwarded to ``ArchitectureEvaluator``.
        """
        logger.info("🚀 Initializing Multi-Category NAS Searcher...")
        logger.info(f"   Dataset: {dataset}")
        logger.info(f"   Hardware Device: {hardware_device}")
        logger.info(f"   Use Predictor: {use_predictor}")

        self.dataset = dataset
        self.hardware_device = hardware_device
        self.use_predictor = use_predictor

        # Components
        self.evaluator = ArchitectureEvaluator(dataset, hardware_device, use_predictor)
        self.llm_client = MultiCategoryLLMClient(hardware_device=hardware_device)
        self.pareto_manager = ConvolutionParetoManager(hardware_device=hardware_device)
        self.visualizer = NASVisualizer(dataset, hardware_device)

        # Search state
        self.iteration_count = 0
        self.all_generated_architectures = []
        self.search_history = []

        logger.info("✅ Multi-Category NAS Searcher initialized successfully")

    def run_search(self, max_iterations: int = 10, output_dir: str = None) -> Dict:
        """Run the multi-category search and write all results.

        Args:
            max_iterations: Number of LLM generation iterations to run.
            output_dir: Destination directory; defaults to a timestamped
                ``outputs/search_<YYYYmmdd_HHMMSS>`` directory. Created if missing.

        Returns:
            The results dictionary built by ``_generate_final_results``.

        Raises:
            ValueError: If the evaluator selects no initial architectures.
        """
        logger.info(f"🎯 Starting multi-category search for {max_iterations} iterations")

        if output_dir is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = f"outputs/search_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)

        # Seed the search with stratified initial architectures.
        logger.info("🎲 Initializing with convolution-based stratified sampling...")
        initial_architectures = self.evaluator.select_initial_architectures()
        if not initial_architectures:
            raise ValueError("Failed to select initial architectures")

        self.all_generated_architectures.extend(initial_architectures)
        self.pareto_manager.update_pareto_fronts(self.all_generated_architectures)
        self._log_category_statistics(initial_architectures, "INITIAL")

        for iteration in range(1, max_iterations + 1):
            self._run_iteration(iteration, max_iterations)

        final_results = self._generate_final_results(output_dir)

        logger.info("🎉 Multi-category search completed!")
        logger.info(f"   Total architectures: {len(self.all_generated_architectures)}")
        logger.info(f"   Results saved to: {output_dir}")

        return final_results

    def _run_iteration(self, iteration: int, max_iterations: int) -> None:
        """Generate, evaluate, and record the architectures of one iteration."""
        logger.info(f"🔄 Starting iteration {iteration}/{max_iterations}")
        self.iteration_count = iteration

        feedback = self.pareto_manager.generate_feedback_by_category(iteration)
        generated_by_category = self.llm_client.generate_all_categories(feedback)
        evaluated_by_category = self.evaluator.evaluate_batch_by_category(generated_by_category)

        new_architectures = [
            arch for category_archs in evaluated_by_category.values() for arch in category_archs
        ]

        if new_architectures:
            self.all_generated_architectures.extend(new_architectures)
            self.pareto_manager.update_pareto_fronts(self.all_generated_architectures)
            self._log_category_statistics(new_architectures, f"ITERATION {iteration}")
        else:
            logger.warning(f"⚠️  No valid architectures generated in iteration {iteration}")

        # Record every executed iteration, including empty ones, so the history
        # and the summary report have no unexplained gaps.
        self.search_history.append({
            'iteration': iteration,
            'generated_count': len(new_architectures),
            'valid_count': sum(1 for arch in new_architectures if arch.is_valid),
            'category_breakdown': {
                category: len(archs) for category, archs in evaluated_by_category.items()
            },
        })

    def _log_category_statistics(self, architectures: List["ArchitectureInfo"], label: str):
        """Log per-category valid counts (vs. target) and mean accuracy.

        Args:
            architectures: Architectures to summarize; invalid ones are skipped.
            label: Heading prefix for the log block.
        """
        logger.info(f"📊 {label} CATEGORY STATISTICS:")

        category_accuracies = defaultdict(list)
        for arch in architectures:
            if arch.is_valid:
                category_accuracies[arch.conv_category].append(arch.accuracy)

        for category, config in SEARCH_CONFIG['conv_categories'].items():
            accuracies = category_accuracies[category]
            count = len(accuracies)
            target = config['target_count']

            if count > 0:
                avg_acc = sum(accuracies) / count
                logger.info(f"   {category:15s}: {count:2d}/{target:2d} (avg acc: {avg_acc:.1f}%)")
            else:
                logger.info(f"   {category:15s}: {count:2d}/{target:2d} (no valid archs)")

    def _generate_final_results(self, output_dir: str) -> Dict:
        """Assemble, save, and return the final results; also render reports.

        Writes ``final_search_results.json`` and ``search_summary.md`` to
        *output_dir* and asks the visualizer for its plots. Visualization
        errors are logged (with traceback) rather than raised.

        Args:
            output_dir: Existing directory to write into.

        Returns:
            The results dictionary, with ``llm_stats`` added.
        """
        logger.info("📊 Generating final results and visualizations...")

        results_data = {
            'search_config': {
                'dataset': self.dataset,
                'hardware_device': self.hardware_device,
                'max_iterations': self.iteration_count,
                'total_architectures': len(self.all_generated_architectures),
                'convolution_categories': SEARCH_CONFIG['conv_categories']
            },
            'pareto_fronts': {},
            'search_history': self.search_history,
            'category_statistics': {},
            'all_architectures': []
        }

        for category, pareto_archs in self.pareto_manager.category_pareto.items():
            results_data['pareto_fronts'][category] = [asdict(arch) for arch in pareto_archs]

            if pareto_archs:
                accuracies = [arch.accuracy for arch in pareto_archs]
                results_data['category_statistics'][category] = {
                    'pareto_count': len(pareto_archs),
                    'avg_accuracy': sum(accuracies) / len(accuracies),
                    'avg_latency': sum(arch.latency for arch in pareto_archs) / len(pareto_archs),
                    'best_accuracy': max(accuracies)
                }

        results_data['all_architectures'] = [asdict(arch) for arch in self.all_generated_architectures]

        results_file = os.path.join(output_dir, 'final_search_results.json')
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2)

        try:
            self.visualizer.create_all_visualizations(
                self.all_generated_architectures,
                self.pareto_manager.category_pareto,
                output_dir
            )
        except Exception as e:
            # Best-effort: plots are optional, but keep the traceback.
            logger.exception(f"❌ Visualization generation failed: {e}")

        self._generate_summary_report(results_data, output_dir)

        client_stats = self.llm_client.get_stats()
        logger.info("📈 LLM Client Statistics:")
        logger.info(f"   Total requests: {client_stats['request_count']}")
        logger.info(f"   Success rate: {client_stats['success_rate']:.1f}%")

        # NOTE: attached after the JSON file was written, so llm_stats is only
        # present in the returned dict, not in final_search_results.json.
        results_data['llm_stats'] = client_stats

        return results_data

    def _generate_summary_report(self, results_data: Dict, output_dir: str):
        """Write a human-readable Markdown summary to ``search_summary.md``.

        Args:
            results_data: Dict with ``search_config``, ``category_statistics``
                and ``search_history`` as built by ``_generate_final_results``.
            output_dir: Existing directory to write into.
        """
        config = results_data['search_config']
        lines = [
            "# Multi-Category Convolution-based NAS Search Summary\n\n",
            "## Search Configuration\n",
            f"- **Dataset**: {config['dataset']}\n",
            f"- **Hardware Device**: {config['hardware_device']}\n",
            f"- **Iterations**: {config['max_iterations']}\n",
            f"- **Total Architectures**: {config['total_architectures']}\n\n",
            "## Category Results\n\n",
        ]

        for category, stats in results_data['category_statistics'].items():
            category_config = config['convolution_categories'][category]
            lines += [
                f"### {category.replace('_', ' ').title()}\n",
                f"- **Description**: {category_config['description']}\n",
                f"- **Target Count**: {category_config['target_count']}\n",
                f"- **Pareto Front Size**: {stats['pareto_count']}\n",
                f"- **Best Accuracy**: {stats['best_accuracy']:.2f}%\n",
                f"- **Average Accuracy**: {stats['avg_accuracy']:.2f}%\n",
                f"- **Average Latency**: {stats['avg_latency']:.3f}ms\n\n",
            ]

        lines.append("## Search Progress\n\n")
        for iteration_data in results_data['search_history']:
            lines += [
                f"### Iteration {iteration_data['iteration']}\n",
                f"- Generated: {iteration_data['generated_count']} architectures\n",
                f"- Valid: {iteration_data['valid_count']} architectures\n",
                "- Category breakdown:\n",
            ]
            lines += [
                f"  - {category}: {count}\n"
                for category, count in iteration_data['category_breakdown'].items()
            ]
            lines.append("\n")

        # Explicit UTF-8: descriptions may contain non-ASCII text, which would
        # fail under a non-UTF-8 platform default encoding.
        report_file = os.path.join(output_dir, 'search_summary.md')
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("".join(lines))
