import re
import itertools
import os
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass


@dataclass()
class HeuristicComponent:
    """Heuristic function component.

    Attributes:
        condition: Condition part of the heuristic (e.g. the expression
            guarding a ``restart();`` call); may be empty.
        action: Action part of the heuristic (e.g. ``"restart();"``).
        parameters: Parameter configuration extracted from the code,
            mapping parameter name to its textual value.
        performance: Performance metrics, mapping metric name to value.
    """

    condition: str
    action: str
    parameters: Dict[str, str]
    performance: Dict[str, float]


class HeuristicCombiner:
    """Heuristic function combiner - implements permutations among three functions"""
    
    def __init__(self):
        self.condition_patterns = {
            'restart_condition': r'else if \((.*?)\) restart\(\);',
            'bump_var_function': r'bump_var\((.*?)\);',
            'rephase_function': r'rephase\(\);'
        }
    
    def extract_top_k_results(self, results: Dict, k: int = 2) -> Dict[str, List[Tuple[str, Dict]]]:
        """Extract top-k best results for each function type from results"""
        # Sort by PAR-2, select best k
        sorted_results = sorted(
            [(global_id, data) for global_id, data in results.items() if global_id != "0"],
            key=lambda x: x[1].get('PAR-2', float('inf'))
        )
        
        # Group by function type
        function_types = {
            'restart_condition': [],
            'bump_var_function': [],
            'rephase_function': []
        }
        
        # Classify by actual task type
        # Assume results contain task type info, or infer from code content
        for global_id, result in sorted_results[:k*3]:  # Take more to ensure each type has enough
            code = result.get('prompt', '')
            
            # Infer function type from code content
            if 'restart();' in code:
                function_types['restart_condition'].append((global_id, result))
            elif 'bump_var(' in code:
                function_types['bump_var_function'].append((global_id, result))
            elif 'rephase();' in code:
                function_types['rephase_function'].append((global_id, result))
        
        # Take top-k for each type
        for func_type in function_types:
            function_types[func_type] = function_types[func_type][:k]
        
        return function_types
    
    def generate_all_permutations(self, top_k_results: Dict[str, List[Tuple[str, Dict]]]) -> List[Dict[str, str]]:
        """Generate all permutations among the three functions"""
        permutations = []
        
        restart_codes = [result[1].get('prompt', '') for result in top_k_results.get('restart_condition', [])]
        bump_codes = [result[1].get('prompt', '') for result in top_k_results.get('bump_var_function', [])]
        rephase_codes = [result[1].get('prompt', '') for result in top_k_results.get('rephase_function', [])]
        
        # Generate all permutations
        for restart_code in restart_codes:
            for bump_code in bump_codes:
                for rephase_code in rephase_codes:
                    permutation = {
                        'restart_condition': restart_code,
                        'bump_var_function': bump_code,
                        'rephase_function': rephase_code
                    }
                    permutations.append(permutation)
        
        return permutations
    
    def generate_solver_cpp(self, permutation: Dict[str, str], template_path: str, output_path: str):
        """Generate a complete solver.cpp file based on the permutation"""
        # Read template file
        with open(template_path, 'r') as f:
            template_content = f.read()
        
        # Replace placeholders in template
        modified_content = template_content
        
        # Replace restart condition
        if 'restart_condition' in permutation:
            modified_content = re.sub(
                r'else if \(lbd_queue_size == 50 && 0\.8 \* fast_lbd_sum / lbd_queue_size > slow_lbd_sum / conflicts\) restart\(\);',
                permutation['restart_condition'],
                modified_content
            )
        
        # Replace bump_var function
        if 'bump_var_function' in permutation:
            bump_pattern = r'void Solver::bump_var\(int var, double coeff\) \{[\s\S]*?\}'
            bump_replacement = f"void Solver::bump_var(int var, double coeff) {{\n    {permutation['bump_var_function']}\n}}"
            modified_content = re.sub(bump_pattern, bump_replacement, modified_content)
        
        # Replace rephase function
        if 'rephase_function' in permutation:
            rephase_pattern = r'void Solver::rephase\(\) \{[\s\S]*?\}'
            rephase_replacement = f"void Solver::rephase() {{\n    {permutation['rephase_function']}\n}}"
            modified_content = re.sub(rephase_pattern, rephase_replacement, modified_content)
        
        # Write new file
        with open(output_path, 'w') as f:
            f.write(modified_content)
        
        return output_path
    
    def generate_all_solver_combinations(self, results: Dict, template_path: str, output_dir: str) -> List[str]:
        """Generate all possible solver.cpp combinations"""
        # Extract top-k results
        top_k_results = self.extract_top_k_results(results, k=2)
        
        # Generate all permutations
        permutations = self.generate_all_permutations(top_k_results)
        
        # Generate all solver.cpp files
        generated_files = []
        for i, permutation in enumerate(permutations):
            output_path = os.path.join(output_dir, f"solver_combination_{i}.cpp")
            self.generate_solver_cpp(permutation, template_path, output_path)
            generated_files.append(output_path)
        
        return generated_files
    
    def parse_heuristic_code(self, code: str, task_type: str) -> HeuristicComponent:
        """Parse heuristic function code"""
        condition = ""
        action = ""
        parameters = {}
        
        if task_type == "restart_condition":
            # Parse restart condition
            match = re.search(self.condition_patterns['restart_condition'], code)
            if match:
                condition = match.group(1).strip()
                action = "restart();"
                
                # Extract parameters
                if "lbd_queue_size" in condition:
                    lbd_match = re.search(r'lbd_queue_size\s*==\s*(\d+)', condition)
                    if lbd_match:
                        parameters['lbd_queue_size'] = lbd_match.group(1)
                
                if "fast_lbd_sum" in condition and "slow_lbd_sum" in condition:
                    # Extract coefficient
                    coeff_match = re.search(r'(\d+\.?\d*)\s*\*\s*fast_lbd_sum', condition)
                    if coeff_match:
                        parameters['fast_coeff'] = coeff_match.group(1)
        
        elif task_type == "bump_var_function":
            # Parse bump_var function
            match = re.search(self.condition_patterns['bump_var_function'], code)
            if match:
                args = match.group(1).split(',')
                if len(args) >= 2:
                    var_expr = args[0].strip()
                    coeff_expr = args[1].strip()
                    action = f"bump_var({var_expr}, {coeff_expr});"
                    
                    # Extract coefficient
                    try:
                        coeff = float(coeff_expr)
                        parameters['bump_coeff'] = str(coeff)
                    except:
                        parameters['bump_coeff'] = coeff_expr
        
        elif task_type == "rephase_function":
            # Parse rephase function
            action = "rephase();"
            # Usually no complex parameters for rephase
        
        return HeuristicComponent(
            condition=condition,
            action=action,
            parameters=parameters,
            performance={}
        )
    
    def get_combination_summary(self, results: Dict) -> str:
        """Get combination summary information"""
        top_k_results = self.extract_top_k_results(results, k=2)
        permutations = self.generate_all_permutations(top_k_results)
        
        summary = f"Generated {len(permutations)} solver combinations:\n"
        summary += f"- Restart conditions: {len(top_k_results.get('restart_condition', []))}\n"
        summary += f"- Bump var functions: {len(top_k_results.get('bump_var_function', []))}\n"
        summary += f"- Rephase functions: {len(top_k_results.get('rephase_function', []))}\n"
        
        return summary 