#!/usr/bin/env python3
"""
Updated FDA search weight optimization script
- Removed 'summary' and 'concepts' from optimization variables.
- Uses Optuna TPESampler(multivariate=True)
- Adds MedianPruner-based early stopping and reports intermediate results for pruning
- Implements a "dropout" strategy: keeps a cached base set of NUM_TEST_CASES and each run replaces DROPOUT_COUNT items with newly generated queries (only generates queries for replacements)
- Logs trial-level results to CSV and produces plots at the end (matplotlib)

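NOTE: This script assumes `create_embeddings.py` is available and prints its search results in the format parsed by `run_search_for_optimization`: a line containing "hybrid results for", followed by numbered lines of the form "... Submission: <id>, Hybrid Similarity: <score>".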
"""

import optuna
import csv
import subprocess
import json
import os
import re
import random
import ast
import sys
import time
import math
from typing import List, Dict, Any
import pandas as pd
import matplotlib.pyplot as plt

# --- Configuration ---
# Run artifacts (test-case cache, trial CSV, log files, plots) are written under ./optimization_results.
RESULTS_DIR = os.path.join(os.getcwd(), "optimization_results")
os.makedirs(RESULTS_DIR, exist_ok=True)
BASE_TESTCASE_CACHE = os.path.join(RESULTS_DIR, "base_test_cases.json")
TRIAL_RESULTS_CSV = os.path.join(RESULTS_DIR, "trial_results.csv")

# Machine-specific locations. Each one can be overridden with an environment
# variable (FDA_RECORDS_CSV, FDA_CREATE_EMBEDDINGS_PY, FDA_EMBEDDINGS_OUTPUT_DIR);
# the defaults are the original hard-coded paths.
FDA_RECORDS_CSV = os.environ.get(
    'FDA_RECORDS_CSV', '/Users/arun/Documents/fda-search/py_src/fda_ai_records.csv')
CREATE_EMBEDDINGS_PY = os.environ.get(
    'FDA_CREATE_EMBEDDINGS_PY', '/Users/arun/Documents/fda-search/py_src/create_embeddings.py')
OUTPUT_DIR = os.environ.get(
    'FDA_EMBEDDINGS_OUTPUT_DIR', '/Users/arun/Documents/fda-search/embedding_data_small')  # passed to create_embeddings.py as --output_dir

NUM_TEST_CASES = 50  # total test set size
DROPOUT_COUNT = 10   # number of cached base test cases replaced with fresh ones per script run
N_OPTUNA_TRIALS = 2424  # target number of COMPLETE trials in the Optuna study
TOP_N_RESULTS_FOR_PASS = 5  # a test case passes if its expected submission is within the top N results

# Field weights tuned by Optuna (each becomes an Optuna param named 'w_<key>').
# These names are sent as the keys of the --weights_json payload for create_embeddings.py.
WEIGHT_KEYS = ['keywords', 'questions', 'thesis', 'search_boost', 'query_match_1', 'query_match_2', 'query_match_3']

# Seed weights from the GA file. Kept for reference only; the optimization below does not read them.
SEED_WEIGHTS_DICT = {
    'summary': 0, 'keywords': 0.07, 'questions': 0.03,
    'concepts': 0.0, 'thesis': 0.1, 'search_boost': 0.1,
    'query_match_1': 0.25, 'query_match_2': 0.25, 'query_match_3': 0.2
}

# Test cases used by every trial; populated in the __main__ section.
TEST_CASES: List[Dict[str, Any]] = []

# Open log file handle used by log_f; set in the __main__ section (None means console only).
LOG_FILE_HANDLE = None

# Ollama model used for query generation (if available)
OLLAMA_MODEL_NAME = "gemma3n:e2b"

# --- Helpers ---

def log_f(message: str):
    """Print *message* to stdout and mirror it into the run log, when one is open.

    The log copy is flushed immediately. Any error while writing to the log
    file is ignored, so logging can never abort the optimization run.
    """
    print(message)
    sink = LOG_FILE_HANDLE
    if not sink:
        return
    try:
        sink.write(message + '\n')
        sink.flush()
    except Exception:
        # Best effort: a broken/closed log file must not stop the script.
        pass


def load_all_records(csv_path: str) -> List[Dict[str,str]]:
    """Load the FDA records CSV and keep only rows usable as test cases.

    A row is kept when it has a non-empty ``submission_number`` and at least one
    non-empty field among ``thesis``, ``concepts``, ``summary_keywords`` and
    ``device_model``.

    Args:
        csv_path: Path to a UTF-8 CSV file whose first row is the header.

    Returns:
        The kept rows as dicts keyed by column name. Read errors are logged via
        ``log_f`` rather than raised; in that case the rows collected before the
        error (an empty list if the file is missing) are returned.
    """
    content_fields = ('thesis', 'concepts', 'summary_keywords', 'device_model')
    all_records: List[Dict[str, str]] = []
    try:
        # newline='' lets the csv module handle line endings itself, as the csv
        # docs require; otherwise line breaks inside quoted fields (e.g. a
        # multi-line thesis) are translated before the parser sees them.
        with open(csv_path, mode='r', encoding='utf-8', newline='') as infile:
            for row in csv.DictReader(infile):
                if row.get('submission_number') and any(row.get(field) for field in content_fields):
                    all_records.append(row)
    except FileNotFoundError:
        log_f(f"Error: CSV file {csv_path} not found.")
    except Exception as e:
        log_f(f"Error reading CSV file {csv_path}: {e}")
    return all_records


def run_ollama_for_query(thesis: str, concepts_str: str, timeout: int = 30) -> str:
    """Ask the local Ollama CLI for a short search query describing a device.

    Args:
        thesis: Thesis text of the record, embedded verbatim in the prompt.
        concepts_str: Key-concepts text of the record, embedded verbatim in the prompt.
        timeout: Seconds to wait for ``ollama run`` before giving up.

    Returns:
        Scanning the output bottom-up, the first line that is non-empty, does not
        start with ``>>>`` and does not mention the model name. If no line
        qualifies, the last output line stripped. Returns "" on a non-zero exit
        status, empty output, timeout, or any other error.
    """
    prompt_for_cli = f'''You are an expert medical researcher. Based on the following thesis and key concepts from a medical device's FDA summary, generate a concise and clinically relevant search query that a clinician or researcher might use to find information about similar devices or technologies. Do NOT include anything about AI, ML, or what those mean. Only return the 1-3 word medical search query. Return only the query itself, without any preamble or explanation..
Thesis: "{thesis}"
Key Concepts: "{concepts_str}"
Clinically Relevant Search Query:'''

    cli_args = ["ollama", "run", OLLAMA_MODEL_NAME, prompt_for_cli]
    try:
        completed = subprocess.run(cli_args, capture_output=True, text=True, check=False, timeout=timeout)
        raw_output = completed.stdout.strip() if completed.returncode == 0 else ""
        if not raw_output:
            return ""
        output_lines = raw_output.splitlines()
        # Skip prompt echoes ('>>>') and banner lines naming the model.
        usable = (
            text
            for text in (ln.strip() for ln in reversed(output_lines))
            if text and not text.startswith('>>>') and OLLAMA_MODEL_NAME not in text
        )
        return next(usable, output_lines[-1].strip())
    except Exception:
        return ""


def generate_query_for_record(record: Dict[str,str]) -> Dict[str,str]:
    """Generate a search query for a CSV record, trying Ollama before fallbacks.

    Ollama is consulted only when the record has a thesis or concepts. A
    non-empty Ollama answer is returned as-is. The fallback chain below runs only
    when Ollama produced nothing:

      1. a 5-12 word prefix of the thesis (when the thesis has >= 5 words);
      2. if that has fewer than 4 words, 1-2 randomly chosen concepts;
      3. if that has fewer than 3 words, 2-3 randomly chosen summary keywords;
      4. finally "<device_model> medical device", or a generic placeholder.

    Args:
        record: A CSV row; reads 'thesis', 'concepts', 'summary_keywords' and
            'device_model' (all optional).

    Returns:
        ``{'query': <query text>, 'query_source': 'ollama' | 'fallback'}``.
    """
    thesis = (record.get('thesis') or '').strip()
    concepts_str = (record.get('concepts') or '').strip()
    keywords_str = (record.get('summary_keywords') or '').strip()
    device_name = record.get('device_model') or 'Unknown Device'

    if thesis or concepts_str:
        ollama_query = run_ollama_for_query(thesis, concepts_str)
        if ollama_query:
            # The prompt asks for a 1-3 word query, so the word-count thresholds
            # of the fallback chain must not apply here. Applying them would discard
            # almost every valid Ollama answer while still labelling it 'ollama'.
            return {'query': ollama_query, 'query_source': 'ollama'}

    # Fallbacks
    query = ""
    thesis_words = thesis.split()
    if len(thesis_words) >= 7:
        query = " ".join(thesis_words[:random.randint(7, 12)])
    elif len(thesis_words) >= 5:
        query = " ".join(thesis_words[:random.randint(5, 7)])

    if len(query.split()) < 4 and concepts_str:
        concept_list = [c.strip() for c in concepts_str.split(',') if c.strip()]
        if concept_list:
            num_to_sample = min(len(concept_list), random.randint(1, 2))
            query = " ".join(random.sample(concept_list, num_to_sample))

    if len(query.split()) < 3 and keywords_str:
        keyword_list = [k.strip() for k in keywords_str.split(',') if k.strip()]
        if keyword_list:
            num_to_sample = min(len(keyword_list), random.randint(2, 3))
            query = " ".join(random.sample(keyword_list, num_to_sample))

    if not query:
        if device_name != 'Unknown Device' and device_name.strip():
            query = f"{device_name} medical device"
        else:
            query = "medical device information"

    return {'query': query, 'query_source': 'fallback'}


# Keys every cached test case must have; evaluation indexes them directly.
_REQUIRED_TEST_CASE_KEYS = ('submission_number', 'query')


def _make_test_case(record: Dict[str, Any]) -> Dict[str, Any]:
    """Build a test-case dict for a CSV record, generating its query."""
    gen = generate_query_for_record(record)
    return {
        'submission_number': record['submission_number'],
        'device_model': record.get('device_model', ''),
        'thesis': record.get('thesis', ''),
        'concepts': record.get('concepts', ''),
        'summary_keywords': record.get('summary_keywords', ''),
        'query': gen['query'],
        'query_source': gen['query_source'],
    }


def _load_base_cache(expected_size: int) -> List[Dict[str, Any]]:
    """Return cached base test cases, or [] if the cache is missing, unreadable or stale."""
    if not os.path.exists(BASE_TESTCASE_CACHE):
        return []
    try:
        with open(BASE_TESTCASE_CACHE, 'r', encoding='utf-8') as fh:
            cached = json.load(fh)
    except Exception as e:
        log_f(f"Failed to read base cache: {e}. Recreating base cache.")
        return []
    if not isinstance(cached, list) or len(cached) != expected_size:
        log_f("Base cache exists but is invalid length. Recreating base cache.")
        return []
    if not all(isinstance(c, dict) and all(k in c for k in _REQUIRED_TEST_CASE_KEYS) for c in cached):
        log_f("Base cache contains malformed entries. Recreating base cache.")
        return []
    return cached


def _build_base_cache(all_records: List[Dict[str, str]], num_cases: int) -> List[Dict[str, Any]]:
    """Deterministically choose the base set, generate its queries and save the cache."""
    # Seeding the global RNG also makes the random choices in the fallback query
    # generation reproducible, not just the record selection.
    random.seed(42)
    if len(all_records) < num_cases:
        selected = all_records
        log_f(f"Warning: Requested {num_cases} test cases, but only {len(all_records)} suitable records available. Using all.")
    else:
        selected = random.sample(all_records, num_cases)

    # Generate queries for all base records (this may call Ollama)
    base_records = [_make_test_case(rec) for rec in selected]

    try:
        with open(BASE_TESTCASE_CACHE, 'w', encoding='utf-8') as fh:
            json.dump(base_records, fh, indent=2)
        log_f(f"Saved base test-case cache to {BASE_TESTCASE_CACHE}")
    except Exception as e:
        log_f(f"Warning: Could not save base test-case cache: {e}")
    return base_records


def generate_test_cases_from_csv(csv_path: str, num_cases: int, dropout_count: int = DROPOUT_COUNT) -> List[Dict[str,Any]]:
    """Return this run's test cases: a cached base set with some entries swapped out.

    Behavior:
    - The base set is loaded from BASE_TESTCASE_CACHE when that cache holds exactly
      ``min(num_cases, number of usable CSV records)`` well-formed entries. Otherwise
      it is rebuilt: ``num_cases`` records are sampled deterministically
      (random.seed(42)), queries are generated for all of them (slow), and the
      result is saved to the cache.
    - ``dropout_count <= 0`` returns the base set unchanged.
    - ``dropout_count >= num_cases`` ignores the base set and generates an entirely
      new random set for this run.
    - Otherwise ``dropout_count`` random positions are replaced with CSV records not
      already in the base set, and queries are generated only for the replacements.

    Args:
        csv_path: Path of the records CSV.
        num_cases: Desired test-set size.
        dropout_count: Number of base entries to replace for this run.

    Returns:
        A list of test-case dicts (see ``_make_test_case``), or [] when the CSV has
        no usable records. The cached base list itself is never modified.
    """
    all_records = load_all_records(csv_path)
    if not all_records:
        log_f("No suitable records found in CSV to generate test cases.")
        return []

    # Accept a cache of the size that would actually be built, so a CSV with fewer
    # than num_cases usable records does not force a rebuild on every run.
    base_records = _load_base_cache(min(num_cases, len(all_records)))
    if not base_records:
        base_records = _build_base_cache(all_records, num_cases)

    if dropout_count <= 0:
        return base_records

    if dropout_count >= num_cases:
        log_f("Dropout count >= num_cases; regenerating entire set for this run.")
        # fallback: just regenerate entirely (non-deterministic)
        random.seed(None)
        selected = random.sample(all_records, num_cases) if len(all_records) >= num_cases else all_records
        return [_make_test_case(rec) for rec in selected]

    # Choose dropout positions non-deterministically. Sample from the real base
    # size; it can be smaller than num_cases when the CSV is small.
    random.seed(None)
    dropout_indices = random.sample(range(len(base_records)), min(dropout_count, len(base_records)))

    # Replacement candidates: records whose submission is not already in the base set.
    current_submissions = {rec['submission_number'] for rec in base_records}
    candidates = [r for r in all_records if r['submission_number'] not in current_submissions]
    if len(candidates) < len(dropout_indices):
        log_f("Not enough candidate records to replace dropout_count items; reducing dropout count.")
        dropout_indices = dropout_indices[:len(candidates)]

    # Shallow per-entry copies, so replacing entries never touches base_records.
    run_records = [dict(r) for r in base_records]
    for di in dropout_indices:
        replacement = random.choice(candidates)
        candidates.remove(replacement)  # avoid reuse within this run
        run_records[di] = _make_test_case(replacement)

    log_f(f"Using {len(run_records)} test cases with dropout: replaced {len(dropout_indices)} items this run.")
    return run_records


def run_command_with_progress(command_list, step_name="Command", timeout_seconds=600):
    """Run a command to completion and capture its output.

    Nothing is printed or logged here; callers decide how to report the outcome.

    Args:
        command_list: argv-style list passed to ``subprocess.run`` (no shell).
        step_name: Label used in the timeout error message.
        timeout_seconds: Seconds to wait before ``subprocess.run`` kills the child.

    Returns:
        ``(success, stdout, stderr)``, where success is True only for exit code 0.
        On a timeout, stdout is "" and stderr describes the timeout. On any other
        launch error (e.g. the executable does not exist), stdout is "" and stderr
        holds the exception text.
    """
    try:
        process = subprocess.run(command_list, capture_output=True, text=True, check=False, timeout=timeout_seconds)
    except subprocess.TimeoutExpired:
        return False, "", f"{step_name} timed out after {timeout_seconds} seconds."
    except Exception as e:
        return False, "", str(e)
    return process.returncode == 0, process.stdout, process.stderr


def _parse_hybrid_results(stdout: str) -> List[Dict[str, Any]]:
    """Parse create_embeddings.py output into ``[{'submissionNumber', 'similarity'}, ...]``.

    Parsing starts after a line containing "hybrid results for" (case-insensitive)
    and stops at the first blank line, or the first line containing "No results found",
    after that. Result lines must start with a digit, with "Submission: <id>" before
    the first comma and "Hybrid Similarity: <float>" after it. Similarity is None
    when it cannot be parsed. Lines without a "Submission:" marker are skipped.
    """
    parsed_results: List[Dict[str, Any]] = []
    results_started = False
    for raw_line in stdout.strip().split('\n'):
        line = raw_line.strip()
        if "hybrid results for" in line.lower():
            results_started = True
            continue
        if not results_started:
            continue
        if line and line[0].isdigit():
            parts = line.split(',')
            if len(parts) >= 2:
                submission_part = parts[0].split('Submission:')
                if len(submission_part) >= 2:
                    similarity_score = None
                    similarity_part = parts[1].split('Hybrid Similarity:')
                    if len(similarity_part) >= 2:
                        tokens = similarity_part[1].split()
                        if tokens:
                            try:
                                similarity_score = float(tokens[0])
                            except ValueError:
                                similarity_score = None
                    parsed_results.append({
                        "submissionNumber": submission_part[1].strip(),
                        "similarity": similarity_score
                    })
        if not line or "No results found" in line:
            break
    return parsed_results


def run_search_for_optimization(query: str, weights_for_search_json: str) -> List[Dict[str,Any]]:
    """Run create_embeddings.py for one query with the given weights.

    Args:
        query: Search query text passed as --query.
        weights_for_search_json: JSON object of field weights passed as --weights_json.

    Returns:
        Up to TOP_N_RESULTS_FOR_PASS + 2 results in output order, each
        ``{"submissionNumber": str, "similarity": float | None}``. Returns [] if the
        subprocess fails to start, exits non-zero, times out (60 s) or its output
        cannot be parsed.
    """
    # Use the interpreter running this script. A bare 'python' may be missing from
    # PATH or belong to a different environment, which would make every search
    # silently return [] and every trial score 0.
    interpreter = sys.executable or 'python'
    command = [interpreter, CREATE_EMBEDDINGS_PY, '--query', query, '--weights_json', weights_for_search_json, '--csv_path', FDA_RECORDS_CSV, '--output_dir', OUTPUT_DIR]

    try:
        process = subprocess.run(
            command,
            capture_output=True, text=True, check=False, timeout=60  # 1 min timeout per search
        )
    except Exception:  # includes subprocess.TimeoutExpired
        return []

    if process.returncode != 0:
        return []

    try:
        return _parse_hybrid_results(process.stdout)[:TOP_N_RESULTS_FOR_PASS + 2]
    except Exception:
        return []


def evaluate_search_performance_for_trial(current_test_cases: List[Dict[str,Any]], trial_obj: optuna.trial.Trial, trial_weights_json: str) -> float:
    """Score one weight configuration by running every test case through the search.

    A case passes when its ``submission_number`` is among the first
    TOP_N_RESULTS_FOR_PASS results returned for its ``query``. After each case the
    running pass rate is reported to *trial_obj* (step = number of cases seen so
    far), so the study's pruner can stop unpromising trials. Errors raised while
    reporting are ignored.

    Args:
        current_test_cases: Test-case dicts. 'query' and 'submission_number' are
            required; 'device_model' and 'query_source' are only used in logs.
        trial_obj: The Optuna trial being evaluated.
        trial_weights_json: JSON weights forwarded to create_embeddings.py.

    Returns:
        Pass rate in [0.0, 1.0]; 0.0 for an empty test set.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner decides to stop the trial.
    """
    if not current_test_cases:
        return 0.0

    total_cases = len(current_test_cases)
    log_f(f"    Detailed results for Trial {trial_obj.number} (reporting for pruning):")

    passed_count = 0
    for step, case in enumerate(current_test_cases, start=1):
        query = case['query']
        expected_submission = case['submission_number']
        source_label = case.get('query_source', 'unknown')
        device_label = case.get('device_model', 'N/A')

        results = run_search_for_optimization(query, trial_weights_json)
        ranked_submissions = [hit['submissionNumber'] for hit in results if 'submissionNumber' in hit]
        is_pass = expected_submission in ranked_submissions[:TOP_N_RESULTS_FOR_PASS]
        passed_count += 1 if is_pass else 0

        # Per-case log block
        log_f(f"      Test Case {step}/{total_cases}: Query='{query}' (Source: {source_label}, Expected: {expected_submission} for '{device_label}')")
        if results:
            for rank, hit in enumerate(results[:TOP_N_RESULTS_FOR_PASS + 2], start=1):
                similarity = hit.get('similarity')
                similarity_text = "N/A" if similarity is None else f"{similarity:.4f}"
                log_f(f"        -> Rank {rank}: {hit.get('submissionNumber', 'Error: No SubNum')} (Similarity: {similarity_text})")
        else:
            log_f("        -> No results returned.")
        log_f(f"        -> Outcome: {'PASS' if is_pass else 'FAIL'}\n")

        # Report the running pass rate; prune only if the pruner asks for it.
        running_rate = passed_count / step
        try:
            trial_obj.report(running_rate, step)
            if not trial_obj.should_prune():
                continue
            log_f(f"    Pruning trial {trial_obj.number} at step {step} (intermediate success: {running_rate:.4f})")
        except Exception:
            # Reporting is best effort; some setups do not support it.
            continue
        raise optuna.exceptions.TrialPruned()

    return passed_count / total_cases


# --- Optuna Objective Function ---

def _append_trial_row(trial_number: int, status: str, success_rate: str, weights: Dict[str, float]) -> None:
    """Append one row to TRIAL_RESULTS_CSV, writing the header first if the file is new or empty."""
    write_header = not os.path.exists(TRIAL_RESULTS_CSV) or os.path.getsize(TRIAL_RESULTS_CSV) == 0
    with open(TRIAL_RESULTS_CSV, 'a', encoding='utf-8') as csvf:
        if write_header:
            header = ['trial_number', 'status', 'success_rate'] + list(weights)
            csvf.write(','.join(header) + '\n')
        row = [str(trial_number), status, success_rate] + [str(weights[k]) for k in weights]
        csvf.write(','.join(row) + '\n')


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: score one set of field weights by the test-case pass rate.

    Each key in WEIGHT_KEYS gets a raw weight in [0.01, 0.5] (Optuna param
    ``w_<key>``). The raw weights are normalized to sum to 1 and passed to
    create_embeddings.py as JSON keyed by the bare field names. Every completed or
    pruned trial is appended to TRIAL_RESULTS_CSV (best effort; write failures
    are logged).

    Returns:
        The fraction of TEST_CASES whose expected submission ranks within the top
        TOP_N_RESULTS_FOR_PASS results.

    Raises:
        optuna.exceptions.TrialPruned: re-raised after logging when the pruner stops the trial.
    """
    raw_weights = {key: trial.suggest_float(f'w_{key}', 0.01, 0.5) for key in WEIGHT_KEYS}
    # Every raw weight is >= 0.01, so the total is strictly positive.
    total_weight = sum(raw_weights.values())
    weights_for_search_script = {key: float(value / total_weight) for key, value in raw_weights.items()}
    trial_weights_json = json.dumps(weights_for_search_script)

    weights_log_str = ", ".join([f"{k[:6]}={v:.3f}" for k, v in weights_for_search_script.items()])
    log_f(f"Trial {trial.number}: Trying weights: {weights_log_str}")

    eval_start_time = time.time()
    try:
        success_rate = evaluate_search_performance_for_trial(TEST_CASES, trial, trial_weights_json)
    except optuna.exceptions.TrialPruned:
        log_f(f"Trial {trial.number} was pruned.")
        # Record the pruned trial (no success rate), then re-raise so Optuna marks it PRUNED.
        try:
            _append_trial_row(trial.number, 'PRUNED', '', weights_for_search_script)
        except Exception as e:
            log_f(f"Warning: could not write trial results CSV: {e}")
        raise

    eval_duration = time.time() - eval_start_time
    log_f(f"  Trial {trial.number}: Success rate = {success_rate:.4f} (Eval took {eval_duration:.2f}s)")

    # Append trial results to CSV for later plotting/analysis
    try:
        _append_trial_row(trial.number, 'COMPLETE', f"{success_rate:.6f}", weights_for_search_script)
    except Exception as e:
        log_f(f"Warning: could not write trial results CSV: {e}")

    return success_rate


# --- Main Execution ---
def _report_best_trial(study) -> None:
    """Log the best trial and its weights normalized to sum to 1."""
    try:
        best_trial = study.best_trial
    except ValueError:
        log_f("No best trial found. This might happen if no trials completed or all failed.")
        return
    log_f(f"Best trial number: {best_trial.number}")
    log_f(f"Best success rate: {best_trial.value:.4f}")

    # Params are the raw 'w_<key>' values; normalize them the same way objective() does.
    best_weights_raw = best_trial.params
    total_best_weight = sum(best_weights_raw.values())
    if total_best_weight > 0:
        final_normalized_weights = {k.replace('w_', ''): v / total_best_weight for k, v in best_weights_raw.items()}
    elif best_weights_raw:
        final_normalized_weights = {k.replace('w_', ''): 1.0 / len(best_weights_raw) for k in best_weights_raw}
    else:
        final_normalized_weights = {}

    log_f("\nRecommended 'weights' dictionary for create_embeddings.py (normalized):")
    log_f("current_weights = {")
    for field, weight_val in final_normalized_weights.items():
        log_f(f"    '{field}': {weight_val:.6f},")
    log_f("}")


def _generate_plots(study) -> None:
    """Save history, per-parameter scatter and importance plots into RESULTS_DIR.

    Uses only COMPLETE rows of TRIAL_RESULTS_CSV. The importance plot needs *study*
    and is skipped when it is None.
    """
    if not (os.path.exists(TRIAL_RESULTS_CSV) and os.path.getsize(TRIAL_RESULTS_CSV) > 0):
        log_f("No trial CSV found; skipping plotting.")
        return
    df = pd.read_csv(TRIAL_RESULTS_CSV)
    # Only COMPLETE rows carry a success rate (PRUNED rows leave it empty).
    df_complete = df[df['status'] == 'COMPLETE'].copy()
    if df_complete.empty:
        log_f("No COMPLETE trial rows found in trial CSV; skipping detailed plots.")
        return
    df_complete['trial_number'] = pd.to_numeric(df_complete['trial_number'], errors='coerce')
    df_complete['success_rate'] = pd.to_numeric(df_complete['success_rate'], errors='coerce')

    # Optimization history (success rate by trial)
    plt.figure(figsize=(10, 4))
    plt.plot(df_complete['trial_number'], df_complete['success_rate'], marker='o', linestyle='-')
    plt.xlabel('Trial Number')
    plt.ylabel('Success Rate')
    plt.title('Optimization History (Success Rate by Trial)')
    hist_png = os.path.join(RESULTS_DIR, 'optimization_history.png')
    plt.tight_layout()
    plt.savefig(hist_png)
    plt.close()
    log_f(f"Saved optimization history plot: {hist_png}")

    # One scatter plot per weight column: weight value vs success rate.
    params = [c for c in df_complete.columns if c not in ('trial_number', 'status', 'success_rate')]
    if params:
        n_params = len(params)
        cols = int(math.ceil(math.sqrt(n_params)))
        rows = int(math.ceil(n_params / cols))
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3))
        axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]
        for ax, p in zip(axes, params):
            ax.scatter(df_complete[p].astype(float), df_complete['success_rate'].astype(float))
            ax.set_xlabel(p)
            ax.set_ylabel('success_rate')
        # Hide any unused axes
        for ax in axes[n_params:]:
            fig.delaxes(ax)
        plt.tight_layout()
        scatter_png = os.path.join(RESULTS_DIR, 'param_scatter_plots.png')
        plt.savefig(scatter_png)
        plt.close()
        log_f(f"Saved parameter vs success scatter plots: {scatter_png}")

    if study is None:
        return
    try:
        from optuna.importance import get_param_importances
        importance = get_param_importances(study)
        if importance:
            keys = list(importance.keys())
            vals = [importance[k] for k in keys]
            plt.figure(figsize=(8, 4))
            plt.bar(keys, vals)
            plt.xticks(rotation=45, ha='right')
            plt.title('Parameter Importance (Optuna)')
            plt.tight_layout()
            imp_png = os.path.join(RESULTS_DIR, 'param_importance.png')
            plt.savefig(imp_png)
            plt.close()
            log_f(f"Saved parameter importance plot: {imp_png}")
    except Exception:
        log_f("Could not compute param importances with optuna on this environment.")


if __name__ == "__main__":
    try:
        LOG_FILE_HANDLE = open(os.path.join(RESULTS_DIR, f"optimization_run_{int(time.time())}.log"), 'w', encoding='utf-8')

        log_f("Starting Search Weight Optimization Script (updated)...")
        log_f(f"Optuna trials: {N_OPTUNA_TRIALS}, Test cases per run: {NUM_TEST_CASES} (dropout {DROPOUT_COUNT})")
        log_f(f"Top N results for pass: {TOP_N_RESULTS_FOR_PASS}")
        log_f(f"Output directory: {OUTPUT_DIR}")
        log_f(f"CSV file: {FDA_RECORDS_CSV}")
        log_f(f"Weight keys being considered: {WEIGHT_KEYS}")
        log_f("Excluded fields from optimization: ['summary', 'concepts']")

        # Generate test cases (uses cache + dropout strategy)
        TEST_CASES = generate_test_cases_from_csv(FDA_RECORDS_CSV, NUM_TEST_CASES, dropout_count=DROPOUT_COUNT)

        if not TEST_CASES or len(TEST_CASES) < NUM_TEST_CASES * 0.8:
            log_f("Failed to generate a sufficient number of test cases. Exiting.")
            sys.exit(1)

        log_f(f"Using {len(TEST_CASES)} test cases for optimization. (Cache: {BASE_TESTCASE_CACHE})\n")

        # Resume one fixed study so trials accumulate across runs of this script.
        # These names are also what the log messages below report.
        study_name = "fda-search-weight-opt-1754753855"
        storage_name = "sqlite:////Users/arun/Documents/fda-search/py_src/optimization_results/fda-search-weight-opt-1754753855.db"

        sampler = optuna.samplers.TPESampler(multivariate=True)
        pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=1, interval_steps=1)

        log_f(f"Creating/loading Optuna study: {study_name} (DB: {storage_name}) with multivariate TPE sampler and MedianPruner")
        study = None
        try:
            study = optuna.create_study(
                study_name=study_name,
                storage=storage_name,
                load_if_exists=True,
                sampler=sampler,
                # Without this, the pruner configured above is ignored and Optuna
                # uses its default pruner settings instead.
                pruner=pruner,
                direction='maximize',
            )

            # Only COMPLETE trials count towards the target; pruned/failed ones do not.
            completed_trials = sum(1 for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE)
            remaining_trials = N_OPTUNA_TRIALS - completed_trials

            if remaining_trials <= 0:
                log_f(f"Study '{study_name}' already has {completed_trials} completed trials (target was {N_OPTUNA_TRIALS}).")
            else:
                log_f(f"Study has {completed_trials} completed trials. Optimizing for {remaining_trials} more trials...")
                try:
                    study.optimize(objective, n_trials=remaining_trials, timeout=3600*24)
                except KeyboardInterrupt:
                    log_f("Optimization manually interrupted — proceeding to graph generation...")
                except Exception as e:
                    log_f(f"Optimization loop encountered an error: {e}")

        except Exception as e:
            log_f(f"Error creating or loading study: {e}")

        log_f("\n--- Optimization Finished! ---")
        if study is None:
            log_f("No study available; skipping best-trial report.")
        else:
            try:
                _report_best_trial(study)
            except Exception as e:
                log_f(f"Error displaying results: {e}")

        log_f(f"\nOptimization data saved in {storage_name}")
        log_f(f"Detailed logs saved in {os.path.join(RESULTS_DIR, 'optimization_run_*.log')}")
        log_f(f"Trial-level CSV: {TRIAL_RESULTS_CSV}")

        # --- Plotting ---
        try:
            _generate_plots(study)
        except Exception as e:
            log_f(f"Error during plotting: {e}")

        log_f("The 'create_embeddings.py' script accepts a --weights_json argument to use custom weights for searching.")

    except Exception as e_main:
        if LOG_FILE_HANDLE:
            log_f(f"A critical error occurred in the main script: {e_main}")
        else:
            print(f"A critical error occurred before logging was set up: {e_main}")
    finally:
        if LOG_FILE_HANDLE:
            try:
                LOG_FILE_HANDLE.write("--- Script Execution Finished ---\n")
                LOG_FILE_HANDLE.flush()
                LOG_FILE_HANDLE.close()
            except Exception:
                pass
            LOG_FILE_HANDLE = None

