#!/usr/bin/env python3
"""
FDA search lambda grid search script
- Converted from multi-parameter Optuna optimization to single-parameter lambda grid search
- Uses grid search from 0.0 to 1.0 in 0.05 increments (21 values total)
- Maintains the same evaluation methodology: 50 test cases per run, 10 of which are randomly replaced ("dropout") on each run
- Logs trial-level results to CSV and produces plots at the end
- Ranks lambda values by success rate at the end

NOTE: This script assumes `create_embeddings.py` is available and prints results in a parseable format.
"""

import csv
import subprocess
import json
import os
import re
import random
import ast
import sys
import time
import math
from typing import List, Dict, Any, Tuple
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# All artifacts of this script go under ./lambda_grid_search_results (created at import time).
RESULTS_DIR = os.path.join(os.getcwd(), "lambda_grid_search_results")
os.makedirs(RESULTS_DIR, exist_ok=True)
BASE_TESTCASE_CACHE = os.path.join(RESULTS_DIR, "base_test_cases.json")
TRIAL_RESULTS_CSV = os.path.join(RESULTS_DIR, "lambda_grid_search_results.csv")

# Machine-specific paths. Each may be overridden via the environment variable named
# in the os.environ.get call; the defaults are the original hard-coded locations.
FDA_RECORDS_CSV = os.environ.get('FDA_RECORDS_CSV', '/Users/arun/Documents/fda-search/py_src/fda_ai_records.csv')
CREATE_EMBEDDINGS_PY = os.environ.get('FDA_CREATE_EMBEDDINGS_PY', '/Users/arun/Documents/fda-search/py_src/create_embeddings.py')
NUM_TEST_CASES = 50  # total test set size
DROPOUT_COUNT = 10   # number of items to replace per run (dropout)
TOP_N_RESULTS_FOR_PASS = 5 # Consider a test passed if the expected result is in top N
OUTPUT_DIR = os.environ.get('FDA_OUTPUT_DIR', '/Users/arun/Documents/fda-search/embedding_data_small')  # passed to create_embeddings.py as --output_dir

# Lambda grid search configuration
LAMBDA_MIN = 0.0
LAMBDA_MAX = 1.0
LAMBDA_STEP = 0.05
# Build the grid from an integer point count. np.arange with a float step decides
# whether the endpoint is included via floating-point division, so the last value
# can silently appear or vanish; linspace always includes both endpoints.
# float() turns numpy scalars into plain floats so logs don't print np.float64(...).
_LAMBDA_NUM_POINTS = int(round((LAMBDA_MAX - LAMBDA_MIN) / LAMBDA_STEP)) + 1
LAMBDA_VALUES = [round(float(x), 2) for x in np.linspace(LAMBDA_MIN, LAMBDA_MAX, _LAMBDA_NUM_POINTS)]

# Fixed weights to use with create_embeddings.py (serialized to JSON for --weights_json)
FIXED_WEIGHTS = {
    'keywords': 0.134207,
    'questions': 0.226103,
    'thesis': 0.094972,
    'search_boost': 0.029563,
    'query_match_1': 0.217395,
    'query_match_2': 0.241111,
    'query_match_3': 0.056650,
}

# Global variable for test cases (not read by any function in this file)
TEST_CASES: List[Dict[str, Any]] = []

# Global log file handle, to be set in main; log_f mirrors output to it when set
LOG_FILE_HANDLE = None

# Ollama model used for query generation (if available)
OLLAMA_MODEL_NAME = "gemma3n:e2b"

# --- Helpers ---

def log_f(message: str):
    """Emit a timestamped log line to stdout and, if open, to the run's log file.

    The line has the form ``[YYYY-mm-dd HH:MM:SS] <message>`` using local time.
    When the module-level LOG_FILE_HANDLE is set, the same line plus a newline is
    written to it and flushed immediately. Any error while writing to the file is
    ignored, so a broken log file can never abort the run.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    line = f"[{stamp}] {message}"
    print(line)

    handle = LOG_FILE_HANDLE
    if not handle:
        return
    try:
        handle.write(f"{line}\n")
        handle.flush()
    except Exception:
        # Deliberately best-effort: logging must not take the grid search down.
        pass


def load_all_records(csv_path: str) -> List[Dict[str,str]]:
    """Load the CSV rows that can serve as search test cases.

    A row is kept when ``submission_number`` is non-empty and at least one of
    ``thesis``, ``concepts``, ``summary_keywords`` or ``device_model`` is
    non-empty (query generation draws from those fields).

    Args:
        csv_path: Path to the records CSV (UTF-8, with a header row).

    Returns:
        Qualifying rows as column-name -> value dicts, in file order. If the file
        is missing or a read error occurs, the error is logged and the rows
        collected so far (possibly none) are returned.
    """
    text_fields = ('thesis', 'concepts', 'summary_keywords', 'device_model')
    all_records = []
    try:
        # newline='' is required by the csv module: without it, newlines embedded
        # in quoted fields (e.g. multi-line thesis text) are translated or mis-split.
        with open(csv_path, mode='r', encoding='utf-8', newline='') as infile:
            for row in csv.DictReader(infile):
                # Ensure essential fields for a test case are present
                if row.get('submission_number') and any(row.get(field) for field in text_fields):
                    all_records.append(row)
    except FileNotFoundError:
        log_f(f"Error: CSV file {csv_path} not found.")
    except Exception as e:
        # e.g. UnicodeDecodeError or csv.Error: report and keep what was read.
        log_f(f"Error reading CSV file {csv_path}: {e}")
    return all_records


def run_ollama_for_query(thesis: str, concepts_str: str, timeout: int = 30) -> str:
    """Ask the local Ollama model for a clinician-style search query.

    Runs ``ollama run OLLAMA_MODEL_NAME <prompt>`` and returns the last non-empty
    stdout line that is neither a prompt echo (starts with ``>>>``) nor a line
    mentioning the model name.

    Args:
        thesis: Thesis text from the FDA record (may be empty).
        concepts_str: Comma-separated key concepts (may be empty).
        timeout: Seconds to wait for the ollama subprocess.

    Returns:
        The generated query, or "" when ollama is missing, times out, exits with a
        non-zero status, or produces no usable line. These failures never raise, so
        the caller can fall back to heuristic query generation.
    """
    prompt_for_cli = f'''You are an expert medical researcher. Based on the following thesis and key concepts from a medical device's FDA summary, generate a concise and clinically relevant search query that a clinician or researcher might use to find information about similar devices or technologies. Do NOT include anything about AI, ML, or what those mean. Only return the medical search query. Return only the query itself, without any preamble or explanation. Phrase it as a human might. Try to use words distinct from those in the thesis or key concepts.
Thesis: "{thesis}"
Key Concepts: "{concepts_str}"
Clinically Relevant Search Query:'''

    command = ["ollama", "run", OLLAMA_MODEL_NAME, prompt_for_cli]
    try:
        process = subprocess.run(command, capture_output=True, text=True, check=False, timeout=timeout)
    except (subprocess.TimeoutExpired, OSError, ValueError):
        # OSError: ollama binary missing / not executable.
        # ValueError: includes UnicodeDecodeError on undecodable output and invalid args.
        return ""

    if process.returncode != 0:
        return ""

    # pick last non-empty, non-ollama banner line
    for line in reversed(process.stdout.splitlines()):
        clean_line = line.strip()
        if clean_line and not clean_line.startswith('>>>') and OLLAMA_MODEL_NAME not in clean_line:
            return clean_line
    # Only blank or banner/echo lines: don't pass banner text off as a query.
    return ""


def generate_query_for_record(record: Dict[str,str]) -> Dict[str,str]:
    """Generate a search query for a CSV record (Ollama first, then heuristics).

    Order of attempts:
      1. Ollama, if the record has a thesis or concepts.
      2. If still empty: the first 7-12 (or 5-7) words of the thesis.
      3. If fewer than 4 words so far: 1-2 random concepts.
      4. If fewer than 3 words so far: 2-3 random summary keywords.
      5. If still empty: "<device_model> medical device" or "medical device information".
    Steps 3 and 4 may also replace a short Ollama answer.

    Args:
        record: A CSV row; reads 'thesis', 'concepts', 'summary_keywords', 'device_model'.

    Returns:
        {'query': <str>, 'query_source': 'ollama' | 'fallback'}. The source names
        whichever step produced the final query.
    """
    thesis = (record.get('thesis') or '').strip()
    concepts_str = (record.get('concepts') or '').strip()
    keywords_str = (record.get('summary_keywords') or '').strip()
    device_name = record.get('device_model') or 'Unknown Device'

    query = ""
    query_source = 'fallback'

    if thesis or concepts_str:
        q = run_ollama_for_query(thesis, concepts_str)
        if q:
            query = q
            query_source = 'ollama'

    # Fallbacks
    if not query:
        if thesis and len(thesis.split()) >= 7:
            query = " ".join(thesis.split()[:random.randint(7,12)])
        elif thesis and len(thesis.split()) >= 5:
            query = " ".join(thesis.split()[:random.randint(5,7)])

    if (not query) or len(query.split()) < 4:
        if concepts_str:
            concept_list = [c.strip() for c in concepts_str.split(',') if c.strip()]
            if concept_list:
                num_to_sample = min(len(concept_list), random.randint(1,2))
                query = " ".join(random.sample(concept_list, num_to_sample))
                # A too-short Ollama answer may have just been replaced.
                query_source = 'fallback'

    if (not query) or len(query.split()) < 3:
        if keywords_str:
            keyword_list = [k.strip() for k in keywords_str.split(',') if k.strip()]
            if keyword_list:
                num_to_sample = min(len(keyword_list), random.randint(2,3))
                query = " ".join(random.sample(keyword_list, num_to_sample))
                query_source = 'fallback'

    if not query:
        if device_name != 'Unknown Device' and device_name.strip():
            query = f"{device_name} medical device"
        else:
            query = "medical device information"

    return {'query': query, 'query_source': query_source}


def _make_test_case(rec: Dict[str, Any], gen: Dict[str, str]) -> Dict[str, Any]:
    """Combine a CSV record and its generated query into one test-case dict (the cached shape)."""
    return {
        'submission_number': rec['submission_number'],
        'device_model': rec.get('device_model', ''),
        'thesis': rec.get('thesis', ''),
        'concepts': rec.get('concepts', ''),
        'summary_keywords': rec.get('summary_keywords', ''),
        'query': gen['query'],
        'query_source': gen['query_source'],
    }


def generate_test_cases_from_csv(csv_path: str, num_cases: int, dropout_count: int = DROPOUT_COUNT) -> List[Dict[str,Any]]:
    """Return this run's test cases: the cached base set with random dropout replacements.

    Behavior:
    - The effective set size is ``min(num_cases, <number of usable CSV records>)``.
    - If BASE_TESTCASE_CACHE exists and holds exactly that many cases, it is loaded
      (it already contains a pre-generated 'query' for each case). Otherwise a base set
      is sampled after ``random.seed(42)``, queries are generated for all of it (slow;
      may call Ollama), and the cache is rewritten.
    - Dropout: ``dropout_count`` positions are chosen at random and replaced with CSV
      records not already in the base set; queries are generated only for those.
      If ``dropout_count`` is at least the set size, a fresh random set is generated.

    Note: this reseeds the global ``random`` module (42 for base selection, then
    ``None`` for dropout), which affects every other user of ``random``.

    Args:
        csv_path: Records CSV path (see load_all_records).
        num_cases: Requested test-set size.
        dropout_count: Number of cases to replace for this run; <= 0 disables dropout.

    Returns:
        Test-case dicts with keys submission_number, device_model, thesis, concepts,
        summary_keywords, query and query_source; [] if the CSV yields no usable records.
    """
    all_records = load_all_records(csv_path)
    if not all_records:
        log_f("No suitable records found in CSV to generate test cases.")
        return []

    # The set can never exceed the pool of usable records. Validating the cache and
    # sampling dropout indices against this effective size keeps them consistent
    # with what was actually generated. Comparing against num_cases made an
    # undersized cache look invalid on every call (forcing full regeneration) and
    # allowed dropout indices past the end of the list.
    target_size = min(num_cases, len(all_records))

    # Load or create base cache
    base_records: List[Dict[str,Any]] = []
    if os.path.exists(BASE_TESTCASE_CACHE):
        try:
            with open(BASE_TESTCASE_CACHE, 'r', encoding='utf-8') as fh:
                base_records = json.load(fh)
            # Validate length
            if not isinstance(base_records, list) or len(base_records) != target_size:
                log_f("Base cache exists but is invalid length. Recreating base cache.")
                base_records = []
        except Exception as e:
            log_f(f"Failed to read base cache: {e}. Recreating base cache.")
            base_records = []

    if not base_records:
        # Deterministic selection for base set to ensure repeatability across runs
        random.seed(42)
        if len(all_records) < num_cases:
            selected = all_records
            log_f(f"Warning: Requested {num_cases} test cases, but only {len(all_records)} suitable records available. Using all.")
        else:
            selected = random.sample(all_records, num_cases)

        # Generate queries for all base records (this may call Ollama)
        log_f(f"Generating queries for {len(selected)} base records...")
        base_records = []
        # Started before the loop so the final timing line never sees an unbound name.
        batch_start_time = time.time()
        for i, rec in enumerate(selected):
            if i > 0 and i % 10 == 0:
                log_f(f"  Query generation progress: {i}/{len(selected)} ({i/len(selected)*100:.1f}%) - Last 10 took {time.time() - batch_start_time:.2f}s")
                batch_start_time = time.time()

            gen = generate_query_for_record(rec)
            base_records.append(_make_test_case(rec, gen))

            # Log a few example queries
            if i < 5:
                log_f(f"    Example query {i+1}: '{gen['query']}' (source: {gen['query_source']}) for {rec['submission_number']}")

        final_batch_time = time.time() - batch_start_time
        log_f(f"  Query generation completed: {len(selected)}/{len(selected)} (100%) - Final batch took {final_batch_time:.2f}s")

        # Save cache
        try:
            with open(BASE_TESTCASE_CACHE, 'w', encoding='utf-8') as fh:
                json.dump(base_records, fh, indent=2)
            log_f(f"Saved base test-case cache to {BASE_TESTCASE_CACHE}")
        except Exception as e:
            log_f(f"Warning: Could not save base test-case cache: {e}")

    # Now perform dropout replacements for this run
    if dropout_count <= 0:
        return base_records

    if dropout_count >= target_size:
        log_f("Dropout count >= num_cases; regenerating entire set for this run.")
        # fallback: just regenerate entirely (non-deterministic)
        random.seed(None)
        selected = random.sample(all_records, num_cases) if len(all_records) >= num_cases else all_records
        run_records = []
        log_f(f"Generating queries for {len(selected)} records (full regeneration)...")
        for i, rec in enumerate(selected):
            if i > 0 and i % 10 == 0:
                log_f(f"  Query generation progress: {i}/{len(selected)} ({i/len(selected)*100:.1f}%)")
            run_records.append(_make_test_case(rec, generate_query_for_record(rec)))
        log_f(f"  Query generation completed: {len(selected)}/{len(selected)} (100%)")
        return run_records

    # choose dropout indices non-deterministically, within the actual set size
    random.seed(None)
    dropout_indices = random.sample(range(target_size), dropout_count)

    # Pool of candidates for replacements: records whose submission is not already in the base set
    current_submissions = {rec['submission_number'] for rec in base_records}
    candidates = [r for r in all_records if r['submission_number'] not in current_submissions]
    if len(candidates) < dropout_count:
        log_f("Not enough candidate records to replace dropout_count items; reducing dropout count.")
        dropout_indices = dropout_indices[:len(candidates)]

    # Per-record shallow copies; replacements below rebind list slots, so the
    # cached base set is never modified.
    run_records = [dict(r) for r in base_records]

    if dropout_indices:
        log_f(f"Generating queries for {len(dropout_indices)} dropout replacements...")

    replaced_count = 0
    for idx, di in enumerate(dropout_indices):
        # Progress logging every 10 replacements
        if idx > 0 and idx % 10 == 0:
            log_f(f"  Dropout query generation progress: {idx}/{len(dropout_indices)} ({idx/len(dropout_indices)*100:.1f}%)")

        if not candidates:
            log_f("No candidates left for replacement; stopping replacements early.")
            break
        replacement = random.choice(candidates)
        candidates.remove(replacement)  # avoid reuse
        gen = generate_query_for_record(replacement)
        run_records[di] = _make_test_case(replacement, gen)
        replaced_count += 1

        # Log first few replacement examples
        if idx < 3:
            log_f(f"    Dropout example {idx+1}: '{gen['query']}' (source: {gen['query_source']}) for {replacement['submission_number']}")

    if dropout_indices:
        log_f(f"  Dropout query generation completed: {len(dropout_indices)}/{len(dropout_indices)} (100%)")

    log_f(f"Using {len(run_records)} test cases with dropout: replaced {replaced_count} items this run.")
    return run_records


# Result lines look like "1. Submission: K250650, Score: 0.9712"; accept any rank width.
_RESULT_RANK_RE = re.compile(r'\d+\. ')


def _parse_search_output(stdout: str, max_results: int) -> List[Dict[str, Any]]:
    """Extract ranked {'submissionNumber', 'similarity'} dicts from create_embeddings.py stdout.

    Parsing begins after a header line containing both "top" and "results for"
    (case-insensitive). Any later line that starts with "<rank>. " and contains
    "Submission:" and "Score:" is parsed; all other lines are skipped. Parsing
    stops once ``max_results`` results are collected. An unparseable score is
    logged and recorded as None.
    """
    output_lines = stdout.strip().split('\n')
    log_f(f"        Parsing {len(output_lines)} output lines from create_embeddings.py")

    parsed_results: List[Dict[str, Any]] = []
    results_started = False
    for line_num, raw_line in enumerate(output_lines):
        line = raw_line.strip()
        lowered = line.lower()

        # Look for the results header
        if "top" in lowered and "results for" in lowered:
            results_started = True
            log_f(f"        Found results start marker at line {line_num}: {line}")
            continue

        # The old check (line[1:3] == '. ') silently dropped ranks >= 10.
        if not results_started or not _RESULT_RANK_RE.match(line):
            continue
        if 'Submission:' not in line or 'Score:' not in line:
            continue

        submission_number = None
        similarity_score = None
        for part in line.split(','):
            part = part.strip()
            if 'Submission:' in part:
                submission_number = part.split('Submission:')[1].strip()
            elif 'Score:' in part:
                try:
                    similarity_score = float(part.split('Score:')[1].strip())
                except ValueError as e:
                    log_f(f"        Warning: Could not parse score from '{part}': {e}")
                    similarity_score = None

        if submission_number:
            parsed_results.append({
                "submissionNumber": submission_number,
                "similarity": similarity_score
            })
            log_f(f"        Parsed result: {submission_number} with score {similarity_score}")
            if len(parsed_results) >= max_results:
                break

    return parsed_results


def run_search_for_optimization(query: str, lambda_value: float) -> List[Dict[str,Any]]:
    """Run create_embeddings.py for one query and lambda; return its top-ranked results.

    The search is invoked with --query, the FIXED_WEIGHTS as --weights_json,
    --csv_path, --output_dir and --lambda_val, with a 60 s timeout.

    Args:
        query: Search query text.
        lambda_value: Value passed as --lambda_val.

    Returns:
        Up to TOP_N_RESULTS_FOR_PASS + 2 dicts {'submissionNumber': str,
        'similarity': float | None} in output order, or [] on timeout, non-zero
        exit, launch failure or parse failure (each is logged).
    """
    max_results = TOP_N_RESULTS_FOR_PASS + 2
    weights_json = json.dumps(FIXED_WEIGHTS)

    # Run the child with this script's own interpreter (same virtualenv and
    # packages) rather than whatever 'python3' resolves to on PATH.
    python_exe = sys.executable or 'python3'
    command = [python_exe, CREATE_EMBEDDINGS_PY, '--query', query, '--weights_json', weights_json, '--csv_path', FDA_RECORDS_CSV, '--output_dir', OUTPUT_DIR, '--lambda_val', str(lambda_value)]

    log_f(f"        Executing search command: {' '.join(command)}")
    search_start_time = time.time()

    try:
        process = subprocess.run(
            command,
            capture_output=True, text=True, check=False, timeout=60  # 1 min timeout per search
        )
    except subprocess.TimeoutExpired:
        log_f(f"        ERROR: Search timed out after {time.time() - search_start_time:.2f}s")
        return []
    except Exception as e:
        log_f(f"        ERROR: Search failed after {time.time() - search_start_time:.2f}s with exception: {e}")
        return []

    search_duration = time.time() - search_start_time
    log_f(f"        Search completed in {search_duration:.2f}s, return code: {process.returncode}")

    if process.returncode != 0:
        log_f(f"        ERROR: Command failed with stderr: {process.stderr}")
        return []

    try:
        parsed_results = _parse_search_output(process.stdout, max_results)
    except Exception as e:
        log_f(f"        ERROR: Failed to parse search output: {e}")
        log_f(f"        Raw stdout (first 1000 chars): {process.stdout[:1000]}")
        return []

    log_f(f"        Successfully parsed {len(parsed_results)} results")
    return parsed_results[:max_results]


def evaluate_search_performance_for_lambda(current_test_cases: List[Dict[str,Any]], lambda_value: float, trial_number: int) -> float:
    """Measure top-N retrieval accuracy of the search for one lambda value.

    Each case's ``query`` is searched with *lambda_value*. The case passes when its
    own ``submission_number`` is among the first TOP_N_RESULTS_FOR_PASS submissions
    returned. Per-case details (query, expected submission, ranked results with
    similarity, PASS/FAIL) are written through log_f.

    Args:
        current_test_cases: Dicts with required 'query' and 'submission_number' keys;
            'device_model' and 'query_source' are optional and used only in logs.
        lambda_value: Forwarded to run_search_for_optimization.
        trial_number: Used only in the log header.

    Returns:
        passed / total as a float in [0.0, 1.0]; 0.0 when there are no cases.
    """
    if not current_test_cases:
        return 0.0

    case_total = len(current_test_cases)
    shown_results = TOP_N_RESULTS_FOR_PASS + 2  # show a couple of ranks past the pass cut-off
    log_f(f"    Detailed results for Lambda = {lambda_value:.2f} (Trial {trial_number}):")

    passes = 0
    for case_no, case in enumerate(current_test_cases, start=1):
        query = case['query']
        expected_submission = case['submission_number']
        device_name = case.get('device_model', 'N/A')
        query_source = case.get('query_source', 'unknown')

        hits = run_search_for_optimization(query, lambda_value)
        ranked_submissions = [hit['submissionNumber'] for hit in hits if 'submissionNumber' in hit]
        is_pass = expected_submission in ranked_submissions[:TOP_N_RESULTS_FOR_PASS]
        if is_pass:
            passes += 1

        log_f(f"      Test Case {case_no}/{case_total}: Query='{query}' (Source: {query_source}, Expected: {expected_submission} for '{device_name}')")
        if hits:
            for rank, hit in enumerate(hits[:shown_results], start=1):
                similarity = hit.get('similarity')
                similarity_text = "N/A" if similarity is None else f"{similarity:.4f}"
                sub_num = hit.get('submissionNumber', 'Error: No SubNum')
                log_f(f"        -> Rank {rank}: {sub_num} (Similarity: {similarity_text})")
        else:
            log_f("        -> No results returned.")
        log_f(f"        -> Outcome: {'PASS' if is_pass else 'FAIL'}\n")

    return passes / case_total


def _append_trial_row(row_fields: List[str]) -> None:
    """Append one comma-joined row to TRIAL_RESULTS_CSV, writing the header first if the file is new or empty.

    Failures are logged as a warning and never raised.
    """
    try:
        needs_header = (not os.path.exists(TRIAL_RESULTS_CSV)) or os.path.getsize(TRIAL_RESULTS_CSV) == 0
        with open(TRIAL_RESULTS_CSV, 'a', encoding='utf-8') as csvf:
            if needs_header:
                columns = ['trial_number', 'lambda_value', 'success_rate', 'eval_duration', 'trial_duration']
                csvf.write(','.join(columns) + '\n')
            csvf.write(','.join(row_fields) + '\n')
        log_f(f"  Results saved to CSV: {TRIAL_RESULTS_CSV}")
    except Exception as e:
        log_f(f"  WARNING: Could not write trial results CSV: {e}")


def run_lambda_grid_search() -> List[Tuple[float, float]]:
    """Evaluate every value in LAMBDA_VALUES and return (lambda, success_rate) pairs.

    Each trial builds its own test set through generate_test_cases_from_csv, so
    dropout varies between trials. It then scores the set with
    evaluate_search_performance_for_lambda and appends a row to TRIAL_RESULTS_CSV.
    A trial whose test set comes back empty is recorded as 0.0 and is not written
    to the CSV. Progress, timings and an ETA are logged throughout.
    """
    n_trials = len(LAMBDA_VALUES)
    rule = '=' * 80
    results: List[Tuple[float, float]] = []

    log_f(f"Starting grid search over {n_trials} lambda values: {LAMBDA_VALUES}")
    grid_t0 = time.time()

    for trial_number, lambda_val in enumerate(LAMBDA_VALUES, 1):
        trial_t0 = time.time()
        log_f(f"\n{rule}")
        log_f(f"TRIAL {trial_number}/{n_trials}: Testing Lambda = {lambda_val:.2f}")
        log_f(rule)

        # Step 1: fresh test set (base cache + dropout) for this trial
        log_f(f"  Step 1: Generating test cases (target: {NUM_TEST_CASES}, dropout: {DROPOUT_COUNT})")
        gen_t0 = time.time()
        current_test_cases = generate_test_cases_from_csv(FDA_RECORDS_CSV, NUM_TEST_CASES, dropout_count=DROPOUT_COUNT)
        gen_elapsed = time.time() - gen_t0

        if not current_test_cases:
            log_f(f"  ERROR: Failed to generate test cases for lambda = {lambda_val:.2f}, skipping trial...")
            results.append((lambda_val, 0.0))
            continue

        log_f(f"  Successfully generated {len(current_test_cases)} test cases in {gen_elapsed:.2f}s")

        n_ollama = len([case for case in current_test_cases if case.get('query_source') == 'ollama'])
        n_fallback = len(current_test_cases) - n_ollama
        log_f(f"  Test case sources: {n_ollama} from Ollama, {n_fallback} from fallback methods")

        log_f("  Sample queries for this trial:")
        for example_no, case in enumerate(current_test_cases[:3], start=1):
            log_f(f"    Example {example_no}: '{case['query']}' -> {case['submission_number']} ({case.get('device_model', 'N/A')})")

        # Step 2: score this lambda
        log_f(f"  Step 2: Running evaluation for lambda = {lambda_val:.2f}")
        eval_t0 = time.time()
        success_rate = evaluate_search_performance_for_lambda(current_test_cases, lambda_val, trial_number)
        eval_duration = time.time() - eval_t0
        trial_duration = time.time() - trial_t0

        for summary_line in (
            f"  TRIAL {trial_number} COMPLETE:",
            f"    Lambda = {lambda_val:.2f}",
            f"    Success Rate = {success_rate:.4f} ({success_rate*100:.1f}%)",
            f"    Evaluation Time = {eval_duration:.2f}s",
            f"    Total Trial Time = {trial_duration:.2f}s",
        ):
            log_f(summary_line)

        if trial_number > 1:
            remaining = n_trials - trial_number
            mean_trial_s = (time.time() - grid_t0) / trial_number
            log_f(f"    ETA for remaining {remaining} trials: {remaining * mean_trial_s / 60:.1f} minutes")

        results.append((lambda_val, success_rate))
        _append_trial_row([
            str(trial_number),
            f"{lambda_val:.2f}",
            f"{success_rate:.6f}",
            f"{eval_duration:.2f}",
            f"{trial_duration:.2f}",
        ])

    total_duration = time.time() - grid_t0
    log_f(f"\n{rule}")
    log_f("GRID SEARCH COMPLETED!")
    log_f(f"Total runtime: {total_duration:.2f}s ({total_duration/60:.1f} minutes)")
    log_f(f"Average time per trial: {total_duration/n_trials:.2f}s")
    log_f(rule)

    return results


# --- Main Execution ---
if __name__ == "__main__":
    # Compute the log path once so the path reported below is the file actually
    # opened (re-evaluating int(time.time()) later can land on a different second).
    log_file_path = os.path.join(RESULTS_DIR, f"lambda_grid_search_run_{int(time.time())}.log")
    exit_code = 0
    try:
        LOG_FILE_HANDLE = open(log_file_path, 'w', encoding='utf-8')

        log_f("="*80)
        log_f("FDA LAMBDA GRID SEARCH STARTING")
        log_f("="*80)
        log_f(f"Script started at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
        log_f(f"Results directory: {RESULTS_DIR}")
        log_f(f"Log file: {log_file_path}")
        log_f("")
        log_f("CONFIGURATION:")
        log_f(f"  Lambda range: {LAMBDA_MIN} to {LAMBDA_MAX} in steps of {LAMBDA_STEP}")
        log_f(f"  Total lambda values to test: {len(LAMBDA_VALUES)}")
        log_f(f"  Lambda values: {LAMBDA_VALUES}")
        log_f(f"  Test cases per run: {NUM_TEST_CASES} (dropout {DROPOUT_COUNT})")
        log_f(f"  Top N results for pass: {TOP_N_RESULTS_FOR_PASS}")
        log_f(f"  Output directory: {OUTPUT_DIR}")
        log_f(f"  CSV file: {FDA_RECORDS_CSV}")
        log_f(f"  Create embeddings script: {CREATE_EMBEDDINGS_PY}")
        log_f("")
        log_f("FIXED WEIGHTS:")
        for key, value in FIXED_WEIGHTS.items():
            log_f(f"  {key}: {value:.6f}")
        log_f("")
        log_f("PATHS:")
        log_f(f"  Base test case cache: {BASE_TESTCASE_CACHE}")
        log_f(f"  Trial results CSV: {TRIAL_RESULTS_CSV}")
        log_f("")

        # Validate paths exist (informational only; the run continues either way)
        log_f("VALIDATING PATHS:")
        if os.path.exists(FDA_RECORDS_CSV):
            log_f(f"  ✓ FDA records CSV found: {FDA_RECORDS_CSV}")
        else:
            log_f(f"  ✗ FDA records CSV NOT FOUND: {FDA_RECORDS_CSV}")

        if os.path.exists(CREATE_EMBEDDINGS_PY):
            log_f(f"  ✓ Create embeddings script found: {CREATE_EMBEDDINGS_PY}")
        else:
            log_f(f"  ✗ Create embeddings script NOT FOUND: {CREATE_EMBEDDINGS_PY}")

        if os.path.exists(OUTPUT_DIR):
            log_f(f"  ✓ Output directory exists: {OUTPUT_DIR}")
        else:
            log_f(f"  ! Output directory will be created: {OUTPUT_DIR}")
        log_f("")

        # Show system info
        log_f("SYSTEM INFO:")
        import platform
        log_f(f"  Python version: {platform.python_version()}")
        log_f(f"  System: {platform.system()} {platform.release()}")
        log_f(f"  Current working directory: {os.getcwd()}")
        log_f("")

        # Run grid search
        grid_search_results = run_lambda_grid_search()

        log_f("="*80)
        log_f("GRID SEARCH RESULTS SUMMARY")
        log_f("="*80)

        # Sort results by success rate (descending); sort is stable, so ties keep lambda order
        sorted_results = sorted(grid_search_results, key=lambda x: x[1], reverse=True)

        log_f(f"LAMBDA RANKING (by Success Rate):")
        log_f(f"{'Rank':<6} {'Lambda':<8} {'Success Rate':<12} {'Percentage':<10}")
        log_f(f"{'-'*6} {'-'*8} {'-'*12} {'-'*10}")
        for rank, (lambda_val, success_rate) in enumerate(sorted_results, 1):
            log_f(f"{rank:<6} {lambda_val:<8.2f} {success_rate:<12.6f} {success_rate*100:<10.1f}%")

        best_lambda, best_success_rate = sorted_results[0]
        worst_lambda, worst_success_rate = sorted_results[-1]

        log_f("")
        log_f("KEY STATISTICS:")
        log_f(f"  Best Lambda: {best_lambda:.2f} with Success Rate: {best_success_rate:.6f} ({best_success_rate*100:.1f}%)")
        log_f(f"  Worst Lambda: {worst_lambda:.2f} with Success Rate: {worst_success_rate:.6f} ({worst_success_rate*100:.1f}%)")
        log_f(f"  Performance Range: {(best_success_rate - worst_success_rate)*100:.1f} percentage points")

        # Calculate some stats
        success_rates = [rate for _, rate in grid_search_results]
        avg_success_rate = sum(success_rates) / len(success_rates)
        log_f(f"  Average Success Rate: {avg_success_rate:.6f} ({avg_success_rate*100:.1f}%)")

        # Show top 5 and bottom 5
        log_f("")
        log_f("TOP 5 PERFORMERS:")
        for i, (lambda_val, success_rate) in enumerate(sorted_results[:5]):
            log_f(f"  {i+1}. Lambda {lambda_val:.2f}: {success_rate:.6f} ({success_rate*100:.1f}%)")

        log_f("")
        log_f("BOTTOM 5 PERFORMERS:")
        bottom_results = sorted_results[-5:]
        # Derive the first bottom rank from the slice length so numbering stays
        # correct even when fewer than 5 lambdas were evaluated.
        first_bottom_rank = len(sorted_results) - len(bottom_results) + 1
        for offset, (lambda_val, success_rate) in enumerate(bottom_results):
            log_f(f"  {first_bottom_rank + offset}. Lambda {lambda_val:.2f}: {success_rate:.6f} ({success_rate*100:.1f}%)")

        log_f("")
        log_f("FILES CREATED:")
        log_f(f"  Grid search data: {TRIAL_RESULTS_CSV}")
        log_f(f"  Detailed logs: {os.path.join(RESULTS_DIR, 'lambda_grid_search_run_*.log')}")
        log_f(f"  Test case cache: {BASE_TESTCASE_CACHE}")

        # --- Plotting ---
        # Plot this run's in-memory results. TRIAL_RESULTS_CSV is appended to across
        # runs and never truncated, so reading it back would mix earlier runs into the
        # charts (duplicate lambdas, zig-zag lines, a "best" marker from another run).
        try:
            df = pd.DataFrame(grid_search_results, columns=['lambda_value', 'success_rate'])
            if not df.empty:
                df = df.sort_values('lambda_value')

                # Lambda vs Success Rate plot
                plt.figure(figsize=(12, 6))
                plt.plot(df['lambda_value'], df['success_rate'], marker='o', linestyle='-', linewidth=2, markersize=8)
                plt.xlabel('Lambda Value')
                plt.ylabel('Success Rate')
                plt.title('Lambda Grid Search Results')
                plt.grid(True, alpha=0.3)

                # Highlight the best lambda (idxmax returns the first maximum, matching sorted_results[0])
                best_idx = df['success_rate'].idxmax()
                plt.plot(df.loc[best_idx, 'lambda_value'], df.loc[best_idx, 'success_rate'],
                        marker='*', markersize=15, color='red', label=f'Best: λ={best_lambda:.2f}')
                plt.legend()

                results_png = os.path.join(RESULTS_DIR, 'lambda_grid_search_results.png')
                plt.tight_layout()
                plt.savefig(results_png, dpi=300)
                plt.close()
                log_f(f"Saved grid search results plot: {results_png}")

                # Bar chart of top 10 lambdas
                top_10 = df.nlargest(10, 'success_rate')
                plt.figure(figsize=(10, 6))
                bars = plt.bar(range(len(top_10)), top_10['success_rate'],
                              color=['red' if i == 0 else 'skyblue' for i in range(len(top_10))])
                plt.xlabel('Lambda Rank')
                plt.ylabel('Success Rate')
                plt.title('Top 10 Lambda Values by Success Rate')
                plt.xticks(range(len(top_10)), [f'λ={x:.2f}' for x in top_10['lambda_value']], rotation=45)

                # Add value labels on bars
                for bar, val in zip(bars, top_10['success_rate']):
                    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                            f'{val:.3f}', ha='center', va='bottom', fontsize=9)

                plt.tight_layout()
                top10_png = os.path.join(RESULTS_DIR, 'top_10_lambdas.png')
                plt.savefig(top10_png, dpi=300)
                plt.close()
                log_f(f"Saved top 10 lambdas plot: {top10_png}")
            else:
                log_f("No grid search results to plot; skipping plotting.")
        except Exception as e:
            log_f(f"Error during plotting: {e}")

    except Exception as e_main:
        import traceback
        # Record the full traceback and signal failure to the caller via the exit code.
        exit_code = 1
        if LOG_FILE_HANDLE:
            log_f(f"A critical error occurred in the main script: {e_main}")
            log_f(traceback.format_exc())
        else:
            print(f"A critical error occurred before logging was set up: {e_main}")
            traceback.print_exc()
    finally:
        if LOG_FILE_HANDLE:
            try:
                LOG_FILE_HANDLE.write("--- Script Execution Finished ---\n")
                LOG_FILE_HANDLE.flush()
                LOG_FILE_HANDLE.close()
            except Exception:
                pass
            LOG_FILE_HANDLE = None

    if exit_code:
        sys.exit(exit_code)