import concurrent.futures
import hashlib
import os
import pickle
import time
from typing import Any, Dict, List, Optional

import numpy as np
from datasets import Dataset, load_from_disk
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment before running, e.g.
#   export DASHSCOPE_API_KEY='your_api_key'
# The check below runs at import time and raises ValueError if it is missing.
API_KEY = os.environ.get("DASHSCOPE_API_KEY")
if not API_KEY:
    raise ValueError(
        "DASHSCOPE_API_KEY environment variable not set. Please set it before running the script."
    )

# DashScope's OpenAI-compatible endpoint (used with the OpenAI client).
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# Short name -> API model id; the short name becomes the "embedding_<name>" column.
EMBEDDING_MODELS = {
    "large": "text-embedding-v4",
}
DATASET_PATH = "processed_datasets/ultratool_decomposed"
COLUMN_TO_EMBED = "prompt"  # Only embed the prompt column
EMBEDDING_DIM = 1536  # Requested output dimensions; only sent for the "large" model
OUTPUT_DIR = "embeddings"
CACHE_DIR = ".embedding_cache"  # One pickled embedding per hash of (text, model, dimensions)
MAX_WORKERS = 10  # Number of batches processed concurrently
BATCH_SIZE = 10  # Texts sent per embeddings API call
RATE_LIMIT_DELAY = 0.1  # Seconds each worker sleeps after finishing a batch

# Create output and cache directories (at import time)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)


def get_embedding(
    client: OpenAI, text: str, model: str, dimensions: Optional[int] = None
) -> Optional[List[float]]:
    """Get the embedding for a single text, or None if the request fails.

    Args:
        client: OpenAI-compatible client whose ``embeddings.create`` is called.
        text: Text to embed.
        model: Embedding model id.
        dimensions: Optional output dimensionality forwarded to the API. When
            None (the default) the parameter is omitted, as before.

    Returns:
        The embedding vector, or None if the call (or reading its result) raised.
    """
    request = {"input": text, "model": model}
    if dimensions is not None:
        request["dimensions"] = dimensions
    try:
        response = client.embeddings.create(**request)
        return response.data[0].embedding
    except Exception as e:
        # Best-effort: callers treat None as "no embedding for this text".
        print(f"Error getting embedding: {e}")
        return None


# Character limits tried, in order, after the API reports a length-related error.
_TRUNCATION_LIMITS = (8192, 8000)
# Substrings (lower-case) that mark an API error as length/token related.
_LENGTH_ERROR_KEYWORDS = ("token", "length", "too long", "maximum")


def _cache_path(text: str, model: str, dimensions: Optional[int]) -> str:
    """Return the cache file path for one (text, model, dimensions) combination."""
    cache_key_str = f"{text}-{model}-{dimensions}"
    cache_key = hashlib.sha256(cache_key_str.encode("utf-8")).hexdigest()
    return os.path.join(CACHE_DIR, cache_key)


def _read_cache(path: str, item_index: int) -> Optional[List[float]]:
    """Return the embedding cached at *path*, or None if absent or unreadable."""
    if not os.path.exists(path):
        return None
    try:
        # NOTE(review): pickle.load is only safe because CACHE_DIR is written by
        # this script itself; never point it at files from an untrusted source.
        with open(path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        print(f"Warning: failed to load cache for item {item_index}: {e}")
        return None


def get_embeddings_batch(
    client: OpenAI, texts: List[str], model: str
) -> List[Optional[List[float]]]:
    """Get embeddings for a batch of texts, with per-item caching and truncation retries.

    Cached items are read from CACHE_DIR. The remaining items are sent to the API
    in one request. On a length-related error the request is retried with every
    text truncated to each limit in ``_TRUNCATION_LIMITS`` in turn. Successful
    vectors are cached under the key of the ORIGINAL (untruncated) text.

    Args:
        client: OpenAI-compatible client used for ``embeddings.create``.
        texts: Texts to embed. Non-string or empty entries are skipped.
        model: Embedding model id. ``dimensions=EMBEDDING_DIM`` is sent only for
            the "large" model.

    Returns:
        A list aligned with ``texts``. Each entry is an embedding, or None for
        skipped or failed items.
    """
    results: List[Optional[List[float]]] = [None] * len(texts)
    dimensions = EMBEDDING_DIM if model == EMBEDDING_MODELS["large"] else None

    # Split into cache hits and texts that still need an API call.
    to_query: List[str] = []
    to_query_indices: List[int] = []
    for i, text in enumerate(texts):
        if not isinstance(text, str) or not text:
            continue  # results[i] stays None
        cached = _read_cache(_cache_path(text, model, dimensions), i)
        if cached is not None:
            results[i] = cached
        else:
            to_query.append(text)
            to_query_indices.append(i)

    if not to_query:
        return results

    def _call_api(inputs: List[str]) -> List[List[float]]:
        request = {"input": inputs, "model": model}
        if dimensions is not None:
            request["dimensions"] = dimensions
        data = list(client.embeddings.create(**request).data)
        # Use the API-reported index (when present) so each vector is matched
        # to its own input. A mismatch here would otherwise be cached permanently.
        if all(isinstance(getattr(d, "index", None), int) for d in data):
            data.sort(key=lambda d: d.index)
        if len(data) != len(inputs):
            raise ValueError(
                f"API returned {len(data)} embeddings for {len(inputs)} inputs"
            )
        return [d.embedding for d in data]

    miss_embeddings: Optional[List[Optional[List[float]]]] = None
    try:
        miss_embeddings = _call_api(to_query)
    except Exception as e:
        if any(k in str(e).lower() for k in _LENGTH_ERROR_KEYWORDS):
            print(f"Length-related error, trying with truncation: {e}")
            for limit in _TRUNCATION_LIMITS:
                try:
                    miss_embeddings = _call_api([t[:limit] for t in to_query])
                except Exception as e_trunc:
                    print(f"{limit} truncation failed: {e_trunc}")
                else:
                    print(f"Success with {limit} character truncation")
                    break
            else:
                print("All truncation attempts failed, skipping batch")
        else:
            print(f"Non-length error getting batch embeddings: {e}")
    if miss_embeddings is None:
        miss_embeddings = [None] * len(to_query)

    # Merge back into input order and cache every successful vector.
    for orig_idx, emb in zip(to_query_indices, miss_embeddings):
        results[orig_idx] = emb
        if emb is None:
            continue
        try:
            with open(_cache_path(texts[orig_idx], model, dimensions), "wb") as f:
                pickle.dump(emb, f)
        except Exception as e:
            print(f"Warning: failed to write cache for item {orig_idx}: {e}")

    return results


def process_batch(args):
    """Embed one batch of texts inside a worker thread.

    Args:
        args: Tuple ``(texts, model, api_key, base_url)``.

    Returns:
        Embeddings aligned with ``texts``, as returned by ``get_embeddings_batch``
        (None for skipped or failed items).
    """
    texts, model, api_key, base_url = args
    client = OpenAI(api_key=api_key, base_url=base_url)
    try:
        embeddings = get_embeddings_batch(client, texts, model)
    finally:
        # One client is built per batch. Close it so its HTTP connection pool is
        # released instead of accumulating one open pool per batch.
        client.close()
    time.sleep(RATE_LIMIT_DELAY)  # Crude throttle: each worker pauses after a batch
    return embeddings


def embed_dataset_parallel(
    dataset: Dataset, model_name: str, model: str
) -> Dataset:
    """Embed COLUMN_TO_EMBED of *dataset* in parallel and return it with a new column.

    Texts are split into BATCH_SIZE chunks and processed by up to MAX_WORKERS
    threads. Results are reassembled in the original row order.

    Args:
        dataset: Dataset containing COLUMN_TO_EMBED.
        model_name: Short name used for the new column ``embedding_<model_name>``.
        model: Embedding model id passed to the API.

    Returns:
        A new dataset with the ``embedding_<model_name>`` column added. Rows whose
        embedding failed hold None.
    """
    print(f"Embedding dataset with {model_name} model ({model})...")

    # Materialize the column once so batching below is plain list slicing.
    texts = list(dataset[COLUMN_TO_EMBED])
    batches = [
        (texts[i : i + BATCH_SIZE], model, API_KEY, BASE_URL)
        for i in range(0, len(texts), BATCH_SIZE)
    ]

    # Batches complete out of order; store each result at its batch index.
    batch_results = [None] * len(batches)
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_index = {
            executor.submit(process_batch, batch): i for i, batch in enumerate(batches)
        }
        with tqdm(
            total=len(batches), desc=f"Processing {model_name} batches"
        ) as pbar:
            for future in concurrent.futures.as_completed(future_to_index):
                batch_results[future_to_index[future]] = future.result()
                pbar.update(1)

    final_embeddings = [emb for batch in batch_results for emb in batch]

    num_succeeded = 0
    for i, embedding in enumerate(final_embeddings):
        if embedding is None:
            print(f"Warning: Failed to get embedding for item {i}, using None")
        else:
            num_succeeded += 1

    print(
        f"Successfully got {num_succeeded} embeddings out of {len(final_embeddings)} items"
    )

    dataset_with_embeddings = dataset.add_column(
        f"embedding_{model_name}", final_embeddings
    )

    print(f"Successfully embedded {num_succeeded} items with {model_name}")
    return dataset_with_embeddings


def main():
    """Embed the UltraTool decomposed dataset and save it to disk.

    Loads DATASET_PATH and prints a summary and one sample row. Adds an
    ``embedding_<name>`` column for every entry in EMBEDDING_MODELS, saves the
    result under OUTPUT_DIR, and reports the dimensionality of the first
    successful embedding in each column.
    """
    print("Loading UltraTool decomposed dataset...")
    dataset = load_from_disk(DATASET_PATH)
    print(f"Loaded dataset with {len(dataset)} items")

    print("\nDataset info:")
    print(f"Columns: {dataset.column_names}")
    print(f"Number of items: {len(dataset)}")

    # dataset[0] below would raise IndexError on an empty dataset.
    if len(dataset) == 0:
        print("Dataset is empty; nothing to embed.")
        return

    print("\nSample item:")
    sample = dataset[0]
    for key, value in sample.items():
        if isinstance(value, str) and len(value) > 100:
            print(f"  {key}: {value[:100]}...")
        else:
            print(f"  {key}: {value}")

    print(f"\nEmbedding column: {COLUMN_TO_EMBED}")
    # str() so a missing (None) prompt prints instead of crashing on slicing.
    print(f"Sample text to embed: {str(sample[COLUMN_TO_EMBED])[:200]}...")

    embedded_dataset = dataset
    for model_name, model in EMBEDDING_MODELS.items():
        print(f"\n{'=' * 50}")
        embedded_dataset = embed_dataset_parallel(
            embedded_dataset, model_name, model
        )

    output_path = os.path.join(OUTPUT_DIR, "ultratool_queries_embedded_qwen")
    print(f"\nSaving embedded dataset to {output_path}...")
    embedded_dataset.save_to_disk(output_path)

    print("\nEmbedding complete!")
    print(f"Final dataset columns: {embedded_dataset.column_names}")
    print(f"Dataset saved to: {output_path}")

    print("\nVerifying embeddings...")
    for model_name in EMBEDDING_MODELS:
        embedding_col = f"embedding_{model_name}"
        if embedding_col not in embedded_dataset.column_names:
            print(f"  {model_name}: Missing!")
            continue
        # Row 0 may hold None if its embedding failed; inspect the first
        # successful vector instead (rows are iterated lazily).
        sample_embedding = next(
            (
                row[embedding_col]
                for row in embedded_dataset
                if row[embedding_col] is not None
            ),
            None,
        )
        if sample_embedding is None:
            print(f"  {model_name}: no successful embeddings!")
        else:
            print(f"  {model_name}: {len(sample_embedding)} dimensions")


# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
