#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multi-Category LLM Client for Convolution-based Architecture Generation
"""

import asyncio
import logging
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import aiohttp
import openai

from pel_nas.core.config import LLM_CONFIG, SEARCH_CONFIG, PROMPT_CONFIG

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

class MultiCategoryLLMClient:
    """LLM client for parallel multi-category architecture generation.

    For every category in ``SEARCH_CONFIG['conv_categories']`` the client
    formats that category's prompt template, asks the chat model for
    architecture strings and extracts NAS-Bench-201 cell strings of the form
    ``|op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|`` from the reply.
    """

    # A complete cell string: 1, 2 and 3 edges for nodes 1, 2 and 3.
    # Compiled once instead of on every parsed line.
    _ARCH_PATTERN = re.compile(
        r'\|[^|]+~\d+\|\+\|[^|]+~\d+\|[^|]+~\d+\|\+\|[^|]+~\d+\|[^|]+~\d+\|[^|]+~\d+\|'
    )
    # Something that merely looks like the start of a cell string; only used
    # to warn about truncated output.
    _PARTIAL_ARCH_PATTERN = re.compile(r'\|[^|]*\|(?:\+\|[^|]*\|)*\+?')

    def __init__(self, hardware_device: str = 'edgegpu'):
        """Read LLM settings, create the OpenAI client and load prompt templates.

        Args:
            hardware_device: Substituted into ``PROMPT_CONFIG['system_prompt']``
                through its ``{hardware_device}`` placeholder.

        Raises:
            ValueError: If ``LLM_CONFIG['api_key']`` is empty.
        """
        self.api_key = LLM_CONFIG['api_key']
        self.model = LLM_CONFIG['model']
        self.temperature = LLM_CONFIG['temperature']
        self.max_tokens = LLM_CONFIG['max_tokens']
        self.max_retries = LLM_CONFIG['max_retries']
        self.hardware_device = hardware_device

        if not self.api_key:
            raise ValueError("Please set OPENAI_API_KEY environment variable")

        self.client = openai.OpenAI(api_key=self.api_key)
        # request_count: attempts that produced a response;
        # error_count: attempts that raised. Each attempt lands in exactly one.
        self.request_count = 0
        self.error_count = 0
        # The counters are updated from worker threads in _generate_parallel;
        # `+=` on an attribute is not atomic, so guard it.
        self._stats_lock = threading.Lock()

        # Load prompts for each category
        self.category_prompts = self._load_category_prompts()
        # Format system prompt with hardware device
        self.system_prompt = PROMPT_CONFIG['system_prompt'].format(hardware_device=hardware_device)

        logger.info(f"✅ Multi-category LLM client initialized for {hardware_device}")
        logger.info(f"   Loaded prompts for {len(self.category_prompts)} categories")

    def _load_category_prompts(self) -> Dict[str, str]:
        """Load one prompt template per configured category.

        Files are read from ``PROMPT_CONFIG['prompt_dir']`` using each
        category's ``prompt_file``. Any failure to read a file is logged and
        replaced by ``_get_fallback_prompt`` (deliberate best-effort).

        Returns:
            Mapping of category name to template text.
        """
        prompts = {}
        prompt_dir = Path(PROMPT_CONFIG['prompt_dir'])

        for category, config in SEARCH_CONFIG['conv_categories'].items():
            prompt_file = prompt_dir / config['prompt_file']

            try:
                with open(prompt_file, 'r', encoding='utf-8') as f:
                    prompts[category] = f.read().strip()
                logger.debug(f"Loaded prompt for {category}: {prompt_file}")
            except Exception as e:
                logger.error(f"Failed to load prompt for {category}: {e}")
                prompts[category] = self._get_fallback_prompt(category)

        return prompts

    def _get_fallback_prompt(self, category: str) -> str:
        """Build a minimal prompt template for ``category`` when its file is unavailable.

        The result is a template like the on-disk ones. It keeps the
        ``{target_count}``, ``{performance_feedback}`` and
        ``{generation_instructions}`` placeholders that
        ``_format_category_prompt`` fills in, so feedback is not dropped.
        Braces coming from config values are escaped so ``str.format`` treats
        them literally.
        """
        config = SEARCH_CONFIG['conv_categories'][category]
        safe_name = str(category).replace('{', '{{').replace('}', '}}')
        safe_description = str(config['description']).replace('{', '{{').replace('}', '}}')
        header = (
            f"Generate architectures for the {safe_name} category.\n"
            f"Description: {safe_description}\n"
        )
        return header + (
            "Performance feedback: {performance_feedback}\n"
            "Instructions: {generation_instructions}\n"
            "Generate EXACTLY {target_count} architectures.\n"
            "Return ONLY architecture strings, one per line."
        )

    def generate_all_categories(self, performance_feedback: Optional[Dict] = None) -> Dict[str, List[str]]:
        """Generate architectures for every configured category.

        Args:
            performance_feedback: Optional mapping of category to feedback.
                Each value is either a plain string or a dict with
                ``performance_feedback`` / ``generation_instructions`` keys
                (see ``_format_category_prompt``).

        Returns:
            Mapping of category to the list of valid architecture strings.
            Keys are in config order. A failed category maps to ``[]``.
        """
        logger.info("🚀 Starting parallel multi-category generation...")

        if performance_feedback is None:
            performance_feedback = {category: "No previous feedback available."
                                    for category in SEARCH_CONFIG['conv_categories'].keys()}

        # Prepare requests for each category
        category_requests = []
        for category, config in SEARCH_CONFIG['conv_categories'].items():
            user_prompt = self._format_category_prompt(
                category,
                config['target_count'],
                performance_feedback.get(category, "No feedback available.")
            )
            category_requests.append((category, user_prompt))

        # Execute parallel requests
        if LLM_CONFIG.get('concurrent_requests', True):
            results = self._generate_parallel(category_requests)
        else:
            results = self._generate_sequential(category_requests)

        # Parse results in config order, not thread-completion order, so the
        # returned dict and the log output are deterministic.
        parsed_results: Dict[str, List[str]] = {}
        total_architectures = 0

        for category, _ in category_requests:
            response = results.get(category, "")
            if response:
                arch_strings = self.parse_architecture_strings(response)
                parsed_results[category] = arch_strings
                total_architectures += len(arch_strings)
                logger.info(f"   {category:15s}: {len(arch_strings):2d} architectures")
            else:
                parsed_results[category] = []
                logger.warning(f"   {category:15s}: 0 architectures (failed)")

        logger.info(f"🎯 Total generated: {total_architectures} architectures across {len(parsed_results)} categories")
        return parsed_results

    def _format_category_prompt(self, category: str, target_count: int, feedback) -> str:
        """Fill a category's template with the target count and feedback.

        Args:
            category: Key into ``self.category_prompts``.
            target_count: Value for the ``{target_count}`` placeholder.
            feedback: Either a string (legacy format, used as the performance
                feedback) or a dict with optional ``performance_feedback`` and
                ``generation_instructions`` keys.
        """
        template = self.category_prompts[category]

        # Handle both old string format and new dict format
        if isinstance(feedback, dict):
            performance_feedback = feedback.get('performance_feedback', 'No feedback available.')
            generation_instructions = feedback.get('generation_instructions', 'Generate diverse architectures.')
        else:
            # Fallback for old string format
            performance_feedback = feedback
            generation_instructions = 'Generate diverse architectures.'

        return template.format(
            target_count=target_count,
            performance_feedback=performance_feedback,
            generation_instructions=generation_instructions
        )

    def _generate_parallel(self, category_requests: List[Tuple[str, str]]) -> Dict[str, str]:
        """Run one generation request per category on a thread pool.

        Returns:
            Mapping of category to raw response text. A category whose request
            ultimately failed maps to ``""``.
        """
        results = {}

        with ThreadPoolExecutor(max_workers=6) as executor:
            # Submit all requests
            future_to_category = {
                executor.submit(self._generate_single_category, category, prompt): category
                for category, prompt in category_requests
            }

            # Collect results as they complete
            for future in as_completed(future_to_category):
                category = future_to_category[future]
                try:
                    # as_completed only yields futures that are already done, so a
                    # timeout here could never fire. Per-request time limits must
                    # be configured on the OpenAI client instead.
                    response = future.result()
                    results[category] = response
                    logger.info(f"✅ Completed generation for {category}")
                except Exception as e:
                    logger.error(f"❌ Failed generation for {category}: {e}")
                    results[category] = ""

        return results

    def _generate_sequential(self, category_requests: List[Tuple[str, str]]) -> Dict[str, str]:
        """Run the per-category requests one after another.

        Returns:
            Same shape as ``_generate_parallel``.
        """
        results = {}

        for category, prompt in category_requests:
            try:
                response = self._generate_single_category(category, prompt)
                results[category] = response
                logger.info(f"✅ Completed generation for {category}")
            except Exception as e:
                logger.error(f"❌ Failed generation for {category}: {e}")
                results[category] = ""

        return results

    def _generate_single_category(self, category: str, user_prompt: str) -> str:
        """Request architectures for one category, retrying with exponential backoff.

        Returns:
            The stripped response text. Returns ``""`` only when
            ``max_retries`` is not positive (no attempt is made).

        Raises:
            Exception: After ``max_retries`` failed attempts, chained to the
                last underlying error.
        """
        logger.debug(f"Generating architectures for {category}")

        for attempt in range(self.max_retries):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": user_prompt}
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    top_p=0.9
                )
                content = response.choices[0].message.content
                # The SDK types message content as optional; treat a missing
                # body as a failed attempt so it is retried and counted once.
                if content is None:
                    raise ValueError("model returned no message content")
                response_text = content.strip()
            except Exception as e:
                with self._stats_lock:
                    self.error_count += 1
                logger.error(f"Generation failed for {category} (attempt {attempt + 1}/{self.max_retries}): {e}")

                if attempt == self.max_retries - 1:
                    raise Exception(f"Failed to generate for {category} after {self.max_retries} attempts: {e}") from e

                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                with self._stats_lock:
                    self.request_count += 1
                logger.debug(f"Successfully generated for {category} (attempt {attempt + 1})")
                return response_text

        return ""

    def parse_architecture_strings(self, response: str) -> List[str]:
        """Extract valid NAS-Bench-201 cell strings from an LLM reply.

        Each line containing ``|``, ``~`` and ``+`` is scanned for complete
        cell strings (``|op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|``). Every
        match on the line is considered, not just the first. A match is kept
        only if ``_validate_architecture_format`` accepts it. Lines that look
        like a cell string but hold no complete match are logged as warnings.

        Args:
            response: Raw text returned by the chat model (``None`` is
                treated as empty).

        Returns:
            Valid architecture strings in the order they appear.
        """
        arch_strings: List[str] = []

        for raw_line in (response or '').split('\n'):
            line = raw_line.strip()
            if not ('|' in line and '~' in line and '+' in line):
                continue

            matches = [m.group(0) for m in self._ARCH_PATTERN.finditer(line)]
            if not matches:
                # Try to find partial matches and report
                partial_match = self._PARTIAL_ARCH_PATTERN.search(line)
                if partial_match:
                    logger.warning(f"Found incomplete architecture string: {partial_match.group(0)}")
                continue

            for arch_str in matches:
                if self._validate_architecture_format(arch_str):
                    arch_strings.append(arch_str)
                    logger.debug(f"Parsed valid architecture: {arch_str}")

        logger.debug(f"Parsed {len(arch_strings)} valid architecture strings from response")
        return arch_strings

    def _validate_architecture_format(self, arch_str: str) -> bool:
        """Return True if ``arch_str`` is a well-formed NAS-Bench-201 cell string.

        Expected layout: ``|op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|``. There
        are three ``+``-separated stages holding 1, 2 and 3 edges.
        """
        if not arch_str.startswith('|') or not arch_str.endswith('|'):
            return False

        stages = arch_str.split('+')
        if len(stages) != 3:
            return False

        valid_ops = ['none', 'avg_pool_3x3', 'nor_conv_1x1', 'nor_conv_3x3', 'skip_connect']

        # Stage k (1-based) must hold exactly k edges. _validate_stage also
        # checks the surrounding '|' bars of every stage.
        return all(
            self._validate_stage(stage, expected_ops, valid_ops)
            for expected_ops, stage in enumerate(stages, start=1)
        )

    def _validate_stage(self, stage: str, expected_ops: int, valid_ops: List[str]) -> bool:
        """Validate one stage such as ``|nor_conv_3x3~0|skip_connect~1|``.

        Args:
            stage: Stage substring including its leading and trailing ``|``.
            expected_ops: Number of ``op~idx`` entries the stage must contain.
            valid_ops: Allowed operation names.

        Returns:
            True when the stage has exactly ``expected_ops`` entries, every
            name is in ``valid_ops``, and the i-th entry's index is exactly
            ``i``. This is the ordering in the documented format
            (``|op1~0|op2~1|``).
        """
        if not stage.startswith('|') or not stage.endswith('|'):
            return False

        ops = stage[1:-1].split('|')
        if len(ops) != expected_ops:
            return False

        for position, op in enumerate(ops):
            # rpartition (not split) so a name containing an extra '~' is
            # rejected instead of raising ValueError while unpacking.
            op_name, sep, connection = op.rpartition('~')
            if not sep or op_name not in valid_ops:
                return False
            # The i-th edge must read '~i'. This rejects duplicate or
            # out-of-order edges and non-canonical digits like '00'.
            if connection != str(position):
                return False

        return True

    def get_stats(self) -> Dict:
        """Return request counters and the percentage of successful attempts.

        ``success_rate`` is successes / (successes + failures) * 100, or 0
        when nothing has been attempted yet.
        """
        attempts = self.request_count + self.error_count
        return {
            'request_count': self.request_count,
            'error_count': self.error_count,
            'success_rate': self.request_count / max(attempts, 1) * 100,
            'categories_loaded': len(self.category_prompts)
        }

# Legacy name: existing code that imports ``LLMClient`` keeps working.
LLMClient = MultiCategoryLLMClient
