# This script optimizes adaptive bitrate (ABR) algorithms: it obtains optimization suggestions by interacting with an LLM and automatically tunes parameters with Bayesian optimization
# Workflow:
# 1. Run test script and get initial reward
# 2. Calculate and save badcases
# 3. Get optimization suggestions and reasons from LLM
# 4. Apply optimization suggestions to the algorithm
# 5. If the Bayesian optimization iteration count is > 0, automatically tune the algorithm's parameters with Bayesian optimization
# 6. Run test script and get new reward
# 7. Loop iteration until optimal results are achieved




import asyncio
import fastapi_poe as fp
import subprocess
import json
import re
import os
import importlib
import numpy as np
import glob
import sys
import ast
import time
import random
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
from skopt.utils import use_named_args
from skopt.plots import plot_convergence
import matplotlib.pyplot as plt
import logging
from datetime import datetime
import matplotlib
import shutil
import importlib.util

# Render figures with the non-interactive Agg backend: plots are saved to files
# with savefig rather than shown in a window. (The original author also meant
# this to sidestep Chinese font issues.)
matplotlib.use('Agg')
# Use DejaVu Sans for all plot text.
plt.rcParams.update({'font.family': 'DejaVu Sans'})

# Create custom JSON encoder to handle NumPy types
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types.

    Use as ``json.dumps(obj, cls=NumpyEncoder)``. NumPy integers become ``int``,
    floats become ``float``, booleans become ``bool`` and arrays become (nested)
    lists; anything else falls through to the standard encoder, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, so json cannot encode it natively.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# ---- Global configuration -------------------------------------------------
ALGORITHM_NAME = "hyb"          # algorithm under optimization; its source is <ALGORITHM_NAME>.py
MAX_BAYESIAN_ITERATIONS = 500   # upper bound on Bayesian-optimization evaluations (<= 0 disables it)
MAX_LLM_ITERATIONS = 5          # upper bound on LLM optimization rounds
dataset = 'hsr'                 # trace dataset passed to run_abr.py
results_dir = './results'
# Directory holding the per-trace simulation logs (log_sim_*) for `dataset`.
dataset_dir = os.path.join(results_dir, dataset)

# Bandwidth summary per network trace set, presumably "mean ± std [min, max]".
NETWORK_INFO = {
    "3g": "1.52 ± 0.72 [0.60, 4.59]Mbps ",
    "oboe": "2.77 ± 1.32 [0.34, 5.70]Mbps ",
    "puffer-2110": "1.60 ± 0.88 [0.30, 3.60]Mbps ",
    "fcc": "1.33 ± 0.55 [0.19, 3.43]Mbps",
}

def setup_logging(log_dir):
    """Configure root logging to a timestamped file in *log_dir* and to the console.

    Creates *log_dir* if needed and logs to
    ``optimization_log_<YYYYmmdd_HHMMSS>.log`` inside it. The file is written
    as UTF-8 so non-ASCII text (e.g. LLM-provided parameter descriptions) can
    always be logged, independent of the platform's default encoding.

    Note: ``logging.basicConfig`` is a no-op when the root logger already has
    handlers; in that case these handlers are not installed.

    Returns:
        logging.Logger: the "ABR_Optimizer" logger.
    """
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f"optimization_log_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger("ABR_Optimizer")

# Create async function to wrap async loop to get LLM response
async def get_responses(api_key, messages, bot_name):
    """Query a Poe bot and return its complete reply as one string.

    ``fp.get_bot_response`` yields the answer as a stream of partial messages;
    their ``text`` fields are collected in order and joined.
    """
    pieces = [
        partial.text
        async for partial in fp.get_bot_response(
            messages=messages, bot_name=bot_name, api_key=api_key
        )
    ]
    return ''.join(pieces)

def extract_code_from_response(response):
    """Extract Python code from an LLM response.

    Lookup order:

    1. Fenced code blocks (```...```). A language tag written on the opening
       fence line (``python``, ``py``, ``python3``, ``json``, ...) is dropped so
       it never leaks into the code. The first block containing ``def abr`` is
       returned; if none does, the first block is returned.
    2. Everything from the first ``def abr`` to the end of the response.
    3. The response unchanged.
    """
    # Either a tag that ends its line, or an inline "python" tag (```python def ...).
    fence_pattern = r'```(?:[\w+-]+[ \t]*(?=\n)|python)?\s*([\s\S]*?)\s*```'
    blocks = re.findall(fence_pattern, response)
    if blocks:
        # Replies often carry extra blocks (examples, JSON); prefer the abr code.
        for block in blocks:
            if 'def abr' in block:
                return block
        return blocks[0]

    # No fenced blocks: take everything from the abr definition onward.
    match = re.search(r'def abr[\s\S]*', response)
    if match:
        return match.group(0)

    # Nothing recognizable: hand back the raw response.
    return response

def extract_reasoning_from_response(response):
    """Return the explanatory prose of an LLM response.

    Every fenced code block found in *response* is cut out (all occurrences of
    its text), and the remaining text is returned with surrounding whitespace
    stripped.
    """
    fence_pattern = r'```(?:python)?\s*[\s\S]*?\s*```'
    prose = response
    for fenced in re.findall(fence_pattern, response):
        prose = prose.replace(fenced, '')
    return prose.strip()

def _parse_log_rewards(lines):
    """Return the reward of every data line of a simulation log.

    The reward is the last tab-separated field. Blank lines, '#' comment lines,
    lines with fewer than 7 fields and lines whose last field is not a number
    are skipped.
    """
    rewards = []
    for line in lines:
        if not line.strip() or line.startswith('#'):
            continue
        parts = line.split('\t')
        if len(parts) < 7:
            continue
        try:
            rewards.append(float(parts[-1]))
        except ValueError:
            continue
    return rewards


def _context_lines(lines, start_idx, end_idx):
    """Return the non-blank entries of lines[start_idx:end_idx + 1], ignoring out-of-range indices."""
    return [lines[j] for j in range(start_idx, end_idx + 1) if j < len(lines) and lines[j].strip()]


def calculate_badcases_from_results(process_log_dir, temp_abr_file, process_id):
    """Find where this process's algorithm loses most against the rl baseline and save the top 5.

    Every ``log_sim_{ALGORITHM_NAME}_{process_id}_<trace>`` file in dataset_dir
    is paired with ``log_sim_rl_<trace>`` from the module-level ``rl_log_dir``.
    For traces where the current algorithm's mean reward (first reward
    excluded) is below rl's, the two logs are compared line by line, skipping
    line 0. Every line where rl earns the higher reward becomes a candidate.
    Each candidate carries that line and up to four preceding lines from both
    logs. The five candidates with the largest reward gap are written to
    ``<process_log_dir>/badcase.txt``.

    Args:
        process_log_dir: Directory that receives ``badcase.txt``.
        temp_abr_file: Unused; kept for interface compatibility.
        process_id: Identifies this process's log files.

    Returns:
        True if a badcase file was written; False if dataset_dir is missing or
        no badcase was found.
    """
    logger.info(f"Current algorithm name: {ALGORITHM_NAME}")

    if not os.path.exists(dataset_dir):
        logger.error(f"Dataset directory {dataset_dir} does not exist")
        return False

    # Tuples of (reward difference, trace name, current-algorithm context lines, rl context lines)
    global_worst_lines = []

    log_prefix = f'log_sim_{ALGORITHM_NAME}_{process_id}_'
    current_logs = [f for f in os.listdir(dataset_dir) if f.startswith(log_prefix)]

    for current_log in current_logs:
        trace_name = current_log[len(log_prefix):]
        rl_log = f'log_sim_rl_{trace_name}'
        rl_log_path = os.path.join(rl_log_dir, rl_log)
        current_log_path = os.path.join(dataset_dir, current_log)

        if not os.path.exists(rl_log_path):
            logger.warning(f"rl baseline log {rl_log_path} not found, skipping trace {trace_name}")
            continue

        try:
            # Blank lines are dropped so that index i refers to the same chunk in both logs.
            with open(current_log_path, 'r') as f:
                current_lines = [line for line in f if line.strip()]
            with open(rl_log_path, 'r') as f:
                rl_lines = [line for line in f if line.strip()]

            current_rewards = _parse_log_rewards(current_lines)
            rl_rewards = _parse_log_rewards(rl_lines)

            # The first reward is left out of the averages, so at least two are
            # needed; averaging an empty slice would only produce NaN.
            if len(current_rewards) < 2 or len(rl_rewards) < 2:
                continue
            current_avg = np.mean(current_rewards[1:])
            rl_avg = np.mean(rl_rewards[1:])
            if not current_avg < rl_avg:
                continue  # the current algorithm is not worse on this trace

            for i in range(1, min(len(current_lines), len(rl_lines))):
                try:
                    current_parts = current_lines[i].split('\t')
                    rl_parts = rl_lines[i].split('\t')
                    if len(current_parts) < 7 or len(rl_parts) < 7:
                        continue
                    current_reward = float(current_parts[-1])
                    rl_reward = float(rl_parts[-1])
                except ValueError as e:
                    logger.warning(f"Error processing line {i}: {e}")
                    continue

                diff = rl_reward - current_reward
                if diff > 0:  # only lines where rl performs better
                    # The line itself plus up to 4 preceding lines (never line 0).
                    start_idx = max(1, i - 4)
                    global_worst_lines.append((
                        diff,
                        trace_name,
                        _context_lines(current_lines, start_idx, i),
                        _context_lines(rl_lines, start_idx, i),
                    ))
        except Exception as e:
            # Best effort: a broken trace must not stop the others.
            logger.error(f"Error processing files {current_log}/{rl_log}: {e}")

    if not global_worst_lines:
        logger.info("No badcases found")
        return False

    global_worst_lines.sort(key=lambda x: x[0], reverse=True)
    top_5_worst = global_worst_lines[:5]

    column_meanings = [
        "time(ms)",
        "bit_rate(Kbps)",
        "buffer_size(s)",
        "rebuf(s)",
        "video_chunk_size(byte)",
        "delay(ms)",
        "reward",
    ]
    top5_log_path = os.path.join(process_log_dir, 'badcase.txt')
    with open(top5_log_path, 'w') as f:
        # Header: column meanings and units of the log lines below.
        f.write("# " + "\t".join(column_meanings) + "\n\n")

        for rank, (diff, trace_name, current_context, rl_context) in enumerate(top_5_worst, start=1):
            f.write(f"=== Top {rank}: Trace {trace_name}, Diff: {diff:.6f} ===\n")
            f.write(f"--- {ALGORITHM_NAME} Algorithm Results ---\n")
            f.writelines(current_context)
            f.write("\n--- rl Algorithm Results ---\n")
            f.writelines(rl_context)
            f.write("\n\n")

    logger.info(f"Saved top 5 largest difference badcases to {top5_log_path}")
    return True

def get_abr_function_content(file_path):
    """Return the source of the ``abr`` function in *file_path*.

    The returned text runs from the first ``def abr(`` to the end of the file,
    so anything defined after ``abr`` is included as well. Signatures with
    return annotations or defaults containing parentheses
    (e.g. ``def abr(x=(1, 2)) -> int:``) are supported.

    Returns None (after logging) if the file cannot be read or contains no
    ``def abr(``.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:
        logger.error(f"Error reading {file_path}: {e}")
        return None

    # Only anchor on the definition keyword; the rest of the signature is free-form.
    match = re.search(r'def abr\s*\(', content)
    if match is None:
        logger.error(f"Cannot find abr function in {file_path}")
        return None
    return content[match.start():]

def update_temp_abr_function(file_path, new_code):
    """Replace the ``abr`` function in *file_path* with *new_code*.

    Everything before the first ``def abr(`` line is kept verbatim and
    *new_code* replaces the rest of the file. The kept part includes imports,
    the VIDEO_BIT_RATE / chunk_length / A_DIM constants and any helpers. If
    the file has no ``def abr(`` yet but contains that three-line constant
    header, the file is rewritten as the header followed by *new_code*.

    Returns:
        True on success, False (after logging) if neither anchor is found or
        the file cannot be read/written.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        abr_match = re.search(r'^[ \t]*(def abr\s*\()', content, re.MULTILINE)
        if abr_match:
            # Keep the whole header (not just the constants) so imports and
            # helpers defined before abr are never dropped.
            header = content[:abr_match.start(1)]
        else:
            const_match = re.search(
                r'(VIDEO_BIT_RATE\s*=.*\nchunk_length\s*=.*\nA_DIM\s*=.*\n)', content
            )
            if not const_match:
                logger.error(f"Cannot find abr function in {file_path}")
                return False
            header = const_match.group(1)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(header + new_code)
        return True
    except Exception as e:
        logger.error(f"Error updating temporary file: {e}")
        return False

def cal_current_reward(process_id):
    """Return the mean per-trace reward of this process's current-algorithm logs.

    Reads every ``log_sim_{ALGORITHM_NAME}_{process_id}_*`` file in
    dataset_dir. For each file it averages the reward column (the last
    tab-separated field) with the first reward excluded. It then returns the
    mean of those per-file averages. Files with fewer than two rewards are left
    out, because their average would be NaN and would poison the overall mean.

    Returns:
        The mean reward, or None (after logging) if no log file or no valid
        reward is found.
    """
    try:
        log_prefix = f'log_sim_{ALGORITHM_NAME}_{process_id}_'
        current_logs = [f for f in os.listdir(dataset_dir) if f.startswith(log_prefix)]

        if not current_logs:
            logger.error(f"Cannot find log files for algorithm {ALGORITHM_NAME}, process ID: {process_id}")
            return None

        per_log_means = []
        for log in current_logs:
            try:
                with open(os.path.join(dataset_dir, log), 'r') as f:
                    lines = f.readlines()
            except Exception as e:
                logger.error(f"Error processing log file {log}: {e}")
                continue

            rewards = []
            for line in lines:
                if not line.strip() or line.startswith('#'):
                    continue
                parts = line.split('\t')
                if len(parts) < 7:
                    continue
                try:
                    rewards.append(float(parts[-1]))
                except ValueError:
                    continue

            # The first reward is excluded, so at least two are needed for a
            # meaningful mean (np.mean of an empty slice is NaN).
            if len(rewards) >= 2:
                per_log_means.append(np.mean(rewards[1:]))
            elif rewards:
                logger.warning(f"Log file {log} has only one reward value, skipping")

        if per_log_means:
            return np.mean(per_log_means)
        logger.error("No valid reward values found")
        return None
    except Exception as e:
        logger.error(f"Error calculating current reward: {e}")
        return None

async def get_optimization_parameters(api_key, bot_name, code):
    """Ask the LLM which parameters of *code* to tune and over which ranges.

    Returns:
        dict: The parsed JSON object, expected to look like
        ``{"parameters": [{"name", "type", "min", "max", "current", "description"}, ...]}``.
        ``{"parameters": []}`` is returned when the reply contains no JSON
        object with a ``"parameters"`` list.
    """
    prompt = f"""Please analyze the following ABR (Adaptive Bitrate) algorithm code and identify parameters that can be optimized using Bayesian optimization. For each parameter, provide its search range.
    
    Code:
    ```python
    {code}
    ```
    
    Please return the results in JSON format as follows:
    ```json
    {{
        "parameters": [
            {{
                "name": "parameter_name",
                "type": "integer|float",
                "min": minimum_value,
                "max": maximum_value,
                "current": current_value,
                "description": "parameter description"
            }},
            ...
        ]
    }}
    ```
    
    Only return results in JSON format, don't add any other explanations. Ensure the JSON format is correct and can be parsed.
    """

    message = fp.ProtocolMessage(role="user", content=prompt)
    response = await get_responses(api_key, [message], bot_name)

    params_data = _parse_parameters_json(response)
    if params_data is None:
        logger.error("Cannot extract parameter information from LLM response")
        return {"parameters": []}
    return params_data


def _parse_parameters_json(response):
    """Extract the parameter-description JSON object from an LLM reply.

    Candidates are tried in this order:

    1. The contents of each fenced block (```json or a bare ``` fence).
    2. The whole reply.
    3. The span from the first '{' to the last '}'.

    The first candidate that parses to a dict holding a "parameters" list is
    returned. If no candidate does, None is returned.
    """
    candidates = re.findall(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
    candidates.append(response)
    first, last = response.find('{'), response.rfind('}')
    if 0 <= first < last:
        # Covers replies that wrap the JSON in prose without a code fence.
        candidates.append(response[first:last + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers call .get() on the result, so only a dict is acceptable.
        if isinstance(data, dict) and isinstance(data.get("parameters"), list):
            return data
    return None

def create_search_space_from_llm_params(params_data):
    """Build a scikit-optimize search space from the LLM's parameter descriptions.

    Args:
        params_data: Dict with a "parameters" list. Each entry may carry
            "name", "type" (a type mentioning int/integer -> skopt Integer,
            anything else -> Real), "min", "max" and "current". A bound or
            current value given as a list uses its first element.

    Returns:
        (dimensions, param_names, param_details): the skopt dimensions, their
        names in the same order, and a name -> raw LLM entry mapping.

    An entry is skipped in any of these cases:

    - it is not a dict or has no string name;
    - it repeats a name already added;
    - it has non-numeric or non-finite bounds;
    - it has neither a complete [min, max] range nor a "current" value from
      which to derive one (current * 0.1 .. current * 10).

    Reversed bounds are swapped. Equal bounds are widened upward.
    """
    dimensions = []
    param_names = []
    param_details = {}

    for param in params_data.get("parameters") or []:
        if not isinstance(param, dict):
            continue
        name = param.get("name")
        # "type" may be missing or null in LLM output.
        param_type = str(param.get("type") or "").lower()

        if not name or not isinstance(name, str):
            continue
        if name in param_details:
            logger.warning(f"Parameter {name} appears more than once, keeping the first definition")
            continue

        min_val = param.get("min")
        max_val = param.get("max")
        current = param.get("current")

        # Check and handle non-standard format range values
        try:
            if isinstance(min_val, list):
                logger.warning(f"Parameter {name} min value is a list: {min_val}, taking first element")
                min_val = min_val[0] if min_val else None

            if isinstance(max_val, list):
                logger.warning(f"Parameter {name} max value is a list: {max_val}, taking first element")
                max_val = max_val[0] if max_val else None

            if min_val is not None:
                min_val = float(min_val)
            if max_val is not None:
                max_val = float(max_val)

            if isinstance(current, list):
                logger.warning(f"Parameter {name} current value is a list: {current}, taking first element")
                current = current[0] if current else None

            if current is not None:
                current = float(current)
        except (ValueError, TypeError) as e:
            logger.error(f"Error processing range values for parameter {name}: {e}")
            continue

        # Without a complete range, derive one from the current value.
        if min_val is None or max_val is None:
            if current is None:
                continue  # not enough information
            min_val = current * 0.1
            max_val = current * 10

        # json.loads accepts NaN/Infinity; skopt cannot search such a range.
        if not (np.isfinite(min_val) and np.isfinite(max_val)):
            logger.error(f"Parameter {name} has a non-finite range [{min_val}, {max_val}], skipping")
            continue

        if min_val >= max_val:
            logger.warning(f"Parameter {name} min value {min_val} is greater than or equal to max value {max_val}, adjusting range")
            if min_val == max_val:
                # Widen upward by half the magnitude plus 0.1. This keeps
                # max > min for negative values too, where min * 1.5 would
                # fall below min.
                max_val = min_val + abs(min_val) * 0.5 + 0.1
            else:
                min_val, max_val = max_val, min_val

        try:
            # Word-start match: "int"/"integer" count, "floating point" does not.
            if re.search(r'\bint', param_type):
                low, high = int(min_val), int(max_val)
                if high <= low:
                    # Truncation can collapse a narrow float range onto one
                    # integer; skopt requires high > low.
                    high = low + 1
                dimensions.append(Integer(low, high, name=name))
            else:
                dimensions.append(Real(min_val, max_val, name=name))

            param_names.append(name)
            param_details[name] = param
            logger.info(f"Successfully added parameter {name} to search space: [{min_val}, {max_val}]")
        except Exception as e:
            logger.error(f"Error creating search space for parameter {name}: {e}")
            continue

    return dimensions, param_names, param_details

def update_code_with_params(code_str, params):
    """Return *code_str* with the first assignment of each parameter set to its new value.

    For every ``name -> value`` in *params*, the first line that is a plain
    assignment statement ``[indent][obj.]name = <expr>  # optional comment``
    is rewritten to ``[indent][obj.]name = value`` and keeps its trailing
    comment. Only whole-line assignments qualify, so comparisons
    (``if name == x``), augmented assignments and keyword arguments
    (``f(name=x)``) are never touched. Parameters without such a line are
    left unchanged.

    Matching always runs against the original lines, so one replacement
    cannot affect another.
    """
    original_lines = code_str.split('\n')
    updated_lines = list(original_lines)

    for name, value in params.items():
        assignment = re.compile(
            rf'^(?P<indent>[ \t]*)(?P<target>(?:\w+\.)*{re.escape(name)})[ \t]*=(?!=)[ \t]*'
            r'(?P<value>[^#\n]*?)(?P<comment>[ \t]*#.*)?$'
        )
        for i, line in enumerate(original_lines):
            match = assignment.match(line)
            if match:
                updated_lines[i] = (
                    f"{match.group('indent')}{match.group('target')} = {value}"
                    f"{match.group('comment') or ''}"
                )
                break

    return '\n'.join(updated_lines)

async def bayesian_optimization(original_code, initial_reward, api_key, bot_name, temp_file_path, run_test_func, process_id):
    """Tune the numeric parameters of *original_code* with Gaussian-process Bayesian optimization.

    The LLM first proposes tunable parameters and ranges. Each gp_minimize
    evaluation then does the following:

    1. Rewrite those assignments in the code.
    2. Write the result to *temp_file_path*.
    3. Run *run_test_func*.
    4. Score the run with ``cal_current_reward(process_id)``.

    A candidate replaces the current best only if its reward is higher.

    Args:
        original_code: Source of the abr function to tune.
        initial_reward: Reward of *original_code*.
        api_key, bot_name: Poe credentials/bot used to identify parameters.
        temp_file_path: Algorithm file evaluated by *run_test_func*; the
            convergence plot is saved next to it.
        run_test_func: Zero-argument callable returning True on a successful run.
        process_id: Identifies this process's simulation logs.

    Returns:
        (best_code, best_reward, optimization_history). The history holds one
        ``{'params', 'reward', 'code'}`` dict per successful evaluation. If the
        optimizer raises, the temp file is restored to *original_code* and
        the initial reward is returned.
    """
    # Skip Bayesian optimization if iterations set to 0
    if MAX_BAYESIAN_ITERATIONS <= 0:
        logger.info("Bayesian optimization iterations set to 0, skipping Bayesian optimization")
        return original_code, initial_reward, []

    # Ask the LLM which parameters to tune
    params_data = await get_optimization_parameters(api_key, bot_name, original_code)

    if not isinstance(params_data, dict) or not params_data.get("parameters"):
        logger.warning("LLM did not identify any optimizable parameters")
        return original_code, initial_reward, []

    dimensions, param_names, param_details = create_search_space_from_llm_params(params_data)

    if not dimensions:
        logger.warning("Search space is empty")
        return original_code, initial_reward, []

    logger.info(f"Will optimize the following parameters: {param_names}")
    for name in param_names:
        param = param_details[name]
        logger.info(f"- {name}: current={param.get('current')}, range=[{param.get('min')}, {param.get('max')}], description={param.get('description')}")

    optimization_history = []  # one {'params', 'reward', 'code'} dict per successful evaluation
    best_reward = initial_reward
    best_code = original_code
    best_params = {}

    def failure_score():
        """Objective value for a failed evaluation: finite and worse than anything seen."""
        # gp_minimize MINIMIZES the objective (= -reward). A failed run must
        # therefore score HIGHER than every real evaluation. Returning -inf
        # would rank the failed point as the global optimum and feed a
        # non-finite value to the GP surrogate.
        seen = [-entry['reward'] for entry in optimization_history]
        if initial_reward is not None and np.isfinite(initial_reward):
            seen.append(-initial_reward)
        worst = max(seen, default=0.0)
        return float(worst + abs(worst) + 1.0)

    @use_named_args(dimensions)
    def objective(**param_values):
        nonlocal best_reward, best_code, best_params

        updated_params = {name: param_values[name] for name in param_names}
        updated_code = update_code_with_params(original_code, updated_params)

        if not update_temp_abr_function(temp_file_path, updated_code):
            logger.error("Cannot update ABR function")
            return failure_score()

        if not run_test_func():
            logger.error("Test run failed")
            return failure_score()

        new_reward = cal_current_reward(process_id)
        if new_reward is None:
            logger.error("Cannot get new reward")
            return failure_score()
        if not np.isfinite(new_reward):
            logger.error(f"Non-finite reward {new_reward}, treating run as failed")
            return failure_score()

        optimization_history.append({
            'params': updated_params,
            'reward': new_reward,
            'code': updated_code,
        })
        logger.info(f"Parameters: {updated_params}, Reward: {new_reward}")

        if best_reward is None or not np.isfinite(best_reward) or new_reward > best_reward:
            best_reward = new_reward
            best_code = updated_code
            best_params = updated_params
            logger.info(f"Found new best result! Reward: {best_reward}")

        # gp_minimize minimizes, so report the negated reward.
        return -new_reward

    try:
        logger.info("Starting Bayesian optimization...")
        res = gp_minimize(
            objective,
            dimensions,
            n_calls=MAX_BAYESIAN_ITERATIONS,
            # skopt raises if n_calls is below n_initial_points (default 10).
            n_initial_points=min(10, MAX_BAYESIAN_ITERATIONS),
            random_state=42,
            verbose=True,
        )

        fig = None
        try:
            fig = plt.figure(figsize=(10, 6))
            plot_convergence(res)
            plt.title("Bayesian Optimization Convergence")
            plt.savefig(os.path.join(os.path.dirname(temp_file_path), "bayesian_optimization_convergence.png"))
            logger.info("Saved optimization convergence plot")
        except Exception as e:
            logger.error(f"Failed to create optimization plot: {e}")
        finally:
            # Free the figure; this function may run once per LLM iteration.
            if fig is not None:
                plt.close(fig)

        # Leave the best code (possibly the original) in the temp file.
        if not update_temp_abr_function(temp_file_path, best_code):
            logger.error("Cannot update to best code")

        logger.info(f"Bayesian optimization complete. Best parameters: {best_params}, Best reward: {best_reward}")
        return best_code, best_reward, optimization_history

    except Exception as e:
        logger.error(f"Error during Bayesian optimization: {e}")
        # Revert the temp file to the original code
        update_temp_abr_function(temp_file_path, original_code)
        return original_code, initial_reward, optimization_history

def _improvement_percentage(new_reward, base_reward):
    """Return the relative change from *base_reward* to *new_reward*, in percent.

    The change is divided by ``abs(base_reward)`` so that a higher reward is
    always reported as a positive percentage, even when rewards are negative.
    Returns ``None`` when *base_reward* is zero (relative change undefined)
    instead of raising ``ZeroDivisionError``.
    """
    if base_reward == 0:
        return None
    return float((new_reward - base_reward) / abs(base_reward) * 100)


def _format_percentage(percentage):
    """Render an optional percentage as ``'12.34%'``, or ``'n/a'`` when undefined."""
    return "n/a" if percentage is None else f"{percentage:.2f}%"


def _to_builtin(value):
    """Convert a NumPy scalar to a plain Python ``float``/``int``; pass other values through."""
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, np.integer):
        return int(value)
    return value


def _summarize_bayesian_run(previous_reward, bayesian_reward, records):
    """Build a JSON-serializable summary of one Bayesian optimization run.

    Args:
        previous_reward: reward of the code before Bayesian tuning.
        bayesian_reward: best reward reported by ``bayesian_optimization``.
        records: history list returned by ``bayesian_optimization``; each
            record is read for its ``'params'`` (dict) and ``'reward'`` keys.

    Returns:
        dict with ``previous_reward``, ``reward``, ``best_params`` (params of the
        highest-reward record, ``{}`` if there are none) and ``history``.
    """
    history = [
        {
            'reward': float(record['reward']),
            'params': {name: _to_builtin(value) for name, value in record['params'].items()},
        }
        for record in records
    ]
    best_params = dict(max(history, key=lambda r: r['reward'])['params']) if history else {}
    return {
        'previous_reward': float(previous_reward),
        'reward': float(bayesian_reward),
        'best_params': best_params,
        'history': history,
    }


def _format_failed_attempts(failed_attempts):
    """Render rejected attempts (code, reward, optional badcase note) for the LLM prompt."""
    parts = []
    for i, attempt in enumerate(failed_attempts, start=1):
        entry = (
            f"Failed attempt {i}:\n"
            f"Code:\n```python\n{attempt['code']}\n```\n"
            f"reward: {attempt['reward']} (original reward: {attempt['previous_reward']})\n"
        )
        if attempt.get('specific_badcase'):
            entry += f"Corresponding badcase:\n{attempt['specific_badcase']}\n"
        # Always separate entries with a blank line so they don't run together.
        parts.append(entry + "\n")
    return "".join(parts)


def _format_successful_attempts(successful_attempts):
    """Render accepted attempts (code, LLM reasoning, reward) for the LLM prompt."""
    return "".join(
        f"Successful attempt {i}:\n"
        f"Code:\n```python\n{attempt['code']}\n```\n"
        f"Modification reason: {attempt['reasoning']}\n"
        f"reward: {attempt['reward']} (original reward: {attempt['previous_reward']})\n\n"
        for i, attempt in enumerate(successful_attempts, start=1)
    )


def _save_final_version(final_file_path, original_abr_file, abr_code):
    """Write the header constants of *original_abr_file* followed by *abr_code*.

    Only the ``VIDEO_BIT_RATE`` / ``chunk_length`` / ``A_DIM`` lines (matched in
    that order) are copied; if they are not found, just the code is written.
    """
    with open(original_abr_file, 'r', encoding='utf-8') as original:
        content = original.read()
    const_pattern = r'(VIDEO_BIT_RATE\s*=.*\nchunk_length\s*=.*\nA_DIM\s*=.*\n)'
    const_match = re.search(const_pattern, content)
    constants = const_match.group(1) if const_match else ""
    with open(final_file_path, 'w', encoding='utf-8') as f:
        f.write(constants + abr_code)


def _plot_progress(initial_reward, baseline_reward, successful_attempts, output_path):
    """Save a bar chart of the initial, baseline and each accepted round's rewards.

    Accepted rounds that went through Bayesian optimization contribute two bars
    (LLM reward, then Bayesian reward); the others contribute one.
    """
    rewards = [initial_reward, baseline_reward]
    labels = ['Initial', 'Baseline Bayes']
    for i, attempt in enumerate(successful_attempts, start=1):
        if 'llm_reward' in attempt and attempt.get('bayesian_reward') is not None:
            rewards.extend([attempt['llm_reward'], attempt['bayesian_reward']])
            labels.extend([f'LLM-{i}', f'Bayes-{i}'])
        else:
            rewards.append(attempt['reward'])
            labels.append(f'LLM-{i}')

    fig = plt.figure(figsize=(12, 7))
    try:
        plt.bar(labels, rewards, color='skyblue')
        plt.title('ABR Algorithm Optimization Progress')
        plt.ylabel('Reward')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.xticks(rotation=45)
        # Annotate each bar with its value.
        for x, value in enumerate(rewards):
            plt.text(x, value + 0.02, f"{value:.4f}", ha='center')
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # pyplot keeps every open figure alive until it is closed explicitly.
        plt.close(fig)


def _cleanup_process_files(temp_abr_file, sim_log_dir, algorithm_name, process_id):
    """Delete the temporary algorithm file and this process's simulation logs.

    Failures are logged, not raised: cleanup is best-effort.
    """
    try:
        if os.path.exists(temp_abr_file):
            os.remove(temp_abr_file)
            logger.info(f"Deleted temporary file: {temp_abr_file}")
    except OSError as e:
        logger.error(f"Error cleaning up temporary file: {e}")
    try:
        # Escape the literal parts so glob metacharacters in them are not expanded.
        log_pattern = os.path.join(
            glob.escape(sim_log_dir),
            f'log_sim_{glob.escape(algorithm_name)}_{process_id}_*',
        )
        for log_file in glob.glob(log_pattern):
            os.remove(log_file)
            logger.info(f"Deleted log file: {log_file}")
    except OSError as e:
        logger.error(f"Error cleaning up log files: {e}")


async def optimize_code():
    """Run the full LLM + Bayesian optimization pipeline for ``ALGORITHM_NAME``.

    Steps:
      1. Copy ``<ALGORITHM_NAME>.py`` into a per-process directory under
         ``log_dir``, run ``run_abr.py`` on the copy and read the initial reward.
      2. Compute badcases, then (if ``MAX_BAYESIAN_ITERATIONS > 0``) run a
         baseline Bayesian optimization ("round 0").
      3. For up to ``MAX_LLM_ITERATIONS`` rounds: ask the LLM for a rewritten
         ``abr`` function, test it, optionally tune it with Bayesian
         optimization, and keep it only if the reward improves on the current one.
      4. Save the final function, a JSON history and a progress chart, then
         delete the temporary algorithm file and this process's simulation logs.

    Relies on module-level globals set by the ``__main__`` block (``log_dir``,
    ``logger``, ``dataset``, ``dataset_dir`` and the iteration limits).
    Returns ``None``; exits early (after logging) if a prerequisite step fails.
    """
    # Use the process ID to create a unique working directory
    process_id = os.getpid()
    process_log_dir = os.path.join(log_dir, f"process_{process_id}")
    os.makedirs(process_log_dir, exist_ok=True)
    logger.info(f"Created working directory for process {process_id}: {process_log_dir}")

    # Set algorithm name and original file
    original_abr_file = f"{ALGORITHM_NAME}.py"
    logger.info(f"Using algorithm name: {ALGORITHM_NAME}")
    logger.info(f"Using ABR algorithm file: {original_abr_file}")

    if not os.path.exists(original_abr_file):
        logger.error(f"ABR file {original_abr_file} does not exist, exiting program")
        return

    # All edits go to a temporary copy; the original file is never modified.
    temp_abr_file = os.path.join(process_log_dir, f"temp_{ALGORITHM_NAME}_{process_id}.py")
    shutil.copy2(original_abr_file, temp_abr_file)
    logger.info(f"Created temporary algorithm file: {temp_abr_file}")

    original_file_backup = os.path.join(process_log_dir, f"original_{ALGORITHM_NAME}.py")
    shutil.copy2(original_abr_file, original_file_backup)
    logger.info(f"Created backup of original algorithm file: {original_file_backup}")

    def run_test_with_temp_file():
        """Run ``run_abr.py`` on the temporary file; return True only on exit status 0.

        ``os.system`` (used previously) never raises and its status was
        ignored, so a crashing test script used to be reported as success and
        stale results could be scored. An argv list also avoids shell parsing
        of the file paths.
        """
        cmd = [sys.executable, 'run_abr.py', f'{dataset}/', ALGORITHM_NAME, temp_abr_file, str(process_id)]
        try:
            completed = subprocess.run(cmd, check=False)
        except OSError as e:
            logger.error(f"Error running test script: {e}")
            return False
        if completed.returncode != 0:
            logger.error(f"Test script run_abr.py exited with status {completed.returncode} "
                         f"(algorithm {ALGORITHM_NAME}, temp file {temp_abr_file}, process ID: {process_id})")
            return False
        logger.info(f"Successfully ran test script run_abr.py using algorithm {ALGORITHM_NAME} and temp file {temp_abr_file}, process ID: {process_id}")
        return True

    # API configuration; the key can be supplied via the POE_API_KEY environment variable.
    api_key = os.environ.get("POE_API_KEY", "yout_key")
    bot_name = 'Claude-3.7-Sonnet'

    # Step 1: Run test and read the initial reward
    if not run_test_with_temp_file():
        logger.error("Failed to run test script, exiting program")
        return

    initial_reward = cal_current_reward(process_id)
    if initial_reward is None:
        logger.error("Cannot get initial reward, exiting program")
        return
    logger.info(f"Initial reward: {initial_reward}")

    # Step 2: Calculate and save badcases
    calculate_badcases_from_results(process_log_dir, temp_abr_file, process_id)
    badcase_path = os.path.join(process_log_dir, 'badcase.txt')

    initial_abr_code = get_abr_function_content(temp_abr_file)
    if not initial_abr_code:
        logger.error(f"Cannot extract abr function from {temp_abr_file}, exiting program")
        return

    # Round 0: baseline Bayesian optimization of the unmodified algorithm
    baseline_reward = initial_reward
    baseline_code = initial_abr_code

    if MAX_BAYESIAN_ITERATIONS > 0:
        logger.info("Round 0: Performing baseline Bayesian optimization...")
        baseline_code, baseline_reward, _baseline_records = await bayesian_optimization(
            initial_abr_code, initial_reward, api_key, bot_name, temp_abr_file, run_test_with_temp_file, process_id
        )
        baseline_pct = _improvement_percentage(baseline_reward, initial_reward)

        bayes_baseline_file = os.path.join(process_log_dir, "baseline_bayesian_optimization.json")
        with open(bayes_baseline_file, 'w', encoding='utf-8') as f:
            baseline_data = {
                'initial_reward': float(initial_reward),
                'baseline_reward': float(baseline_reward),
                'improvement': float(baseline_reward - initial_reward),
                'improvement_percentage': baseline_pct,
            }
            json.dump(baseline_data, f, ensure_ascii=False, cls=NumpyEncoder, indent=2)

        logger.info(f"Baseline Bayesian optimization complete. Initial reward: {initial_reward}, Optimized reward: {baseline_reward}")
        logger.info(f"Improvement: {baseline_reward - initial_reward} ({_format_percentage(baseline_pct)})")
    else:
        logger.info("Bayesian optimization iterations set to 0, skipping baseline Bayesian optimization")

    current_reward = baseline_reward
    current_abr_code = baseline_code
    previous_reward = current_reward  # reward before the most recent accepted change

    failed_attempts = []
    successful_attempts = []
    bayesian_history = []

    bandwidth_info = NETWORK_INFO.get(dataset, "Unknown bandwidth range")

    # The baseline summary never changes between rounds, so build it once.
    baseline_section = ""
    if MAX_BAYESIAN_ITERATIONS > 0 and baseline_reward > initial_reward:
        baseline_section = f"""
            Baseline Bayesian optimization results:
            Initial reward: {initial_reward}
            Post-Bayesian optimization reward: {baseline_reward}
            Improvement percentage: {_format_percentage(_improvement_percentage(baseline_reward, initial_reward))}
            """

    # Stop early once the reward is more than this many percent above the initial one.
    early_stop_improvement_pct = 50.0

    # ======== LLM-assisted optimization loop ========
    for iteration in range(1, MAX_LLM_ITERATIONS + 1):
        logger.info(f"\n===== Optimization Iteration {iteration} =====")

        badcase_section = ""
        if os.path.exists(badcase_path):
            with open(badcase_path, 'r') as f:
                badcase_section = f"Current badcase information:\n{f.read()}\n\n"

        failed_section = _format_failed_attempts(failed_attempts)
        successful_section = _format_successful_attempts(successful_attempts)
        successful_block = (
            'Here are previous successful modifications and their reasons:\n' + successful_section
            if successful_section else ''
        )

        bayes_section = ""
        if bayesian_history:
            last_bayes = bayesian_history[-1]
            bayes_section = f"""
            Last Bayesian optimization results:
            Pre-optimization reward: {last_bayes['previous_reward']}
            Post-optimization reward: {last_bayes['reward']}
            Best parameters found: {last_bayes.get('best_params', "None")}
            """

        previous_line = f"Previous round reward: {previous_reward}" if iteration > 1 else ""

        prompt = f'''I need you to optimize a bitrate adaptive algorithm. There are 6 bitrate levels 0-5, corresponding to [300,750,1200,1850,2850,4300]kbps. The testing environment bandwidth range is {bandwidth_info} Mbps.

        I only need you to optimize the implementation of the abr function, from def abr until the end of the function, keeping the function inputs unchanged:

        Current Python code for abr function:
        ```python
        {current_abr_code}
        ```
        The VIDEO_BIT_RATE, chunk_length, and A_DIM constants are already defined in the file header, you don't need to modify or include them. Just modify the abr function itself.

        Current reward: {current_reward}
        Initial reward: {initial_reward}
        {previous_line}
        
        {baseline_section}

        {badcase_section}

        {failed_section}

        {successful_block}

        {bayes_section}

        Please analyze the code and suggest improvements to increase the reward. First explain in detail your improvement ideas and reasons, then provide the complete modified abr function code.

        Format requirements:

        First provide detailed analysis and reasons for modification
        Then provide complete function code from def abr to the end of the function using ```python code block
        '''

        # Ask the LLM for an improved abr function
        message = fp.ProtocolMessage(role="user", content=prompt)
        response = await get_responses(api_key, [message], bot_name)
        logger.info(f"LLM response:\n{response}")

        optimized_code = extract_code_from_response(response)
        reasoning = extract_reasoning_from_response(response) or ""
        logger.info(f"Modification reasons provided by LLM:\n{reasoning}")

        if not optimized_code:
            logger.error(f"Iteration {iteration}: no code found in LLM response, skipping this iteration")
            continue

        with open(os.path.join(process_log_dir, f"optimized_llm_{iteration}.py"), 'w', encoding='utf-8') as f:
            f.write(optimized_code)
        with open(os.path.join(process_log_dir, f"reasoning_llm_{iteration}.txt"), 'w', encoding='utf-8') as f:
            f.write(reasoning)

        if not update_temp_abr_function(temp_abr_file, optimized_code):
            logger.error(f"Iteration {iteration} cannot update abr function, skipping this iteration")
            continue

        # Evaluate the LLM's version; on any failure restore the current code.
        if not run_test_with_temp_file():
            logger.error(f"Iteration {iteration} test run failed, using previous version")
            update_temp_abr_function(temp_abr_file, current_abr_code)
            failed_attempts.append({
                'code': optimized_code,
                'reasoning': reasoning,
                'reward': 'Run failed',
                'previous_reward': current_reward,
                'specific_badcase': f"Iteration {iteration} - Test run failed"
            })
            continue

        llm_reward = cal_current_reward(process_id)
        if llm_reward is None:
            logger.error(f"Iteration {iteration} cannot get reward, using previous version")
            update_temp_abr_function(temp_abr_file, current_abr_code)
            failed_attempts.append({
                'code': optimized_code,
                'reasoning': reasoning,
                'reward': 'Cannot obtain',
                'previous_reward': current_reward,
                'specific_badcase': f"Iteration {iteration} - Cannot get reward value"
            })
            continue

        logger.info(f"LLM optimized reward: {llm_reward}, previous reward: {current_reward}")

        llm_result = {
            'code': optimized_code,
            'reasoning': reasoning,
            'reward': float(llm_reward),
            'previous_reward': float(current_reward)
        }
        with open(os.path.join(process_log_dir, f'llm_result_{iteration}.json'), 'w', encoding='utf-8') as f:
            json.dump(llm_result, f, ensure_ascii=False, cls=NumpyEncoder, indent=2)

        # Optionally tune the LLM's version with Bayesian optimization.
        if MAX_BAYESIAN_ITERATIONS > 0:
            logger.info("Starting Bayesian optimization...")
            bayesian_code, bayesian_reward, bayes_records = await bayesian_optimization(
                optimized_code, llm_reward, api_key, bot_name, temp_abr_file, run_test_with_temp_file, process_id
            )

            bayesian_entry = _summarize_bayesian_run(llm_reward, bayesian_reward, bayes_records)
            bayesian_history.append(bayesian_entry)

            with open(os.path.join(process_log_dir, f'bayesian_history_{iteration}.json'), 'w', encoding='utf-8') as f:
                json.dump(bayesian_entry, f, ensure_ascii=False, cls=NumpyEncoder, indent=2)
            with open(os.path.join(process_log_dir, f"optimized_bayes_{iteration}.py"), 'w', encoding='utf-8') as f:
                f.write(bayesian_code)

            logger.info(f"LLM optimized reward: {llm_reward}, Bayesian optimized reward: {bayesian_reward}")
            candidate_code, candidate_reward = bayesian_code, bayesian_reward
        else:
            bayesian_reward = None
            candidate_code, candidate_reward = optimized_code, llm_reward

        # Accept the candidate only if it beats the current best.
        if candidate_reward > current_reward:
            logger.info(f"Optimization improved performance from {current_reward} to {candidate_reward}")
            previous_reward = current_reward
            current_abr_code = candidate_code
            current_reward = candidate_reward
            successful_attempts.append({
                'code': candidate_code,
                'reasoning': reasoning,
                'reward': candidate_reward,
                'previous_reward': previous_reward,
                'llm_reward': llm_reward,
                'bayesian_reward': bayesian_reward
            })
        else:
            logger.info(f"Optimization did not improve performance, keeping original code")
            update_temp_abr_function(temp_abr_file, current_abr_code)
            failed_attempts.append({
                'code': candidate_code,
                'reasoning': reasoning,
                'reward': candidate_reward,
                'previous_reward': current_reward,
                'llm_reward': llm_reward,
                'bayesian_reward': bayesian_reward
            })

        # NOTE(review): the code file has just been restored or updated, but the
        # test has not been re-run. Whether these badcases reflect
        # current_abr_code depends on what calculate_badcases_from_results
        # reads; confirm.
        calculate_badcases_from_results(process_log_dir, temp_abr_file, process_id)

        improvement = _improvement_percentage(current_reward, initial_reward)
        if improvement is not None and improvement > early_stop_improvement_pct:
            logger.info(f"\nOptimization target reached! Final reward: {current_reward} (Initial reward: {initial_reward})")
            break

    final_file_path = os.path.join(process_log_dir, f'final_optimized_{ALGORITHM_NAME}.py')
    _save_final_version(final_file_path, original_abr_file, current_abr_code)

    # Save all attempt records (ensure_ascii=False keeps non-ASCII text readable)
    complete_history = {
        'initial_reward': float(initial_reward),
        'baseline_reward': float(baseline_reward),
        'final_reward': float(current_reward),
        'original_code': initial_abr_code,
        'baseline_code': baseline_code,
        'final_code': current_abr_code,
        'bayesian_history': bayesian_history,
        'successful_attempts': successful_attempts,
        'failed_attempts': failed_attempts,
        'algorithm_name': ALGORITHM_NAME,
        'abr_file': temp_abr_file,
        'process_id': process_id
    }
    history_path = os.path.join(process_log_dir, 'complete_optimization_history.json')
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(complete_history, f, ensure_ascii=False, cls=NumpyEncoder, indent=2)

    try:
        _plot_progress(initial_reward, baseline_reward, successful_attempts,
                       os.path.join(process_log_dir, 'optimization_progress.png'))
        logger.info("Saved optimization progress chart")
    except Exception as e:  # plotting is optional; never abort the run for it
        logger.error(f"Failed to create optimization results chart: {e}")

    _cleanup_process_files(temp_abr_file, dataset_dir, ALGORITHM_NAME, process_id)

    logger.info(f"Optimization complete! Initial reward: {initial_reward}, Baseline Bayesian reward: {baseline_reward}, Final reward: {current_reward}")
    logger.info(f"Final improvement ratio: {_format_percentage(_improvement_percentage(current_reward, initial_reward))}")
    logger.info(f"Saved final optimized version to {final_file_path}")
    logger.info(f"History record of all optimization attempts saved to {history_path}")

if __name__ == "__main__":
    # Entry point: parse CLI arguments, set the module-level configuration the
    # optimizer reads, prepare the log directory, then run the async pipeline.
    usage = "Usage: python script.py <dataset_name> <max_bayesian_iterations> <max_llm_iterations> <algorithm_name>"
    if len(sys.argv) < 5:
        print(usage, file=sys.stderr)
        sys.exit(1)

    # Parse command line arguments. Non-integer iteration counts are reported
    # with the usage text instead of an uncaught ValueError traceback.
    dataset = sys.argv[1]
    try:
        MAX_BAYESIAN_ITERATIONS = int(sys.argv[2])
        MAX_LLM_ITERATIONS = int(sys.argv[3])
    except ValueError:
        print(f"Iteration counts must be integers, got {sys.argv[2]!r} and {sys.argv[3]!r}.\n{usage}",
              file=sys.stderr)
        sys.exit(1)
    ALGORITHM_NAME = sys.argv[4]

    # Update dataset directory
    dataset_dir = os.path.join(results_dir, dataset)
    rl_log_dir =  'rl_path'

    # Create log directory
    log_dir = os.path.join("log2", f"{dataset}/{MAX_BAYESIAN_ITERATIONS}_{ALGORITHM_NAME}_{MAX_LLM_ITERATIONS}")
    os.makedirs(log_dir, exist_ok=True)

    # Setup logging
    logger = setup_logging(log_dir)

    # Record parameter information
    logger.info(f"Running parameters - Dataset: {dataset}, Max Bayesian iterations: {MAX_BAYESIAN_ITERATIONS}, Algorithm: {ALGORITHM_NAME}, Max LLM iterations: {MAX_LLM_ITERATIONS}")
    logger.info(f"Logs saved in: {log_dir}")

    # Run optimization
    asyncio.run(optimize_code())