import concurrent.futures
import os
import time
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
from datasets import Dataset, load_from_disk
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables, e.g.
#   export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running the script."
    )

BASE_URL = "https://api.openai.com/v1"
# Short name (used in the output column name "embedding_<name>") -> model id.
EMBEDDING_MODELS = {
    "ada": "text-embedding-ada-002",
    "small": "text-embedding-3-small",
    "large": "text-embedding-3-large",
}
DATASET_PATH = "processed_datasets/ultratool_decomposed"
COLUMN_TO_EMBED = "prompt"  # Only embed the prompt column
# Value passed as the `dimensions` request parameter. It is only sent for the
# "large" model; the other models are called without it.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "embeddings"
MAX_WORKERS = 10  # Number of batches embedded concurrently (thread pool size).
BATCH_SIZE = 100  # Texts sent per embeddings request.
# Seconds each worker sleeps after finishing a batch. With MAX_WORKERS threads
# this is a per-worker pause, not a global gap between requests.
RATE_LIMIT_DELAY = 0.1

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)


def get_embedding(
    client: OpenAI, text: str, model: str
) -> Optional[List[float]]:
    """Get the embedding for a single text.

    Like the batch path, the "large" model is asked for EMBEDDING_DIM
    dimensions, so single-text vectors have the same size as the ones stored
    in the dataset.

    Args:
        client: OpenAI client used to issue the request.
        text: Text to embed. Empty text is not sent, because the embeddings
            API does not accept empty strings.
        model: OpenAI embedding model identifier.

    Returns:
        The embedding vector, or None if the text is empty or the request
        failed. The error is printed in the failure case.
    """
    if not text:
        return None

    extra_params: Dict[str, Any] = {}
    if model == EMBEDDING_MODELS["large"]:
        extra_params["dimensions"] = EMBEDDING_DIM

    try:
        response = client.embeddings.create(
            input=text, model=model, **extra_params
        )
        return response.data[0].embedding
    except Exception as e:
        print(f"Error getting embedding: {e}")
        return None


# Substrings (lower-cased) that mark an API error as caused by over-long input.
_LENGTH_ERROR_KEYWORDS = ("token", "length", "too long", "maximum")


def _is_length_error(error: Exception) -> bool:
    """Heuristically decide whether an embeddings API error was due to input length."""
    message = str(error).lower()
    # OpenAI rate-limit messages typically mention "tokens per min (TPM)",
    # which would otherwise match "token" and trigger pointless truncation.
    if "rate limit" in message:
        return False
    return any(keyword in message for keyword in _LENGTH_ERROR_KEYWORDS)


def _create_embeddings(
    client: OpenAI, inputs: List[str], model: str
) -> List[List[float]]:
    """Send one embeddings request and return the vectors in input order.

    The "large" model is asked for EMBEDDING_DIM dimensions; other models are
    called without a `dimensions` parameter. Any client exception propagates.

    Raises:
        ValueError: if the response does not hold exactly one vector per input.
    """
    extra_params: Dict[str, Any] = {}
    if model == EMBEDDING_MODELS["large"]:
        extra_params["dimensions"] = EMBEDDING_DIM
    response = client.embeddings.create(input=inputs, model=model, **extra_params)
    # Each item carries the index of the input it embeds; order by it instead
    # of trusting the response order.
    ordered = sorted(response.data, key=lambda item: item.index)
    if len(ordered) != len(inputs):
        raise ValueError(
            f"Expected {len(inputs)} embeddings, received {len(ordered)}"
        )
    return [item.embedding for item in ordered]


def _embed_with_truncation(
    client: OpenAI,
    texts: List[str],
    model: str,
    truncation_limits: Sequence[int],
) -> Optional[List[List[float]]]:
    """Embed texts, retrying with shorter character truncation on length errors.

    Returns the vectors in input order, or None if every attempt failed.
    """
    try:
        return _create_embeddings(client, texts, model)
    except Exception as e:
        if not _is_length_error(e):
            print(f"Non-length error getting batch embeddings: {e}")
            return None
        print(f"Length-related error, trying with truncation: {e}")

    # Limits are in characters, not tokens: they are only a cheap proxy for
    # the model's token limit.
    for limit in truncation_limits:
        truncated = [text[:limit] for text in texts]  # no-op for short texts
        try:
            embeddings = _create_embeddings(client, truncated, model)
        except Exception as e:
            print(f"{limit} truncation failed: {e}")
            continue
        print(f"Success with {limit} character truncation")
        return embeddings

    print("All truncation attempts failed, skipping batch")
    return None


def get_embeddings_batch(
    client: OpenAI,
    texts: List[str],
    model: str,
    truncation_limits: Sequence[int] = (8192, 8000),
) -> List[Optional[List[float]]]:
    """Get embeddings for a batch of texts, with truncation retries on length errors.

    Empty strings and non-string entries (e.g. nulls from the dataset) are not
    sent to the API, which rejects empty input. Doing so would otherwise fail
    the whole request. Those entries get None.

    Args:
        client: OpenAI client used to issue the requests.
        texts: Texts to embed.
        model: OpenAI embedding model identifier.
        truncation_limits: Character limits tried, in order, after a
            length-related error.

    Returns:
        A list with one entry per element of ``texts``: the embedding vector,
        or None where no embedding could be obtained.
    """
    results: List[Optional[List[float]]] = [None] * len(texts)

    positions = [
        i for i, text in enumerate(texts) if isinstance(text, str) and text
    ]
    skipped = len(texts) - len(positions)
    if skipped:
        print(f"Skipping {skipped} empty or non-string text(s) in batch")
    if not positions:
        return results

    to_send = [texts[i] for i in positions]
    embeddings = _embed_with_truncation(client, to_send, model, truncation_limits)
    if embeddings is not None:
        for position, embedding in zip(positions, embeddings):
            results[position] = embedding
    return results


def process_batch(args):
    """Embed one batch of texts using a client that is closed afterwards.

    Args:
        args: A ``(texts, model, api_key, base_url)`` tuple.

    Returns:
        One entry per input text: the embedding vector, or None on failure.
    """
    texts, model, api_key, base_url = args
    # Closing the client releases its HTTP connection pool as soon as the
    # batch is done, instead of leaving that to garbage collection.
    with OpenAI(api_key=api_key, base_url=base_url) as client:
        embeddings = get_embeddings_batch(client, texts, model)
    # Crude throttle: each worker pauses after its own request, so with
    # MAX_WORKERS threads this is not a global gap between requests.
    time.sleep(RATE_LIMIT_DELAY)
    return embeddings


def embed_dataset_parallel(
    dataset: Dataset, model_name: str, model: str
) -> Dataset:
    """Embed COLUMN_TO_EMBED with ``model`` and append the vectors as a new column.

    The column is split into batches of BATCH_SIZE texts. The batches are
    embedded concurrently on up to MAX_WORKERS threads and reassembled in the
    original row order.

    Args:
        dataset: Dataset containing COLUMN_TO_EMBED.
        model_name: Short model key, used in the new column name
            ``embedding_{model_name}``.
        model: OpenAI embedding model identifier.

    Returns:
        A new Dataset with an ``embedding_{model_name}`` column. Rows whose
        embedding failed hold None.
    """
    print(f"Embedding dataset with {model_name} model ({model})...")

    # Materialize as a plain list so slices are lists that can be sent to the
    # API, in case the column accessor returns a lazy column object.
    texts = list(dataset[COLUMN_TO_EMBED])
    batches = [
        (texts[start : start + BATCH_SIZE], model, API_KEY, BASE_URL)
        for start in range(0, len(texts), BATCH_SIZE)
    ]

    results: List[List[Optional[List[float]]]] = [[] for _ in batches]
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as executor:
        future_to_index = {
            executor.submit(process_batch, batch): index
            for index, batch in enumerate(batches)
        }
        with tqdm(
            total=len(batches), desc=f"Processing {model_name} batches"
        ) as pbar:
            # Batches finish out of order; store each result at its own slot.
            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                batch_embeddings = future.result()
                expected = len(batches[index][0])
                if len(batch_embeddings) != expected:
                    # A short or long batch would shift every later row onto
                    # the wrong embedding, so discard it instead.
                    print(
                        f"Warning: batch {index} returned {len(batch_embeddings)} "
                        f"embeddings for {expected} texts, discarding it"
                    )
                    batch_embeddings = [None] * expected
                results[index] = batch_embeddings
                pbar.update(1)

    final_embeddings = [
        embedding for batch_embeddings in results for embedding in batch_embeddings
    ]

    failed_count = 0
    for i, embedding in enumerate(final_embeddings):
        if embedding is None:
            failed_count += 1
            print(f"Warning: Failed to get embedding for item {i}, using None")
    success_count = len(final_embeddings) - failed_count

    print(
        f"Successfully got {success_count} embeddings out of {len(final_embeddings)} items"
    )

    dataset_with_embeddings = dataset.add_column(
        f"embedding_{model_name}", final_embeddings
    )

    print(f"Successfully embedded {success_count} items with {model_name}")
    return dataset_with_embeddings


def main():
    """Embed the UltraTool decomposed dataset with every configured model and save it.

    Loads DATASET_PATH and prints a short summary plus a sample row. Then, for
    each entry of EMBEDDING_MODELS, adds an ``embedding_<model_name>`` column.
    The result is saved under OUTPUT_DIR, and the vector size found in each
    embedding column is reported.
    """
    print("Loading UltraTool decomposed dataset...")
    dataset = load_from_disk(DATASET_PATH)
    print(f"Loaded dataset with {len(dataset)} items")

    # Print some statistics
    print("\nDataset info:")
    print(f"Columns: {dataset.column_names}")
    print(f"Number of items: {len(dataset)}")

    # Nothing below works on an empty dataset (dataset[0] would raise).
    if len(dataset) == 0:
        print("Dataset is empty; nothing to embed.")
        return

    # Show sample data
    print("\nSample item:")
    sample = dataset[0]
    for key, value in sample.items():
        if isinstance(value, str) and len(value) > 100:
            print(f"  {key}: {value[:100]}...")
        else:
            print(f"  {key}: {value}")

    # Show what we're embedding
    print(f"\nEmbedding column: {COLUMN_TO_EMBED}")
    print(f"Sample text to embed: {sample[COLUMN_TO_EMBED][:200]}...")

    # Embed with each model; each pass adds one column.
    embedded_dataset = dataset
    for model_name, model in EMBEDDING_MODELS.items():
        print(f"\n{'=' * 50}")
        embedded_dataset = embed_dataset_parallel(
            embedded_dataset, model_name, model
        )

    # Save the embedded dataset
    output_path = os.path.join(OUTPUT_DIR, "ultratool_queries_embedded")
    print(f"\nSaving embedded dataset to {output_path}...")
    embedded_dataset.save_to_disk(output_path)

    print("\nEmbedding complete!")
    print(f"Final dataset columns: {embedded_dataset.column_names}")
    print(f"Dataset saved to: {output_path}")

    # Verify embeddings
    print("\nVerifying embeddings...")
    for model_name in EMBEDDING_MODELS:
        embedding_col = f"embedding_{model_name}"
        if embedding_col not in embedded_dataset.column_names:
            print(f"  {model_name}: Missing!")
            continue
        # Failed rows are stored as None, so row 0 may have no vector; report
        # the first row that does.
        sample_embedding = next(
            (
                row[embedding_col]
                for row in embedded_dataset
                if row[embedding_col] is not None
            ),
            None,
        )
        if sample_embedding is None:
            print(f"  {model_name}: no successful embeddings!")
        else:
            print(f"  {model_name}: {len(sample_embedding)} dimensions")


# Run the embedding pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
