import json
import sys
import concurrent.futures
sys.path.append(".")
from prompts.rubric_tree_generate import RUBRIC_GENERATION_PROMPT
from utils.gemini import generate

def generate_rubric(user_query):
    """Fill the rubric-generation prompt with ``user_query`` and run the model.

    The literal placeholder ``[INSERT USER QUERY HERE]`` inside
    ``RUBRIC_GENERATION_PROMPT`` is replaced by ``user_query`` and the
    resulting prompt is passed to ``generate``.

    Args:
        user_query: Text substituted for the placeholder in the prompt.

    Returns:
        The ``(response, metadata)`` pair produced by ``generate``.
    """
    placeholder = "[INSERT USER QUERY HERE]"
    filled_prompt = RUBRIC_GENERATION_PROMPT.replace(placeholder, user_query)
    model_output, call_metadata = generate(filled_prompt)
    return model_output, call_metadata

def process_query(args):
    """Generate a rubric for one query and package it as a result record.

    Args:
        args: A ``(index, question_id, query)`` tuple. ``index`` is the
            zero-based position of the query and is used only for the
            progress message.

    Returns:
        A dict with the keys ``"question_id"``, ``"model_response"`` and
        ``"metadata"``; the latter two are the values returned by
        ``generate_rubric``.
    """
    index, question_id, query = args
    response, metadata = generate_rubric(query)

    # The total comes from the module-level ``queries`` list, which is only
    # assigned when this file runs as a script. If it is missing (e.g. the
    # function is called after importing the module), report progress without
    # the total instead of raising NameError and discarding the response that
    # was just generated.
    try:
        total = len(queries)
    except NameError:
        total = None

    if total is None:
        print(f"Processed {index+1}: {question_id}")
    else:
        print(f"Processed {index+1}/{total}: {question_id}")

    return {
        "question_id": question_id,
        "model_response": response,
        "metadata": metadata,
    }

# Running totals of the token counts reported in each result's metadata;
# every counter starts at zero and is incremented by the script entry point.
total_tokens = dict.fromkeys(
    ("prompt_token_count", "candidates_token_count", "thoughts_token_count"),
    0,
)

if __name__ == "__main__":
    # Script entry point: read the queries, generate a rubric for each one in
    # a thread pool, then append the results to the annotations file.
    import os  # needed only by this entry point, so imported locally

    input_path = "data/new.jsonl"
    output_path = "data/annotations/rubric_tree.jsonl"

    # Explicit UTF-8 so reading does not depend on the platform's locale
    # encoding. Blank lines (such as a trailing empty line) are skipped
    # instead of crashing json.loads.
    with open(input_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # One query string per record: the text of the first content part of
    # every user turn in conversation_a, joined with single spaces.
    # NOTE: `queries` must remain a module-level name because process_query
    # reads it to print the progress total.
    queries = [
        " ".join(
            turn["content"][0]["text"]
            for turn in item["conversation_a"]
            if turn["role"] == "user"
        )
        for item in data
    ]
    question_ids = [item["question_id"] for item in data]

    process_args = [
        (i, qid, query)
        for i, (qid, query) in enumerate(zip(question_ids, queries))
    ]

    # Create the output directory before any model call is made, so a missing
    # directory cannot cause the finished results to be lost at write time.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=40) as executor:
        future_to_query = {
            executor.submit(process_query, args): args for args in process_args
        }

        # Results are collected in completion order, not input order.
        for future in concurrent.futures.as_completed(future_to_query):
            _, failed_id, _ = future_to_query[future]
            try:
                result = future.result()
            except Exception as e:
                # Best effort: one failed query must not abort the whole run.
                print(f"Error processing query {failed_id}: {e}")
                continue

            results.append(result)

            # Update token counts. This is guarded separately from the call
            # above so that a missing or None counter is reported as a
            # counting problem rather than as a failed query (the result
            # itself has already been kept).
            metadata = result["metadata"]
            for key in total_tokens:
                try:
                    total_tokens[key] += metadata[key] or 0
                except (KeyError, TypeError) as e:
                    print(f"Could not count {key} for {result['question_id']}: {e!r}")

    print(f"Total tokens: {total_tokens}")

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # must be opened with an encoding that can represent them.
    with open(output_path, "a", encoding="utf-8") as f:
        for result in results:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")

    print(f"Results saved to {output_path}")