import argparse
import json
import os
import re
from tqdm import tqdm
import random
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed

# Switch the working directory to the directory containing this script (with
# symlinks resolved). This runs at import time, before any argument parsing,
# so relative paths used later in this file (e.g. relative --file_path / --code
# values) are resolved against the script's location rather than the caller's
# current directory.
os.chdir(os.path.dirname(os.path.realpath(__file__)))

# Prompt sent as the user message when verifying a single (question, answer)
# pair. It is filled in with str.format() and expects three named fields:
#   {chart_code}       - source of the chart-generation script
#   {question}         - the question text
#   {provided_answer}  - the answer to verify (substituted in two places)
# The model is instructed to echo the provided answer when it is correct and
# otherwise to output only the corrected numeric value.
check_prompt_template = (
    "Here are three pieces of content:\n\n"
    "1. Chart generation code (Python):\n"
    "```\n{chart_code}\n```\n\n"
    "2. Question:\n"
    "{question}\n\n"
    "3. Provided Answer:\n"
    "{provided_answer}\n\n"
    "Based on the chart produced by the above Python code, determine whether the “Provided Answer” is correct:\n"
    "- If it is correct, simply return “{provided_answer}.”\n"
    "- If it is not correct, please provide the “Correct Answer” (output only the numeric value, without any additional explanation).\n\n"
    "Note: Do not return any extra commentary; output only the final “Correct Answer.”"
)

# API key for the shared client below. Taken from the OPENAI_API_KEY
# environment variable when it is set; otherwise falls back to the previous
# hard-coded placeholder "EMPTY".
# NOTE(review): the placeholder presumably targets an OpenAI-compatible server
# that does not validate keys - confirm against the deployment.
openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")

# Single module-level client used by check_and_correct_with_code.
client = OpenAI(
    api_key=openai_api_key,
)

def check_and_correct_with_code(chart_code: str, question: str, provided_answer: str, id: str) -> str:
    """Ask the LLM to verify an answer against the chart code and return the numeric result.

    The prompt is built from ``check_prompt_template`` and sent through the
    module-level ``client``. The first number found in the model's reply is
    returned as a string.

    Args:
        chart_code: Source of the script that generated the chart.
        question: Question text; newlines are collapsed to spaces in the prompt.
        provided_answer: Answer to verify. Non-string values (e.g. numbers
            decoded from JSON) are converted with ``str()`` for the prompt.
        id: Identifier used only in the log message printed on a correction.

    Returns:
        The first number in the model's reply, as a string. If the request
        fails or the reply contains no number, ``provided_answer`` is returned
        unchanged.
    """
    # Normalise once so that numeric JSON answers do not crash on .strip() and
    # so that the "was it changed?" comparison below ignores stray whitespace.
    answer_text = str(provided_answer).strip()

    prompt = check_prompt_template.format(
        chart_code=chart_code.strip(),
        question=question.strip().replace("\n", " "),
        provided_answer=answer_text,
    )

    try:
        completion = client.chat.completions.create(
            model='gpt-4o',
            messages=[
                {"role": "system", "content": "You are a strict chart QA verification and correction assistant."},
                {"role": "user", "content": prompt}
            ],
        )
        # content may be None; treat that as an empty reply.
        result = (completion.choices[0].message.content or "").strip()
    except Exception as e:
        # Best-effort: any API/response failure keeps the original answer.
        print(f"[Warning] LLM call failed: {e}\nPrompt preview (first 200 characters): {prompt[:200]}")
        return provided_answer

    # Accept a leading minus sign only when it is not glued to a preceding
    # word character or dot, so "-5" yields "-5" while "item-5" still yields "5".
    m = re.search(r"(?:(?<![\w.])-)?\d+(?:\.\d+)?", result)
    if m is None:
        print(f"[Warning] No numeric value in LLM reply {result!r}; keeping original answer, ID: {id}")
        return provided_answer

    res = m.group(0)
    if res != answer_text:
        print(f'Original answer: {provided_answer}, corrected answer: {res}, ID: {id}')
    return res

def worker_process_line(line: str, line_no: int, code_folder):
    """Verify every QA pair of one JSONL record and return the updated record.

    Steps:
      1) Parse the line as a JSON object.
      2) Skip records whose ``chart_type`` is ``'Node Charts'`` or that have
         no ``plot_id``.
      3) Load ``<code_folder>/<digits>.py``, where ``<digits>`` are the
         trailing digits of ``plot_id``.
      4) Run ``check_and_correct_with_code`` on each (question, answer) pair
         and store the results in ``QA.answer_list``.

    Args:
        line: One raw line of the JSONL file.
        line_no: 1-based line number, used only in log messages.
        code_folder: Directory containing the chart-generation scripts.

    Returns:
        The updated record serialised as a JSON string, or ``None`` when the
        line is blank, malformed, filtered out, or its code file is
        missing/unreadable.
    """
    line = line.strip()
    if not line:
        return None

    try:
        plot_data = json.loads(line)
    except json.JSONDecodeError as e:
        print(f"[Line {line_no}] JSON parsing error: {e}, skipping this line")
        return None

    # Valid JSON that is not an object (list, number, ...) has no fields to read.
    if not isinstance(plot_data, dict):
        print(f"[Line {line_no}] Expected a JSON object, got {type(plot_data).__name__}, skipping this line")
        return None

    if plot_data.get("chart_type") == 'Node Charts':
        return None

    plot_id = plot_data.get("plot_id")
    if not plot_id:
        return None

    # str() lets a numeric plot_id (e.g. 42) map to "42.py" instead of raising.
    match = re.search(r"(\d+)$", str(plot_id))
    if not match:
        print(f"[Line {line_no}] Could not extract digits from plot_id: {plot_id}, skipping")
        return None

    code_path = os.path.join(code_folder, f"{match.group(1)}.py")

    if not os.path.isfile(code_path):
        print(f"[Line {line_no}] Code file not found: {code_path}, skipping this line")
        return None

    try:
        with open(code_path, "r", encoding="utf-8") as cf:
            chart_code = cf.read()
    except Exception as e:
        print(f"[Line {line_no}] Error reading code file: {e}, skipping this line")
        return None

    qa = plot_data.get("QA", {})
    if not isinstance(qa, dict):
        print(f"[Line {line_no}] QA field is not a JSON object, skipping this line")
        return None

    questions = qa.get("question_list", [])
    answers = qa.get("answer_list", [])

    if not chart_code or not questions or not answers:
        return None

    # zip() stops at the shorter list, so the stored answer list has
    # min(len(questions), len(answers)) entries.
    corrected_list = [
        check_and_correct_with_code(chart_code, q, provided_ans, plot_id)
        for q, provided_ans in zip(questions, answers)
    ]

    plot_data["QA"]["answer_list"] = corrected_list

    return json.dumps(plot_data, ensure_ascii=False)

# ------------------ Multithreaded processing of the entire file ------------------

def process_file_with_threads(file_path: str, max_workers: int, code_folder):
    """Verify every record of a JSONL file in parallel and rewrite the file.

    Steps:
      1) Read all lines of ``file_path``.
      2) Submit one ``worker_process_line`` task per line to a thread pool.
      3) Collect the results, dropping records for which the worker returned
         ``None`` or raised.
      4) Shuffle the surviving records.
      5) Overwrite ``file_path`` with them, one JSON string per line.

    If the file had content but no record survived, the file is left
    untouched so that a misconfiguration (for example a wrong ``code_folder``)
    cannot wipe the input.

    Args:
        file_path: Path of the JSONL file to process and overwrite.
        max_workers: Size of the thread pool.
        code_folder: Directory passed through to ``worker_process_line``.

    Returns:
        None.
    """
    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist. Aborting.")
        return

    with open(file_path, "r", encoding="utf-8") as rf:
        all_lines = rf.readlines()

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future to its 1-based line number for error reporting.
        futures_map = {
            executor.submit(worker_process_line, line, line_no, code_folder): line_no
            for line_no, line in enumerate(all_lines, start=1)
        }

        for future in tqdm(as_completed(futures_map),
                           total=len(futures_map),
                           desc="Processing QA"):
            line_no = futures_map[future]
            try:
                processed_json = future.result()
            except Exception as e:
                print(f"[Line {line_no}] Thread exception: {e}")
                continue
            if processed_json:
                results.append(processed_json)

    # Refuse to replace a non-empty input with an empty file.
    if all_lines and not results:
        print(f"No records were produced; leaving {file_path} unchanged.")
        return

    random.shuffle(results)

    with open(file_path, "w", encoding="utf-8") as wf:
        wf.writelines(rec + "\n" for rec in results)

    print(f"Parallel processing completed. {len(results)} records written back to {file_path}")

# ------------------ Entry point ------------------

def arg_parser():
    """Parse the command-line arguments of the script.

    Both options are required: without them the processing step would
    receive ``None`` as a path.

    Returns:
        argparse.Namespace with ``file_path`` and ``code`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--file_path",
        type=str,
        required=True,
        help="JSONL file to verify; it is overwritten with the corrected records",
    )
    parser.add_argument(
        "--code",
        type=str,
        required=True,
        help="folder containing the chart-generation scripts (<digits>.py)",
    )

    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: parse the CLI options, then process the file with a
    # fixed pool of 25 worker threads.
    cli_args = arg_parser()
    process_file_with_threads(
        file_path=cli_args.file_path,
        max_workers=25,
        code_folder=cli_args.code,
    )