#!/usr/bin/env python3
"""
Generate heuristic combinations from multiple search results
Selects top1 and baseline for each function type
"""

import os
import sys
import json
import re
import argparse
from datetime import datetime

def _search_folder_stamp(folder: str, prefixes: tuple) -> str:
    """Return the part of *folder* that follows its ``<task>_<format>_`` prefix."""
    for prefix in prefixes:
        if folder.startswith(prefix):
            return folder[len(prefix):]
    return folder


def find_latest_search_folders(tasks=None) -> dict:
    """Find the latest search folder for each function type.

    Scans the current working directory for directories named in any of the
    supported formats:

    - Old format: ``<task>_search_<date>``
    - Unified format: ``<task>_session_<date>_<id>``
    - Current format: ``<task>_results_<date>``

    Only directories that contain a ``top_k_results/`` subdirectory qualify.

    Args:
        tasks: Optional iterable of task names to look for. Defaults to
            ``restart_condition``, ``bump_var_function`` and
            ``rephase_function``.

    Returns:
        Mapping of task name to the chosen folder name. Tasks without any
        valid folder are omitted (a warning is printed for them).
    """
    if tasks is None:
        tasks = ["restart_condition", "bump_var_function", "rephase_function"]

    search_folders = {}
    entries = os.listdir('.')

    for task in tasks:
        prefixes = tuple(
            f"{task}_{marker}_" for marker in ("search", "session", "results")
        )
        folders = []
        for folder in entries:
            if not (folder.startswith(prefixes) and os.path.isdir(folder)):
                continue

            # top_k_results/ is required for all formats
            if os.path.isdir(os.path.join(folder, "top_k_results")):
                folders.append(folder)
                print(f"  Found valid folder: {folder}")
            else:
                print(f"  Skipping {folder} (no top_k_results/)")

        if folders:
            # Rank by the text after the format marker rather than by the
            # whole name; otherwise the marker word itself ("session" >
            # "search" > "results") decides the winner instead of the date.
            # NOTE(review): assumes the date part sorts chronologically as a
            # string and uses the same layout in every format - confirm.
            latest = max(
                folders,
                key=lambda name: (_search_folder_stamp(name, prefixes), name),
            )
            search_folders[task] = latest
            print(f"Found latest {task} search folder: {latest}")
        else:
            print(f"Warning: No search folder found for {task}")

    return search_folders

def _build_result_record(data: dict, is_baseline: bool) -> dict:
    """Build the normalized record stored for one top-k / baseline JSON file."""
    return {
        'prompt': data['prompt'],
        'time': data['time'],
        'PAR-2': data['PAR-2'],
        'extra_params': data.get('extra_params', {}),
        'task_description': data.get('task_description', ''),
        'modification_direction': data.get('modification_direction', ''),
        'is_baseline': is_baseline,
        'rank': data.get('rank', 0),
    }


def load_top_k_results(search_folder: str) -> dict:
    """Load the baseline and top-k results stored in a search folder.

    Reads ``<search_folder>/top_k_results/baseline_id_0.json`` (if present)
    and every ``top_*.json`` file in that directory.

    Args:
        search_folder: Path of the search folder to read.

    Returns:
        Mapping of each file's ``global_id`` to a record with the keys
        ``prompt``, ``time``, ``PAR-2``, ``extra_params``,
        ``task_description``, ``modification_direction``, ``is_baseline``
        and ``rank``. The baseline (when present) is inserted first, then
        the top-k files in sorted filename order.

    Raises:
        FileNotFoundError: If the ``top_k_results`` folder does not exist.
        KeyError: If a file lacks ``global_id``, ``prompt``, ``time`` or
            ``PAR-2``.
    """
    top_k_folder = os.path.join(search_folder, "top_k_results")
    if not os.path.exists(top_k_folder):
        raise FileNotFoundError(f"Top-k results folder not found: {top_k_folder}")

    results = {}

    # Load baseline first
    baseline_file = os.path.join(top_k_folder, "baseline_id_0.json")
    if os.path.exists(baseline_file):
        with open(baseline_file, 'r', encoding='utf-8') as f:
            baseline_data = json.load(f)
        global_id = baseline_data['global_id']
        results[global_id] = _build_result_record(baseline_data, True)
        print(f"  Loaded baseline (ID: {global_id})")

    # Load individual result files (top-k results). os.listdir() returns
    # entries in arbitrary order, so sort to keep the resulting dict order
    # (and any later tie-breaking that relies on it) reproducible.
    for filename in sorted(os.listdir(top_k_folder)):
        if filename.startswith("top_") and filename.endswith(".json"):
            result_file = os.path.join(top_k_folder, filename)
            with open(result_file, 'r', encoding='utf-8') as f:
                result_data = json.load(f)
            results[result_data['global_id']] = _build_result_record(result_data, False)

    return results

def load_all_search_results(search_folders: dict) -> dict:
    """Load the stored results of every task's search folder.

    Args:
        search_folders: Mapping of task name to search folder path.

    Returns:
        Mapping of task name to the dictionary returned by
        ``load_top_k_results`` for its folder. A task whose folder cannot be
        loaded maps to an empty dict; the error is printed, not raised.
    """
    loaded = {}

    for task_name, folder_path in search_folders.items():
        print(f"Loading results from {folder_path}...")
        try:
            task_results = load_top_k_results(folder_path)
            print(f"  Loaded {len(task_results)} results for {task_name}")
        except Exception as err:
            # Best effort: one broken folder must not abort the others.
            print(f"  Error loading {task_name}: {err}")
            task_results = {}
        loaded[task_name] = task_results

    return loaded

def classify_heuristic_type(code: str) -> str:
    """Classify a heuristic's function type from keywords in its code.

    The match is case-insensitive and checked in this order:

    1. ``restart`` -> ``'restart_condition'``
    2. ``bump``    -> ``'bump_var_function'``
    3. ``rephase`` -> ``'rephase_function'``
    4. ``var``     -> ``'bump_var_function'`` (loose, last-resort match)

    The loose ``var`` check comes last on purpose: the substring also occurs
    in ordinary identifiers such as "variable", so testing it before
    ``rephase`` would label rephase code as a bump function.

    Args:
        code: Source text of the heuristic.

    Returns:
        One of ``'restart_condition'``, ``'bump_var_function'``,
        ``'rephase_function'`` or ``'unknown'``.
    """
    code_lower = code.lower()

    if 'restart' in code_lower:
        return 'restart_condition'
    if 'bump' in code_lower:
        return 'bump_var_function'
    if 'rephase' in code_lower:
        return 'rephase_function'
    if 'var' in code_lower:
        return 'bump_var_function'
    return 'unknown'

def extract_top_k_and_baseline(all_results: dict, sample_counts: dict = None) -> dict:
    """Pick the baseline and the best non-baseline results for each task.

    Non-baseline results are ordered by ascending ``PAR-2`` (a missing value
    counts as infinity). The baseline, when one exists, occupies one of the
    ``k`` slots; the remaining slots go to the best non-baseline results.

    Args:
        all_results: Mapping of task name to ``{global_id: result}``.
        sample_counts: Mapping of task name to the number of samples ``k`` to
            select, e.g. ``{'restart_condition': 3, 'bump_var_function': 2}``.
            Tasks that are not listed (or ``None``) use ``k = 2``.

    Returns:
        Mapping of task name to a list of ``(global_id, result)`` tuples,
        baseline first, then the selected results from best to worst.
    """
    counts = {} if sample_counts is None else sample_counts
    picked_by_task = {}

    for task, results in all_results.items():
        k = counts.get(task, 2)
        entries = list(results.items())

        baselines = [entry for entry in entries if entry[1].get('is_baseline', False)]
        ranked = sorted(
            (entry for entry in entries if not entry[1].get('is_baseline', False)),
            key=lambda entry: entry[1].get('PAR-2', float('inf')),
        )

        # The baseline (only the first one, if several) uses up one slot.
        has_baseline = bool(baselines)
        free_slots = k - 1 if has_baseline else k
        winners = ranked[:max(free_slots, 0)]

        chosen = baselines[:1] + winners
        picked_by_task[task] = chosen

        # Selection summary
        baseline_count = int(has_baseline)
        topk_count = len(chosen) - baseline_count
        print(f"  {task}: {len(chosen)} selected results (baseline={baseline_count}, top-{topk_count}={topk_count}, k={k})")

        # Per-result details; a winner's rank is its position in `ranked`.
        labelled = [("baseline", entry) for entry in baselines[:1]]
        labelled += [(f"top{position}", entry) for position, entry in enumerate(winners, start=1)]
        for result_type, (global_id, result) in labelled:
            print(f"    {result_type} (ID: {global_id}): PAR-2={result.get('PAR-2', 'N/A')}, time={result.get('time', 'N/A')}")

    return picked_by_task

def generate_permutations(selected_by_type: dict) -> list:
    """Build every combination that takes one heuristic from each type.

    Args:
        selected_by_type: Mapping of function type to a list of
            ``(global_id, result)`` tuples, where each result has a
            ``'prompt'`` entry.

    Returns:
        List of dicts, one per element of the Cartesian product, each mapping
        function type to the chosen result's prompt.
    """
    from itertools import product

    type_names = list(selected_by_type)

    # One pool of prompts per function type, in the same order as type_names.
    prompt_pools = [
        [entry[1]['prompt'] for entry in selected_by_type[name]]
        for name in type_names
    ]

    return [dict(zip(type_names, combo)) for combo in product(*prompt_pools)]

def generate_solver_cpp(permutation: dict, template_path: str, output_path: str):
    """Render one solver source file from a permutation with Jinja2.

    Args:
        permutation: Mapping of function type to replacement code. The
            ``restart_condition`` entry is passed to the template as
            ``replace_code``; ``bump_var_function`` and ``rephase_function``
            are passed under their own names. Missing entries are not set.
        template_path: Path to the Jinja2 template file. May be relative or
            absolute.
        output_path: Path of the file to write. ``EasySAT.hpp`` and
            ``heap.hpp`` are copied next to it when they are found in
            ``./examples/EasySAT/original_EasySAT``.
    """
    # Local imports: jinja2 is only needed when a file is actually rendered.
    import shutil
    from jinja2 import Environment, FileSystemLoader

    # Load the template by file name from its own directory. Passing the raw
    # path to a loader rooted at '.' fails for absolute paths and for paths
    # containing '..'. '.' stays on the search path so that any template
    # names resolved relative to the working directory still work.
    template_dir, template_name = os.path.split(os.path.abspath(template_path))
    env = Environment(loader=FileSystemLoader([template_dir, '.']))
    template = env.get_template(template_name)

    # Runtime values are rendered as their own placeholder text so that they
    # survive this pass and can be substituted later.
    template_vars = {
        'lbd_queue_size': 50,
        'data_dir': '{{ data_dir }}',  # Keep as placeholder for runtime replacement
        'results_dir': '{{ results_dir }}',  # Keep as placeholder for runtime replacement
        'timeout': '{{ timeout }}',  # Keep as placeholder for runtime replacement
        'random_seed': 42  # Set default random seed
    }

    # Add function replacements to template variables
    if 'restart_condition' in permutation:
        template_vars['replace_code'] = permutation['restart_condition']

    if 'bump_var_function' in permutation:
        template_vars['bump_var_function'] = permutation['bump_var_function']

    if 'rephase_function' in permutation:
        template_vars['rephase_function'] = permutation['rephase_function']

    # Debug: print template variables
    print(f"  Template variables: {list(template_vars.keys())}")
    if 'bump_var_function' in template_vars:
        print(f"  bump_var_function length: {len(template_vars['bump_var_function'])}")
        print(f"  bump_var_function value: {template_vars['bump_var_function'][:100]}...")

    # Render template
    output = template.render(**template_vars)

    # Turn the literal marker tokens back into Jinja2-style placeholders.
    output = output.replace('TIMEOUT_PLACEHOLDER', '{{ timeout }}')
    output = output.replace('DATA_DIR_PLACEHOLDER', '{{ data_dir }}')

    # Write as UTF-8 (Jinja2's default template encoding) so non-ASCII text
    # in the template or in the injected code does not depend on the locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(output)

    # Copy necessary header files to the output directory
    output_dir = os.path.dirname(output_path)
    original_dir = './examples/EasySAT/original_EasySAT'
    for file in ['EasySAT.hpp', 'heap.hpp']:
        src_file = os.path.join(original_dir, file)
        dst_file = os.path.join(output_dir, file)
        if os.path.exists(src_file):
            shutil.copy2(src_file, dst_file)
        else:
            print(f"Warning: {file} not found in {original_dir}")

def main():
    """Command-line entry point: build every heuristic combination.

    Steps:
      1. Resolve one search folder per task, either from ``--search_folders``
         (matched to a task by folder-name prefix) or by auto-discovering the
         latest folder per task in the current directory.
      2. Load the stored results and select baseline + top-k per task.
      3. Render one ``solver_combination_<i>.cpp`` per combination into
         ``--output_dir``.

    On any error a message is printed and the function returns early without
    generating files.
    """
    parser = argparse.ArgumentParser(description='Generate heuristic combinations from multiple search results')
    parser.add_argument('--search_folders', nargs='+', default=None,
                       help='Specific search folders to use (if not provided, uses latest for each type)')
    parser.add_argument('--template_path', type=str,
                       default="./examples/EasySAT/combination/EasySAT.cpp",
                       help='Path to the template EasySAT.cpp file')
    parser.add_argument('--output_dir', type=str, default="./combinations/",
                       help='Output directory for generated combinations')
    parser.add_argument('--tasks', nargs='+',
                       default=["restart_condition", "bump_var_function", "rephase_function"],
                       help='Function types to combine')
    parser.add_argument('--sample_restart', type=int, default=2,
                       help='Number of restart_condition samples to select (baseline + top-k), default=2')
    parser.add_argument('--sample_bump_var', type=int, default=2,
                       help='Number of bump_var_function samples to select (baseline + top-k), default=2')
    parser.add_argument('--sample_rephase', type=int, default=2,
                       help='Number of rephase_function samples to select (baseline + top-k), default=2')

    args = parser.parse_args()

    # Find search folders
    if args.search_folders:
        # Use provided folders
        search_folders = {}
        for folder in args.search_folders:
            if not os.path.exists(folder):
                print(f"Error: Search folder not found: {folder}")
                return

            # Determine task type from folder name (support all formats).
            # normpath drops a trailing separator first; without it a path
            # such as "folder_name/" has an empty basename and never matches.
            folder_basename = os.path.basename(os.path.normpath(folder))
            folder_matched = False
            for task in args.tasks:
                prefixes = (f"{task}_search_", f"{task}_session_", f"{task}_results_")
                if folder_basename.startswith(prefixes):
                    search_folders[task] = folder
                    folder_matched = True
                    print(f"  Matched {folder} to task {task}")
                    break

            if not folder_matched:
                print(f"  Warning: Could not match folder {folder} (basename: {folder_basename}) to any task")

        if len(search_folders) != len(args.tasks):
            print(f"Error: Expected {len(args.tasks)} search folders, found {len(search_folders)}")
            return
    else:
        # Use latest folders, restricted to the requested tasks. Discovery
        # reports every task it finds, so without this filter a --tasks
        # subset would fail the count check below.
        discovered = find_latest_search_folders()
        search_folders = {task: discovered[task] for task in args.tasks if task in discovered}
        if len(search_folders) != len(args.tasks):
            print(f"Error: Found {len(search_folders)} search folders, expected {len(args.tasks)}")
            print("Available folders:", list(search_folders.keys()))
            return

    # Load all search results
    all_results = load_all_search_results(search_folders)

    # Check if we have results for all required tasks
    missing_tasks = [task for task in args.tasks if task not in all_results or not all_results[task]]
    if missing_tasks:
        print(f"Error: Missing results for tasks: {missing_tasks}")
        return

    # Prepare sample counts dictionary
    sample_counts = {
        'restart_condition': args.sample_restart,
        'bump_var_function': args.sample_bump_var,
        'rephase_function': args.sample_rephase
    }

    # Extract top-k and baseline results by type
    print("\nExtracting top-k and baseline results by function type...")
    print(f"Sample counts: restart={args.sample_restart}, bump_var={args.sample_bump_var}, rephase={args.sample_rephase}")
    selected_by_type = extract_top_k_and_baseline(all_results, sample_counts)

    # Generate all permutations
    print("Generating all permutations...")
    permutations = generate_permutations(selected_by_type)
    print(f"Generated {len(permutations)} combinations")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate solver.cpp files for each permutation
    generated_files = []
    for i, permutation in enumerate(permutations):
        output_path = os.path.join(args.output_dir, f"solver_combination_{i}.cpp")
        generate_solver_cpp(permutation, args.template_path, output_path)
        generated_files.append(output_path)

        # Print combination details
        print(f"Combination {i}:")
        for func_type, code in permutation.items():
            print(f"  {func_type}: {code[:50]}...")

    # Expected count covers only the tasks actually combined; a task without
    # a dedicated --sample_* option falls back to the selection default of 2.
    expected_factors = [sample_counts.get(task, 2) for task in args.tasks]
    total_expected = 1
    for factor in expected_factors:
        total_expected *= factor

    # Print summary
    print("\nCombination Summary:")
    print(f"  Tasks combined: {args.tasks}")
    print(f"  Sample counts: restart={args.sample_restart}, bump_var={args.sample_bump_var}, rephase={args.sample_rephase}")
    print(f"  Expected combinations: {' × '.join(str(factor) for factor in expected_factors)} = {total_expected}")
    print(f"  Actual combinations generated: {len(generated_files)}")
    print(f"  Output directory: {args.output_dir}")

    # List generated files
    print("\nGenerated files:")
    for i, file_path in enumerate(generated_files):
        print(f"  {i+1}. {os.path.basename(file_path)}")

    print(f"\nAll combinations saved to: {args.output_dir}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 