import concurrent.futures
import csv
import json
import os
import sys

sys.path.append("./")
from utils.gpt import generate

# Names of the categories a query can be assigned to. The entries mirror, in
# the same order, the categories described in `classify_prompt`.
# NOTE(review): nothing in the visible part of this script reads this list;
# the model's answer is not checked against it.
category_list = [
    "Website Design",
    "Game Development",
    "Clone Development",
    "App Development",
    "Web Development",
    "UI Design",
    "Multilingual Queries",
    "Digital Tools",
    "App Design",
    "AI Applications",
    "Simulations",
    "Creative Humor"
]

# Prompt template for the classifier. `{query}` is the only placeholder and is
# filled in with `str.format` by `process_query`; any literal braces added to
# this text later must be doubled (`{{` / `}}`) or the format call will fail.
classify_prompt = """
You are a query classifier. Your task is to classify a given user query into exactly one of the predefined categories.

Categories:
- Website Design: Queries about designing, building, or improving websites and web pages
- Game Development: Queries related to creating games, game mechanics, or game engines
- Clone Development: Requests to create clones of existing applications or websites
- App Development: Queries about building mobile or desktop applications
- Web Development: General web development queries including backend, frontend, APIs
- UI Design: Queries focusing on user interface design and components
- Multilingual Queries: Requests involving non-English languages
- Digital Tools: Queries about digital tools, utilities, or productivity software
- App Design: Queries specifically about the design aspects of applications
- AI Applications: Queries related to implementing AI or machine learning features
- Simulations: Requests for simulating processes, events, or systems
- Creative Humor: Queries with a creative or humorous intent

User Query: {query}

Based on this user query, classify it into exactly one of the above categories. Respond with only the category name, nothing else. 
"""

# Load the dataset: one JSON object per line. The file is read as UTF-8
# explicitly because the platform default encoding may fail on (or silently
# mis-decode) non-ASCII query text; the CSV output below is UTF-8 as well.
with open("data/full.jsonl", "r", encoding="utf-8") as f:
    # Blank lines (e.g. extra newlines at the end of the file) are skipped,
    # since json.loads would raise on them.
    data = [json.loads(line) for line in f if line.strip()]

# One string per record: the text of the first content part of every user
# turn in `conversation_a`, joined with single spaces.
queries = [
    " ".join(
        turn["content"][0]["text"]
        for turn in item["conversation_a"]
        if turn["role"] == "user"
    )
    for item in data
]

# NOTE: disabled experiment kept for reference -- removes the word "clone"
# from every query before classification.
# for i, item in enumerate(queries):
#     queries[i] = item.replace("clone", "").replace("Clone", "")

# Identifiers of the records, index-aligned with `queries`.
question_ids = [item["question_id"] for item in data]

# Function to process a single query
def process_query(args):
    """Classify one query with the model and return the outcome.

    Args:
        args: An ``(index, question_id, query)`` tuple. ``index`` is the
            zero-based position of the query and is only used for the
            progress message.

    Returns:
        A dict with the keys ``question_id``, ``model_response`` (the raw
        text returned by ``generate``), ``category`` (that text with
        surrounding whitespace stripped) and ``metadata`` (the second value
        returned by ``generate``, passed through unchanged).
    """
    index, question_id, query = args

    # Single-turn conversation holding the prompt template with the query
    # substituted in.
    messages = [
        {"role": "user", "content": classify_prompt.format(query=query)},
    ]
    model_response, metadata = generate(model="gpt-4o", messages=messages)

    # The prompt asks for the bare category name, so trimming is the only
    # clean-up applied.
    category = model_response.strip()
    print(f"Processed {index+1}/{len(queries)}: {question_id} - {category}")

    return dict(
        question_id=question_id,
        model_response=model_response,
        category=category,
        metadata=metadata,
    )

# Running totals of the token usage reported in each result's metadata.
total_tokens = {
    "prompt_token_count": 0,
    "candidates_token_count": 0,
    "thoughts_token_count": 0
}

# One (index, question_id, query) tuple per query, in input order.
process_args = [(i, qid, query) for i, (qid, query) in enumerate(zip(question_ids, queries))]

# Classify the queries concurrently. Rows are collected in completion order,
# which is generally not the input order.
results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
    future_to_query = {executor.submit(process_query, args): args for args in process_args}

    for future in concurrent.futures.as_completed(future_to_query):
        submitted_qid = future_to_query[future][1]

        # Only the worker call is guarded, so that a bookkeeping problem
        # further down is never misreported as a failed query.
        try:
            result = future.result()
        except Exception as e:
            # Best effort: report which query failed and carry on with the rest.
            print(f"Error processing query {submitted_qid}: {e}")
            continue

        results.append(
            {field: result[field] for field in ("question_id", "model_response", "category")}
        )

        # Update token counts. A usage field that is missing, or metadata that
        # is None, must not discard or mislabel an otherwise successful row;
        # a field present with value None is counted as zero.
        metadata = result["metadata"]
        for token_key in total_tokens:
            try:
                total_tokens[token_key] += metadata[token_key] or 0
            except (KeyError, TypeError) as e:
                print(f"Warning: could not count {token_key} for {submitted_qid}: {e!r}")

# Save results to CSV
output_file = "data/statistics/query_categories.csv"

# Create the destination directory first: without it, open() would fail and
# every classification obtained above would be lost.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# newline="" lets the csv module control line endings itself.
with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
    fieldnames = ["question_id", "model_response", "category"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(results)

print(f"Results saved to {output_file}")
print(total_tokens)