import concurrent.futures
import hashlib
import os
import pickle
import time
from typing import Dict, List

import pandas as pd
from datasets import Dataset, load_dataset
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables.
# Example: export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Fail fast at import time instead of on the first API call.
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running."
    )

BASE_URL = "https://api.openai.com/v1"  # Or your custom API endpoint

# Short key -> OpenAI model name. The key is used as the suffix of the output
# column (``embedding_<key>``) in the saved dataset.
EMBEDDING_MODELS = {
    # "ada": "text-embedding-ada-002",
    # "small": "text-embedding-3-small",
    "large": "text-embedding-3-large",
}
DATASET_NAME = "mteb/fiqa"
# Requested embedding size. It is sent as the ``dimensions`` parameter only for
# the model registered under the "large" key. text-embedding-3-large natively
# returns 3072 dimensions; 1536 is a reduced size chosen by this script, not a
# requirement of the model.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "fiqa_embeddings"
# One pickle file per (text, model, dimensions) combination.
CACHE_DIR = ".embedding_cache"

# --- Concurrency Configuration ---
# Adjust based on your API rate limits and machine capabilities.
MAX_WORKERS = 32

# --- Create OpenAI Client ---
# A single client instance is used by all embedding worker threads.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def get_embedding(task_args: tuple) -> tuple[int, list[float] | None]:
    """
    Generates an embedding for a given text using the specified model.
    Includes retry logic and progressive truncation for handling potential API errors.
    Returns a tuple of (index, embedding) or (index, None) if embedding fails.
    """
    index, text, model, dimensions = task_args

    if not text or not isinstance(text, str):
        print(index, text)
        return index, None

    # 1. Create a unique cache key based on the function's inputs
    cache_key_str = f"{text}-{model}-{dimensions}"
    cache_key = hashlib.sha256(cache_key_str.encode("utf-8")).hexdigest()
    cache_path = os.path.join(CACHE_DIR, cache_key)

    # 2. Check if the result is already in the cache
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            cached_embedding = pickle.load(f)
        return index, cached_embedding

    # --- If not in cache, proceed with the original API call logic ---
    current_text = text.replace("\n", " ").strip()
    if not current_text:
        return index, None

    max_retries = 5
    retry_delay = 5  # seconds
    length_error_occurred = False

    for attempt in range(max_retries):
        try:
            params = {"input": [current_text], "model": model}
            if model == EMBEDDING_MODELS["large"]:
                params["dimensions"] = dimensions

            response = client.embeddings.create(**params)
            embedding = response.data[0].embedding

            # 3. Save the new result to the cache before returning
            with open(cache_path, "wb") as f:
                pickle.dump(embedding, f)

            return index, embedding

        except Exception as e:
            error_message = str(e).lower()

            logger.error(e)

            if any(
                keyword in error_message
                for keyword in ["token", "length", "too long", "maximum"]
            ):
                if not length_error_occurred:
                    print(f"Length-related error for sample {index}: {e}")
                    length_error_occurred = True

                # Progressive truncation
                if len(current_text) > 8192:
                    current_text = current_text[:8192]
                    print(
                        f"Sample {index}: Truncated to 8192 chars, retrying..."
                    )
                    continue
                elif len(current_text) > 8000:
                    current_text = current_text[:8000]
                    print(
                        f"Sample {index}: Truncated to 8000 chars, retrying..."
                    )
                    continue
                elif len(current_text) > 7000:
                    current_text = current_text[:7000]
                    print(
                        f"Sample {index}: Truncated to 7000 chars, retrying..."
                    )
                    continue
                elif len(current_text) > 1000:
                    current_text = current_text[:1000]
                    print(
                        f"Sample {index}: Truncated to 4000 chars, retrying..."
                    )
                    continue
                else:
                    print(f"Sample {index}: Text still too long, skipping.")
                    return index, None
            else:
                print(
                    f"Sample {index}: API error on attempt {attempt + 1}: {e}. Retrying in {retry_delay}s..."
                )
                time.sleep(retry_delay)

    print(
        f"Sample {index}: Failed to get embedding after {max_retries} retries, skipping."
    )
    return index, None


def get_embeddings_ordered(
    texts: List[str], model_name: str, model_key: str
) -> Dict[int, List[float]]:
    """
    Embed ``texts`` concurrently with ``model_name``.

    Work is spread over up to ``MAX_WORKERS`` threads. Results are consumed in
    input order, and the returned dict is keyed by each text's position in
    ``texts``. Texts whose embedding failed are absent from the result.

    Args:
        texts: Texts to embed.
        model_name: OpenAI model name forwarded to ``get_embedding``.
        model_key: Short label, used only in the progress-bar description.

    Returns:
        ``{index: embedding}`` for every successfully embedded text.
    """
    print(
        f"\n--- Generating embeddings for model: {model_name} (up to {MAX_WORKERS} workers) ---"
    )
    n_texts = len(texts)
    jobs = [
        (position, item, model_name, EMBEDDING_DIM)
        for position, item in enumerate(texts)
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        # Executor.map yields results in submission order, whichever thread
        # finishes first.
        ordered_results = tqdm(
            pool.map(get_embedding, jobs),
            total=n_texts,
            desc=f"Embedding with {model_key}",
        )
        embeddings = {
            position: vector
            for position, vector in ordered_results
            if vector is not None
        }

    print(
        f"Successfully embedded {len(embeddings)} samples, skipped {n_texts - len(embeddings)} samples."
    )
    return embeddings


def create_final_dataset(
    original_dataset: Dataset, all_embeddings: Dict[str, Dict[int, List[float]]]
) -> Dataset:
    """
    Attach one ``embedding_<model_key>`` column per model to ``original_dataset``.

    Only rows that were embedded successfully by *every* model are kept, in
    their original order.

    Args:
        original_dataset: The dataset whose rows were embedded.
        all_embeddings: ``{model_key: {row_index: embedding}}``.

    Returns:
        A new dataset with the original columns followed by the embedding
        columns. If ``all_embeddings`` is empty or no row is common to all
        models, an empty dataset is returned.
    """
    if not all_embeddings:
        # No models ran, so there is nothing to attach.
        return Dataset.from_dict({})

    # Row indices that have an embedding from every model.
    common_indices = set(range(len(original_dataset)))
    for embeddings_dict in all_embeddings.values():
        common_indices.intersection_update(embeddings_dict)

    print(
        f"Found {len(common_indices)} samples with successful embeddings for all models."
    )

    sorted_indices = sorted(common_indices)
    if not sorted_indices:
        return Dataset.from_list([])

    # select() + add_column() keeps the data in Arrow and preserves the original
    # column features. The alternative, materialising every row as a Python dict
    # and rebuilding with from_list(), re-infers the schema row by row.
    result = original_dataset.select(sorted_indices)
    for model_key, embeddings_dict in all_embeddings.items():
        result = result.add_column(
            f"embedding_{model_key}",
            [embeddings_dict[i] for i in sorted_indices],
        )
    return result


def process_dataset(
    dataset: Dataset, text_column: str, output_name: str, output_dir: str
):
    """
    Embed ``dataset[text_column]`` with every configured model and save the result.

    The embedded dataset is written with ``save_to_disk`` to
    ``<output_dir>/<output_name>``. Nothing happens if that path already
    exists. Nothing is written if no row was embedded by every model.
    """
    output_path = os.path.join(output_dir, output_name)
    if os.path.exists(output_path):
        print(f"Dataset already exists at '{output_path}'. Skipping.")
        return

    banner = "=" * 20
    print(f"\n{banner} Processing dataset for '{output_name}' {banner}")
    texts = dataset[text_column]

    # One {row index: embedding} dict per model key, in EMBEDDING_MODELS order.
    embeddings_by_model = {
        key: get_embeddings_ordered(texts, name, key)
        for key, name in EMBEDDING_MODELS.items()
    }

    result = create_final_dataset(dataset, embeddings_by_model)
    row_count = len(result)
    if row_count == 0:
        print(
            "No samples were successfully embedded across all models. No dataset will be saved."
        )
        return

    print(f"Saving embedded dataset with {row_count} rows to '{output_path}'")
    result.save_to_disk(output_path)

    column_names = result.column_names
    embedding_cols = [name for name in column_names if name.startswith("embedding_")]
    print("Summary:")
    print(f"  - Total rows: {row_count}")
    print(f"  - Total columns: {len(column_names)}")
    print(
        f"  - Embedding columns: {len(embedding_cols)} ({', '.join(embedding_cols)})"
    )


def main():
    """
    Embed the FiQA corpus and save it under ``OUTPUT_DIR``.

    Creates the cache and output directories if needed, then passes the corpus
    split through ``process_dataset``. A disabled pipeline for the queries
    split is kept below as commented-out code.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    print("--- Starting Fiqa Dataset Embedding Process ---")

    # Create output directory if it doesn't exist
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    # # --- 1. Prepare and Process Queries Dataset ---
    # print("\n--- Preparing Queries Dataset ---")
    # try:
    #     queries_ds = load_dataset(DATASET_NAME, name="queries", split="queries")
    #     relations_ds = load_dataset(DATASET_NAME, name="default", split="test")

    #     # Convert to pandas DataFrames
    #     queries_df = queries_ds.to_pandas()
    #     relations_df = relations_ds.to_pandas()

    #     print(relations_df["query-id"].nunique())
    #     print(queries_df["_id"].nunique())

    #     # Group relations by query-id and aggregate corpus-ids into a list
    #     relations_grouped_df = (
    #         relations_df.groupby("query-id")["corpus-id"]
    #         .apply(list)
    #         .reset_index()
    #         .rename(columns={"corpus-id": "relevant_corpus_ids"})
    #     )

    #     print(relations_grouped_df.shape)

    #     # Merge the original queries with the grouped relations
    #     merged_df = pd.merge(
    #         queries_df,
    #         relations_grouped_df,
    #         left_on="_id",
    #         right_on="query-id",
    #     )

    #     print(merged_df.shape)

    #     # Fill NaN values for queries with no matches with an empty list
    #     merged_df["relevant_corpus_ids"] = merged_df[
    #         "relevant_corpus_ids"
    #     ].apply(lambda d: d if isinstance(d, list) else [])

    #     # Drop the redundant 'query-id' column from the merge
    #     merged_df = merged_df.drop(columns=["query-id"])

    #     # Convert back to Hugging Face Dataset
    #     final_queries_ds = Dataset.from_pandas(merged_df)
    #     print(
    #         f"Successfully joined queries with relations. New dataset has {len(final_queries_ds)} rows and includes 'relevant_corpus_ids' column."
    #     )

    #     process_dataset(
    #         dataset=final_queries_ds,
    #         text_column="text",
    #         output_name="fiqa_queries_embedded",
    #         output_dir=OUTPUT_DIR,
    #     )

    # except Exception as e:
    #     print(f"Failed to process the queries dataset. Error: {e}")

    # --- 2. Prepare and Process Corpus Dataset ---
    print("\n--- Preparing Corpus Dataset ---")
    try:
        corpus_ds = load_dataset(DATASET_NAME, name="corpus", split="corpus")

        # NOTE(review): the output name says "tools" although this is the FiQA
        # corpus; it may have been copied from another script. It is kept as-is
        # because existing outputs/consumers may depend on this path. Confirm.
        process_dataset(
            dataset=corpus_ds,
            text_column="text",
            output_name="fiqa_tools_embedded",
            output_dir=OUTPUT_DIR,
        )

    except Exception:
        # Top-level boundary: keep going to the final message, but log the full
        # traceback. Printing only str(e) hides where the failure happened.
        logger.exception("Failed to process the corpus dataset.")

    print("\n--- All Fiqa embedding tasks completed! ---")


if __name__ == "__main__":
    main()
