
import fastapi_poe as fp
import asyncio
import json
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess
import re
import shutil
import glob
import logging
from datetime import datetime
from collections import defaultdict
import argparse

# LLM API configuration.
# Prefer supplying the key through the POE_API_KEY environment variable rather
# than editing a secret into this file; the literal is only a placeholder
# fallback so existing behaviour is unchanged when the variable is unset.
API_KEY = os.environ.get("POE_API_KEY", "your_key")
BOT_NAME = 'Claude-3.7-Sonnet'

# Command-line configuration parameters
parser = argparse.ArgumentParser(description='LLM-based Scheduling Optimization')
parser.add_argument('--algorithm', type=str, default='sjf', help='Algorithm to optimize')
parser.add_argument('--exec_cap', type=int, default=10, help='Executor capacity')
parser.add_argument('--init_stream_config', type=str, default="1_0", help='Configuration of initial and streaming DAGs (e.g., 1_0, 2_0, 1_1)')
parser.add_argument('--query_size', type=str, default='2g', help='Query size to use')
parser.add_argument('--config_idx', type=int, default=0, help='Index of query configuration to use (0-9)')
parser.add_argument('--iterations', type=int, default=2, help='Number of optimization iterations')
args = parser.parse_args()

# --init_stream_config is "<num_init_dags>_<num_stream_dags>". Reject anything
# else with a usage message instead of an unpacking/int() traceback.
try:
    num_init_dags, num_stream_dags = map(int, args.init_stream_config.split('_'))
except ValueError:
    parser.error(
        "--init_stream_config must be two integers joined by '_' (e.g. 1_0), "
        f"got {args.init_stream_config!r}"
    )
args.num_init_dags = num_init_dags
args.num_stream_dags = num_stream_dags

# query_configurations (defined further down) has 10 entries. Validate here so
# a bad index fails immediately: an index >= 10 would otherwise raise
# IndexError mid-run, and a negative one would silently select a configuration
# counted from the end of the list.
if not 0 <= args.config_idx <= 9:
    parser.error(f"--config_idx must be between 0 and 9, got {args.config_idx}")

# All artefacts of this run (LLM prompts/responses, code snapshots and the log
# file) live in one directory keyed by the run's configuration.
HISTORY_DIR = "history/{}_{}_{}/{}/{}".format(
    args.exec_cap,
    args.init_stream_config,
    args.config_idx,
    args.iterations,
    args.algorithm,
)
os.makedirs(HISTORY_DIR, exist_ok=True)

# INFO-level logging, written to HISTORY_DIR/optimization.log and echoed to
# the console.
log_file = os.path.join(HISTORY_DIR, "optimization.log")
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)


# Ten fixed subsets of query indices (10 queries each); one is chosen per run
# by its position in this list.
query_configurations = [
    list(range(1, 11)),                             # 1..10 (first 10 numbers)
    list(range(13, 23)),                            # 13..22 (last 10 numbers)
    list(range(1, 20, 2)),                          # odd numbers 1..19
    list(range(2, 21, 2)),                          # even numbers 2..20
    list(range(1, 23, 3)) + [2, 5],                 # step 3 from 1, topped up
    list(range(3, 22, 3)) + [1, 8, 11],             # step 3 from 3, topped up
    list(range(2, 21, 3)) + [3, 6, 9],              # step 3 from 2, topped up
    list(range(1, 22, 4)) + [2, 6, 10, 14],         # step 4 from 1, topped up
    list(range(3, 20, 4)) + list(range(4, 21, 4)),  # step 4 from 3, then from 4
    list(range(1, 22, 5)) + [3, 8, 13, 18, 22],     # step 5 from 1, topped up
]


# collect_log_files() looks for per-query scheme logs under log/; create the
# folder up front so it always exists.
os.makedirs(name='log', exist_ok=True)

# Description of the observation (obs) structure the scheduling algorithm
# receives. It is embedded verbatim in the prompts built by
# apply_llm_suggestions() and get_optimization_advice(), so editing this text
# changes what the LLM is told about the system state.
SYSTEM_STRUCTURE_INFO = """
System state structure information:
1. job_dags: Job Directed Acyclic Graph (DAG) collection
   - name: Job name
   - num_nodes: Total number of nodes
   - num_nodes_done: Number of completed nodes
   - executor_count: Number of executors
   - frontier_nodes: List of schedulable frontier node IDs
   - nodes: Node details dictionary
     - Node ID: {
       - idx: Node ID
       - num_tasks: Total number of tasks
       - num_finished_tasks: Number of completed tasks
       - no_more_tasks: Whether there are no more tasks
       - tasks_all_done: Whether all tasks are completed
       - node_finish_time: Node completion time
       - active_executors: Number of active executors
       - is_schedulable: Whether the node is schedulable
       - node_duration: Node duration
       - is_frontier: Whether it's a frontier node
       - parent_nodes: List of parent node IDs
       - child_nodes: List of child node IDs
       - descendant_nodes: List of all descendant node IDs
     }

2. source_job: Source job name
3. num_source_exec: Number of source executors
4. frontier_nodes: List of all frontier nodes [[node ID, job name], ...]
5. executor_limits: Executor limits for each job {job name: limit count}
6. moving_executors: Executors currently moving [{executor_id, target_node}, ...]
7. system_stats: System statistics
   - total_jobs: Total number of jobs
   - total_nodes: Total number of nodes
   - total_frontier_nodes: Total number of frontier nodes
   - total_executors: Total number of executors
   - total_moving_executors: Number of executors currently moving
"""

# Input/output contract of the scheduling algorithm, embedded verbatim in the
# LLM prompts next to SYSTEM_STRUCTURE_INFO.
# NOTE(review): this text lists exec_commit and action_map as inputs, while
# SYSTEM_STRUCTURE_INFO lists system_stats and omits those two; confirm which
# fields the real observation carries so the LLM is not given conflicting info.
ALGORITHM_INTERFACE_INFO = """
Scheduling algorithm interface specification:

1. Input parameters (observation obs):
   - job_dags: Collection of all job DAGs
   - source_job: Source job (usually a job with executors available for reassignment)
   - num_source_exec: Number of available executors in the source job
   - frontier_nodes: Schedulable frontier nodes across all jobs
   - executor_limits: Executor quantity limits for each job
   - exec_commit: Executors already committed to each node
   - moving_executors: Executors currently in transit
   - action_map: Action mapping (currently unused)

2. Output requirements:
   - Return a tuple of node and use_exec flag
   - Node: Should be one of the frontier_nodes, indicating the node to schedule
   - use_exec: Whether to use an executor (typically 1 means use, 0 means don't use)
   
3. Notes:
   - Don't access fields or attributes that don't exist in obs
   - Ensure the algorithm can safely handle various edge cases
   - Don't change any interface signatures or return types
"""

# Asynchronous LLM call that also archives the prompt and the reply.
async def get_responses(api_key, messages, bot_name):
    """Send *messages* to the Poe bot *bot_name* and return the full reply text.

    The prompt (the ``content`` of every message joined by newlines) and the
    reply are saved under HISTORY_DIR as ``llm_prompt_<ts>.txt`` and
    ``llm_response_<ts>.txt``. Both files share one timestamp so that each
    prompt can be matched with its response.

    Args:
        api_key: API key forwarded to ``fp.get_bot_response``.
        messages: List of ``fp.ProtocolMessage`` objects to send.
        bot_name: Name of the bot to query.

    Returns:
        The streamed partial responses concatenated into one string.
    """
    # Microsecond resolution keeps two calls in the same second from
    # overwriting each other's files; the shared stamp pairs prompt/response.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')

    prompt_text = "\n".join(msg.content for msg in messages)
    prompt_path = os.path.join(HISTORY_DIR, f"llm_prompt_{timestamp}.txt")
    # LLM text routinely contains non-ASCII characters; write UTF-8 explicitly
    # instead of depending on the platform's default encoding.
    with open(prompt_path, 'w', encoding='utf-8') as f:
        f.write(prompt_text)
    logger.info(f"Prompt sent to LLM saved to {prompt_path}")

    # Collect streamed chunks and join once (avoids repeated string copies).
    chunks = []
    async for partial in fp.get_bot_response(messages=messages, bot_name=bot_name, api_key=api_key):
        chunks.append(partial.text)
    response = ''.join(chunks)

    response_path = os.path.join(HISTORY_DIR, f"llm_response_{timestamp}.txt")
    with open(response_path, 'w', encoding='utf-8') as f:
        f.write(response)
    logger.info(f"Response from LLM saved to {response_path}")

    return response

# Map an algorithm name to its source file.
def get_algorithm_file(algorithm):
    """Return the relative path ``<algorithm>.py`` of the algorithm's source file.

    The path is relative, so it resolves against the current working directory.
    """
    return "{}.py".format(algorithm)

# Keep a pristine copy of the algorithm before any LLM edits.
def backup_original_code(algorithm):
    """Copy the algorithm's source file to ``HISTORY_DIR/original.py``.

    Returns:
        True if the copy succeeded. False if the source file does not exist or
        the copy raised; both cases are logged at ERROR level.
    """
    source = get_algorithm_file(algorithm)
    if os.path.exists(source):
        destination = os.path.join(HISTORY_DIR, "original.py")
        try:
            shutil.copy2(source, destination)
            logger.info(f"Original code backed up to {destination}")
        except Exception as e:
            logger.error(f"Error backing up original code: {e}")
            return False
        return True

    logger.error(f"Algorithm file {source} does not exist.")
    return False

# Put the backed-up original algorithm back in place.
def restore_original_code(algorithm):
    """Copy ``HISTORY_DIR/original.py`` back over the algorithm's source file.

    Returns:
        True on success, False (after logging the error) if the copy raised.
    """
    target = get_algorithm_file(algorithm)
    backup = os.path.join(HISTORY_DIR, "original.py")
    try:
        shutil.copy2(backup, target)
        logger.info(f"Original code restored from {backup}")
    except Exception as e:
        logger.error(f"Error restoring original code: {e}")
        return False
    return True

# Snapshot the current version of the algorithm code.
def save_iteration_code(algorithm, iteration, stage="llm", notes=""):
    """Save a copy of the algorithm file plus a notes file into HISTORY_DIR.

    Writes ``iteration_<iteration>_<stage>_<timestamp>.py`` (a copy of the
    current algorithm file) and a sibling ``.txt`` containing *notes*.

    Args:
        algorithm: Algorithm name; its file comes from get_algorithm_file().
        iteration: Iteration number used in the file names.
        stage: Label for the snapshot stage (default ``"llm"``).
        notes: Free text stored next to the snapshot; ``None`` is stored as "".

    Returns:
        True on success, False (after logging) if either write raised.
    """
    filename = get_algorithm_file(algorithm)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = os.path.join(HISTORY_DIR, f"iteration_{iteration}_{stage}_{timestamp}")
    save_path = f"{stem}.py"
    notes_path = f"{stem}.txt"

    try:
        shutil.copy2(filename, save_path)

        # Notes typically carry LLM text, which may contain non-ASCII
        # characters; write UTF-8 explicitly rather than the locale default.
        with open(notes_path, 'w', encoding='utf-8') as f:
            f.write(notes if notes is not None else "")

        logger.info(f"Code for iteration {iteration} ({stage}) saved to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving iteration code: {e}")
        return False

# Run test_group.py and get reward
def run_test_group_script(algorithm):
    # Build a unique identifier to ensure log files won't conflict
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    
    cmd = f"python test_group2.py --config_idx {args.config_idx} --test_schemes {algorithm} --exec_cap {args.exec_cap} --init_stream_config  {args.num_init_dags}_{args.num_stream_dags} --num_exp 1 --query_size {args.query_size}"
    logger.info(f"Running test: {cmd}")
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    
    # Extract average reward from output
    output = result.stdout
    reward_pattern = rf"{algorithm} average reward: ([-\d.]+)"
    match = re.search(reward_pattern, output)
    
    if match:
        return float(match.group(1))
    else:
        logger.warning("Could not extract reward from output")
        logger.warning(f"Output: {output}")
        return None

# Log parsing: extract JSON objects from a (possibly messy) log and group them
# into step sequences.
def _extract_json_objects(content):
    """Return every top-level ``{...}`` span of *content* that parses as JSON.

    Braces inside JSON string literals are not counted, and a ``}`` that
    appears outside any object is ignored instead of driving the depth counter
    negative (which would otherwise break extraction for the rest of the
    file). Spans that fail to parse are logged at DEBUG level and dropped.
    """
    objects = []
    depth = 0
    start_pos = 0
    in_string = False
    escaped = False

    for i, char in enumerate(content):
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
            continue

        # Only track string state inside an object, so quotes in free log
        # text cannot swallow the braces of following objects.
        if char == '"' and depth > 0:
            in_string = True
        elif char == '{':
            if depth == 0:
                start_pos = i
            depth += 1
        elif char == '}':
            if depth == 0:
                continue  # unmatched closing brace in free text
            depth -= 1
            if depth == 0:
                obj_str = content[start_pos:i + 1]
                try:
                    objects.append(json.loads(obj_str))
                except json.JSONDecodeError:
                    logger.debug(f"Unable to parse JSON object: {obj_str[:100]}...")

    return objects


def _group_into_sequences(objects):
    """Group parsed log objects into lists of scheduling steps.

    An object with ``action`` but no ``obs`` is held as a pending action and
    merged with the next ``obs`` object into ``{'action', 'obs', 'reward'}``
    (a list-valued action becomes ``{'node': ..., 'use_exec': ...}``). An
    object carrying both ``obs`` and ``action`` is kept as a complete step.
    """
    sequences = []
    current_seq = []
    pending_action = None

    for obj in objects:
        # Standalone action entry: remember it until the next observation.
        if 'action' in obj and 'obs' not in obj:
            if isinstance(obj['action'], list):
                pending_action = {
                    'action': {
                        'node': obj['action'],
                        'use_exec': obj.get('use_exec', 0)
                    }
                }
            else:
                pending_action = obj
        elif 'obs' in obj:
            if pending_action:
                # Merge the pending action with this observation.
                complete_step = {
                    'action': pending_action.get('action', {}),
                    'obs': obj.get('obs', {}),
                    'reward': obj.get('reward', 0.0)
                }
                if 'use_exec' in pending_action:
                    complete_step['use_exec'] = pending_action['use_exec']
                current_seq.append(complete_step)
                pending_action = None
            elif 'action' in obj:
                current_seq.append(obj)  # already a complete step

            obs = obj.get('obs')
            job_dags = obs.get('job_dags', {}) if isinstance(obs, dict) else {}
            # NOTE(review): intended as "new job arrives", but this fires on any
            # non-empty job_dags once the current sequence has two or more
            # steps, and the triggering step (if it has an action) ends up in
            # both the closed and the new sequence. Kept unchanged because the
            # intended split rule is not visible here - confirm against the
            # log format before relying on per-sequence comparisons.
            if job_dags and len(current_seq) > 1:
                sequences.append(current_seq)
                current_seq = []
                if 'action' in obj:
                    current_seq.append(obj)

    if current_seq:
        sequences.append(current_seq)

    return sequences


def parse_log_file(filename):
    """Parse a scheduler log file into a list of step sequences.

    Newlines are removed and whitespace runs collapsed before JSON objects are
    extracted, so objects may be split across lines. See
    _group_into_sequences() for how objects become steps.

    Returns:
        A list of sequences (each a list of step dicts), or [] if reading or
        parsing raised (the error is logged).
    """
    try:
        with open(filename, 'r') as f:
            content = f.read()

        content = content.replace('\n', '')
        content = re.sub(r'\s+', ' ', content)

        return _group_into_sequences(_extract_json_objects(content))
    except Exception as e:
        logger.error(f"Error processing log file {filename}: {e}")
        return []

def calculate_cumulative_reward(sequence):
    """Return the total of the numeric ``reward`` fields in *sequence*.

    Entries that are not dicts, have no ``reward`` key, or whose reward is not
    an int/float are skipped. An empty sequence gives 0.
    """
    total = 0
    for step in sequence:
        reward = step.get('reward') if isinstance(step, dict) else None
        if isinstance(reward, (int, float)):
            total += reward
    return total

# Collect the (algorithm log, learning-based log) pairs for every query.
def collect_log_files():
    """Return index-aligned lists of algorithm and learning-based log paths.

    For each query index in ``query_configurations[args.config_idx]`` the
    pair is kept only when BOTH files exist. The lists are consumed by
    zipping them together, so appending each list independently would shift
    later entries and pair one query's log with another query's.

    Returns:
        (algorithm_logs, learn_logs): two lists of equal length.
    """
    result_folder = 'log/'

    # Queries for the current configuration
    config = query_configurations[args.config_idx]

    algorithm_logs = []
    learn_logs = []

    for query_idx in config:
        algorithm_log = f"{result_folder}scheme_{args.algorithm}_{args.exec_cap}_{args.num_init_dags}_{args.num_stream_dags}_{args.query_size}_{query_idx}_{args.query_size}_{query_idx}.log"
        # NOTE(review): 'rl_path' looks like a placeholder - the same path is
        # used for every query. Confirm the real per-query learning log name.
        learn_log = 'rl_path'

        algo_found = os.path.exists(algorithm_log)
        learn_found = os.path.exists(learn_log)
        if not algo_found:
            logger.warning(f"Could not find algorithm log file: {algorithm_log}")
        if not learn_found:
            logger.warning(f"Could not find learn log file: {learn_log}")

        # Only keep complete pairs so both lists stay aligned per query.
        if algo_found and learn_found:
            algorithm_logs.append(algorithm_log)
            learn_logs.append(learn_log)

    logger.info(f"Found {len(algorithm_logs)} algorithm logs and {len(learn_logs)} learn logs")
    return algorithm_logs, learn_logs

# Analyze log pairs and extract comparison cases.
def extract_cases_from_logs(algorithm_logs, learn_logs):
    """Compare matching sequences of each (algorithm, learn) log pair.

    Logs are paired positionally via zip. For every sequence index present in
    both logs, the cumulative rewards are compared and a case is recorded.

    Returns:
        A list of tuples ``(algo_log, learn_log, seq_index, gap, algo_sequence,
        learn_sequence, query_info)`` where ``gap`` is the learning log's
        cumulative reward minus the algorithm's. Pairs that fail to process
        are logged and skipped.
    """
    all_cases = []

    # The same path may appear in several pairs (e.g. a shared learning-log
    # path), so parse each distinct file only once.
    parsed = {}

    def _sequences(path):
        if path not in parsed:
            parsed[path] = parse_log_file(path)
        return parsed[path]

    for algo_log, learn_log in zip(algorithm_logs, learn_logs):
        try:
            algorithm_sequences = _sequences(algo_log)
            learn_sequences = _sequences(learn_log)

            if not algorithm_sequences or not learn_sequences:
                logger.warning(f"Skipping empty log files: {algo_log} or {learn_log}")
                continue

            # Recover "<size>_<query>" (e.g. "2g_3") from the algorithm log name.
            query_match = re.search(r'(\d+g)_(\d+)', algo_log)
            query_info = f"query {query_match.group(2)} (size {query_match.group(1)})" if query_match else "unknown query"

            algo_rewards = [calculate_cumulative_reward(seq) for seq in algorithm_sequences]
            learn_rewards = [calculate_cumulative_reward(seq) for seq in learn_sequences]

            # Compare only indices present in both logs.
            for i, (algo_reward, learn_reward) in enumerate(zip(algo_rewards, learn_rewards)):
                gap = learn_reward - algo_reward
                logger.info(f"{query_info} - Sequence {i}: gap {gap:.4f}, learn {learn_reward:.4f}, algo {algo_reward:.4f}")

                # All cases are kept, regardless of the gap's sign.
                all_cases.append((algo_log, learn_log, i, gap, algorithm_sequences[i], learn_sequences[i], query_info))

        except Exception as e:
            logger.error(f"Error processing logs {algo_log} and {learn_log}: {e}")

    return all_cases

# Merge LLM suggestions into the algorithm file.
def _compiles_as_python(code, filename):
    """Return True if *code* is syntactically valid Python, else log and return False.

    Only the built-in ``compile`` is used; nothing is executed.
    """
    try:
        compile(code, filename, 'exec')
    except (SyntaxError, ValueError) as e:
        # ValueError covers source containing NUL bytes on older Pythons.
        logger.warning(f"LLM-generated code for {filename} is not valid Python: {e}")
        return False
    return True


async def apply_llm_suggestions(algorithm, advice):
    """Ask the LLM to integrate *advice* into the algorithm file and write the result.

    Steps:
      1. Require at least one ```python block in *advice*.
      2. Send the current file, the advice and the interface descriptions to
         the LLM and extract the returned code.
      3. If the code shows none of the expected safe-access patterns, ask once
         more. The revision is adopted only if it is valid Python.
      4. Write the code only if it compiles.

    Returns:
        True if the algorithm file was rewritten. False if the file is missing,
        the advice has no code block, the response does not look like or
        compile as Python, or an exception was raised (logged).
    """
    algorithm_file = get_algorithm_file(algorithm)
    if not os.path.exists(algorithm_file):
        logger.error(f"Could not find file for algorithm {algorithm}")
        return False

    # Python source is UTF-8 by default (PEP 3120); don't rely on the locale.
    with open(algorithm_file, 'r', encoding='utf-8') as f:
        current_code = f.read()

    code_blocks = re.findall(r'```python(.*?)```', advice, re.DOTALL)
    if not code_blocks:
        logger.warning("No code blocks found in LLM advice")
        return False

    integration_prompt = f"""
I have an algorithm file and some suggested code changes. Please help me integrate these changes properly.

Current algorithm file:
```python
{current_code}
```

Suggested code changes:
{advice}

{SYSTEM_STRUCTURE_INFO}

{ALGORITHM_INTERFACE_INFO}

IMPORTANT REQUIREMENTS:
1. The input and output interfaces of the algorithm MUST remain EXACTLY the same.
2. Do not change any function signatures or return values.
3. Ensure the code maintains compatibility with the existing system.
4. Focus only on improving the algorithm's internal logic while preserving its integration capabilities.
5. All original functionality must be maintained.
6. Always use safe access patterns for observation (obs) data:
   - Use hasattr() to check if an attribute exists before accessing it
   - Use get() method for dictionary access with default values
   - Use try/except blocks to handle potential errors
   - Never assume a field exists in the observation
7. Make sure the algorithm handles edge cases gracefully.

Please provide the complete updated file with all changes integrated. Return ONLY the complete Python code file with no explanations.
"""

    message = fp.ProtocolMessage(role="user", content=integration_prompt)
    try:
        updated_code = await get_responses(API_KEY, [message], BOT_NAME)

        code_match = re.search(r'```python(.*?)```', updated_code, re.DOTALL)
        if code_match:
            updated_code = code_match.group(1).strip()
        elif not updated_code.strip().startswith('import') and 'class' not in updated_code:
            # No fenced block: accept the raw response only if it looks like code.
            logger.warning("LLM response doesn't appear to be valid Python code")
            logger.debug(f"Response: {updated_code[:500]}...")
            return False

        # Heuristic check for the defensive access patterns requested above.
        safer_code_indicators = [
            'hasattr(' in updated_code,
            '.get(' in updated_code,
            'try:' in updated_code and 'except' in updated_code,
        ]

        if not any(safer_code_indicators):
            logger.warning("Updated code may not include safe access patterns")
            safety_prompt = f"""
The code you provided doesn't include sufficient safety checks for accessing observation data. 
Please modify the code to include:
1. hasattr() checks before accessing object attributes
2. Dictionary .get() method with default values
3. try/except blocks for error handling

Here's your current code:
```python
{updated_code}
```

Please update it with proper safety checks and return ONLY the complete Python code.
"""
            message = fp.ProtocolMessage(role="user", content=safety_prompt)
            revised_code = await get_responses(API_KEY, [message], BOT_NAME)

            code_match = re.search(r'```python(.*?)```', revised_code, re.DOTALL)
            if code_match:
                revised_code = code_match.group(1).strip()

            # A reply without a code block may be prose; only adopt the
            # revision if it compiles, otherwise keep the previous candidate.
            if _compiles_as_python(revised_code, algorithm_file):
                updated_code = revised_code
            else:
                logger.warning("Safety-revised code rejected; keeping the previous version")

        # Never overwrite the algorithm with code that does not even compile.
        if not _compiles_as_python(updated_code, algorithm_file):
            return False

        with open(algorithm_file, 'w', encoding='utf-8') as f:
            f.write(updated_code)

        logger.info(f"Applied LLM suggestions to {algorithm_file}")
        return True
    except Exception as e:
        logger.error(f"Error applying LLM suggestions: {e}")
        return False

# Request optimization advice from the LLM using the collected cases.
def _summarize_observation(obs):
    """Return the subset of an observation dict that is shown to the LLM.

    ``frontier_nodes`` and ``system_stats`` are copied as-is. Each job in
    ``job_dags`` is reduced to name / num_nodes / executor_count /
    frontier_nodes and, if present, per-node num_tasks / node_duration /
    is_frontier. Non-dict containers or entries are skipped instead of raising.
    """
    summary = {}
    if 'frontier_nodes' in obs:
        summary['frontier_nodes'] = obs['frontier_nodes']
    if 'system_stats' in obs:
        summary['system_stats'] = obs['system_stats']
    if 'job_dags' in obs:
        summary['job_dags'] = {}
        job_dags = obs['job_dags'] if isinstance(obs['job_dags'], dict) else {}
        for job_name, job_info in job_dags.items():
            if not isinstance(job_info, dict):
                continue
            job_summary = {
                'name': job_info.get('name', ''),
                'num_nodes': job_info.get('num_nodes', 0),
                'executor_count': job_info.get('executor_count', 0),
                'frontier_nodes': job_info.get('frontier_nodes', [])
            }
            if 'nodes' in job_info:
                nodes = job_info['nodes'] if isinstance(job_info['nodes'], dict) else {}
                job_summary['nodes'] = {
                    node_id: {
                        'num_tasks': node_info.get('num_tasks', 0),
                        'node_duration': node_info.get('node_duration', 0),
                        'is_frontier': node_info.get('is_frontier', False)
                    }
                    for node_id, node_info in nodes.items()
                    if isinstance(node_info, dict)
                }
            summary['job_dags'][job_name] = job_summary
    return summary


def _describe_step(label, entry, cumulative):
    """Render one algorithm's action and reward for a step.

    Args:
        label: Prefix such as "My algorithm" or "Learning algorithm".
        entry: Step dict containing an ``action`` key.
        cumulative: Running reward total before this step.

    Returns:
        ``(text, new_cumulative)``. A non-numeric reward is shown as-is but is
        not added to the running total (matching calculate_cumulative_reward).
    """
    reward = entry.get('reward', 0)
    if isinstance(reward, (int, float)):
        cumulative += reward

    action_info = entry['action']
    if isinstance(action_info, dict):
        text = f"\n{label}'s action:\n{json.dumps(action_info, indent=2)}\n"
    elif isinstance(action_info, list):
        text = f"\n{label}'s action:\nnode: {action_info}\nuse_exec: {entry.get('use_exec', 'unknown')}\n"
    else:
        text = f"\n{label}'s action: {action_info}\n"

    text += f"Reward: {reward} (Cumulative: {cumulative})\n"
    return text, cumulative


async def get_optimization_advice(algorithm, cases, failed_optimizations=None):
    """Build a comparison prompt from *cases* and return the LLM's advice.

    Args:
        algorithm: Name of the algorithm being optimized.
        cases: Tuples as produced by extract_cases_from_logs(); only the first
            three are included, each limited to its first 10 steps.
        failed_optimizations: Optional list of ``(optimization, result)``
            pairs describing earlier attempts that did not help.

    Returns:
        The LLM response text, or an ``"Error getting LLM response: ..."``
        string if the call raised.
    """
    prompt = f"""
I'm optimizing a {algorithm} scheduling algorithm. I've compared it with a learning-based algorithm across multiple test cases and found these situations:

"""
    # Past attempts that did not improve the reward.
    if failed_optimizations:
        prompt += "\n### Previous Optimization Attempts That Did Not Work ###\n"
        for i, (optimization, result) in enumerate(failed_optimizations):
            prompt += f"\nAttempt {i+1}:\n"
            prompt += f"Optimization tried: {optimization}\n"
            prompt += f"Result: {result}\n"
            prompt += "---\n"

    # At most 3 cases keep the prompt length manageable.
    for i, (_algo_log, _learn_log, _index, gap, algo_sequence, learn_sequence, query_info) in enumerate(cases[:3]):
        prompt += f"\nCase {i+1} - {query_info} (cumulative reward difference: {gap:.4f}):\n"

        algo_cumulative = 0
        learn_cumulative = 0

        # At most 10 decision points per case.
        max_steps = min(10, len(algo_sequence), len(learn_sequence))

        for j in range(max_steps):
            algo_entry = algo_sequence[j]
            learn_entry = learn_sequence[j]

            if not isinstance(algo_entry, dict) or not isinstance(learn_entry, dict):
                continue

            # Only the algorithm-side observation is shown, in reduced form.
            obs = algo_entry.get('obs')
            if isinstance(obs, dict):
                prompt += f"\nStep {j+1} - Current state (observation):\n"
                prompt += f"{json.dumps(_summarize_observation(obs), indent=2)}\n"

            if 'action' in algo_entry:
                text, algo_cumulative = _describe_step("My algorithm", algo_entry, algo_cumulative)
                prompt += text

            if 'action' in learn_entry:
                text, learn_cumulative = _describe_step("Learning algorithm", learn_entry, learn_cumulative)
                prompt += text

            prompt += "\n" + "-" * 40 + "\n"

    prompt += f"""
{SYSTEM_STRUCTURE_INFO}

{ALGORITHM_INTERFACE_INFO}

Based on these examples across multiple queries, please provide:
1. Specific suggestions to improve my {algorithm} scheduling algorithm
2. The root causes of the performance gap
3. Can you suggest any code changes to improve performance?

IMPORTANT SAFETY REQUIREMENTS:
1. Your suggestions must maintain complete compatibility with the existing codebase.
2. The input and output formats of the algorithm must remain exactly the same.
3. Any changes should not break algorithm integration with the existing system.
4. Always use safe access patterns for observation data:
   - Check if attributes exist with hasattr() before accessing them
   - Use dict.get() with default values for dictionary access
   - Use try/except blocks to handle potential errors
   - Never assume a field or attribute exists in the observation
5. Make sure the algorithm handles edge cases gracefully.

Please be concrete and actionable in your advice. Provide specific code snippets that I can implement.
"""

    message = fp.ProtocolMessage(role="user", content=prompt)
    try:
        response = await get_responses(API_KEY, [message], BOT_NAME)
        return response
    except Exception as e:
        logger.error(f"Error getting LLM response: {e}")
        return f"Error getting LLM response: {str(e)}"

def _format_improvement(new_reward, base_reward):
    """Return the relative change from *base_reward* to *new_reward* as text.

    The value is formatted as a percentage with two decimals, e.g. ``"12.34%"``.
    A zero baseline has no meaningful relative change, so a descriptive
    ``"n/a (baseline reward is 0)"`` is returned instead of raising
    ZeroDivisionError (which would otherwise abort the whole optimization run).
    """
    if base_reward == 0:
        return "n/a (baseline reward is 0)"
    return f"{(new_reward - base_reward) / abs(base_reward) * 100:.2f}%"


def _revert_to_best(algorithm, best_code_path):
    """Copy the best known version back over the live algorithm file.

    Returns True if a copy was made, False if no saved best version exists.
    """
    if best_code_path and os.path.exists(best_code_path):
        shutil.copy2(best_code_path, get_algorithm_file(algorithm))
        logger.info("Reverted to previous best version")
        return True
    return False


# Main optimization loop
async def optimize_algorithm():
    """Iteratively ask the LLM to improve the scheduling algorithm, keeping the best version.

    Driven by the module-level ``args``:
      1. Back up the original algorithm source and measure a baseline reward.
      2. For ``args.iterations`` rounds: collect logs, request advice from the
         LLM, apply it, re-run the test group, and keep the new code only if
         its reward is strictly higher than the current best. Otherwise revert
         to the best version and record the attempt in ``failed_optimizations``,
         which is passed back to the LLM on later rounds.
      3. Write the best version and a summary file into ``HISTORY_DIR``.

    The original source is always restored in ``finally``, so the best version
    survives only as a copy inside ``HISTORY_DIR``.
    """
    algorithm = args.algorithm
    best_reward = None
    best_code_path = None
    failed_optimizations = []  # (description, result) tuples fed back to the LLM

    logger.info(f"Starting optimization for {algorithm} algorithm using configuration {args.config_idx}")
    logger.info(f"Detailed configuration: exec_cap={args.exec_cap}, init_dags={args.num_init_dags}, stream_dags={args.num_stream_dags}, query_size={args.query_size}")

    # First backup original code
    if not backup_original_code(algorithm):
        logger.error("Failed to backup original code. Aborting.")
        return

    try:
        # Step 1: Run original algorithm to get baseline reward
        logger.info("Step 1: Running original algorithm to get baseline reward")
        baseline_reward = run_test_group_script(algorithm)
        if baseline_reward is None:
            logger.error("Could not get baseline reward. Aborting.")
            return

        logger.info(f"Baseline reward: {baseline_reward}")
        best_reward = baseline_reward

        # Save baseline version; it is the "best" until an iteration beats it.
        baseline_code_path = os.path.join(HISTORY_DIR, "baseline_version.py")
        shutil.copy2(get_algorithm_file(algorithm), baseline_code_path)
        best_code_path = baseline_code_path

        # Step 2: Perform LLM optimization iterations
        logger.info("Step 2: Starting LLM optimization iterations")

        for iteration in range(args.iterations):
            iter_num = iteration + 1
            logger.info(f"\n=== LLM Iteration {iter_num}/{args.iterations} ===")

            # Collect all related log files
            algorithm_logs, learn_logs = collect_log_files()

            if not algorithm_logs or not learn_logs:
                logger.error("Cannot proceed without valid log files.")
                continue

            # Extract all cases from logs
            all_cases = extract_cases_from_logs(algorithm_logs, learn_logs)

            if not all_cases:
                logger.info("No cases found in logs.")
                continue

            # Get LLM optimization advice, including previous failed attempts
            logger.info("Requesting optimization advice from LLM...")
            advice = await get_optimization_advice(algorithm, all_cases, failed_optimizations)
            logger.info("LLM advice received")

            # Save LLM advice (explicit UTF-8: LLM output may contain non-ASCII text)
            advice_path = os.path.join(HISTORY_DIR, f"iteration_{iter_num}_llm_advice.txt")
            with open(advice_path, 'w', encoding='utf-8') as f:
                f.write(advice)

            # Save code before LLM changes
            save_iteration_code(algorithm, iter_num, "pre_llm", notes="Pre-LLM version")

            # Apply LLM suggestions to code
            logger.info("Applying LLM suggestions to code...")
            success = await apply_llm_suggestions(algorithm, advice)
            if not success:
                logger.warning("Failed to apply LLM suggestions. Continuing with best known code.")
                _revert_to_best(algorithm, best_code_path)

                # Record failed attempt
                failed_optimizations.append((
                    f"LLM code suggestion failed to apply (iteration {iter_num})",
                    "Could not extract valid code from LLM response"
                ))

                continue

            # Save code after LLM changes
            save_iteration_code(algorithm, iter_num, "post_llm", notes=advice)

            # Reward after LLM suggestions
            logger.info("Running test with optimized algorithm...")
            llm_reward = run_test_group_script(algorithm)
            if llm_reward is None:
                logger.warning("Could not get reward after LLM suggestions.")
                # Do not leave untested (possibly broken) code in place for the
                # next iteration to build on.
                _revert_to_best(algorithm, best_code_path)

                # Record failed attempt
                failed_optimizations.append((
                    f"LLM code optimization attempt (iteration {iter_num}) could not get results",
                    "Test run did not return valid reward"
                ))

                continue

            logger.info(f"Reward after LLM iteration {iter_num}: {llm_reward}")

            # Detailed recording of optimization results
            optimization_result = (
                f"Iteration {iter_num} results:\n"
                f"Baseline reward: {baseline_reward}\n"
                f"Optimized reward: {llm_reward}\n"
                f"Improvement: {_format_improvement(llm_reward, baseline_reward)}\n"
                f"Optimization suggestions: {advice_path}"
            )

            result_path = os.path.join(HISTORY_DIR, f"iteration_{iter_num}_results.txt")
            with open(result_path, 'w', encoding='utf-8') as f:
                f.write(optimization_result)

            # Update best version record (higher reward is better)
            if llm_reward > best_reward:
                best_reward = llm_reward
                best_code_path = os.path.join(HISTORY_DIR, f"best_version_iter_{iter_num}.py")
                shutil.copy2(get_algorithm_file(algorithm), best_code_path)
                logger.info(f"New best version saved with reward: {best_reward}")

                # Record successful optimization
                success_details = (
                    f"LLM optimization successful (iteration {iter_num})\n"
                    f"Improvement: {_format_improvement(llm_reward, baseline_reward)}"
                )

                success_path = os.path.join(HISTORY_DIR, f"success_details_iter_{iter_num}.txt")
                with open(success_path, 'w', encoding='utf-8') as f:
                    f.write(success_details)

            else:
                logger.info(f"No improvement in this iteration. Best reward remains: {best_reward}")
                # If no improvement, rollback to best version
                _revert_to_best(algorithm, best_code_path)

                # Record failed optimization attempt (equal rewards also land here)
                failed_optimizations.append((
                    f"LLM optimization attempt (iteration {iter_num})",
                    f"Post-optimization reward {llm_reward} did not exceed current best reward {best_reward}"
                ))

        overall_improvement = _format_improvement(best_reward, baseline_reward)
        logger.info("\n=== Optimization complete ===")
        logger.info(f"Initial baseline reward: {baseline_reward}")
        logger.info(f"Final best reward: {best_reward}")
        logger.info(f"Overall improvement: {overall_improvement}")

        # Save final best version
        if best_code_path and os.path.exists(best_code_path):
            final_best_path = os.path.join(HISTORY_DIR, "final_best_version.py")
            shutil.copy2(best_code_path, final_best_path)
            logger.info(f"Best version saved to {final_best_path}")

        # Save optimization results summary
        with open(os.path.join(HISTORY_DIR, "optimization_summary.txt"), 'w', encoding='utf-8') as f:
            f.write(f"Optimization Summary for {algorithm}\n")
            f.write(f"Configuration: {args.exec_cap}_{args.init_stream_config}_{args.config_idx}\n\n")
            f.write(f"Initial baseline reward: {baseline_reward}\n")
            f.write(f"Final best reward: {best_reward}\n")
            f.write(f"Overall improvement: {overall_improvement}\n\n")

            # Record all failed attempts
            if failed_optimizations:
                f.write("\nFailed Optimization Attempts:\n")
                for i, (optimization, result) in enumerate(failed_optimizations, start=1):
                    f.write(f"\n{i}. {optimization}\n")
                    f.write(f"   Result: {result}\n")

    except Exception as e:
        # Top-level boundary: log the message together with the full traceback.
        logger.exception(f"Error during optimization: {e}")

    finally:
        # Whether a better algorithm is found or not, restore to original algorithm for future optimization runs
        restore_original_code(algorithm)
        logger.info("Original code restored for future optimization runs.")

# Main function
if __name__ == "__main__":
    # Echo the effective CLI configuration so each run's log records what was executed.
    startup_lines = [
        "Starting optimization with configuration:",
        f"Algorithm: {args.algorithm}",
        f"Executor capacity: {args.exec_cap}",
        f"Initial DAGs: {args.num_init_dags}, Streaming DAGs: {args.num_stream_dags}",
        f"Config index: {args.config_idx}",
        f"Query size: {args.query_size}",
        f"LLM iterations: {args.iterations}",
    ]
    for startup_line in startup_lines:
        logger.info(startup_line)

    # Drive the async optimization loop to completion on a fresh event loop.
    asyncio.run(optimize_algorithm())

    logger.info("Optimization process completed.")
