import concurrent.futures
import hashlib
import json
import os
import pickle
import tempfile
import time
from typing import Any, Dict, List

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from openai import OpenAI
from tqdm import tqdm

# --- Configuration ---
# IMPORTANT: Set your API key in your environment variables,
# e.g., export OPENAI_API_KEY='your_api_key'
API_KEY = os.environ.get("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. Please set it before running the script."
    )

BASE_URL = "https://api.openai.com/v1"
# Short key -> API model name. The short key is what gets stored in the
# "embed_model" column of the output rows.
EMBEDDING_MODELS = {
    "large": "text-embedding-3-large",
}
DATASET_NAME = "mangopy/ToolRet-Tools"
DATASET_SUBSETS = ["code", "customized", "web"]  # Process all subsets
COLUMN_TO_EMBED = "documentation"
# Directory holding one pickled embedding per (text, model, dimensions).
CACHE_DIR = ".embedding_cache"

# Requested output dimensionality. It is sent as the `dimensions` argument for
# the "large" model (text-embedding-3-large returns 3072 dims by default, so
# 1536 asks for a shortened vector). It is also part of the cache key.
EMBEDDING_DIM = 1536
OUTPUT_DIR = "embeddings"
# --- Concurrency Configuration ---
# Adjust this based on your API rate limits and machine capabilities.
MAX_WORKERS = 64

# --- Create OpenAI Client ---
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

# --- Schema Configuration ---
# For each subset, the JSON fields of the documentation that are concatenated
# (in this order) to form the text that gets embedded.
SUBSET_SCHEMAS = {
    "code": [
        "name",
        "description",
        "func_description",
        "functionality",
    ],
    "web": ["name", "description"],
    "customized": ["name", "description"],
}


def extract_fields_from_documentation(
    doc_text: str,
    subset: str,
    *,
    schemas: "Dict[str, List[str]] | None" = None,
) -> str:
    """
    Extract and concatenate specific fields from JSON documentation based on subset schema.

    Args:
        doc_text: The documentation string (expected to be a JSON object).
        subset: The subset name ("code", "web", or "customized").
        schemas: Optional mapping of subset name -> ordered list of field
            names to extract. Defaults to the module-level SUBSET_SCHEMAS.

    Returns:
        The non-empty schema field values, stripped and joined with newlines
        in schema order. ``doc_text`` is returned unchanged when the subset is
        unknown, the text is not a JSON object, no listed field has a
        non-empty value, or parsing fails. A warning is printed in each of
        these fallback cases.
    """
    if schemas is None:
        schemas = SUBSET_SCHEMAS

    if subset not in schemas:
        print(f"Warning: Unknown subset '{subset}', using raw documentation")
        return doc_text

    try:
        doc_dict = json.loads(doc_text)
    except json.JSONDecodeError as e:
        print(
            f"Warning: Failed to parse JSON documentation: {e}, using raw text"
        )
        return doc_text
    except Exception as e:
        # Best-effort: e.g. a TypeError for non-str input still falls back.
        print(
            f"Warning: Unexpected error processing documentation: {e}, using raw text"
        )
        return doc_text

    if not isinstance(doc_dict, dict):
        print("Warning: Documentation is not a JSON dict, using raw text")
        return doc_text

    extracted_parts = []
    for field in schemas[subset]:
        field_value = doc_dict.get(field)
        # Missing fields and JSON nulls are skipped.
        if field_value is None:
            continue
        # Non-string JSON values (numbers, lists, objects) use Python's str().
        text_value = (
            field_value if isinstance(field_value, str) else str(field_value)
        ).strip()
        if text_value:  # Only keep non-empty values
            extracted_parts.append(text_value)

    if not extracted_parts:
        # Fallback to raw documentation if no fields found
        print("Warning: No valid fields found in documentation, using raw text")
        return doc_text

    # Fields are separated by newlines.
    return "\n".join(extracted_parts)


def truncate_text_progressively(text: str) -> str:
    """
    Cap overly long text at a character limit.

    Texts of at most 8000 characters are returned unchanged. Longer texts are
    cut to 8192 characters if they exceed 8192. Otherwise (8001-8192 chars)
    they are cut to 8000 characters. A message is printed whenever truncation
    happens.

    NOTE(review): the result is not monotonic in input length. A 9000-char
    input yields 8192 chars while an 8100-char input yields 8000. This mirrors
    the two truncation steps get_embedding applies across retries. Confirm
    whether a single fixed cap was intended before relying on this helper.

    Args:
        text: The text to truncate.

    Returns:
        The (possibly truncated) text. Never None.
    """
    if len(text) <= 8000:
        return text

    limit = 8192 if len(text) > 8192 else 8000
    print(f"Truncated text from {len(text)} to {limit} characters")
    return text[:limit]


# Lower-cased substrings that mark an API error as "input too long".
_LENGTH_ERROR_KEYWORDS = ("token", "length", "too long", "maximum")


def _is_length_error(exc: Exception) -> bool:
    """Return True if *exc* looks like an input-length rejection, not a rate limit."""
    message = str(exc).lower()
    # OpenAI rate-limit errors typically mention "tokens per min", which would
    # match the keyword list. Rule them out first so they get backoff-and-retry
    # instead of truncation, which ends with the sample being skipped.
    if getattr(exc, "status_code", None) == 429 or "rate limit" in message:
        return False
    return any(keyword in message for keyword in _LENGTH_ERROR_KEYWORDS)


def _load_cached_embedding(cache_path: str, index: int):
    """Return the pickled embedding at *cache_path*, or None if absent or unreadable."""
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        print(f"Warning: failed to load cache for index {index}: {e}")
        return None


def _write_cache(cache_path: str, embedding, index: int) -> None:
    """Persist *embedding* to *cache_path* (best-effort; failures only warn)."""
    # Write to a temp file in the same directory, then rename it into place.
    # Concurrent workers or an interrupted run therefore never leave a
    # half-written pickle at cache_path.
    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(cache_path) or ".", suffix=".tmp"
        )
        with os.fdopen(fd, "wb") as f:
            pickle.dump(embedding, f)
        os.replace(tmp_path, cache_path)
        tmp_path = None
    except Exception as e:
        print(f"Warning: failed to write cache for index {index}: {e}")
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def get_embedding(task_args):
    """
    Generate an embedding for one text, with caching, retries and truncation.

    Args:
        task_args: ``(index, text, model, dimensions)`` tuple. ``index`` is
            echoed back so callers can match results to inputs.

    Returns:
        ``(index, embedding)``. ``embedding`` is None when the text is empty
        or not a string, or when every attempt failed (the caller skips it).

    Behaviour:
        * Results are cached on disk under CACHE_DIR, keyed by a SHA-256 of
          (text, model, dimensions).
        * Length-related errors truncate the text to 8192 and then 8000
          characters before giving up on the sample.
        * Other errors, including rate limits, are retried up to
          ``max_retries`` times with capped exponential backoff.
    """
    index, text, model, dimensions = task_args

    # Return None for empty or invalid text to skip this sample
    if not text or not isinstance(text, str):
        return index, None

    # Build cache key
    cache_key_str = f"{text}-{model}-{dimensions}"
    cache_key = hashlib.sha256(cache_key_str.encode("utf-8")).hexdigest()
    cache_path = os.path.join(CACHE_DIR, cache_key)

    cached_embedding = _load_cached_embedding(cache_path, index)
    if cached_embedding is not None:
        return index, cached_embedding

    # Newlines are flattened to spaces before sending. The cache key above
    # still uses the original text.
    current_text = text.replace("\n", " ").strip()
    if not current_text:
        return index, None

    max_retries = 5
    retry_delay = 5  # seconds; base delay for exponential backoff
    max_retry_delay = 60  # seconds; cap on a single backoff sleep
    length_error_occurred = False

    for attempt in range(max_retries):
        try:
            if model == EMBEDDING_MODELS["large"]:
                response = client.embeddings.create(
                    input=[current_text], model=model, dimensions=dimensions
                )
            else:
                response = client.embeddings.create(
                    input=[current_text], model=model
                )
            embedding = response.data[0].embedding
        except Exception as e:
            if _is_length_error(e):
                if not length_error_occurred:
                    print(f"Length-related error for sample {index}: {e}")
                    print(
                        f"Current text length: {len(current_text)} characters"
                    )
                    length_error_occurred = True

                # Try progressive truncation: 8192 -> 8000 -> give up
                if len(current_text) > 8192:
                    current_text = current_text[:8192]
                    print(
                        f"Sample {index}: Truncated to 8192 characters, retrying..."
                    )
                    continue
                if len(current_text) > 8000:
                    current_text = current_text[:8000]
                    print(
                        f"Sample {index}: Truncated to 8000 characters, retrying..."
                    )
                    continue
                print(
                    f"Sample {index}: Text still too long after truncation, skipping sample"
                )
                return index, None

            # For other errors, back off and retry (no sleep after the last try).
            if attempt == max_retries - 1:
                print(
                    f"Sample {index}: Non-length error on final attempt: {e}"
                )
            else:
                delay = min(retry_delay * 2**attempt, max_retry_delay)
                print(
                    f"Sample {index}: Non-length error on attempt {attempt + 1}: {e}. Retrying in {delay}s..."
                )
                time.sleep(delay)
        else:
            _write_cache(cache_path, embedding, index)
            return index, embedding

    print(
        f"Sample {index}: Failed to get embedding after {max_retries} retries, skipping sample"
    )
    return index, None


def get_embeddings_ordered(
    texts: List[str], model_name: str, model_key: str
) -> Dict[int, List[float]]:
    """
    Embed every text concurrently and collect the successful results by position.

    Work is fanned out over a thread pool of MAX_WORKERS threads. Results are
    consumed in submission order, so the progress bar advances in input order.
    Texts whose embedding came back as None are counted as skipped and left
    out of the returned mapping.

    Args:
        texts: Texts to embed. A text's list position becomes its result key.
        model_name: Embedding model identifier forwarded to get_embedding.
        model_key: Short label used in the progress-bar description.

    Returns:
        Mapping from input position to embedding vector for successful samples.
    """
    print(
        f"\n--- Generating embeddings for model: {model_name} (using up to {MAX_WORKERS} workers) ---"
    )

    # One argument tuple per text; the position travels with the task.
    job_args = [
        (position, text, model_name, EMBEDDING_DIM)
        for position, text in enumerate(texts)
    ]

    embedded: Dict[int, List[float]] = {}
    failures = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        progress = tqdm(
            pool.map(get_embedding, job_args),
            total=len(texts),
            desc=f"Embedding with {model_key}",
        )
        for position, vector in progress:
            if vector is None:
                failures += 1
                continue
            embedded[position] = vector

    print(
        f"Successfully embedded {len(embedded)} samples, skipped {failures} samples"
    )
    return embedded


def create_embedded_dataset(
    original_dataset: Dataset,
    embeddings_dict: Dict[int, List[float]],
    model_key: str,
) -> Dataset:
    """
    Build a Dataset of (id, documentation, embed, embed_model) rows.

    Rows of original_dataset whose position is absent from embeddings_dict
    (i.e. whose embedding failed) are dropped. The remaining rows keep their
    original relative order.
    """
    rows = [
        {
            "id": record["id"],
            "documentation": record["documentation"],
            "embed": embeddings_dict[position],
            "embed_model": model_key,
        }
        for position, record in enumerate(original_dataset)
        if position in embeddings_dict
    ]
    return Dataset.from_list(rows)


def _prepare_texts(dataset, subset: str) -> List[str]:
    """
    Build the list of texts to embed for one subset, one entry per dataset row.

    Each row's COLUMN_TO_EMBED value is passed through
    extract_fields_from_documentation. Non-string or empty values become "".
    Prints the first three processed examples and a preprocessing summary.
    """
    texts_to_embed: List[str] = []
    preprocessing_stats = {"successful": 0, "fallback": 0, "empty": 0}

    for i, item in enumerate(dataset):
        doc_text = item[COLUMN_TO_EMBED]

        # Non-string documentation is treated as empty. It is counted exactly
        # once, in the empty branch just below.
        if not isinstance(doc_text, str):
            doc_text = ""

        if not doc_text:
            texts_to_embed.append("")
            preprocessing_stats["empty"] += 1
            continue

        try:
            processed_text = extract_fields_from_documentation(doc_text, subset)
        except Exception as e:
            print(f"Error preprocessing item {i}: {e}, using raw text")
            texts_to_embed.append(doc_text)
            preprocessing_stats["fallback"] += 1
            continue

        # The extractor returns doc_text unchanged when it falls back to raw text.
        if processed_text != doc_text:
            preprocessing_stats["successful"] += 1
        else:
            preprocessing_stats["fallback"] += 1
        texts_to_embed.append(processed_text)

        # Log first few examples for verification
        if i < 3:
            print(f"\nExample {i+1} preprocessing:")
            print(f"  Original length: {len(doc_text)} chars")
            print(f"  Processed length: {len(processed_text)} chars")
            print(f"  Processed text preview: {processed_text[:200]}...")

    print(f"\nPreprocessing summary for {subset}:")
    print(f"  - Successfully processed: {preprocessing_stats['successful']}")
    print(f"  - Fallback to raw text: {preprocessing_stats['fallback']}")
    print(f"  - Empty documents: {preprocessing_stats['empty']}")
    print(f"  - Total: {len(texts_to_embed)}")
    return texts_to_embed


def _embedded_rows(dataset, embeddings_dict, model_key: str) -> List[dict]:
    """Return output rows for the dataset rows that received an embedding."""
    return [
        {
            "id": item["id"],
            "documentation": item["documentation"],
            "embed": embeddings_dict[i],
            "embed_model": model_key,
        }
        for i, item in enumerate(dataset)
        if i in embeddings_dict
    ]


def main():
    """
    Main function to load datasets, generate embeddings for each model and subset,
    and save the results as a dataset with separate subsets.

    Does nothing if the final output directory already exists. Subsets that
    fail to load are reported and left out of the output. If no subset loads,
    nothing is saved, so a later run is not blocked by an empty output
    directory.
    """
    print("--- Starting Dataset Embedding Process ---")
    # Ensure cache directory exists
    os.makedirs(CACHE_DIR, exist_ok=True)

    # --- Create Output Directory ---
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    # Check if final dataset already exists
    final_output_path = os.path.join(OUTPUT_DIR, "toolret_tools_embedded")
    if os.path.exists(final_output_path):
        print(
            f"\nFinal embedded dataset already exists at '{final_output_path}'. Skipping."
        )
        return

    # Dictionary to collect data by subset
    subset_data = {}
    failed_subsets = []

    # --- Process Each Subset ---
    for subset in DATASET_SUBSETS:
        print(f"\n=== Processing subset: {subset} ===")

        # --- Load Dataset ---
        print(f"Loading dataset: {DATASET_NAME}, subset: {subset}")
        try:
            dataset = load_dataset(DATASET_NAME, name=subset, split="tools")
        except Exception as e:
            print(f"Failed to load dataset subset '{subset}'. Error: {e}")
            failed_subsets.append(subset)
            continue
        print(f"Dataset loaded successfully with {len(dataset)} samples.")

        # --- Prepare all texts to be embedded with schema-aware preprocessing ---
        texts_to_embed = _prepare_texts(dataset, subset)

        # Rows from every model for this subset (only successful embeddings)
        subset_embedded_data = []
        for model_key, model_name in EMBEDDING_MODELS.items():
            print(f"\n--- Processing {subset} with {model_name} ---")
            embeddings_dict = get_embeddings_ordered(
                texts_to_embed, model_name, model_key
            )
            subset_embedded_data.extend(
                _embedded_rows(dataset, embeddings_dict, model_key)
            )

        # Create dataset for this subset
        subset_data[subset] = Dataset.from_list(subset_embedded_data)
        print(
            f"Created subset '{subset}' with {len(subset_embedded_data)} rows"
        )

    if not subset_data:
        # Saving an empty DatasetDict would make every later run skip.
        print("\nNo subsets were processed successfully; nothing was saved.")
        return
    if failed_subsets:
        print(
            "\nWarning: these subsets failed to load and will be missing from the output: "
            + ", ".join(failed_subsets)
            + f". Delete '{final_output_path}' and re-run to include them."
        )

    # --- Create Final Dataset with Separate Subsets ---
    print(
        f"\n--- Creating dataset with {len(subset_data)} separate subsets ---"
    )
    final_dataset = DatasetDict(subset_data)

    # --- Save Final Dataset ---
    final_dataset.save_to_disk(final_output_path)
    print(f"Dataset with separate subsets saved to '{final_output_path}'")

    # Print summary
    total_rows = sum(len(ds) for ds in subset_data.values())
    print(f"Total rows across all subsets: {total_rows}")
    print("Subset breakdown:")
    for subset_name, subset_ds in subset_data.items():
        print(f"  - {subset_name}: {len(subset_ds)} rows")
    print(
        "Each subset contains 4 columns (id, documentation, embed, embed_model)"
    )
    print(f"Each subset has data from {len(EMBEDDING_MODELS)} embedding models")

    if failed_subsets:
        print("\n--- Embedding finished, but some subsets are missing ---")
    else:
        print("\n--- All embedding tasks completed successfully! ---")


if __name__ == "__main__":
    main()
