#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Architecture Analyzer for Convolution-based Classification
"""

import re
import logging
from typing import Dict, Tuple, Optional
from pel_nas.core.config import SEARCH_CONFIG

logger = logging.getLogger(__name__)

class ArchitectureAnalyzer:
    """Analyze architecture strings to determine convolution-based categories.

    Categories come from ``SEARCH_CONFIG['conv_categories']``: a mapping of
    category name -> config dict. Each config is read for the keys
    ``conv_3x3_count`` and ``conv_1x1_count``, whose values are either an exact
    count or an inclusive ``(min, max)`` tuple. ``description`` is read by
    :meth:`get_category_description`.
    """

    # "op_name~connection_index" tokens; compiled once and shared by all instances.
    _OP_PATTERN = re.compile(r'([a-zA-Z_0-9]+)~(\d+)')

    def __init__(self):
        """Initialize the analyzer from the global search configuration."""
        self.conv_categories = SEARCH_CONFIG['conv_categories']

    def analyze_architecture(self, arch_str: str) -> Dict:
        """Analyze architecture string and return convolution information.

        Args:
            arch_str: Architecture string in format |op0~0|+|op1~0|op2~1|+|op3~0|op4~1|op5~2|

        Returns:
            Dict with keys ``arch_str``, ``conv_3x3_count``, ``conv_1x1_count``,
            ``total_conv_count``, ``operations`` (list of
            ``{'operation', 'connection'}`` dicts), ``category`` (matching
            category name or ``None``), ``is_valid`` (``category is not None``)
            and ``category_info`` (the matching category's config, or ``{}``).
            Empty, non-string, or unanalyzable input yields the empty analysis.
        """
        if not arch_str or not isinstance(arch_str, str):
            return self._empty_analysis()

        try:
            # Substring counts of the two convolution op names.
            conv_3x3_count = arch_str.count('nor_conv_3x3')
            conv_1x1_count = arch_str.count('nor_conv_1x1')

            operations = self._extract_operations(arch_str)
            category = self._determine_category(conv_3x3_count, conv_1x1_count)

            return {
                'arch_str': arch_str,
                'conv_3x3_count': conv_3x3_count,
                'conv_1x1_count': conv_1x1_count,
                'total_conv_count': conv_3x3_count + conv_1x1_count,
                'operations': operations,
                'category': category,
                'is_valid': category is not None,
                'category_info': self.conv_categories.get(category, {}) if category else {}
            }

        except Exception as e:
            # Best-effort boundary: a malformed category config (missing keys,
            # bad range tuple) must not crash callers; degrade to "invalid".
            logger.warning("Error analyzing architecture %s: %s", arch_str, e)
            return self._empty_analysis()

    def _extract_operations(self, arch_str: str) -> list:
        """Return every ``op~idx`` token as ``{'operation': str, 'connection': int}``."""
        return [
            {'operation': op_name, 'connection': int(connection)}
            for op_name, connection in self._OP_PATTERN.findall(arch_str)
        ]

    def _determine_category(self, conv_3x3_count: int, conv_1x1_count: int) -> Optional[str]:
        """Return the first category (in config order) matching both counts, else None."""
        for category, config in self.conv_categories.items():
            if self._matches_category(conv_3x3_count, conv_1x1_count, config):
                return category
        return None

    @staticmethod
    def _count_matches(actual: int, expected) -> bool:
        """Check one count against an exact value or an inclusive ``(min, max)`` tuple."""
        if isinstance(expected, tuple):
            low, high = expected
            return low <= actual <= high
        return actual == expected

    def _matches_category(self, conv_3x3_count: int, conv_1x1_count: int, config: Dict) -> bool:
        """Check if convolution counts match a specific category config.

        Raises:
            KeyError: if ``config`` lacks a required count key. ``conv_1x1_count``
                is only read when the 3x3 check passes.
        """
        return (
            self._count_matches(conv_3x3_count, config['conv_3x3_count'])
            and self._count_matches(conv_1x1_count, config['conv_1x1_count'])
        )

    def _empty_analysis(self) -> Dict:
        """Return the analysis result used for empty or unanalyzable input."""
        return {
            'arch_str': '',
            'conv_3x3_count': 0,
            'conv_1x1_count': 0,
            'total_conv_count': 0,
            'operations': [],
            'category': None,
            'is_valid': False,
            'category_info': {}
        }

    def get_category_stats(self, architectures: list) -> Dict:
        """Get per-category statistics for a list of architecture strings.

        Returns:
            ``{category: {'count', 'architectures', 'avg_3x3_count',
            'avg_1x1_count'}}`` for every configured category. Averages stay 0
            for categories with no matching architecture. ``None`` input is
            treated as an empty list.
        """
        stats = {
            category: {
                'count': 0,
                'architectures': [],
                'avg_3x3_count': 0,
                'avg_1x1_count': 0
            }
            for category in self.conv_categories
        }
        # Running [sum_3x3, sum_1x1] per category, so each architecture is analyzed once.
        totals = {category: [0, 0] for category in stats}

        for arch_str in architectures or ():
            analysis = self.analyze_architecture(arch_str)
            category = analysis['category']

            if category and category in stats:
                data = stats[category]
                data['count'] += 1
                data['architectures'].append(arch_str)
                totals[category][0] += analysis['conv_3x3_count']
                totals[category][1] += analysis['conv_1x1_count']

        for category, data in stats.items():
            if data['count'] > 0:
                sum_3x3, sum_1x1 = totals[category]
                data['avg_3x3_count'] = sum_3x3 / data['count']
                data['avg_1x1_count'] = sum_1x1 / data['count']

        return stats

    def validate_architecture_for_category(self, arch_str: str, target_category: str) -> bool:
        """Return True only if ``arch_str`` is valid and falls in ``target_category``.

        An invalid architecture never validates, even against ``None``.
        """
        analysis = self.analyze_architecture(arch_str)
        return analysis['is_valid'] and analysis['category'] == target_category

    def get_category_description(self, category: str) -> str:
        """Get human-readable description of a category."""
        if category in self.conv_categories:
            return self.conv_categories[category]['description']
        return f"Unknown category: {category}"

    def get_all_categories(self) -> list:
        """Get list of all configured category names, in config order."""
        return list(self.conv_categories)

# Shared, module-wide analyzer; constructed once at import time from SEARCH_CONFIG.
analyzer: ArchitectureAnalyzer = ArchitectureAnalyzer()
