import concurrent.futures
import hashlib
import os
import pickle
import time
from typing import Dict, Generator, List

from datasets import Dataset, load_dataset
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables.
# Example: export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running."
    )

BASE_URL = "https://api.openai.com/v1"  # Or your custom API endpoint
# Short key -> API model name. The key is used in the output column name
# ("embedding_<key>") and in the progress-bar label.
EMBEDDING_MODELS = {
    "large": "text-embedding-3-large",
}
# Requested vector size. text-embedding-3-large natively returns 3072
# dimensions; get_embedding sends this value as the API's `dimensions`
# parameter (for model names containing "large") to obtain shortened vectors.
# It is also part of the cache key, so changing it bypasses old cache entries.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "mhr_decomposed_embeddings"  # Where the final dataset is saved.
CACHE_DIR = ".embedding_cache"  # One pickled embedding file per cache key.

# --- Concurrency Configuration ---
# Number of worker threads issuing embedding requests at once.
# Adjust based on your API rate limits and machine capabilities.
MAX_WORKERS = 32

# --- Create OpenAI Client ---
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


# ==============================================================================
# SECTION 1: DATA PREPARATION AND EXPANSION FUNCTIONS
# ==============================================================================


def format_query(original_question: str, sub_question: str) -> str:
    """
    Join a parent question and one of its sub-questions into a single query.

    Both parts are stripped of surrounding whitespace and combined as
    ``"Context: <original> | Focus: <sub>"``.
    """
    context = original_question.strip()
    focus = sub_question.strip()
    return f"Context: {context} | Focus: {focus}"


def generate_expanded_rows(dataset: Dataset) -> Generator[Dict, None, None]:
    """
    Expand every row of ``dataset`` into one row per sub-question.

    For each entry of a row's ``decomposed_questions`` list, yield a dict whose
    first key is ``formatted_query`` (built by ``format_query`` from the row's
    ``query`` and that sub-question), followed by every other column of the
    source row copied unchanged. The ``query`` and ``decomposed_questions``
    columns are not copied.

    Rows whose ``decomposed_questions`` is empty or None yield nothing.

    Args:
        dataset: Source dataset; each row must have ``query`` and
            ``decomposed_questions`` columns.

    Yields:
        One dict per (source row, sub-question) pair.
    """
    logger.info("Starting dataset expansion...")
    # Columns not copied verbatim: the two inputs consumed by format_query, and
    # the computed column itself so that a pre-existing column of that name
    # cannot overwrite the freshly built query.
    excluded = {"query", "decomposed_questions", "formatted_query"}
    for example in tqdm(dataset, desc="Expanding rows"):
        # A null list cell arrives as None; iterating it would raise TypeError.
        sub_questions = example["decomposed_questions"] or []
        if not sub_questions:
            continue
        # The pass-through columns are identical for every sub-question of
        # this row, so build them once rather than once per sub-question.
        passthrough = {
            key: value for key, value in example.items() if key not in excluded
        }
        for sub_question in sub_questions:
            yield {
                "formatted_query": format_query(example["query"], sub_question),
                **passthrough,
            }


# ==============================================================================
# SECTION 2: EMBEDDING GENERATION FUNCTIONS
# ==============================================================================


def _read_cached_embedding(cache_path: str) -> list[float] | None:
    """Return the embedding pickled at ``cache_path``, or None on a miss."""
    # NOTE(security): pickle.load can execute arbitrary code. This is only
    # acceptable because CACHE_DIR is written exclusively by this script;
    # never point CACHE_DIR at files from an untrusted source.
    try:
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
    except (OSError, EOFError, pickle.UnpicklingError) as e:
        # A truncated/corrupt entry (e.g. from an interrupted run) must not
        # abort the whole batch: treat it as a miss so it gets rewritten.
        logger.warning(f"Ignoring unreadable cache entry {cache_path}: {e}")
        return None


def _write_cached_embedding(cache_path: str, embedding: list[float]) -> None:
    """Store ``embedding`` at ``cache_path`` atomically; failures are only logged."""
    import tempfile

    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(cache_path) or ".", suffix=".tmp"
        )
        with os.fdopen(fd, "wb") as f:
            pickle.dump(embedding, f)
        # Write-then-rename: os.replace is atomic on one filesystem, so other
        # worker threads (or a later run after a crash) never read a
        # half-written entry.
        os.replace(tmp_path, cache_path)
    except OSError as e:
        # The cache is only an optimisation; a write failure must not throw
        # away an embedding that was already obtained from the API.
        logger.warning(f"Could not write cache entry {cache_path}: {e}")
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # Best effort: a leftover *.tmp file is harmless.


def get_embedding(task_args: tuple) -> tuple[int, list[float] | None]:
    """
    Embed one text, using an on-disk cache and retrying transient API errors.

    Args:
        task_args: ``(index, text, model, dimensions)``. ``index`` is returned
            unchanged so callers can match results to inputs. ``dimensions`` is
            sent to the API only when ``model`` contains "large".

    Returns:
        ``(index, embedding)``. ``embedding`` is None when the text is empty or
        not a string, when the API rejects the request as invalid (e.g. input
        too long), or when every retry failed.
    """
    from openai import BadRequestError

    index, text, model, dimensions = task_args

    if not text or not isinstance(text, str):
        logger.warning(f"Invalid text at index {index}: {text}. Skipping.")
        return index, None

    # 1. Unique cache key over everything that affects the vector. It uses the
    #    raw text (before newline normalisation) so existing entries stay valid.
    cache_key_str = f"{text}-{model}-{dimensions}"
    cache_key = hashlib.sha256(cache_key_str.encode("utf-8")).hexdigest()
    cache_path = os.path.join(CACHE_DIR, cache_key)

    # 2. Check cache
    cached = _read_cached_embedding(cache_path)
    if cached is not None:
        return index, cached

    # 3. If not in cache, call API
    current_text = text.replace("\n", " ").strip()
    if not current_text:
        return index, None

    params = {"input": [current_text], "model": model}
    if "large" in model:
        params["dimensions"] = dimensions

    max_retries = 5
    retry_delay = 5  # seconds; doubled after each failed attempt

    for attempt in range(max_retries):
        try:
            response = client.embeddings.create(**params)
            embedding = response.data[0].embedding
        except BadRequestError as e:
            # HTTP 400 (notably "maximum context length ... tokens") fails the
            # same way on every retry, so skip immediately. Do not match on the
            # word "token": rate-limit errors mention "tokens per min" yet
            # succeed once retried.
            logger.error(f"Sample {index} | Attempt {attempt + 1} | Error: {e}")
            logger.warning(
                f"Sample {index}: Request rejected as invalid (e.g. text too long), skipping."
            )
            return index, None
        except Exception as e:
            # Transient failures (rate limits, timeouts, 5xx): back off, retry.
            logger.error(f"Sample {index} | Attempt {attempt + 1} | Error: {e}")
            if attempt < max_retries - 1:  # no point sleeping after the last try
                time.sleep(retry_delay * 2**attempt)
        else:
            # Caching is kept out of the try so a disk error never re-calls the API.
            _write_cached_embedding(cache_path, embedding)
            return index, embedding

    logger.error(f"Sample {index}: Failed after {max_retries} retries.")
    return index, None


def get_embeddings_ordered(
    texts: List[str], model_name: str, model_key: str
) -> Dict[int, List[float]]:
    """
    Embed ``texts`` concurrently with ``model_name`` on a thread pool.

    Requests are fanned out over ``MAX_WORKERS`` threads with
    ``executor.map``, so results are consumed in input order. Texts for which
    ``get_embedding`` returned None are left out of the result.

    Args:
        texts: Texts to embed; each text's position becomes its result key.
        model_name: API model identifier forwarded to ``get_embedding``.
        model_key: Short label shown in the progress-bar description.

    Returns:
        Mapping from input position to embedding vector, successes only.
    """
    logger.info(f"Generating embeddings for model: {model_name}...")
    total = len(texts)
    work_items = [
        (position, item, model_name, EMBEDDING_DIM)
        for position, item in enumerate(texts)
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        progress = tqdm(
            pool.map(get_embedding, work_items),
            total=total,
            desc=f"Embedding with {model_key}",
        )
        embedded = {
            slot: vector for slot, vector in progress if vector is not None
        }

    logger.success(
        f"Embedded {len(embedded)} samples, skipped {total - len(embedded)}."
    )
    return embedded


def create_final_dataset(
    original_dataset: Dataset, all_embeddings: Dict[str, Dict[int, List[float]]]
) -> Dataset:
    """
    Keep the rows embedded by every model and attach the embedding columns.

    Args:
        original_dataset: Dataset whose row positions are the keys of each
            inner embeddings dict.
        all_embeddings: ``{model_key: {row_index: vector}}``. A row missing
            from any inner dict is dropped.

    Returns:
        A new dataset with the surviving rows in ascending index order. It has
        the original columns (original feature types preserved), followed by
        one ``embedding_<model_key>`` column per model.
    """
    common_indices = set(range(len(original_dataset)))
    for embeddings_dict in all_embeddings.values():
        common_indices.intersection_update(embeddings_dict.keys())

    logger.info(
        f"Found {len(common_indices)} samples with successful embeddings for all models."
    )

    kept = sorted(common_indices)
    # select() keeps the source schema (Dataset.from_list would re-infer types
    # and lose e.g. ClassLabel/nested features). It also avoids decoding every
    # row into a Python dict just to rebuild the table.
    result = original_dataset.select(kept)
    for model_key, embeddings_dict in all_embeddings.items():
        result = result.add_column(
            f"embedding_{model_key}", [embeddings_dict[i] for i in kept]
        )
    return result


# ==============================================================================
# SECTION 3: MAIN EXECUTION LOGIC
# ==============================================================================


def main():
    """
    Run the full pipeline.

    Load the source split, expand it to one row per sub-question, embed every
    formatted query with each configured model, and save the rows embedded by
    all models under OUTPUT_DIR.
    """
    for directory in (OUTPUT_DIR, CACHE_DIR):
        os.makedirs(directory, exist_ok=True)

    logger.info("--- Starting Decomposed Query Embedding Process ---")

    # Load the source split.
    logger.info("Loading original dataset: <ANONYMIZED>")
    source = load_dataset("<ANONYMIZED>", split="train")
    logger.info(f"Loaded {len(source)} original rows.")

    # Expand to one row per (row, sub-question); gen_kwargs is forwarded to the
    # generator function by Dataset.from_generator.
    expanded = Dataset.from_generator(
        generate_expanded_rows, gen_kwargs={"dataset": source}
    )
    logger.success(f"Dataset expanded to {len(expanded)} rows.")

    queries = expanded["formatted_query"]

    # Embed with every configured model, keyed by the model's short name.
    all_embeddings = {
        key: get_embeddings_ordered(queries, name, key)
        for key, name in EMBEDDING_MODELS.items()
    }

    embedded_dataset = create_final_dataset(expanded, all_embeddings)
    row_count = len(embedded_dataset)

    if row_count == 0:
        logger.warning(
            "No samples were successfully embedded. No dataset will be saved."
        )
    else:
        output_path = os.path.join(OUTPUT_DIR, "mhr_decomposed_embedded")
        logger.info(
            f"Saving final dataset with {row_count} rows to '{output_path}'"
        )
        embedded_dataset.save_to_disk(output_path)

        logger.info("Summary of the final dataset:")
        logger.info(f"  - Total rows: {row_count}")
        logger.info(f"  - Columns: {embedded_dataset.column_names}")

    logger.info("\n--- All embedding tasks completed! ---")


if __name__ == "__main__":
    main()
