import argparse
import json
import os
import re
import time
from tqdm import tqdm
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed

# Switch to the directory where the current script is located, so that
# relative paths used by this script (e.g. "./prompt/llm_generate_qa.txt")
# resolve against the script's own location rather than the caller's
# working directory.
# NOTE(review): this runs at import time and changes the process-wide working
# directory for any importer, not only when the file is executed as a script.
os.chdir(os.path.dirname(os.path.realpath(__file__)))

def read_prompt_from_file(file_path):
    """Return the UTF-8 text of *file_path* with leading/trailing whitespace removed.

    Args:
        file_path: Path of the prompt file to read.

    Returns:
        The stripped file contents as a string.
    """
    with open(file_path, encoding='utf-8') as prompt_file:
        contents = prompt_file.read()
    return contents.strip()

# Prompt used to ask the LLM to solve a question from the chart's plotting
# code. process_plot fills it with str.format(python_code=..., question=...),
# so any literal braces in the prompt file must be escaped as {{ }}.
# The relative path resolves against the script directory (see os.chdir above).
LLM_SOLVE_QA_PROMPT = read_prompt_from_file("./prompt/llm_generate_qa.txt")

# Credentials and endpoint come from the environment instead of being
# hard-coded in source:
#   OPENAI_API_KEY  - defaults to "" (e.g. for a local server that does not
#                     check keys).
#   OPENAI_BASE_URL - e.g. "http://localhost:8000/v1" for a local vLLM server;
#                     when unset, None is passed and the client library uses
#                     its default endpoint.
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
client = OpenAI(api_key=openai_api_key, base_url=os.environ.get("OPENAI_BASE_URL") or None)

def evaluate_answer(pred_ans_str, gt_ans_str, rel_tol: float = 0.01) -> bool:
    """
    Judge whether the model's predicted answer matches the ground truth.

    Two checks are tried in order:

    1. Case-insensitive exact match after stripping surrounding whitespace.
    2. Numeric match: the first number found in each string is extracted
       (surrounding text such as units or symbols is ignored) and the
       prediction is accepted when ``|pred - gt| <= rel_tol * |gt|``.

    Args:
        pred_ans_str: Model prediction; any value, converted with ``str()``.
        gt_ans_str: Ground-truth answer; any value, converted with ``str()``.
        rel_tol: Relative tolerance for the numeric check (default 1%).
            When the ground truth is 0, only an exact 0 is accepted.

    Returns:
        True if the prediction is judged correct, otherwise False.
    """
    pred_clean = str(pred_ans_str).strip().lower()
    gt_clean = str(gt_ans_str).strip().lower()

    # 1. Exact string comparison
    if pred_clean == gt_clean:
        return True

    # 2. Numerical approximation comparison.
    # The sign is grouped so it applies to both alternatives; in the ungrouped
    # form `[-+]?\d*\.\d+|\d+` an integer like "-5" was read as 5.
    number_pattern = r'[-+]?(?:\d*\.\d+|\d+)'
    pred_num_match = re.search(number_pattern, pred_clean)
    gt_num_match = re.search(number_pattern, gt_clean)
    if not (pred_num_match and gt_num_match):
        return False

    # float() cannot fail on text matched by number_pattern.
    p_val = float(pred_num_match.group(0))
    g_val = float(gt_num_match.group(0))
    return abs(p_val - g_val) <= rel_tol * abs(g_val)

# Chat templates used to assemble each saved SFT sample (see process_plot).
#
# USER_CONTENT_TEMPLATE is the user turn stored in each output record: a
# literal "<image>" token followed by the solving instructions. Its only
# str.format placeholder is {question}; any literal brace added to this text
# would have to be escaped as {{ }}.
# NOTE(review): "<image>" is presumably the image token consumed by the SFT
# framework together with the record's "images" list — confirm.
# NOTE(review): the text contains typos ("STEPBYSTEP", "a answer",
# "question.Your"). It is runtime data baked into the generated samples, so it
# is left byte-identical here; if it mirrors a prompt used at inference time,
# change both together.
USER_CONTENT_TEMPLATE = """<image>
========================================
ROLE
========================================
You are an expert vision-language analyst.  
Your job is to look at the image, read the question, and provide a answer.

========================================
CRITICAL RULES (must follow all)
========================================
1.  **STEPBYSTEP THINKING:** You need to think step-by-step first before answering the question.Your thought process (which you may output in the <think> tag) should explicitly focus on:
    *   **Axes:** What do the horizontal (X-axis) and vertical (Y-axis) represent? Note their labels, units, and scale.
    *   **Data Points:** Locate the specific bars, points, lines, or other points relevant to the question.
    *   **Context:** Read the chart's title, legend, and any other text to fully understand the context.
2.  **FINAL ANSWER** Your output MUST contain the answer tag: `<answer>your answer</answer>`.
3.  **STRICT FORMAT:** The answer inside the `<answer>` tag must be the final, concise result (e.g., a single number). Do not include explanations or units unless required by the chart's notation.

========================================
INPUT FIELDS
========================================
Question      : {question}"""

# ASSISTANT_CONTENT_TEMPLATE is the assistant turn: the reasoning extracted
# from the LLM response wrapped in <think>, and its final answer wrapped in
# <answer>. Placeholders: {thought}, {answer}.
ASSISTANT_CONTENT_TEMPLATE = """<think>{thought}</think>
<answer>{answer}</answer>"""

def _load_plot_code(plot_id, code_base_path=""):
    """Return the plotting-script source for *plot_id*, or None if it cannot be read.

    The script is looked up as ``<code_base_path>/<suffix>.py``, where *suffix*
    is the text after the last '-' in the id
    (e.g. "reachqa-train-plot-14838" -> "14838.py").
    """
    # str() so non-string ids (e.g. ints) don't raise AttributeError on .split.
    code_id = str(plot_id).split('-')[-1]
    if not code_id:
        print(f"Error: Could not extract a valid code ID from plot_id '{plot_id}'.")
        return None

    code_file_path = os.path.join(code_base_path, f"{code_id}.py")
    try:
        with open(code_file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except FileNotFoundError:
        print(f"Error: Python code file not found for plot_id '{plot_id}'. Searched at: {code_file_path}")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error: Could not read code file {code_file_path} for plot_id '{plot_id}': {type(e).__name__} - {e}")
    return None


def _parse_response(response_text):
    """Extract ``(thought, answer)`` from an LLM response, or return None.

    The thought may be wrapped in <think>...</think> or <tool_call>...</tool_call>
    (a backreference forces the closing tag to match the opening one); the
    answer must be wrapped in <answer>...</answer>. Both parts are required.
    """
    if not response_text:  # guard against None/empty message content
        return None
    match_thought = re.search(r'<(think|tool_call)>\s*(.*?)\s*</\1>', response_text, re.S)
    match_answer = re.search(r'<answer>\s*(.*?)\s*</answer>', response_text, re.S)
    if not match_thought or not match_answer:
        return None
    # group(1) of the thought match is the tag name; group(2) is its content.
    return match_thought.group(2).strip(), match_answer.group(1).strip()


def _build_sft_record(question, thought, answer, image_path):
    """Build one chat-format SFT record (user prompt, assistant CoT + answer, image)."""
    return {
        "messages": [
            {
                "role": "user",
                "content": USER_CONTENT_TEMPLATE.format(question=question),
            },
            {
                "role": "assistant",
                "content": ASSISTANT_CONTENT_TEMPLATE.format(thought=thought, answer=answer),
            },
        ],
        "images": [image_path],
    }


def process_plot(plot, client, model_name, max_attempts=1, code_base_path=""):
    """
    Generate answer-validated chain-of-thought (CoT) samples for one plot.

    For each (question, ground-truth answer) pair in ``plot["QA"]``, the LLM is
    prompted with the plot's source code and the question. A response is kept
    only if it contains both a thought tag (<think> or <tool_call>) and an
    <answer> tag, and ``evaluate_answer`` judges the answer correct. Up to
    ``max_attempts`` requests are made per question; the first correct
    response is kept.

    Args:
        plot: One record from qa_data.jsonl. Keys read: "plot_id", "image",
            and "QA" -> {"question_list", "answer_list"}.
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model_name: Model name passed to the chat completion call.
        max_attempts: LLM calls allowed per question (default 1: one attempt).
        code_base_path: Directory holding the per-plot ``<id>.py`` scripts.
            The default "" resolves relative to the current working directory.

    Returns:
        A list of SFT records, one per correctly answered question. Always a
        list; it is empty when the plot is skipped.

    Raises:
        Errors from formatting ``LLM_SOLVE_QA_PROMPT`` (e.g. KeyError for an
        unescaped brace in the prompt file) propagate to the caller.
    """
    plot_id = plot.get("plot_id")
    if not plot_id:
        print("Skipping entry because it is missing a 'plot_id'.")
        return []

    # Checked up front so no LLM calls are spent on a plot whose samples
    # could not be saved anyway.
    image_path = plot.get("image")
    if image_path is None:
        print(f"Skipping plot {plot_id} because it is missing an 'image'.")
        return []

    python_code = _load_plot_code(plot_id, code_base_path)
    if python_code is None:
        return []

    qa = plot.get("QA") or {}
    questions = qa.get("question_list") or []
    ground_truths = qa.get("answer_list") or []
    if len(questions) != len(ground_truths):
        print(f"Warning: plot {plot_id} has {len(questions)} questions but "
              f"{len(ground_truths)} answers; unmatched entries are ignored.")

    attempts = max(1, int(max_attempts))
    final_json_objects = []

    for question, gt_ans in zip(questions, ground_truths):
        print(f"Processing Q: '{str(question)[:80]}...' for plot {plot_id}")
        formatted_prompt = LLM_SOLVE_QA_PROMPT.format(
            python_code=python_code,
            question=question
        )

        for attempt in range(1, attempts + 1):
            try:
                message = client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": formatted_prompt}],
                    temperature=0.7,  # sampled (not greedy) so repeated attempts can differ
                    top_p=0.8,
                    timeout=60.0,
                    max_tokens=1024
                )
                response_text = message.choices[0].message.content
            except Exception as e:  # API/network failure: log it and try again
                print(f"  [Attempt {attempt}/{attempts}] API error for plot {plot_id}: {type(e).__name__} - {e}")
                continue

            parsed = _parse_response(response_text)
            if parsed is None:
                print(f"  [Attempt {attempt}/{attempts}] Missing a valid thought (<think> or <tool_call>) or <answer> tag.")
                continue
            pred_think, pred_ans = parsed

            if evaluate_answer(pred_ans, gt_ans):
                print(f"  [Correct] GT: {gt_ans}, Pred: {pred_ans}. Saving CoT.")
                final_json_objects.append(
                    _build_sft_record(question, pred_think, pred_ans, image_path)
                )
                break  # keep only the first correct response per question
            print(f"  [Attempt {attempt}/{attempts}] [Incorrect] GT: {gt_ans}, Pred: {pred_ans}.")

    return final_json_objects

def generate_validated_cot_data(client, data_path, num_data, num_workers, model_name, output_filename):
    """
    Generate answer-validated CoT data for the plots in ``qa_data.jsonl``.

    Reads up to *num_data* records from ``<data_path>/qa_data.jsonl`` (blank
    lines are skipped), runs ``process_plot`` on each in a thread pool with
    *num_workers* workers, and writes every returned SFT record as one JSON
    line to ``<data_path>/<output_filename>``, overwriting any existing file.

    Args:
        client: OpenAI-compatible client passed through to ``process_plot``.
        data_path: Directory containing ``qa_data.jsonl``; output goes here too.
        num_data: Maximum number of records to process (list-slice semantics).
        num_workers: Number of worker threads.
        model_name: Model name passed through to ``process_plot``.
        output_filename: Name of the output JSONL file inside *data_path*.

    Returns:
        None. Problems reading the input are printed and abort the run;
        failures for individual plots are printed and skipped.
    """
    output_file_path = os.path.join(data_path, output_filename)
    meta_file = os.path.join(data_path, "qa_data.jsonl")

    if not os.path.exists(meta_file):
        print(f"Error: Input data file not found at {meta_file}")
        return

    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads fail and abort the whole load.
            meta_data = [json.loads(line) for line in f if line.strip()][:num_data]
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print(f"Error reading or parsing meta_file at {meta_file}: {e}")
        return

    total_qa_pairs_saved = 0
    with open(output_file_path, "w", encoding="utf-8") as f_out:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # Map each future back to its input record so errors can name the plot.
            futures = {
                executor.submit(process_plot, plot, client, model_name): plot
                for plot in meta_data
            }

            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating Validated CoT"):
                plot = futures[future]
                plot_id = plot.get("plot_id", "N/A") if isinstance(plot, dict) else "N/A"
                try:
                    json_objects = future.result()
                except Exception as e:
                    print(f"Error processing plot {plot_id}: {type(e).__name__} - {e}")
                    continue

                for json_obj in json_objects or []:
                    # Only write well-formed SFT records; anything else a worker
                    # returns would otherwise end up as a junk line in the JSONL.
                    if not (isinstance(json_obj, dict) and "messages" in json_obj):
                        continue
                    f_out.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
                    total_qa_pairs_saved += 1
                # Flush per plot so completed work survives an interrupted run.
                f_out.flush()

    print(f"\n[Finished] All validated CoT data has been saved to {output_file_path}")
    print(f"Successfully generated and saved {total_qa_pairs_saved} individual QA pairs.")

def arg_parser():
    """Parse the command-line options of the validated CoT generation script.

    Returns:
        argparse.Namespace with ``data_path``, ``num_data``, ``num_workers``,
        ``model_name`` and ``output_filename``.
    """
    cli = argparse.ArgumentParser(description="Generate validated Chart CoT data using an LLM.")
    # (flag, type, default, help) — registered in this order.
    option_specs = (
        ("--data_path", str, '/dfs/data/CR/sft_train_1', "Path to the directory containing qa_data.jsonl."),
        ("--num_data", int, 100000, "Maximum number of plots to process."),
        ("--num_workers", int, 20, "Number of parallel workers."),
        ("--model_name", str, "", "Name of the OpenAI model to use."),
        ("--output_filename", str, "llm_validated_cot.jsonl", "Name for the output file."),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()

if __name__ == "__main__":
    # Parse CLI options and run the generation with the module-level client.
    cli_args = arg_parser()
    generate_validated_cot_data(
        client=client,
        data_path=cli_args.data_path,
        num_data=cli_args.num_data,
        num_workers=cli_args.num_workers,
        model_name=cli_args.model_name,
        output_filename=cli_args.output_filename,
    )