import json
import os
import numpy as np
import jsonlines
from openai import OpenAI
from tqdm import tqdm

# Anchor every local path at this script's own directory so the script
# behaves the same regardless of the current working directory.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))

# Configuration: embedding size requested from the API and the location of
# the on-disk cache that stores the computed embeddings.
EMBEDDING_DIM = 1024
CACHE_DIR = os.path.join(SCRIPT_DIR, "cache")
EMBEDDING_FILE = os.path.join(CACHE_DIR, "query_embeddings_combined.json")

# Datasets to process. Both files live in the same MATH dataset directory;
# each entry pairs the JSONL path with the prefix used to build query IDs.
_MATH_DATASET_DIR = os.path.join(
    PROJECT_ROOT, "ttsrouter-v1.1", "src", "envs", "MATH", "dataset"
)

DATASETS = [
    {"path": os.path.join(_MATH_DATASET_DIR, filename), "prefix": prefix}
    for filename, prefix in (
        ("test150.jsonl", "test150_question"),
        ("test_aime.jsonl", "test_aime_question"),
    )
]

# Initialize the OpenAI-compatible client pointed at the DashScope endpoint.
# The key is read from the DASHSCOPE_API_KEY environment variable so that a
# real credential never has to be written into this file; when the variable
# is unset or empty the original placeholder is used, as before.
client = OpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY") or "your_api_key_here",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

def _request_embedding(text):
    """Request the embedding of ``text`` from the embeddings API.

    Args:
        text: The string to embed.

    Returns:
        The ``embedding`` of the first item in the response's ``data``.

    Raises:
        ValueError: If the response has no ``data`` or ``data`` is empty.
        Exception: Any error raised by the client call propagates unchanged.
    """
    completion = client.embeddings.create(
        model="text-embedding-v4",
        input=[text],
        dimensions=EMBEDDING_DIM,
        encoding_format="float",
    )
    data = getattr(completion, "data", None)
    if not data:
        raise ValueError("Invalid API response format")
    return data[0].embedding


def get_embedding(text):
    """Use OpenAI API to get text embedding, with a random-vector fallback.

    On any failure the error is printed and a random vector of length
    ``EMBEDDING_DIM`` is returned instead of raising.

    WARNING: the fallback is indistinguishable from a real embedding, so the
    result must not be persisted. ``main`` deliberately calls
    ``_request_embedding`` instead so failures never reach the cache.

    Args:
        text: The string to embed.

    Returns:
        A list of floats: the API embedding, or a random vector on failure.
    """
    try:
        return _request_embedding(text)
    except Exception as e:
        print(f"Embedding API failed for text: {text[:30]}... Error: {e}")
        # Return a random vector as a fallback (kept for backward compatibility)
        return np.random.randn(EMBEDDING_DIM).tolist()


def _save_embeddings(embeddings, path):
    """Write ``embeddings`` as JSON to ``path`` without exposing partial files.

    The data is first written to a sibling temporary file and then moved over
    ``path`` with ``os.replace``, so an interruption during the dump cannot
    leave a truncated cache that would fail to load on the next run.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(embeddings, f)
    os.replace(tmp_path, path)


def main():
    """Embed every dataset problem that is not cached yet and save the cache.

    Existing entries in ``EMBEDDING_FILE`` are loaded and skipped. Query IDs
    are ``<prefix><1-based line index>``. Queries whose API request fails are
    reported and left out of the cache so that a later run retries them,
    rather than permanently storing a meaningless fallback vector.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    all_embeddings = {}

    # If the file already exists, load it first
    if os.path.exists(EMBEDDING_FILE):
        print(f"Loading existing embeddings from {EMBEDDING_FILE}")
        with open(EMBEDDING_FILE, 'r', encoding='utf-8') as f:
            all_embeddings = json.load(f)

    total_processed = 0
    failed_ids = []

    for dataset_info in DATASETS:
        file_path = dataset_info["path"]
        prefix = dataset_info["prefix"]

        print(f"Processing {file_path}...")

        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue

        # Read everything up front so the progress bar knows the total and
        # the dataset file is closed before the slow API calls start.
        with jsonlines.open(file_path) as reader:
            items = list(reader)

        # Both test150 and test_aime IDs start from 1.
        for index, item in enumerate(tqdm(items), start=1):
            query_id = f"{prefix}{index}"

            # Skip if already exists
            if query_id in all_embeddings:
                continue

            query_text = item.get('problem', '')
            if not query_text:
                print(f"Warning: Empty problem text for {query_id}")
                continue

            # Broad catch on purpose: whatever goes wrong for one query, it
            # is reported and skipped so the rest of the run can continue.
            try:
                embedding = _request_embedding(query_text)
            except Exception as e:
                print(
                    f"Embedding API failed for {query_id}: "
                    f"{query_text[:30]}... Error: {e}"
                )
                failed_ids.append(query_id)
                continue

            all_embeddings[query_id] = embedding
            total_processed += 1

            # Save every 10 entries to prevent data loss on interruption
            if total_processed % 10 == 0:
                _save_embeddings(all_embeddings, EMBEDDING_FILE)

    # Final save
    _save_embeddings(all_embeddings, EMBEDDING_FILE)

    print(f"Done. Total embeddings: {len(all_embeddings)}. Saved to {EMBEDDING_FILE}")
    if failed_ids:
        print(
            f"Warning: {len(failed_ids)} queries failed and were not cached; "
            f"re-run to retry: {failed_ids}"
        )


if __name__ == "__main__":
    main()
