import concurrent.futures
import json
import os
import time
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
from datasets import Dataset, load_from_disk
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables, e.g.
#   export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. "
        "Please set it before running the script."
    )

BASE_URL = "https://api.openai.com/v1"
# Short name -> OpenAI model id. The short name becomes the suffix of the
# output column (``embedding_<short name>``).
EMBEDDING_MODELS = {
    "ada": "text-embedding-ada-002",
    "small": "text-embedding-3-small",
    "large": "text-embedding-3-large",
}
DATASET_PATH = "processed_datasets/ultratool_tools"
# Output size requested (via the API's ``dimensions`` parameter) for the
# "large" model only; the other models are called without it.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "embeddings"
MAX_WORKERS = 10  # concurrent batch requests
BATCH_SIZE = 100  # texts per embeddings request
RATE_LIMIT_DELAY = 0.1  # seconds each worker sleeps after finishing a batch

# Create the output directory at import time so it exists before saving.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _format_json_field(label: str, raw: Any) -> Optional[str]:
    """Return ``"<label>: <value>"`` for a JSON-encoded tool field, or None.

    ``raw`` is normally a JSON string. If it parses, the value is re-dumped
    compactly (``ensure_ascii=False`` keeps non-ASCII text readable). A parsed
    value that is falsy (``{}``, ``[]``, ``null``, ``0`` ...) yields None, so
    the field is omitted. If ``raw`` is not valid JSON, or is not a string at
    all, it is included verbatim.
    """
    if not raw:
        return None
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers values
        # json.loads cannot take, e.g. an already-decoded dict.
        return f"{label}: {raw}"
    if not parsed:
        return None
    return f"{label}: {json.dumps(parsed, ensure_ascii=False)}"


def create_tool_text_for_embedding(tool: Dict[str, Any]) -> str:
    """Build the text sent to the embedding model for one tool.

    Fields are emitted in a fixed order (name, description, arguments,
    results), one ``"Label: value"`` line each. Empty or missing fields are
    skipped. ``arguments`` and ``results`` are expected to hold JSON strings
    and are re-serialized compactly; values that are not valid JSON are
    included as-is.

    Args:
        tool: A dataset row with (some of) the keys ``name``,
            ``description``, ``arguments`` and ``results``.

    Returns:
        The newline-joined text, or ``""`` if every field is empty.
    """
    parts = []

    if tool.get("name"):
        parts.append(f"Name: {tool['name']}")

    if tool.get("description"):
        parts.append(f"Description: {tool['description']}")

    for label, key in (("Arguments", "arguments"), ("Results", "results")):
        line = _format_json_field(label, tool.get(key))
        if line is not None:
            parts.append(line)

    return "\n".join(parts)


def get_embedding(
    client: OpenAI,
    text: str,
    model: str,
    dimensions: Optional[int] = None,
) -> Optional[List[float]]:
    """Embed a single text.

    Args:
        client: OpenAI client used for the request.
        text: The text to embed.
        model: Embedding model id.
        dimensions: Optional output size. It is forwarded as the API's
            ``dimensions`` parameter only when given, and omitted by default.

    Returns:
        The embedding vector, or ``None`` if the request failed. Errors are
        printed rather than raised.
    """
    request: Dict[str, Any] = {"input": text, "model": model}
    if dimensions is not None:
        request["dimensions"] = dimensions
    try:
        response = client.embeddings.create(**request)
        return response.data[0].embedding
    except Exception as e:  # best-effort: report and signal failure via None
        print(f"Error getting embedding: {e}")
        return None


def _is_rate_limit_error(error: Exception) -> bool:
    """Return True if *error* looks like an HTTP 429 / rate-limit rejection."""
    if getattr(error, "status_code", None) == 429:
        return True
    return "rate limit" in str(error).lower()


def _is_length_error(error: Exception) -> bool:
    """Return True if *error* looks like an input-too-long rejection."""
    if _is_rate_limit_error(error):
        # Rate-limit messages mention e.g. "tokens per min", which would
        # otherwise match the "token" keyword and trigger useless truncation.
        return False
    message = str(error).lower()
    return any(
        keyword in message
        for keyword in ["token", "length", "too long", "maximum"]
    )


def _request_embeddings(
    client: OpenAI,
    texts: List[str],
    model: str,
    max_retries: int,
    retry_backoff: float,
) -> List[List[float]]:
    """Call the embeddings endpoint, retrying only rate-limit errors.

    The "large" model is asked for ``EMBEDDING_DIM`` dimensions. Any other
    error, or a rate-limit error after ``max_retries`` retries, is re-raised.
    """
    request: Dict[str, Any] = {"input": texts, "model": model}
    if model == EMBEDDING_MODELS["large"]:
        request["dimensions"] = EMBEDDING_DIM
    attempt = 0
    while True:
        try:
            response = client.embeddings.create(**request)
        except Exception as e:
            if attempt >= max_retries or not _is_rate_limit_error(e):
                raise
            delay = retry_backoff * (2**attempt)
            print(f"Rate limited, retrying in {delay:.1f}s: {e}")
            time.sleep(delay)
            attempt += 1
            continue
        return [data.embedding for data in response.data]


def get_embeddings_batch(
    client: OpenAI,
    texts: List[str],
    model: str,
    *,
    truncation_limits: Sequence[int] = (8192, 8000),
    max_retries: int = 3,
    retry_backoff: float = 1.0,
) -> List[Optional[List[float]]]:
    """Embed a batch of texts, degrading gracefully on failure.

    The whole batch is sent in one request. Rate-limit rejections are retried
    with exponential backoff. If the request is rejected for what looks like
    an input-length reason, it is retried with every text truncated to each
    limit in ``truncation_limits``, in turn.

    Args:
        client: OpenAI client used for the requests.
        texts: Texts to embed. Results are taken in ``response.data`` order,
            which is assumed to match ``texts``.
        model: Embedding model id. The "large" model is requested with
            ``EMBEDDING_DIM`` dimensions.
        truncation_limits: Character (not token) limits tried in order after
            a length error.
        max_retries: Extra attempts per request after a rate-limit error.
        retry_backoff: Base delay in seconds; retry ``k`` waits
            ``retry_backoff * 2**k``.

    Returns:
        One embedding per input text. If the batch could not be embedded,
        returns a list of ``None`` with the same length as ``texts``. Errors
        are printed, not raised.
    """
    if not texts:
        return []

    try:
        return _request_embeddings(
            client, texts, model, max_retries, retry_backoff
        )
    except Exception as e:
        if not _is_length_error(e):
            print(f"Non-length error getting batch embeddings: {e}")
            return [None] * len(texts)
        print(f"Length-related error, trying with truncation: {e}")

    last_position = len(truncation_limits) - 1
    for position, limit in enumerate(truncation_limits):
        truncated = [text[:limit] for text in texts]
        try:
            embeddings = _request_embeddings(
                client, truncated, model, max_retries, retry_backoff
            )
        except Exception as e:
            if position == last_position:
                print(f"{limit} truncation also failed: {e}, skipping batch")
            else:
                print(f"{limit} truncation failed: {e}")
            continue
        print(f"Success with {limit} character truncation")
        return embeddings

    return [None] * len(texts)


def process_batch(args):
    """Embed one batch; intended to run in a worker thread.

    Args:
        args: A ``(texts, model, api_key, base_url)`` tuple.

    Returns:
        The per-text results of ``get_embeddings_batch`` (``None`` entries
        mark texts that could not be embedded).
    """
    texts, model, api_key, base_url = args
    # A fresh client is created per batch. The context manager closes its
    # HTTP resources deterministically once the batch is done.
    with OpenAI(api_key=api_key, base_url=base_url) as client:
        embeddings = get_embeddings_batch(client, texts, model)
    time.sleep(RATE_LIMIT_DELAY)  # crude per-worker throttle
    return embeddings


def _build_tool_texts(dataset: Dataset) -> List[str]:
    """Return one embedding text per row, reusing ``text_representation``."""
    if "text_representation" in dataset.column_names:
        print("Using existing text representations...")
        # list() guarantees a plain, sliceable list of strings whatever
        # container type the datasets library returns for a column.
        return list(dataset["text_representation"])
    print("Creating tool text representations...")
    return [
        create_tool_text_for_embedding(item)
        for item in tqdm(dataset, desc="Creating texts")
    ]


def _embed_texts_parallel(
    texts: List[str], model_name: str, model: str
) -> List[Optional[List[float]]]:
    """Embed *texts* in ``BATCH_SIZE`` chunks on ``MAX_WORKERS`` threads.

    Results are reassembled in input order regardless of completion order.
    A batch whose worker raises is recorded as all-``None``, so the results
    of the other batches are kept.
    """
    batches = [
        texts[start : start + BATCH_SIZE]
        for start in range(0, len(texts), BATCH_SIZE)
    ]
    results: List[List[Optional[List[float]]]] = [[] for _ in batches]
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as executor:
        future_to_index = {
            executor.submit(
                process_batch, (batch, model, API_KEY, BASE_URL)
            ): index
            for index, batch in enumerate(batches)
        }
        with tqdm(
            total=len(batches), desc=f"Processing {model_name} batches"
        ) as pbar:
            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as e:
                    print(f"Batch {index} failed unexpectedly: {e}")
                    results[index] = [None] * len(batches[index])
                pbar.update(1)
    return [embedding for batch in results for embedding in batch]


def embed_dataset_parallel(
    dataset: Dataset, model_name: str, model: str
) -> Dataset:
    """Embed every row of *dataset* with *model* using parallel batches.

    Adds a ``text_representation`` column if it is missing, plus an
    ``embedding_<model_name>`` column. Rows whose embedding failed hold
    ``None`` in that column.

    Args:
        dataset: Rows accepted by ``create_tool_text_for_embedding``, or a
            dataset that already has a ``text_representation`` column.
        model_name: Short name used in log messages and the column name.
        model: OpenAI embedding model id.

    Returns:
        A new dataset with the added column(s).
    """
    print(f"Embedding dataset with {model_name} model ({model})...")

    tool_texts = _build_tool_texts(dataset)
    embeddings = _embed_texts_parallel(tool_texts, model_name, model)

    for index, embedding in enumerate(embeddings):
        if embedding is None:
            print(
                f"Warning: Failed to get embedding for item {index}, "
                "using None"
            )
    success_count = sum(embedding is not None for embedding in embeddings)
    print(
        f"Successfully got {success_count} embeddings out of "
        f"{len(embeddings)} items"
    )

    result = dataset
    if "text_representation" not in result.column_names:
        result = result.add_column("text_representation", tool_texts)
    result = result.add_column(f"embedding_{model_name}", embeddings)

    print(f"Successfully embedded {success_count} items with {model_name}")
    return result


def _print_dataset_overview(dataset: Dataset) -> None:
    """Print column names, size, the first row and its embedding text."""
    print("\nDataset info:")
    print(f"Columns: {dataset.column_names}")
    print(f"Number of items: {len(dataset)}")

    if len(dataset) == 0:
        print("\nDataset is empty; no sample to show.")
        return

    print("\nSample tool:")
    sample = dataset[0]
    for key, value in sample.items():
        if isinstance(value, dict):
            print(f"  {key}: {json.dumps(value, indent=2)[:200]}...")
        elif isinstance(value, str) and len(value) > 100:
            print(f"  {key}: {value[:100]}...")
        else:
            print(f"  {key}: {value}")

    print("\nSample text representation:")
    sample_text = create_tool_text_for_embedding(sample)
    print(sample_text[:300] + "..." if len(sample_text) > 300 else sample_text)


def _verify_embeddings(dataset: Dataset) -> None:
    """Report the vector length found in each model's embedding column."""
    print("\nVerifying embeddings...")
    for model_name in EMBEDDING_MODELS:
        embedding_col = f"embedding_{model_name}"
        if embedding_col not in dataset.column_names:
            print(f"  {model_name}: Missing!")
            continue
        # Failed rows hold None, so row 0 may be empty: use the first row
        # that actually has a vector (iteration stops as soon as one is found).
        sample_embedding = next(
            (
                row[embedding_col]
                for row in dataset
                if row[embedding_col] is not None
            ),
            None,
        )
        if sample_embedding is None:
            print(f"  {model_name}: no successful embeddings")
        else:
            print(f"  {model_name}: {len(sample_embedding)} dimensions")


def main():
    """Embed the UltraTool tools dataset with every configured model and save it."""
    print("Loading UltraTool tools dataset...")
    dataset = load_from_disk(DATASET_PATH)
    print(f"Loaded dataset with {len(dataset)} items")

    _print_dataset_overview(dataset)

    # Each pass adds one embedding column (and the text column on the first).
    embedded_dataset = dataset
    for model_name, model in EMBEDDING_MODELS.items():
        print(f"\n{'='*50}")
        embedded_dataset = embed_dataset_parallel(
            embedded_dataset, model_name, model
        )

    output_path = os.path.join(OUTPUT_DIR, "ultratool_tools_embedded")
    print(f"\nSaving embedded dataset to {output_path}...")
    embedded_dataset.save_to_disk(output_path)

    print("\nEmbedding complete!")
    print(f"Final dataset columns: {embedded_dataset.column_names}")
    print(f"Dataset saved to: {output_path}")

    _verify_embeddings(embedded_dataset)


if __name__ == "__main__":
    main()
