import concurrent.futures
import os
import time
from typing import Any, Dict, List

import numpy as np
from datasets import Dataset, load_dataset
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables, e.g., export OPENAI_API_KEY='your_api_key'
# The key is read from OPENAI_API_KEY; the script refuses to start without it.
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Name the variable that is actually read so the message is actionable.
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running the script."
    )

# Endpoint handed to the OpenAI client created below.
BASE_URL = "https://api.openai.com/v1"
# Short key -> API model name. The short keys are reused as the suffix of the
# "embedding_<key>" columns added to the output datasets.
EMBEDDING_MODELS = {
    "ada": "text-embedding-ada-002",
    "small": "text-embedding-3-small",
    "large": "text-embedding-3-large",
}
# Dataset identifier passed to load_dataset().
DATASET_NAME = "yixuantt/MultiHopRAG"
DATASET_SUBSETS = ["MultiHopRAG", "corpus"]  # Process both subsets
# Requested embedding size. It is only sent to the API (as `dimensions`) for
# the "large" model; the other models are called without it.
EMBEDDING_DIM = 1536
# Directory under which the embedded datasets are saved.
OUTPUT_DIR = "embeddings"
# --- Concurrency Configuration ---
# Size of the thread pool that issues embedding requests.
MAX_WORKERS = (
    32  # Adjust this based on your API rate limits and machine capabilities
)

# --- Create OpenAI Client ---
# One module-level client, used by get_embedding() from the worker threads.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def get_embedding(task_args):
    """
    Generate an embedding for one text using the specified model.

    Args:
        task_args: Tuple ``(index, text, model, dimensions)``. ``index`` is
            returned unchanged so the caller can match results to inputs.
            ``dimensions`` is only forwarded to the API when ``model`` is the
            "large" entry of ``EMBEDDING_MODELS``.

    Returns:
        ``(index, embedding)`` on success, or ``(index, None)`` when the text
        is empty / whitespace-only or no embedding could be obtained (the
        caller skips such samples).

    Failures other than length errors (including rate limiting) are retried up
    to ``max_retries`` times with a fixed delay. Length-related errors trigger
    truncation by character count (8192, then 8000) before giving up.
    """
    index, text, model, dimensions = task_args

    # Return None for empty text to skip this sample
    if not text:
        return index, None

    # Clean text
    current_text = text.replace("\n", " ").strip()

    # Whitespace-only input collapses to "" here; skip it instead of sending
    # an empty string to the API and burning every retry on it.
    if not current_text:
        return index, None

    max_retries = 5
    retry_delay = 5  # seconds
    length_error_occurred = False

    # Only the "large" model is asked for an explicit output size.
    request_kwargs = {"model": model}
    if model == EMBEDDING_MODELS["large"]:
        request_kwargs["dimensions"] = dimensions

    for attempt in range(max_retries):
        try:
            response = client.embeddings.create(
                input=[current_text], **request_kwargs
            )
            return index, response.data[0].embedding

        except Exception as e:
            # Best-effort boundary: any failure is classified by its message
            # and either retried or turned into a skipped sample.
            error_message = str(e).lower()

            # Rate-limit errors mention "tokens per min", which would match the
            # "token" keyword below and get a short text dropped without any
            # retry. Detect them first so they go through the retry path.
            rate_limited = (
                getattr(e, "status_code", None) == 429
                or "rate limit" in error_message
            )

            # Check if error is related to text length/tokens
            length_related = not rate_limited and any(
                keyword in error_message
                for keyword in ["token", "length", "too long", "maximum"]
            )

            if length_related:
                if not length_error_occurred:
                    print(f"Length-related error for sample {index}: {e}")
                    print(
                        f"Current text length: {len(current_text)} characters"
                    )
                    length_error_occurred = True

                # Try progressive truncation: 8192 -> 8000 -> give up
                if len(current_text) > 8192:
                    current_text = current_text[:8192]
                    print(
                        f"Sample {index}: Truncated to 8192 characters, retrying..."
                    )
                    continue
                elif len(current_text) > 8000:
                    current_text = current_text[:8000]
                    print(
                        f"Sample {index}: Truncated to 8000 characters, retrying..."
                    )
                    continue
                else:
                    print(
                        f"Sample {index}: Text still too long after truncation, skipping sample"
                    )
                    return index, None
            else:
                # For other errors, use normal retry logic
                if attempt == max_retries - 1:
                    # No further attempt follows, so do not sleep.
                    print(
                        f"Sample {index}: Non-length error on final attempt: {e}"
                    )
                else:
                    print(
                        f"Sample {index}: Non-length error on attempt {attempt + 1}: {e}. Retrying..."
                    )
                    time.sleep(retry_delay)

    print(
        f"Sample {index}: Failed to get embedding after {max_retries} retries, skipping sample"
    )
    return index, None


def get_embeddings_ordered(
    texts: List[str], model_name: str, model_key: str
) -> Dict[int, List[float]]:
    """
    Embed every text in ``texts`` with ``model_name`` using a thread pool.

    Each text is paired with its position in ``texts`` and handed to
    ``get_embedding``. The returned dictionary maps those positions to the
    embeddings; positions whose embedding came back as ``None`` are left out
    and only counted as skipped. ``model_key`` is used for the progress bar
    label only.
    """
    print(
        f"\n--- Generating embeddings for model: {model_name} (using up to {MAX_WORKERS} workers) ---"
    )

    # One job tuple per text; the position travels with the text so results
    # can be keyed by it.
    jobs = [
        (position, text, model_name, EMBEDDING_DIM)
        for position, text in enumerate(texts)
    ]

    embeddings_by_position = {}
    n_skipped = 0

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as pool:
        # executor.map yields results in input order; tqdm advances as each
        # one becomes available.
        progress = tqdm(
            pool.map(get_embedding, jobs),
            total=len(texts),
            desc=f"Embedding with {model_key}",
        )
        for position, vector in progress:
            if vector is None:
                n_skipped += 1
                continue
            embeddings_by_position[position] = vector

    print(
        f"Successfully embedded {len(embeddings_by_position)} samples, skipped {n_skipped} samples"
    )
    return embeddings_by_position


def create_corpus_text(item):
    """
    Build the text to embed for a corpus item.

    Every field of ``item`` is rendered as ``"<key>: <value>"``, with fields
    ordered by key, and the rendered fields are joined with newlines.
    """
    return "\n".join(
        f"{field}: {content}" for field, content in sorted(item.items())
    )


def create_embedded_dataset_with_multiple_models(
    original_dataset: Dataset,
    all_embeddings: Dict[str, Dict[int, List[float]]],
    subset_name: str,
) -> Dataset:
    """
    Build a dataset holding the original fields plus one embedding column per model.

    A row is kept only when every model present in ``all_embeddings`` produced
    an embedding for its index. Each kept row gains an ``embedding_<model_key>``
    field for each of those models. ``subset_name`` is accepted but not used.
    """
    # Models from the configuration that actually have results.
    available_keys = [key for key in EMBEDDING_MODELS if key in all_embeddings]

    # Indices that every available model managed to embed.
    index_sets = [set(all_embeddings[key]) for key in available_keys]
    shared_indices = set.intersection(*index_sets) if index_sets else set()

    print(f"Found {len(shared_indices)} samples with embeddings for all models")

    rows = []
    for position, record in enumerate(original_dataset):
        if position not in shared_indices:
            continue

        # Start from a copy of the original fields, then append the
        # embedding columns in configuration order.
        row = dict(record)
        for key in available_keys:
            row[f"embedding_{key}"] = all_embeddings[key][position]
        rows.append(row)

    return Dataset.from_list(rows)


def process_subset(subset_name: str):
    """
    Load one dataset subset, embed it with every configured model, and return
    the combined dataset.

    For the 'MultiHopRAG' subset the text to embed is the ``query`` field; for
    'corpus' it is the output of ``create_corpus_text``; any other subset name
    yields empty texts. Returns ``None`` when the subset cannot be loaded.
    """
    print(f"\n=== Processing subset: {subset_name} ===")

    # Load the subset; a failure here is reported and aborts this subset only.
    print(f"Loading dataset: {DATASET_NAME}, subset: {subset_name}")
    try:
        dataset = load_dataset(DATASET_NAME, name=subset_name, split="train")
        print(f"Dataset loaded successfully with {len(dataset)} samples.")
    except Exception as load_error:
        print(
            f"Failed to load dataset subset '{subset_name}'. Error: {load_error}"
        )
        return None

    # Collect one text per row, keeping positions aligned with the dataset.
    texts_to_embed = []
    usable_count = 0
    empty_count = 0

    for position, record in enumerate(dataset):
        if subset_name == "MultiHopRAG":
            candidate = record.get("query", "")
        elif subset_name == "corpus":
            candidate = create_corpus_text(record)
        else:
            candidate = ""

        # Anything that is not a string is treated as an empty document.
        text = candidate if isinstance(candidate, str) else ""
        if text.strip():
            usable_count += 1
        else:
            empty_count += 1

        texts_to_embed.append(text)

        # Show the first three rows so the preprocessing can be eyeballed.
        if position < 3:
            print(f"\nExample {position + 1} preprocessing:")
            print(f"  Text length: {len(text)} chars")
            print(f"  Text preview: {text[:200]}...")

    print(f"\nPreprocessing summary for {subset_name}:")
    print(f"  - Successfully processed: {usable_count}")
    print(f"  - Empty documents: {empty_count}")
    print(f"  - Total: {len(texts_to_embed)}")

    # Run every configured model over the same list of texts.
    all_embeddings = {}
    for model_key, model_name in EMBEDDING_MODELS.items():
        print(f"\n--- Processing {subset_name} with {model_name} ---")
        all_embeddings[model_key] = get_embeddings_ordered(
            texts_to_embed, model_name, model_key
        )

    # Merge the original rows with the per-model embedding columns.
    embedded_dataset = create_embedded_dataset_with_multiple_models(
        dataset, all_embeddings, subset_name
    )

    print(
        f"Created embedded dataset for '{subset_name}' with {len(embedded_dataset)} rows"
    )

    # Report how many columns are original vs. embeddings.
    if len(embedded_dataset) > 0:
        columns = list(embedded_dataset[0].keys())
        embedding_columns = list(
            filter(lambda name: name.startswith("embedding_"), columns)
        )
        print(f"  - Original columns: {len(columns) - len(embedding_columns)}")
        print(f"  - Embedding columns: {embedding_columns}")

    return embedded_dataset


def main():
    """
    Embed every configured subset and save each result under ``OUTPUT_DIR``.

    A subset whose output path already exists is skipped. A subset for which
    ``process_subset`` returns ``None`` is recorded as failed, and the final
    message reports the failures instead of claiming overall success.
    """
    print("--- Starting MultiHopRAG Dataset Embedding Process ---")

    # --- Create Output Directory ---
    if not os.path.exists(OUTPUT_DIR):
        # exist_ok tolerates the directory appearing between check and create.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"Created output directory: {OUTPUT_DIR}")

    # Define output paths
    output_paths = {
        "MultiHopRAG": os.path.join(OUTPUT_DIR, "mhr_queries_embedded"),
        "corpus": os.path.join(OUTPUT_DIR, "mhr_tools_embedded"),
    }

    # Subsets for which process_subset() returned None.
    failed_subsets = []

    # --- Process Each Subset ---
    for subset_name in DATASET_SUBSETS:
        output_path = output_paths[subset_name]

        # Check if dataset already exists
        if os.path.exists(output_path):
            print(
                f"\nEmbedded dataset already exists at '{output_path}'. Skipping {subset_name}."
            )
            continue

        # Process the subset
        embedded_dataset = process_subset(subset_name)

        if embedded_dataset is None:
            failed_subsets.append(subset_name)
            continue

        # --- Save Dataset ---
        embedded_dataset.save_to_disk(output_path)
        print(f"Dataset saved to '{output_path}'")

        # Print summary
        print(f"Summary for {subset_name}:")
        print(f"  - Total rows: {len(embedded_dataset)}")
        if len(embedded_dataset) > 0:
            columns = list(embedded_dataset[0].keys())
            embedding_columns = [
                col for col in columns if col.startswith("embedding_")
            ]
            print(f"  - Total columns: {len(columns)}")
            print(
                f"  - Embedding columns: {len(embedding_columns)} ({', '.join(embedding_columns)})"
            )

    if failed_subsets:
        # Do not claim success when at least one subset produced nothing.
        print(
            f"\n--- MultiHopRAG embedding finished with failed subsets: {', '.join(failed_subsets)} ---"
        )
    else:
        print(
            "\n--- All MultiHopRAG embedding tasks completed successfully! ---"
        )


if __name__ == "__main__":
    main()
