import concurrent.futures
import os
import time
from typing import Any, Dict, List

import numpy as np
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment before running, e.g.
#   export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    # Name the variable the code actually reads, so the hint is actionable.
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running the script."
    )

BASE_URL = "https://api.openai.com/v1"
# Short key (stored in each row's "embed_model" column) -> API model name.
EMBEDDING_MODELS = {
    "large": "text-embedding-3-large",
}
DATASET_NAME = "mangopy/ToolRet-before-sample"
# All 35 subsets from the dataset; each is loaded with its "queries" split.
DATASET_SUBSETS = [
    "apibank",
    "apigen",
    "appbench",
    "autotools-food",
    "autotools-music",
    "autotools-weather",
    "craft-math-algebra",
    "craft-tabmwp",
    "craft-vqa",
    "gorilla-huggingface",
    "gorilla-pytorch",
    "gorilla-tensor",
    "gpt4tools",
    "gta",
    "metatool",
    "mnms",
    "restgpt-spotify",
    "restgpt-tmdb",
    "reversechain",
    "rotbench",
    "t-eval-dialog",
    "t-eval-step",
    "taskbench-daily",
    "taskbench-huggingface",
    "taskbench-multimedia",
    "tool-be-honest",
    "toolace",
    "toolalpaca",
    "toolbench",
    "toolbench-sam",
    "toolemu",
    "tooleyes",
    "toolink",
    "toollens",
    "ultratool",
]
# Values of the "category" column to keep; each becomes one saved split.
TARGET_CATEGORIES = ["web", "code", "customized"]
# Requested embedding size; sent as `dimensions=` for the "large" model.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "embeddings"
# --- Concurrency Configuration ---
# Number of concurrent embedding requests. Adjust this based on your API rate
# limits and machine capabilities.
MAX_WORKERS = 16

# --- Create OpenAI Client ---
# Shared by all worker threads in get_embedding.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def get_embedding(task_args):
    """
    Embed a single text, retrying transient API failures.

    Args:
        task_args: A ``(index, text, model, dimensions)`` tuple. ``index`` is
            echoed back so callers can match results to inputs. ``dimensions``
            is only sent to the API when ``model`` is
            ``EMBEDDING_MODELS["large"]``.

    Returns:
        ``(index, embedding)`` on success. ``(index, None)`` when the sample
        should be skipped:
          * the text is empty or whitespace-only;
          * a length error persists after truncation;
          * every retry failed.

    Error handling:
        * Rate-limit errors (HTTP 429, or a message containing "rate limit")
          are always retried with backoff. They are identified *before* the
          keyword-based length check, because rate-limit messages can mention
          token limits and would otherwise be mistaken for "input too long".
        * Errors whose message mentions tokens/length are treated as
          input-too-long. The text is truncated to 8192, then 8000 characters,
          then the sample is skipped. Truncation is by characters, only a
          rough proxy for a token limit.
        * Any other error is retried with exponential backoff.
    """
    index, text, model, dimensions = task_args

    # Nothing to embed: skip this sample.
    if not text:
        return index, None

    # Newlines are flattened to spaces before embedding.
    current_text = text.replace("\n", " ").strip()
    # Whitespace-only input cleans down to "", which would fail on every
    # attempt; skip it up front instead of burning all the retries.
    if not current_text:
        return index, None

    max_retries = 5
    retry_delay = 5  # base delay in seconds, doubled after each retried failure
    length_error_occurred = False
    length_keywords = ("token", "length", "too long", "maximum")

    for attempt in range(max_retries):
        try:
            if model == EMBEDDING_MODELS["large"]:
                response = client.embeddings.create(
                    input=[current_text], model=model, dimensions=dimensions
                )
            else:
                response = client.embeddings.create(
                    input=[current_text], model=model
                )
            return index, response.data[0].embedding

        except Exception as e:
            error_message = str(e).lower()

            # Throttling must be retried, never truncated or dropped, so rule
            # it out before matching the (broad) length keywords.
            is_rate_limited = (
                getattr(e, "status_code", None) == 429
                or "rate limit" in error_message
            )
            is_length_error = not is_rate_limited and any(
                keyword in error_message for keyword in length_keywords
            )

            if is_length_error:
                if not length_error_occurred:
                    print(f"Length-related error for sample {index}: {e}")
                    print(
                        f"Current text length: {len(current_text)} characters"
                    )
                    length_error_occurred = True

                # Try progressive truncation: 8192 -> 8000 -> give up
                if len(current_text) > 8192:
                    current_text = current_text[:8192]
                    print(
                        f"Sample {index}: Truncated to 8192 characters, retrying..."
                    )
                    continue
                if len(current_text) > 8000:
                    current_text = current_text[:8000]
                    print(
                        f"Sample {index}: Truncated to 8000 characters, retrying..."
                    )
                    continue
                print(
                    f"Sample {index}: Text still too long after truncation, skipping sample"
                )
                return index, None

            # Transient / other errors: retry with backoff.
            if attempt == max_retries - 1:
                print(
                    f"Sample {index}: Non-length error on final attempt: {e}"
                )
                break  # no further attempt follows, so don't sleep
            print(
                f"Sample {index}: Non-length error on attempt {attempt + 1}: {e}. Retrying..."
            )
            # Exponential backoff (5s, 10s, 20s, ...) gives a throttled or
            # flaky API time to recover.
            time.sleep(retry_delay * 2**attempt)

    print(
        f"Sample {index}: Failed to get embedding after {max_retries} retries, skipping sample"
    )
    return index, None


def get_embeddings_ordered(
    texts: List[str], model_name: str, model_key: str
) -> Dict[int, List[float]]:
    """
    Embed ``texts`` concurrently, keyed by each text's position in the list.

    Args:
        texts: Texts to embed.
        model_name: Model identifier forwarded to ``get_embedding``.
        model_key: Short label shown in the progress-bar description.

    Returns:
        A dict mapping the index of every successfully embedded text to its
        embedding, with keys in ascending order. Texts for which
        ``get_embedding`` returned ``None`` are omitted and counted as skipped.
    """
    print(
        f"\n--- Generating embeddings for model: {model_name} (using up to {MAX_WORKERS} workers) ---"
    )

    # Each task carries its index so results can be matched back to inputs
    # regardless of the order in which requests finish.
    tasks = [
        (i, text, model_name, EMBEDDING_DIM) for i, text in enumerate(texts)
    ]

    results_dict = {}
    skipped_count = 0

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as executor:
        futures = [executor.submit(get_embedding, task) for task in tasks]
        # Consume results as they finish so the progress bar reflects real
        # progress. executor.map yields in submission order, which would stall
        # the bar behind any slow or retrying early request.
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc=f"Embedding with {model_key}",
        ):
            index, embedding = future.result()
            if embedding is not None:
                results_dict[index] = embedding
            else:
                skipped_count += 1

    print(
        f"Successfully embedded {len(results_dict)} samples, skipped {skipped_count} samples"
    )
    # Completion order is nondeterministic; hand back keys in input order.
    return dict(sorted(results_dict.items()))


def load_and_concatenate_subsets():
    """
    Load the "queries" split of each subset in DATASET_SUBSETS and merge them.

    Loading is best effort. A subset that fails to load is reported and left
    out, and the remaining subsets are still concatenated.

    Returns:
        One Dataset holding the rows of every subset that loaded, in
        DATASET_SUBSETS order.

    Raises:
        ValueError: If not a single subset could be loaded.
    """
    print("--- Loading and concatenating all subsets ---")
    loaded_parts = []

    for subset_name in DATASET_SUBSETS:
        print(f"Loading subset: {subset_name}")
        try:
            part = load_dataset(DATASET_NAME, name=subset_name, split="queries")
            print(f"  - Loaded {len(part)} samples from {subset_name}")
            loaded_parts.append(part)
        except Exception as exc:
            # Report the failure and move on to the next subset.
            print(f"  - Failed to load subset '{subset_name}'. Error: {exc}")

    if not loaded_parts:
        raise ValueError("No datasets were successfully loaded!")

    merged = concatenate_datasets(loaded_parts)
    print(f"Total concatenated dataset size: {len(merged)} samples")
    return merged


def filter_by_categories(
    dataset: Dataset, categories: List[str]
) -> Dict[str, Dataset]:
    """
    Split ``dataset`` into one subset per requested category.

    Args:
        dataset: Rows that have a "category" column.
        categories: Category values to select. Each becomes a key of the result.

    Returns:
        A dict mapping every requested category to the rows whose "category"
        equals it. The subset may be empty.
    """
    print(f"--- Filtering dataset by categories: {categories} ---")
    category_datasets = {}

    for category in categories:
        # Only the "category" column is passed to the predicate, which avoids
        # decoding every full row. The target is bound as a default argument,
        # so the lambda does not rely on late binding of the loop variable.
        filtered_dataset = dataset.filter(
            lambda value, target=category: value == target,
            input_columns="category",
        )
        category_datasets[category] = filtered_dataset
        print(f"Category '{category}': {len(filtered_dataset)} samples")

    return category_datasets


def create_embedding_text(item):
    """
    Build the text to embed for one dataset row.

    The row's "instruction" and "query" fields are joined with a newline, in
    that order. A missing or non-string field counts as an empty string. The
    result is stripped of surrounding whitespace, so a row with only one field
    yields that field without a stray newline.
    """
    pieces = []
    for field_name in ("instruction", "query"):
        raw_value = item.get(field_name, "")
        pieces.append(raw_value if isinstance(raw_value, str) else "")
    return "\n".join(pieces).strip()


def _embed_category(category, dataset):
    """
    Embed one category's rows with every model in EMBEDDING_MODELS.

    Args:
        category: Category name, used for logging.
        dataset: The rows belonging to that category.

    Returns:
        A Dataset with one row per (sample, model) pair whose embedding
        succeeded. Columns: id, query, instruction, labels, category, embed,
        embed_model.
    """
    print(f"\n=== Processing category: {category} ===")
    print(f"Dataset size: {len(dataset)} samples")

    # Materialize the rows once and reuse them for every model, instead of
    # re-iterating the dataset once per model.
    rows = list(dataset)
    texts_to_embed = [create_embedding_text(item) for item in rows]

    category_embedded_data = []
    for model_key, model_name in EMBEDDING_MODELS.items():
        print(f"\n--- Processing {category} with {model_name} ---")

        # Keys are row indices; failed samples are absent.
        embeddings_dict = get_embeddings_ordered(
            texts_to_embed, model_name, model_key
        )

        # Keep only samples with a successful embedding, in dataset order.
        for i, item in enumerate(rows):
            if i not in embeddings_dict:
                continue
            category_embedded_data.append(
                {
                    "id": item["id"],
                    "query": item["query"],
                    "instruction": item["instruction"],
                    "labels": item["labels"],
                    "category": item["category"],
                    "embed": embeddings_dict[i],
                    "embed_model": model_key,
                }
            )

    category_dataset = Dataset.from_list(category_embedded_data)
    print(
        f"Created category '{category}' with {len(category_embedded_data)} rows"
    )
    return category_dataset


def _print_summary(category_data):
    """Print per-category and total row counts for the saved dataset."""
    total_rows = sum(len(ds) for ds in category_data.values())
    print(f"Total rows across all categories: {total_rows}")
    print("Category breakdown:")
    for category_name, category_ds in category_data.items():
        print(f"  - {category_name}: {len(category_ds)} rows")
    print(
        "Each category contains 7 columns (id, query, instruction, labels, category, embed, embed_model)"
    )
    print(
        f"Each category has data from {len(EMBEDDING_MODELS)} embedding models"
    )


def main():
    """
    Embed the ToolRet queries of every target category and save them to disk.

    Steps:
      1. Ensure OUTPUT_DIR exists. Return early if the final dataset has
         already been saved.
      2. Load and concatenate every subset in DATASET_SUBSETS.
      3. Split the rows by TARGET_CATEGORIES.
      4. Embed each category with every model in EMBEDDING_MODELS, keeping only
         rows whose embedding succeeded.
      5. Save a DatasetDict (one split per category) and print a summary.
    """
    print("--- Starting ToolRet-before-sample Dataset Embedding Process ---")

    # --- Create Output Directory ---
    if not os.path.exists(OUTPUT_DIR):
        # exist_ok avoids a crash if the directory appears between the
        # existence check and the creation.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"Created output directory: {OUTPUT_DIR}")

    # NOTE(review): the directory name says "qwen" although EMBEDDING_MODELS
    # holds an OpenAI model. It is left unchanged so previously saved outputs
    # are still found by the skip check below; rename deliberately if the
    # mismatch is unintended.
    final_output_path = os.path.join(
        OUTPUT_DIR, "toolret_queries_embedded_qwen"
    )
    if os.path.exists(final_output_path):
        print(
            f"\nFinal embedded dataset already exists at '{final_output_path}'. Skipping."
        )
        return

    # --- Load, concatenate and split by category ---
    concatenated_dataset = load_and_concatenate_subsets()
    category_datasets = filter_by_categories(
        concatenated_dataset, TARGET_CATEGORIES
    )

    # --- Embed each category ---
    category_data = {}
    for category in TARGET_CATEGORIES:
        if category not in category_datasets:
            print(f"Category '{category}' not found in dataset. Skipping.")
            continue
        category_data[category] = _embed_category(
            category, category_datasets[category]
        )

    # --- Create and save the final dataset (one split per category) ---
    print(
        f"\n--- Creating dataset with {len(category_data)} separate categories ---"
    )
    final_dataset = DatasetDict(category_data)
    final_dataset.save_to_disk(final_output_path)
    print(f"Dataset with separate categories saved to '{final_output_path}'")

    _print_summary(category_data)

    print("\n--- All embedding tasks completed successfully! ---")


if __name__ == "__main__":
    main()
