import json
import sys
import os
import argparse
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from datetime import datetime
import faiss
import pickle
from collections import defaultdict

# python3 eval_hybrid_v2.py --output_dir /Users/arun/Documents/fda-search/embedding_data_small --query_csv /Users/arun/Documents/fda-search/py_src/test_v2/test_query_db.csv

"""
Why Your Embedding Search Might Score Lower But Be "Better"

Semantic Understanding: Your embedding search is likely finding devices that treat the same condition with similar modalities - which is actually what a real user would want! If someone searches "lung cancer CT", they probably want to see ALL relevant lung cancer imaging devices, not just one specific submission.
The "Exact Match" Problem: Your professor's evaluation only gives credit for finding the exact submission number that was in the test case. But if there are 5 FDA devices for "lung cancer CT screening", and your system returns 4 of them (including the "correct" one at position 3), you get penalized even though that's a great result for a user.

"""

# Configuration Parameters
# K cutoffs used for Hit@K when --k_values is not supplied on the command line.
DEFAULT_K_VALUES = [1, 3, 5, 10, 25]  # Default K values for Hit@K evaluation

# Add the parent directory to the Python path to allow for imports
# NOTE(review): presumably create_embeddings and bm25_baseline live in that
# parent directory, which is why this runs before the imports below — confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from create_embeddings import search_hybrid, MODEL_NAME
from bm25_baseline import BM25Baseline

def load_test_queries(filepath):
    """Build the list of evaluation queries from a CSV file.

    The CSV must provide the columns ``Number``, ``disease_query`` and
    ``modality_query``.  Rows without a usable disease are dropped.  Every
    remaining row yields a ``disease_only`` query and, when a modality is
    present, a ``disease_modality`` query placed before it.

    Returns:
        A list of dicts with the keys ``query``, ``expected_submission``
        (always a string) and ``query_type``.
    """
    frame = pd.read_csv(filepath)

    def _as_text(cell):
        # Missing cells become the empty string instead of the text "nan".
        return "" if not pd.notna(cell) else str(cell)

    def _is_usable(text):
        # A literal "nan" string is treated the same as a missing value.
        return bool(text) and text.lower() != 'nan'

    queries = []
    for _, record in frame.iterrows():
        # Stringified so it compares equal to the string IDs used elsewhere.
        expected = str(record['Number'])
        disease_text = _as_text(record['disease_query'])
        modality_text = _as_text(record['modality_query'])

        # The disease is mandatory; without it no query can be formed.
        if not _is_usable(disease_text):
            continue

        # Combined query first (only when a modality exists) ...
        if _is_usable(modality_text):
            queries.append({
                'query': f"{disease_text} {modality_text}",
                'expected_submission': expected,
                'query_type': 'disease_modality'
            })

        # ... followed by the disease-only query for the same submission.
        queries.append({
            'query': disease_text,
            'expected_submission': expected,
            'query_type': 'disease_only'
        })

    return queries

def calculate_hit_at_k(retrieved_submissions, expected_submission, k):
    """
    Hit@K indicator: 1.0 when the expected submission appears among the
    first K retrieved submissions, otherwise 0.0.
    """
    window = retrieved_submissions[:k]
    if expected_submission in window:
        return 1.0
    return 0.0

def reciprocal_rank(retrieved_submissions, target_submission):
    """Reciprocal of the 1-based rank of the first match; 0.0 when absent."""
    matching_ranks = (
        rank
        for rank, candidate in enumerate(retrieved_submissions, start=1)
        if candidate == target_submission
    )
    first_rank = next(matching_ranks, None)
    if first_rank is None:
        return 0.0
    return 1.0 / first_rank

def find_rank_position(retrieved_submissions, target_submission):
    """1-based rank of the first occurrence of the target, or -1 if it is absent."""
    return next(
        (
            rank
            for rank, candidate in enumerate(retrieved_submissions, start=1)
            if candidate == target_submission
        ),
        -1,
    )

def _min_max_scores(scored_pairs, score_when_constant):
    """Min-max normalize ``(submission, score)`` pairs into ``{str(submission): score}``.

    When every raw score is identical there is no range to scale by, so each
    item receives ``score_when_constant(raw_score)`` instead.
    """
    if not scored_pairs:
        return {}
    raw_values = [score for _, score in scored_pairs]
    lowest, highest = min(raw_values), max(raw_values)
    if highest == lowest:
        return {str(sub): score_when_constant(score) for sub, score in scored_pairs}
    span = highest - lowest
    return {str(sub): (score - lowest) / span for sub, score in scored_pairs}


def _rank_scores(ranked_submissions):
    """Score an ordered list by rank: first item gets 1.0, last gets 1/len."""
    total = len(ranked_submissions)
    return {
        str(sub): (total - position) / total
        for position, sub in enumerate(ranked_submissions)
    }


def combine_search_results(embedding_results, bm25_results, embedding_weight=0.6, bm25_weight=0.4, max_results=100):
    """
    Combines embedding and BM25 search results using weighted scoring.

    Each result list is first mapped to scores in the 0-1 range (min-max
    normalization of the raw scores, or rank-based scores when the embedding
    results carry no 'hybrid_similarity'), then the two scores are blended as
    ``embedding_weight * emb + bm25_weight * bm25``.  A submission missing
    from one list contributes 0.0 for that list.

    Args:
        embedding_results: List of dicts with 'submission_number' and 'hybrid_similarity' keys.
            Dicts without 'hybrid_similarity', or a plain list of submission
            numbers, are also accepted and scored by rank.
        bm25_results: List of tuples (submission_number, bm25_score) from BM25
        embedding_weight: Weight for embedding scores (default 0.6)
        bm25_weight: Weight for BM25 scores (default 0.4)
        max_results: Maximum number of results to return

    Returns:
        List of submission numbers (as strings) ranked by combined score,
        best first.  Ties keep first-seen order (embedding results first,
        then BM25-only results), so the ranking is reproducible across runs.
    """
    embedding_scores = {}
    if embedding_results:
        first_result = embedding_results[0]
        if isinstance(first_result, dict):
            if 'hybrid_similarity' in first_result:
                # When all similarities are equal (including a single result)
                # every item is an equally good match, so each gets 1.0.
                embedding_scores = _min_max_scores(
                    [(res['submission_number'], res['hybrid_similarity']) for res in embedding_results],
                    lambda _score: 1.0,
                )
            else:
                # Fall back to rank-based scoring if hybrid_similarity not found
                embedding_scores = _rank_scores(
                    [res['submission_number'] for res in embedding_results]
                )
        else:
            # embedding_results is just an ordered list of submission numbers
            embedding_scores = _rank_scores(embedding_results)

    bm25_scores = {}
    if bm25_results:
        # With identical BM25 scores, only a positive score counts as a match.
        bm25_scores = _min_max_scores(
            bm25_results,
            lambda score: 1.0 if score > 0 else 0.0,
        )

    # dict.fromkeys de-duplicates while keeping insertion order; iterating a
    # set of strings here would make tie order depend on hash randomization.
    combined_scores = {}
    for sub_num in dict.fromkeys([*embedding_scores, *bm25_scores]):
        emb_score = embedding_scores.get(sub_num, 0.0)
        bm25_score = bm25_scores.get(sub_num, 0.0)
        combined_scores[sub_num] = (embedding_weight * emb_score) + (bm25_weight * bm25_score)

    # sorted() is stable (also with reverse=True), so ties keep first-seen order.
    ranked = sorted(combined_scores, key=combined_scores.get, reverse=True)
    return ranked[:max_results]

def _new_hit_table(k_values):
    """Return an empty ``{method: {k: []}}`` table for collecting Hit@K outcomes."""
    return {
        method: {k: [] for k in k_values}
        for method in ('embedding_search', 'bm25_baseline', 'hybrid_search')
    }


def _new_position_table():
    """Return empty per-method lists for collecting 1-based rank positions."""
    return {'embedding': [], 'bm25': [], 'hybrid': []}


def _mean_or_zero(values):
    """Return the mean of *values*, or 0 when there is nothing to average."""
    return np.mean(values) if values else 0


def _log_overall_positions(logger, heading, positions, processed_queries):
    """Log rank statistics for one method over all processed queries (silent if never found)."""
    if not positions:
        return
    found = len(positions)
    logger(f"{heading}:")
    logger(f"  Average position: {np.mean(positions):.2f}")
    logger(f"  Median position: {np.median(positions):.1f}")
    logger(f"  Min position: {np.min(positions)}")
    logger(f"  Max position: {np.max(positions)}")
    logger(f"  Stdev: {np.std(positions)}")
    logger(f"  Found in results: {found}/{processed_queries} ({found/processed_queries*100:.1f}%)")


def _log_type_positions(logger, prefix, positions, type_count):
    """Log one-line rank statistics for one method within a single query type."""
    if not positions:
        logger(f"{prefix}No results found for this query type")
        return
    found_count = len(positions)
    logger(
        f"{prefix}avg={np.mean(positions):.2f}, median={np.median(positions):.1f}, "
        f"min={np.min(positions)}, max={np.max(positions)}, stdev={np.std(positions):.2f}"
    )
    logger(f"               found={found_count}/{type_count} ({found_count/type_count*100:.1f}%)")


def run_database_evaluation(test_cases, model, all_texts, all_indexes, bm25_baseline, k_values, logger, custom_weights=None, allowed_submissions=None):
    """
    Runs the database evaluation using Hit@K metrics and additional position metrics.

    For every test case the query is sent to the embedding search
    (``search_hybrid``), to ``bm25_baseline.search`` and to the weighted
    combination of both (``combine_search_results``, 60% embedding / 40% BM25).
    Hit@K and the rank of the expected submission are collected per method and
    per query type, and a report is written through ``logger``.

    Args:
        test_cases: List of dicts with 'query', 'expected_submission' and
            'query_type' ('disease_modality' or 'disease_only').
        model: Embedding model, passed through to ``search_hybrid``.
        all_texts: Loaded text data; 'submission_numbers' must list every
            submission in the database.
        all_indexes: Loaded indexes, passed through to ``search_hybrid``.
        bm25_baseline: Object whose ``search(query, top_k=...)`` returns
            (submission_number, score) pairs.
        k_values: K cutoffs to report Hit@K for.
        logger: Callable taking one message string.
        custom_weights: Optional field weights forwarded to ``search_hybrid``.
        allowed_submissions: Optional collection of submission numbers; when
            given, only these submissions count as part of the database and
            all result lists are filtered down to them.

    Returns:
        None. All results are reported through ``logger``.
    """
    query_types = ('disease_modality', 'disease_only')

    # Hit@K outcomes (1.0 / 0.0 per processed query), overall and by query type.
    results = _new_hit_table(k_values)
    query_type_results = {qt: _new_hit_table(k_values) for qt in query_types}

    # 1-based ranks of the expected submission; only recorded when it was found.
    overall_positions = _new_position_table()
    query_type_positions = {qt: _new_position_table() for qt in query_types}

    # Processed queries per type, tracked directly so the count does not
    # depend on k_values being non-empty or free of duplicates.
    query_type_counts = {qt: 0 for qt in query_types}

    skipped_queries = []
    processed_queries = 0

    # Get the set of available submission numbers for validation
    available_submissions = set(str(sub) for sub in all_texts['submission_numbers'])
    if allowed_submissions is not None:
        available_submissions &= {str(sub) for sub in allowed_submissions}

    logger(f"Starting evaluation with {len(test_cases)} test cases...")
    logger(f"Available submissions in database: {len(available_submissions)}")
    logger("")

    # Request the full ranking from every search system (no truncation).
    full_result_k = 1500  # Larger than total submissions to get everything

    for i, case in enumerate(test_cases):
        query = case['query']
        expected_submission = str(case['expected_submission'])
        query_type = case['query_type']

        def record_skip(reason):
            skipped_queries.append({
                'query': query,
                'expected_submission': expected_submission,
                'reason': reason
            })

        if i % 50 == 0:  # Progress indicator
            logger(f"Processing query {i+1}/{len(test_cases)}: '{query}'")

        # Check if expected submission exists in our database
        if expected_submission not in available_submissions:
            logger(f"Skipping query '{query}' - expected submission {expected_submission} not in database")
            record_skip('submission_not_in_database')
            continue

        # A failing search backend skips the query instead of aborting the run.
        try:
            embedding_results = search_hybrid(query, model, all_texts, all_indexes, top_k=full_result_k, custom_weights=custom_weights)
            embedding_submissions = [str(res['submission_number']) for res in embedding_results]
            embedding_submissions = [s for s in embedding_submissions if s in available_submissions]
        except Exception as e:
            logger(f"Error in embedding search for query '{query}': {e}")
            record_skip('embedding_search_error')
            continue

        try:
            bm25_results_with_scores = bm25_baseline.search(query, top_k=full_result_k)
            bm25_submissions = [str(result[0]) for result in bm25_results_with_scores]
            bm25_submissions = [s for s in bm25_submissions if s in available_submissions]
        except Exception as e:
            logger(f"Error in BM25 search for query '{query}': {e}")
            record_skip('bm25_search_error')
            continue

        # Create hybrid search results
        try:
            hybrid_submissions = combine_search_results(
                embedding_results, bm25_results_with_scores,
                embedding_weight=0.6, bm25_weight=0.4,
                max_results=full_result_k
            )
            hybrid_submissions = [s for s in hybrid_submissions if s in available_submissions]
        except Exception as e:
            logger(f"Error in hybrid search for query '{query}': {e}")
            record_skip('hybrid_search_error')
            continue

        # Find positions in each search method (ONCE, with the large result set)
        embedding_pos = find_rank_position(embedding_submissions, expected_submission)
        bm25_pos = find_rank_position(bm25_submissions, expected_submission)
        hybrid_pos = find_rank_position(hybrid_submissions, expected_submission)

        # Store position data (only if found)
        for method, pos in (('embedding', embedding_pos), ('bm25', bm25_pos), ('hybrid', hybrid_pos)):
            if pos != -1:
                overall_positions[method].append(pos)
                query_type_positions[query_type][method].append(pos)

        # Calculate Hit@K once per K; the values are reused for the detailed log.
        hits_by_k = {}
        for k in k_values:
            hits = {
                'embedding_search': calculate_hit_at_k(embedding_submissions, expected_submission, k),
                'bm25_baseline': calculate_hit_at_k(bm25_submissions, expected_submission, k),
                'hybrid_search': calculate_hit_at_k(hybrid_submissions, expected_submission, k),
            }
            hits_by_k[k] = hits
            for method, hit in hits.items():
                results[method][k].append(hit)
                query_type_results[query_type][method][k].append(hit)

        processed_queries += 1
        query_type_counts[query_type] += 1

        # Log detailed results for first few queries or when every method
        # misses at K=1 (i.e. none of them ranks the expected submission first).
        if i < 5 or (embedding_pos != 1 and bm25_pos != 1 and hybrid_pos != 1):
            logger(f"Query: '{query}' (Expected: {expected_submission})")
            logger(f"  Embedding top 5: {embedding_submissions[:5]} (pos: {embedding_pos if embedding_pos != -1 else 'not found'})")
            logger(f"  BM25 top 5: {bm25_submissions[:5]} (pos: {bm25_pos if bm25_pos != -1 else 'not found'})")
            logger(f"  Hybrid top 5: {hybrid_submissions[:5]} (pos: {hybrid_pos if hybrid_pos != -1 else 'not found'})")
            for k in k_values:
                emb_hit = hits_by_k[k]['embedding_search']
                bm25_hit = hits_by_k[k]['bm25_baseline']
                hybrid_hit = hits_by_k[k]['hybrid_search']
                logger(f"  Hit@{k} - Embedding: {emb_hit}, BM25: {bm25_hit}, Hybrid: {hybrid_hit}")
            logger("")

    # Calculate and display results
    logger("="*80)
    logger("EVALUATION RESULTS")
    logger("="*80)

    logger(f"Total test cases: {len(test_cases)}")
    logger(f"Successfully processed: {processed_queries}")
    logger(f"Skipped: {len(skipped_queries)}")
    logger("")

    # Overall results
    logger("Overall Hit@K Results:")
    logger("-" * 60)
    for k in k_values:
        embedding_accuracy = _mean_or_zero(results['embedding_search'][k])
        bm25_accuracy = _mean_or_zero(results['bm25_baseline'][k])
        hybrid_accuracy = _mean_or_zero(results['hybrid_search'][k])

        logger(f"Hit@{k}:")
        logger(f"  Embedding Search: {embedding_accuracy:.4f} ({embedding_accuracy*100:.2f}%)")
        logger(f"  BM25 Baseline:    {bm25_accuracy:.4f} ({bm25_accuracy*100:.2f}%)")
        logger(f"  Hybrid Search:    {hybrid_accuracy:.4f} ({hybrid_accuracy*100:.2f}%)")

        # Find the best performing method (ties go to hybrid, then embedding)
        best_score = max(embedding_accuracy, bm25_accuracy, hybrid_accuracy)
        if hybrid_accuracy == best_score:
            logger(f"  → Hybrid search wins!")
        elif embedding_accuracy == best_score:
            logger(f"  → Embedding search wins!")
        else:
            logger(f"  → BM25 search wins!")
        logger("")

    # Position Analysis
    logger("Rank Position Analysis:")
    logger("-" * 60)
    _log_overall_positions(logger, "Embedding Search", overall_positions['embedding'], processed_queries)
    _log_overall_positions(logger, "BM25 Search", overall_positions['bm25'], processed_queries)
    _log_overall_positions(logger, "Hybrid Search", overall_positions['hybrid'], processed_queries)
    logger("")

    # Results by query type
    logger("Results by Query Type:")
    logger("-" * 60)
    for query_type in query_types:
        type_name = "Disease + Modality" if query_type == 'disease_modality' else "Disease Only"
        type_count = query_type_counts[query_type]

        if type_count == 0:
            continue

        logger(f"{type_name} Queries (n={type_count}):")
        for k in k_values:
            embedding_accuracy = _mean_or_zero(query_type_results[query_type]['embedding_search'][k])
            bm25_accuracy = _mean_or_zero(query_type_results[query_type]['bm25_baseline'][k])
            hybrid_accuracy = _mean_or_zero(query_type_results[query_type]['hybrid_search'][k])

            logger(f"  Hit@{k}:")
            logger(f"    Embedding: {embedding_accuracy:.4f} ({embedding_accuracy*100:.2f}%)")
            logger(f"    BM25:      {bm25_accuracy:.4f} ({bm25_accuracy*100:.2f}%)")
            logger(f"    Hybrid:    {hybrid_accuracy:.4f} ({hybrid_accuracy*100:.2f}%)")

        # Position statistics by query type
        logger(f"  Position Statistics:")
        type_positions = query_type_positions[query_type]
        _log_type_positions(logger, "    Embedding: ", type_positions['embedding'], type_count)
        _log_type_positions(logger, "    BM25:      ", type_positions['bm25'], type_count)
        _log_type_positions(logger, "    Hybrid:    ", type_positions['hybrid'], type_count)

        logger("")

    # Final recommendation
    logger("CONCLUSION:")
    logger("-" * 60)
    if not k_values:
        logger("No K values were evaluated; skipping Hit@K conclusion.")
        return

    # Hit@1 is the primary metric; if K=1 was not evaluated, fall back to the
    # smallest evaluated K instead of failing with a KeyError.
    best_k = 1 if 1 in k_values else min(k_values)
    embedding_hit = _mean_or_zero(results['embedding_search'][best_k])
    bm25_hit = _mean_or_zero(results['bm25_baseline'][best_k])
    hybrid_hit = _mean_or_zero(results['hybrid_search'][best_k])

    # Same tie-breaking as the per-K section above: hybrid, then embedding, then BM25.
    best_score = max(hybrid_hit, embedding_hit, bm25_hit)
    if hybrid_hit == best_score:
        best_method = "Hybrid"
    elif embedding_hit == best_score:
        best_method = "Embedding"
    else:
        best_method = "BM25"

    logger(f"Best performing method: {best_method} Search with Hit@{best_k} = {best_score:.4f}")
    if best_k == 1:
        logger(f"This means the correct FDA submission is found in the top result {best_score*100:.1f}% of the time.")
    else:
        logger(f"This means the correct FDA submission is found in the top {best_k} results {best_score*100:.1f}% of the time.")

    if hybrid_hit > max(embedding_hit, bm25_hit):
        improvement = hybrid_hit - max(embedding_hit, bm25_hit)
        logger(f"Hybrid search improves over the best individual method by {improvement:.4f} ({improvement*100:.2f} percentage points).")

def main():
    """Command-line entry point: load the search data, then run the evaluation.

    Parses the arguments, loads the sentence-transformer model, the FAISS
    indexes with their pickled texts and the submission numbers from
    ``--output_dir``, builds the BM25 baseline, loads the test queries and
    calls ``run_database_evaluation``.  Every message is printed and also
    written to a timestamped file in ``logs/``.
    """
    parser = argparse.ArgumentParser(description="Run FDA database query evaluation using Hit@K metrics with hybrid search.")
    parser.add_argument('--query_csv', type=str, default='/Users/arun/Documents/fda-search/py_src/test_v2/test_query_db.csv', 
                       help="Path to the CSV file containing test queries.")
    parser.add_argument('--output_dir', type=str, required=True, 
                       help="Directory where the FAISS indexes and text data are stored.")
    parser.add_argument('--k_values', type=int, nargs='+', default=DEFAULT_K_VALUES, 
                       help=f"K values to evaluate Hit@K (default: {DEFAULT_K_VALUES})")
    parser.add_argument('--weights_json', type=str, 
                       help="JSON string of weights to use for search_hybrid. Example: '{\"summary\": 0.1, \"keywords\": 0.2, ...}'")

    args = parser.parse_args()

    # Parse custom weights if provided; any problem falls back to the defaults.
    passed_weights = None
    if args.weights_json:
        try:
            passed_weights = json.loads(args.weights_json)
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON string passed to --weights_json: {args.weights_json}")
            print("Using default weights.")
        else:
            # NOTE(review): assumes search_hybrid expects a field -> weight
            # mapping like default_weights below — confirm.
            if isinstance(passed_weights, dict):
                print(f"Using custom weights: {passed_weights}")
            else:
                print(f"Error: --weights_json must be a JSON object: {args.weights_json}")
                print("Using default weights.")
                passed_weights = None

    # Setup logging
    log_dir = 'logs'
    os.makedirs(log_dir, exist_ok=True)
    log_filename = f"enhanced_fda_evaluation_{datetime.now().strftime('%Y%m%d-%H%M%S')}.log"
    log_filepath = os.path.join(log_dir, log_filename)

    # The context manager closes the log on the early returns below and on
    # errors.  UTF-8 is explicit because the report contains non-ASCII
    # characters ("→") that some platform default encodings cannot write.
    with open(log_filepath, 'w', encoding='utf-8') as log_file:

        def logger(message):
            """Prints to console and writes to log file."""
            print(message)
            log_file.write(message + '\n')

        # Load model
        logger("Loading Sentence Transformer model...")
        model = SentenceTransformer(MODEL_NAME)
        logger("Model loaded.")

        # Load all indexes and texts
        all_loaded_texts = {}
        all_loaded_indexes = {}

        fields_to_load = {
            "summary": ('summary_index.faiss', 'summary_texts.pkl'),
            "keywords": ('keywords_index.faiss', 'keywords_texts.pkl'),
            "questions": ('questions_index.faiss', 'questions_texts.pkl'),
            "concepts": ('concepts_index.faiss', 'concepts_texts.pkl'),
            "thesis": ('thesis_index.faiss', 'thesis_texts.pkl'),
            "search_boost": ('search_boost_index.faiss', 'search_boost_texts.pkl'),
            "query_match_1": ('query_match_1_index.faiss', 'query_match_1_texts.pkl'),
            "query_match_2": ('query_match_2_index.faiss', 'query_match_2_texts.pkl'),
            "query_match_3": ('query_match_3_index.faiss', 'query_match_3_texts.pkl'),
        }

        # NOTE: pickle.load can execute arbitrary code contained in the file;
        # only point --output_dir at data you generated or otherwise trust.
        for name, (index_file, texts_file) in fields_to_load.items():
            index_path = os.path.join(args.output_dir, index_file)
            texts_path = os.path.join(args.output_dir, texts_file)
            if os.path.exists(index_path) and os.path.exists(texts_path):
                all_loaded_indexes[name] = faiss.read_index(index_path)
                with open(texts_path, 'rb') as f:
                    all_loaded_texts[name] = pickle.load(f)
            else:
                # A missing field is recorded as empty rather than aborting.
                logger(f"Warning: {name.capitalize()} index/texts not found. Skipping.")
                all_loaded_indexes[name] = None
                all_loaded_texts[name] = []

        # Load submission numbers
        submission_numbers_file = os.path.join(args.output_dir, 'submission_numbers.pkl')
        if os.path.exists(submission_numbers_file):
            with open(submission_numbers_file, 'rb') as f:
                all_loaded_texts['submission_numbers'] = pickle.load(f)
        else:
            logger("Error: Submission numbers file not found. Cannot run evaluation.")
            return

        # Prepare BM25 baseline: one document per submission, built by joining
        # that submission's entry from each available text field.
        fields_to_concat = ['keywords', 'questions', 'concepts', 'thesis', 'search_boost']
        bm25_docs = {}
        for i, sub_num in enumerate(all_loaded_texts['submission_numbers']):
            doc_text = " ".join(all_loaded_texts[field][i] for field in fields_to_concat if field in all_loaded_texts and i < len(all_loaded_texts[field]))
            bm25_docs[sub_num] = doc_text

        bm25_baseline = BM25Baseline(bm25_docs)

        # Load test queries
        try:
            test_cases = load_test_queries(args.query_csv)
            logger(f"Loaded {len(test_cases)} test cases from {args.query_csv}")
        except Exception as e:
            logger(f"Error loading test queries: {e}")
            return
        # Restrict the evaluated corpus to submissions that appear as an
        # expected answer in the test set.
        allowed_submissions = set(tc['expected_submission'] for tc in test_cases)

        # Default weights (you can modify these based on your current best configuration)
        default_weights = {
            'summary': 0, 'keywords': 0, 'questions': 0.03,
            'concepts': 0.07, 'thesis': 0.1, 'search_boost': 0.1,
            'query_match_1': 0.25, 'query_match_2': 0.25, 'query_match_3': 0.2
        }

        final_weights = passed_weights if passed_weights is not None else default_weights

        logger("Configuration:")
        logger(f"  Embedding weights: {final_weights}")
        logger(f"  Hybrid search: 60% embedding + 40% BM25")
        logger("")

        # Run evaluation
        run_database_evaluation(test_cases, model, all_loaded_texts, all_loaded_indexes, 
                              bm25_baseline, args.k_values, logger, custom_weights=final_weights, allowed_submissions=allowed_submissions)

    print(f"\nDetailed log saved to: {log_filepath}")

if __name__ == '__main__':
    main()