import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import argparse
import os
import pickle
import json
from collections import defaultdict
from bm25_baseline import BM25Baseline

# create_embeddings.py
# Configuration
#
# Embedding model used both when building the indexes and when encoding queries
# (main() loads this single model for every action), so index vectors and query
# vectors presumably must come from the same model: regenerate all indexes after
# changing it. A larger alternative is 'abhinand/MedEmbed-base-v0.1'.
MODEL_NAME = 'abhinand/MedEmbed-small-v0.1'

# Input CSV column names. load_data() requires the summary, keywords, questions
# and submission-number columns; the remaining generated columns are created as
# empty strings when absent. Every column except the submission number is
# embedded into its own FAISS index; submission numbers are pickled alongside.
SUMMARY_COLUMN_NAME = 'summary'
KEYWORDS_COLUMN_NAME = 'summary_keywords'
QUESTIONS_COLUMN_NAME = 'generated_questions'
CONCEPTS_COLUMN_NAME = 'concepts'
THESIS_COLUMN_NAME = 'thesis'
SEARCH_BOOST_COLUMN_NAME = 'search_boost_text'
SUBMISSION_NUMBER_COLUMN_NAME = 'submission_number'
QUERY_MATCH_1_COLUMN_NAME = 'query_match_1'
QUERY_MATCH_2_COLUMN_NAME = 'query_match_2'
QUERY_MATCH_3_COLUMN_NAME = 'query_match_3'

def load_data(csv_path):
    """Load the records CSV and clean it for embedding.

    The summary, keywords, questions and submission-number columns must be
    present. The optional generated columns (concepts, thesis, search boost and
    the three query-match columns) are created as empty strings when missing.
    NaN cells in text columns become '' and missing submission numbers become
    'N/A'. The thesis column is whitespace-stripped, and rows whose thesis is
    empty, 'error' or 'none' (case-insensitive) are dropped. The stored thesis
    text keeps its original casing because it is embedded later.

    Args:
        csv_path: Path to the input CSV file.

    Returns:
        The cleaned DataFrame (possibly empty after filtering), or None if the
        file does not exist, a required column is missing, or loading fails.
    """
    # Generated columns that may be absent from older CSVs; created as '' if missing.
    optional_columns = (
        CONCEPTS_COLUMN_NAME, THESIS_COLUMN_NAME, SEARCH_BOOST_COLUMN_NAME,
        QUERY_MATCH_1_COLUMN_NAME, QUERY_MATCH_2_COLUMN_NAME, QUERY_MATCH_3_COLUMN_NAME,
    )
    # Columns whose text is embedded: NaN must become '' before encoding.
    text_columns = (
        SUMMARY_COLUMN_NAME, KEYWORDS_COLUMN_NAME, QUESTIONS_COLUMN_NAME,
    ) + optional_columns
    required_columns = (
        SUMMARY_COLUMN_NAME, KEYWORDS_COLUMN_NAME, QUESTIONS_COLUMN_NAME,
        CONCEPTS_COLUMN_NAME, THESIS_COLUMN_NAME, SEARCH_BOOST_COLUMN_NAME,
        SUBMISSION_NUMBER_COLUMN_NAME, QUERY_MATCH_1_COLUMN_NAME,
        QUERY_MATCH_2_COLUMN_NAME, QUERY_MATCH_3_COLUMN_NAME,
    )
    invalid_thesis_values = ['', 'error', 'none']

    try:
        df = pd.read_csv(csv_path)
        for col in required_columns:
            if col not in df.columns:
                if col not in optional_columns:
                    raise ValueError(f"CSV must contain '{col}' column.")
                print(f"Warning: Column '{col}' not found in CSV. Initializing as empty.")
                df[col] = ''

            if col in text_columns:
                df[col] = df[col].fillna('')
            elif col == SUBMISSION_NUMBER_COLUMN_NAME:
                # IDs may be numeric or strings; only fill gaps, keep the dtype otherwise.
                df[col] = df[col].fillna('N/A')

        # The thesis column always exists here (the loop above creates it if absent).
        # Strip the stored text but keep its casing, since it is embedded later.
        df[THESIS_COLUMN_NAME] = df[THESIS_COLUMN_NAME].astype(str).str.strip()
        # Lower-case only a temporary copy for the invalid-value comparison.
        mask = ~df[THESIS_COLUMN_NAME].str.lower().isin(invalid_thesis_values)

        original_row_count = len(df)
        df = df[mask]
        filtered_row_count = len(df)
        print(f"Filtered out {original_row_count - filtered_row_count} rows due to invalid thesis statements.")
        if filtered_row_count == 0:
            # An empty frame is returned; callers handle empty column lists.
            print("Warning: After filtering, no rows remain. Check your CSV and thesis column content.")

        return df
    except FileNotFoundError:
        print(f"Error: The file {csv_path} was not found.")
        return None
    except Exception as e:
        # CLI boundary: report any load/validation failure and signal it with None.
        print(f"Error loading data: {e}")
        return None

def create_and_save_index(texts, model, index_path, texts_path, column_name, index=None):
    """Embed ``texts`` and write them into a new or existing FAISS index.

    Vectors are L2-normalised and stored in an inner-product index, so search
    scores are cosine similarities. The raw texts are pickled to ``texts_path``
    in the same order as the vectors. That positional alignment is what lets a
    position returned by ``index.search`` be mapped back to its text.

    Args:
        texts: New texts to embed, in row order.
        model: Encoder exposing ``encode(texts, show_progress_bar=..., convert_to_numpy=...)``.
        index_path: File the FAISS index is written to.
        texts_path: Pickle file holding the list of all indexed texts.
        column_name: Field name, used only in log messages.
        index: Existing index to extend. When None, a fresh ``IndexFlatIP`` is
            created and any texts previously pickled at ``texts_path`` are
            discarded, so the saved texts match the new index exactly.

    Returns:
        None. Nothing is written when ``texts`` is empty.
    """
    if not texts:
        print(f"No new text data found for {column_name} to update index.")
        return

    print(f"Generating embeddings for {len(texts)} new {column_name} entries...")
    new_embeddings = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    # faiss.normalize_L2 works in place and needs a C-contiguous float32 array.
    new_embeddings = np.ascontiguousarray(new_embeddings, dtype=np.float32)
    faiss.normalize_L2(new_embeddings)

    extending_existing = index is not None
    if not extending_existing:
        print(f"Creating new FAISS index for {column_name}.")
        index = faiss.IndexFlatIP(new_embeddings.shape[1])

    index.add(new_embeddings)

    print(f"Saving {column_name} FAISS index to {index_path}...")
    faiss.write_index(index, index_path)

    # Only prepend previously saved texts when their vectors are in `index`;
    # a fresh index must not inherit stale texts from an earlier run.
    existing_texts = []
    if extending_existing and os.path.exists(texts_path):
        # Pickle files are produced by this script itself; never point this at untrusted data.
        with open(texts_path, 'rb') as f:
            existing_texts = list(pickle.load(f))

    all_texts = existing_texts + list(texts)

    print(f"Saving {len(all_texts)} total {column_name} texts to {texts_path}...")
    with open(texts_path, 'wb') as f:
        pickle.dump(all_texts, f)
    print(f"Successfully updated and saved index for {column_name}.")


def search_hybrid(query, model, all_texts, all_indexes, top_k=5, search_multiplier=3, custom_weights=None):
    """Rank documents by a weighted sum of per-field embedding similarities.

    Each loaded index (summary, keywords, questions, concepts, thesis,
    search_boost, query_match_1..3) is searched for ``top_k * search_multiplier``
    candidates. A document's score for a field is its inner-product similarity
    in that field's candidate list, or 0.0 if it did not appear there. For the
    normalised IndexFlatIP indexes this script builds, that similarity is
    cosine similarity. The hybrid score is ``sum(weight[field] * score[field])``.

    Args:
        query: Free-text query.
        model: Encoder exposing ``encode([...], convert_to_numpy=True)``.
        all_texts: Dict with 'submission_numbers' and (for display) 'summary'
            lists, positionally aligned with the vectors in every index.
        all_indexes: Dict mapping field name -> FAISS index, or None to skip.
        top_k: Number of results to return.
        search_multiplier: Over-fetch factor per field before hybrid scoring.
        custom_weights: Optional field -> weight dict. When non-empty it
            replaces the defaults entirely; fields it omits weigh 0.0.

    Returns:
        Up to ``top_k`` dicts sorted by 'hybrid_similarity' (descending), each
        with 'original_index', 'submission_number', the per-field
        '*_similarity' values and the summary 'text' ("N/A" if unavailable).
    """
    # Field name -> short key used in each document's score record.
    field_keys = {
        'summary': 's',
        'keywords': 'k',
        'questions': 'q',
        'concepts': 'c',
        'thesis': 't',
        'search_boost': 'sb',
        'query_match_1': 'qm1',
        'query_match_2': 'qm2',
        'query_match_3': 'qm3',
    }
    default_weights = {
        'summary': 0,
        'concepts': 0,
        'keywords': 0.134207,
        'questions': 0.226103,
        'thesis': 0.094972,
        'search_boost': 0.029563,
        'query_match_1': 0.217395,
        'query_match_2': 0.241111,
        'query_match_3': 0.056650,
    }
    current_weights = custom_weights if custom_weights else default_weights

    query_embedding = model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(query_embedding)

    submission_numbers_list = all_texts.get('submission_numbers', [])
    # The submission list is the reference length; other text lists are assumed aligned with it.
    num_docs = len(submission_numbers_list)
    summary_texts = all_texts.get('summary', [])

    # Original document position -> {short field key: similarity, 'sub_num': id}
    doc_similarities = {}
    for field_type, index in all_indexes.items():
        if index is None:
            print(f"Skipping search for {field_type} as index is not loaded.")
            continue

        # Over-fetch so the hybrid scoring has a reasonable candidate pool.
        distances, indices = index.search(query_embedding, top_k * search_multiplier)
        score_key = field_keys.get(field_type)

        for doc_idx, similarity_score in zip(indices[0], distances[0]):
            # FAISS pads with -1 when an index has fewer than k vectors; skip those
            # (and out-of-range positions) before any list lookup, since a
            # negative index would silently wrap around in Python.
            if doc_idx < 0 or doc_idx >= num_docs:
                continue

            record = doc_similarities.get(doc_idx)
            if record is None:
                record = {key: 0.0 for key in field_keys.values()}
                record['sub_num'] = submission_numbers_list[doc_idx]
                doc_similarities[doc_idx] = record
            if score_key is not None:
                record[score_key] = similarity_score

    hybrid_results = []
    for doc_idx, data in doc_similarities.items():
        # Summed in the fixed field order of field_keys.
        hybrid_score = sum(
            current_weights.get(field, 0.0) * data[key] for field, key in field_keys.items()
        )
        display_text = summary_texts[doc_idx] if doc_idx < len(summary_texts) else "N/A"

        hybrid_results.append({
            'original_index': doc_idx,
            'submission_number': data['sub_num'],
            'hybrid_similarity': hybrid_score,
            'summary_similarity': data['s'],
            'keywords_similarity': data['k'],
            'questions_similarity': data['q'],
            'concepts_similarity': data['c'],
            'thesis_similarity': data['t'],
            'search_boost_similarity': data['sb'],
            'query_match_1_similarity': data['qm1'],
            'query_match_2_similarity': data['qm2'],
            'query_match_3_similarity': data['qm3'],
            'text': display_text,
        })

    hybrid_results.sort(key=lambda result: result['hybrid_similarity'], reverse=True)
    return hybrid_results[:top_k]


def _min_max_normalize(scores):
    """Linearly rescale ``scores`` to [0, 1], preserving order.

    If every score is equal (zero range), each value maps to 0.0.

    Args:
        scores: Iterable of numeric scores.

    Returns:
        List of normalised scores, same length and order as the input.
    """
    values = list(scores)
    if not values:
        return []
    low = min(values)
    span = max(values) - low
    if span == 0:
        span = 1.0
    return [(value - low) / span for value in values]


def search_embedding_bm25(query, model, all_texts, all_indexes, top_k=5, custom_weights=None, lambda_val=0.5):
    """Blend embedding (search_hybrid) and BM25 relevance into one ranking.

    Both score sets are min-max normalised to [0, 1] and combined per
    submission as ``lambda_val * embedding + (1 - lambda_val) * bm25``. A
    submission missing from one side scores 0.0 there. Submission numbers are
    compared as strings on both sides, so numeric IDs still line up.

    Args:
        query: Free-text query.
        model: Encoder forwarded to search_hybrid.
        all_texts: Dict with 'submission_numbers' and the per-field text lists,
            all positionally aligned.
        all_indexes: Dict of FAISS indexes forwarded to search_hybrid.
        top_k: Number of results to return.
        custom_weights: Optional field weights forwarded to search_hybrid.
        lambda_val: Weight of the embedding score; BM25 gets ``1 - lambda_val``.

    Returns:
        Up to ``top_k`` dicts, best first, each with:
        'submission_number' (str), 'hybrid_similarity' (blended score),
        'text' (summary text, or 'N/A') and 'original_index' (position in the
        aligned lists, or -1 if unknown). Ties keep embedding rank order,
        then BM25 document order.
    """
    embedding_weight = lambda_val
    bm25_weight = 1 - lambda_val

    # 1. Embedding scores; over-fetch so the pool overlaps with BM25 hits.
    embedding_results = search_hybrid(query, model, all_texts, all_indexes, top_k=top_k * 5, custom_weights=custom_weights)
    normalized_embedding = _min_max_normalize(res['hybrid_similarity'] for res in embedding_results)
    embedding_scores = {}
    for res, score in zip(embedding_results, normalized_embedding):
        embedding_scores[str(res['submission_number'])] = score

    # 2. BM25 scores over the concatenated generated-text fields.
    submission_numbers = all_texts['submission_numbers']
    fields_to_concat = ['keywords', 'questions', 'concepts', 'thesis', 'search_boost']
    bm25_docs = {}
    for i, sub_num in enumerate(submission_numbers):
        bm25_docs[sub_num] = " ".join(
            all_texts[field][i] for field in fields_to_concat
            if field in all_texts and i < len(all_texts[field])
        )

    # NOTE(review): the BM25 index is rebuilt on every query; cache it if this becomes hot.
    bm25_baseline = BM25Baseline(bm25_docs)
    # NOTE(review): assumes BM25Baseline tokenises documents as lower-case whitespace splits — confirm.
    bm25_doc_scores = bm25_baseline.bm25.get_scores(query.lower().split())

    bm25_scores = {}
    if bm25_doc_scores is not None and len(bm25_doc_scores) > 0:
        for i, score in enumerate(_min_max_normalize(bm25_doc_scores)):
            bm25_scores[str(bm25_baseline.doc_map[i])] = score

    # 3. Blend. dict.fromkeys gives an ordered union, so tie order is deterministic.
    candidates = dict.fromkeys([*embedding_scores, *bm25_scores])
    combined_scores = {
        sub_num: embedding_weight * embedding_scores.get(sub_num, 0.0)
        + bm25_weight * bm25_scores.get(sub_num, 0.0)
        for sub_num in candidates
    }
    ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)

    # Highest-ranked embedding hit per submission (results are already sorted).
    embedding_by_sub = {}
    for res in embedding_results:
        embedding_by_sub.setdefault(str(res['submission_number']), res)
    # Position of each submission in the aligned lists, so BM25-only hits still get text.
    position_by_sub = {}
    for position, sub_num in enumerate(submission_numbers):
        position_by_sub.setdefault(str(sub_num), position)
    summary_texts = all_texts.get('summary', [])

    final_results = []
    for sub_num, score in ranked[:top_k]:
        original_res = embedding_by_sub.get(sub_num)
        if original_res is not None:
            text = original_res.get('text', 'N/A')
            original_index = original_res.get('original_index', -1)
        else:
            original_index = position_by_sub.get(sub_num, -1)
            text = summary_texts[original_index] if 0 <= original_index < len(summary_texts) else 'N/A'
        final_results.append({
            'submission_number': sub_num,
            'hybrid_similarity': score,
            'text': text,
            'original_index': original_index,
        })

    return final_results


def main():
    """Command-line entry point: build/update the indexes or run a search.

    Exactly one action runs, checked in this order:
    --generate/--update (embedding pipeline; requires --csv_path), then
    --query (search pipeline). With neither, usage help is printed and the
    embedding model is never loaded.
    """
    parser = argparse.ArgumentParser(description="Create and search embeddings for FDA AI records.")
    parser.add_argument('--csv_path', type=str, help="Path to the input CSV file (required with --generate/--update).")
    parser.add_argument('--output_dir', type=str, required=True, help="Directory to save the output files.")
    parser.add_argument('--generate', action='store_true', help="Generate and save embeddings and FAISS indexes from scratch.")
    parser.add_argument('--update', action='store_true', help="Incrementally update embeddings and FAISS indexes.")
    parser.add_argument('--query', type=str, help="Text to search for (performs a hybrid search).")
    parser.add_argument('--top_k', type=int, default=5, help="Number of top results to return for hybrid search.")
    parser.add_argument('--weights_json', type=str, help="JSON string of weights to use for search_hybrid. Example: '{\"summary\": 0.1, \"keywords\": 0.2, ...}'")
    parser.add_argument('--lambda_val', type=float, help="Lambda value for weighted hybrid search (embedding vs. BM25). If provided, triggers BM25 hybrid search.")

    args = parser.parse_args()

    wants_embedding = args.generate or args.update
    if not wants_embedding and not args.query:
        parser.print_help()
        print("\nExample usage:")
        print("  To generate all embeddings: python create_embeddings.py --generate --csv_path /path/to/your.csv --output_dir /path/to/output")
        print("  To perform a hybrid search: python create_embeddings.py --query \"your search query\" --output_dir /path/to/output")
        return

    # --csv_path is only needed for building indexes, so search works without it.
    if wants_embedding and not args.csv_path:
        parser.error("--csv_path is required with --generate or --update.")

    passed_weights = None
    if args.weights_json:
        try:
            parsed_weights = json.loads(args.weights_json)
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON string passed to --weights_json: {args.weights_json}")
            print("Using default weights.")
        else:
            # search_hybrid looks weights up by field name, so only a JSON object works.
            if isinstance(parsed_weights, dict):
                passed_weights = parsed_weights
                print(f"Received custom weights via --weights_json: {passed_weights}")
            else:
                print(f"Error: --weights_json must be a JSON object mapping field names to weights: {args.weights_json}")
                print("Using default weights.")

    # Load the (comparatively expensive) model only once an action is certain.
    print("Loading Sentence Transformer model...")
    model = SentenceTransformer(MODEL_NAME)
    print("Model loaded.")

    if wants_embedding:
        run_embedding_pipeline(model, args.csv_path, args.output_dir, update=args.update)
    else:
        run_search_pipeline(model, args.query, args.output_dir, args.top_k, passed_weights, args.lambda_val)

def run_embedding_pipeline(model, csv_path, output_dir, update=False):
    """Embed the CSV's text columns and write one FAISS index per field.

    With ``update=False``, every row is embedded into fresh indexes. With
    ``update=True``, only rows whose submission number is not already in
    ``submission_numbers.pkl`` are embedded and appended to the existing
    indexes.

    Search maps FAISS positions to submission numbers by position. An update
    is therefore refused, with nothing written, if any existing index is
    missing or holds a different number of vectors than there are saved
    submission numbers.

    Args:
        model: Encoder passed to create_and_save_index.
        csv_path: Input CSV path (see load_data).
        output_dir: Directory for the index, text and submission-number files.
        update: Incrementally extend existing indexes instead of rebuilding.
    """
    print(f"Loading data from {csv_path}...")
    df = load_data(csv_path)
    if df is None:
        return

    os.makedirs(output_dir, exist_ok=True)
    submission_numbers_path = os.path.join(output_dir, 'submission_numbers.pkl')

    existing_submission_numbers = []
    if update and os.path.exists(submission_numbers_path):
        print("Loading existing submission numbers for update...")
        # Pickle files are produced by this script itself; never point this at untrusted data.
        with open(submission_numbers_path, 'rb') as f:
            existing_submission_numbers = list(pickle.load(f))
        print(f"Found {len(existing_submission_numbers)} existing records.")

    # Without existing numbers (generate mode, or a first update) this keeps every row.
    target_df = df[~df[SUBMISSION_NUMBER_COLUMN_NAME].isin(existing_submission_numbers)]
    if target_df.empty:
        print("No new records to process.")
        return
    print(f"Found {len(target_df)} new records to process.")

    fields_to_process = {
        "summary": (SUMMARY_COLUMN_NAME, 'summary_index.faiss', 'summary_texts.pkl'),
        "keywords": (KEYWORDS_COLUMN_NAME, 'keywords_index.faiss', 'keywords_texts.pkl'),
        "questions": (QUESTIONS_COLUMN_NAME, 'questions_index.faiss', 'questions_texts.pkl'),
        "concepts": (CONCEPTS_COLUMN_NAME, 'concepts_index.faiss', 'concepts_texts.pkl'),
        "thesis": (THESIS_COLUMN_NAME, 'thesis_index.faiss', 'thesis_texts.pkl'),
        "search_boost": (SEARCH_BOOST_COLUMN_NAME, 'search_boost_index.faiss', 'search_boost_texts.pkl'),
        "query_match_1": (QUERY_MATCH_1_COLUMN_NAME, 'query_match_1_index.faiss', 'query_match_1_texts.pkl'),
        "query_match_2": (QUERY_MATCH_2_COLUMN_NAME, 'query_match_2_index.faiss', 'query_match_2_texts.pkl'),
        "query_match_3": (QUERY_MATCH_3_COLUMN_NAME, 'query_match_3_index.faiss', 'query_match_3_texts.pkl'),
    }

    # In update mode, load every existing index up front and verify it is
    # aligned with the saved submission numbers before writing anything.
    existing_indexes = {}
    if update:
        expected_count = len(existing_submission_numbers)
        misaligned = []
        for name, (_, index_file, _) in fields_to_process.items():
            index_path = os.path.join(output_dir, index_file)
            if os.path.exists(index_path):
                print(f"Loading existing index for {name}...")
                index = faiss.read_index(index_path)
                if index.ntotal != expected_count:
                    misaligned.append(f"{name} ({index.ntotal} vectors)")
                existing_indexes[name] = index
            elif expected_count:
                # A new index would hold only the new rows and be offset from the rest.
                misaligned.append(f"{name} (index missing)")
        if misaligned:
            print(f"Error: cannot update incrementally; these indexes do not match the {expected_count} saved submission numbers: {', '.join(misaligned)}.")
            print("Re-run with --generate to rebuild all indexes from scratch.")
            return

    for name, (column, index_file, texts_file) in fields_to_process.items():
        create_and_save_index(
            target_df[column].tolist(),
            model,
            os.path.join(output_dir, index_file),
            os.path.join(output_dir, texts_file),
            name,
            index=existing_indexes.get(name),
        )

    # existing_submission_numbers is empty unless updating, so this covers both modes.
    all_submission_numbers = existing_submission_numbers + target_df[SUBMISSION_NUMBER_COLUMN_NAME].tolist()

    if all_submission_numbers:
        print(f"Saving {len(all_submission_numbers)} total submission numbers to {submission_numbers_path}...")
        with open(submission_numbers_path, 'wb') as f:
            pickle.dump(all_submission_numbers, f)
        print("Successfully saved submission numbers.")
    else:
        print("No submission numbers found to save.")

    print("Embedding processing complete.")

def run_search_pipeline(model, query, output_dir, top_k, custom_weights, lambda_val=None):
    """Load the saved indexes and print the top results for ``query``.

    Uses the embedding + BM25 blend when ``lambda_val`` is given, and the
    embedding-only hybrid search otherwise. Missing per-field index/text files
    are reported and skipped. A warning is printed when an index's size
    differs from the number of saved submission numbers, because results are
    mapped to submissions by position.

    Args:
        model: Encoder used to embed the query.
        query: Free-text query.
        output_dir: Directory containing the files written by the embedding pipeline.
        top_k: Number of results to print.
        custom_weights: Optional per-field weights (see search_hybrid).
        lambda_val: Embedding weight for the BM25 blend, or None for embedding-only.
    """
    fields_to_load = {
        "summary": ('summary_index.faiss', 'summary_texts.pkl'),
        "keywords": ('keywords_index.faiss', 'keywords_texts.pkl'),
        "questions": ('questions_index.faiss', 'questions_texts.pkl'),
        "concepts": ('concepts_index.faiss', 'concepts_texts.pkl'),
        "thesis": ('thesis_index.faiss', 'thesis_texts.pkl'),
        "search_boost": ('search_boost_index.faiss', 'search_boost_texts.pkl'),
        "query_match_1": ('query_match_1_index.faiss', 'query_match_1_texts.pkl'),
        "query_match_2": ('query_match_2_index.faiss', 'query_match_2_texts.pkl'),
        "query_match_3": ('query_match_3_index.faiss', 'query_match_3_texts.pkl'),
    }

    all_loaded_texts = {}
    all_loaded_indexes = {}
    for name, (index_file, texts_file) in fields_to_load.items():
        index_path = os.path.join(output_dir, index_file)
        texts_path = os.path.join(output_dir, texts_file)
        if os.path.exists(index_path) and os.path.exists(texts_path):
            all_loaded_indexes[name] = faiss.read_index(index_path)
            # Pickle files are produced by this script itself; never point this at untrusted data.
            with open(texts_path, 'rb') as f:
                all_loaded_texts[name] = pickle.load(f)
        else:
            print(f"{name.capitalize()} index/texts not found. Please generate first.")
            all_loaded_indexes[name] = None
            all_loaded_texts[name] = []

    submission_numbers_file = os.path.join(output_dir, 'submission_numbers.pkl')
    if os.path.exists(submission_numbers_file):
        with open(submission_numbers_file, 'rb') as f:
            all_loaded_texts['submission_numbers'] = pickle.load(f)
    else:
        print("Submission numbers file not found. Please generate first.")
        all_loaded_texts['submission_numbers'] = []

    has_texts = any(all_loaded_texts.values())
    has_indexes = any(index is not None for index in all_loaded_indexes.values())
    # Without submission numbers no result can be mapped back to a record.
    if not has_texts or not has_indexes or not all_loaded_texts['submission_numbers']:
        print("No data or indexes found. Cannot perform search.")
        return

    num_docs = len(all_loaded_texts['submission_numbers'])
    for name, index in all_loaded_indexes.items():
        if index is not None and index.ntotal != num_docs:
            print(f"Warning: {name} index holds {index.ntotal} vectors but {num_docs} submission numbers were loaded; "
                  f"results may map to the wrong records. Consider regenerating with --generate.")

    if lambda_val is not None:
        print(f"Performing hybrid embedding/BM25 search with lambda = {lambda_val}...")
        results = search_embedding_bm25(query, model, all_loaded_texts, all_loaded_indexes, top_k, custom_weights=custom_weights, lambda_val=lambda_val)
    else:
        print("Performing embedding-only search...")
        results = search_hybrid(query, model, all_loaded_texts, all_loaded_indexes, top_k, custom_weights=custom_weights)

    if not results:
        print("No results found or error during search.")
        return

    print(f"\nTop {len(results)} results for '{query}':")
    for rank, res in enumerate(results, start=1):
        print(f"{rank}. Submission: {res['submission_number']}, Score: {res['hybrid_similarity']:.4f}")
        print(f"   Text (Index {res['original_index']}): {str(res['text'])[:200]}...")

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
