import argparse
import base64
import json
import mimetypes
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

from openai import OpenAI
from tqdm import tqdm
from transformers import set_seed

set_seed(42)  # seed local RNGs before the remaining imports run

from evaluate import evaluate_accuracy

# Make the directory containing this script the working directory, so every
# relative path used later (the prompt files, and also any relative
# --data_path / --save_path passed on the command line) resolves against the
# script's location rather than the caller's current directory.
_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
os.chdir(_SCRIPT_DIR)

def read_prompt_from_file(file_path):
    """Return the text of the UTF-8 file at *file_path* with surrounding whitespace removed."""
    with open(file_path, encoding='utf-8') as prompt_file:
        contents = prompt_file.read()
    return contents.strip()

def file_to_data_url(file_path: str) -> str:
    """
    Convert a local image file to a base64 ``data:`` URL.

    The MIME type is looked up with ``mimetypes`` so that, for example, ``.jpg``
    maps to the registered ``image/jpeg`` rather than the non-standard
    ``image/jpg``. Extensions that ``mimetypes`` does not map to an ``image/*``
    type fall back to ``image/<lower-cased extension>``.

    Args:
        file_path: Path to the image file.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(file_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None or not mime_type.startswith("image/"):
        # Unknown / non-image extension: derive the subtype from the suffix.
        _, extension = os.path.splitext(file_path)
        mime_type = f"image/{extension[1:].lower()}"

    return f"data:{mime_type};base64,{encoded_string}"

# Generic "<think> reasoning, then <answer>" output-format instruction; fill in
# with .format(original_question=...). Not referenced elsewhere in this file.
# NOTE(review): the text has typos ("thinks", "provides", "Quetion") and asks for
# a "sequence of functions", which looks copied from another task. It is kept
# verbatim because editing prompt text changes model outputs.
prompt_cot='''### Output Format  

You should first thinks about the reasoning process internally and then provides the user with the answer. The **reasoning process** and **answer** are enclosed within specific tags:  

- **Reasoning process**: Enclosed within `<think>...</think>`  
- **Final answer (sequence of functions only)**: Enclosed within `<answer>...</answer>`  

Now, it's your turn!
Quetion: {original_question}'''

# Direct-answer instruction (named after ChartQAPro). It has no format
# placeholder and ends with "Question: ", so presumably the question text is
# appended to it. Not referenced elsewhere in this file.
prompt_chartqapro='''You are given a question that you need to answer based on the provided image.Your answer should be a single word, number, or phrase. If the question is unanswerable based on the information in the provided image, your answer should be unanswerable. Do not generate units.
But if numerical units such as million, m, billion, B, or K are required, use the exact notation shown in the chart.
Remember to generate the final answer only without any additional text!
Question: '''

# Step-by-step variant of prompt_chartqapro that asks for the final answer inside
# <answer>...</answer>. It also ends with "Question: " and has no placeholder
# (presumably the question is appended). Not referenced elsewhere in this file.
prompt_chartqapro_cot='''You are given a question that you need to answer based on the provided image.
You need to think step-by-step, but your final answer should be a single word, number, or phrase. If
the question is unanswerable based on the information in the provided image, your answer should be
unanswerable. Do not generate units. But if numerical units such as million, m, billion, B, or K are
required, use the exact notation shown in the chart.
Remember to think step-by-step and format the final answer in a single number like "X", and the answer should be enclosed within `<answer>...</answer>`  
Question: '''

# Stricter rule-list variant with a one-shot example: the model should output
# only <answer>...</answer>. It ends with "Question: " and has no placeholder
# (presumably the question is appended). Not referenced elsewhere in this file.
prompt_chartqapro_cot_format='''You are an expert at analyzing charts and providing concise answers. Your task is to answer a question based on an image by following these rules precisely.

**RULES:**
1.  **Analyze the image** and question to find the answer.
2.  **Think step-by-step** to arrive at your conclusion. This is your internal process.
3.  **Provide ONLY the final answer**. Your response must be a single word, number, or phrase, enclosed within `<answer>` tags.
4.  **Do NOT include any reasoning, explanation, or introductory text in your final output.**

**EXAMPLE:**
Question:
(Imagine an image of a bar chart showing city populations)
"What is the population of Tokyo?"
Your Correct Response:
<answer>37.3</answer>

**YOUR TASK:**
Question: '''

# User-message template for direct answering, filled in with
# .format(original_question=question). process_plot uses it whenever
# answer_prompt is anything other than 'cot'.
direct='''========================================
ROLE
========================================
You are an expert vision-language analyst.  
Your job is to look at the image, read the question, and provide a answer.

========================================
CRITICAL RULES (must follow all)
========================================
1.  **FINAL ANSWER** Your output MUST contain the answer tag: `<answer>your answer</answer>`.
2.  **STRICT FORMAT:** The answer inside the `<answer>` tag must be the final, concise result (e.g., a single number). Do not include explanations or units unless required by the chart's notation.

========================================
INPUT FIELDS
========================================
Question: {original_question}  
'''

# Chain-of-thought template that requires the reasoning inside <think>...</think>
# and includes a one-shot example; fill in with .format(original_question=...).
# Not referenced elsewhere in this file. (An earlier, immediately-overwritten
# duplicate assignment without the INPUT FIELDS separator was removed; this is
# the value the module has always ended up with.)
cot_output='''
========================================
ROLE
========================================
You are an expert vision-language analyst.  
Your job is to look at the image, read the question, and provide a answer.

========================================
CRITICAL RULES (must follow all)
========================================
1.  **STEPBYSTEP THINKING:** You need to think step-by-step first before answering the question.Your thought process (which you must output in the <think> tag) should explicitly focus on:
    *   **Axes:** What do the horizontal (X-axis) and vertical (Y-axis) represent? Note their labels, units, and scale.
    *   **Data Points:** Locate the specific bars, points, lines, or other points relevant to the question.
    *   **Context:** Read the chart's title, legend, and any other text to fully understand the context.
2.  **FINAL ANSWER** Your output MUST contain the answer tag: `<answer>your answer</answer>`.
3.  **STRICT FORMAT:** The answer inside the `<answer>` tag must be the final, concise result (e.g., a single number). Do not include explanations or units unless required by the chart's notation.

========================================
EXAMPLE
========================================
Question: "What is the population of Tokyo?"
(Imagine an image of a bar chart showing city populations)
Your Correct Response:
<think>To answer the question, I first look at the chart's X-axis, which represents different cities, and the Y-axis, which shows their populations in millions. The bar for Tokyo is approximately from 35 million to 40 million. So I estimate the population to be around 37.3 million.</think>
<answer>37.3</answer>
========================================
INPUT FIELDS
========================================
Question      : {original_question}  
'''
# User-message template for chain-of-thought answering, filled in with
# .format(original_question=question). process_plot uses it when
# answer_prompt == 'cot'. Unlike cot_output it says the reasoning "may" go in
# <think> and has no worked example.
cot='''
========================================
ROLE
========================================
You are an expert vision-language analyst.  
Your job is to look at the image, read the question, and provide a answer.

========================================
CRITICAL RULES (must follow all)
========================================
1.  **STEPBYSTEP THINKING:** You need to think step-by-step first before answering the question.Your thought process (which you may output in the <think> tag) should explicitly focus on:
    *   **Axes:** What do the horizontal (X-axis) and vertical (Y-axis) represent? Note their labels, units, and scale.
    *   **Data Points:** Locate the specific bars, points, lines, or other points relevant to the question.
    *   **Context:** Read the chart's title, legend, and any other text to fully understand the context.
2.  **FINAL ANSWER** Your output MUST contain the answer tag: `<answer>your answer</answer>`.
3.  **STRICT FORMAT:** The answer inside the `<answer>` tag must be the final, concise result (e.g., a single number). Do not include explanations or units unless required by the chart's notation.

========================================
INPUT FIELDS
========================================
Question      : {original_question}  
'''

# Auxiliary prompt texts, loaded from disk when the module is imported. The
# paths are relative, so they resolve against the working directory set by the
# os.chdir call near the top of the file. Example content of chart_answer.txt:
# "Please provide a direct answer to the question without any additional explanations or comments. Only output the final answer, keeping the format concise and clear."
# NOTE(review): none of these names is referenced elsewhere in this file, yet
# any missing file makes the import fail. ANSWER_PROMPT_cot and r1_prompt both
# load ./prompt/cot_answer.txt.
(
    ANSWER_PROMPT,
    ANSWER_PROMPT_newcot,
    stepbystep_prompt,
    lmmr1_prompt,
    ANSWER_PROMPT_cot,
    r1_prompt,
    make_data_prompt,
) = (
    read_prompt_from_file(prompt_path)
    for prompt_path in (
        "./make_data/prompt/chart_answer.txt",
        "./test/prompt/cot_nothink.txt",
        "./test/prompt/cot_step.txt",
        "./test/prompt/lmm-r1.txt",
        "./prompt/cot_answer.txt",
        "./prompt/cot_answer.txt",
        "./chart_cot/prompt/combined_cot.txt",
    )
)

def process_plot(plot, answer_prompt, data_path, client, model):
    """
    Ask the model every question attached to one chart and collect the replies.

    Args:
        plot: One record from ``qa_data.jsonl``. The fields read here are
            ``plot["image"]`` (absolute path, or relative to ``data_path``) and
            ``plot["QA"]["question_list"]``.
        answer_prompt: Prompting mode. ``'cot'`` wraps each question in the
            module-level ``cot`` template; any other value uses ``direct``.
        data_path: Base directory used to resolve a relative image path.
        client: OpenAI-style client exposing ``chat.completions.create``.
        model: Model name passed through to the API call.

    Returns:
        ``(answers, plot)``. ``answers[i]`` is the raw message content returned
        for ``question_list[i]``; it is not parsed, so any ``<think>``/``<answer>``
        tags are still present. ``plot`` is the input record, unchanged.

    Errors from reading the image or from the API call are not caught here;
    they propagate to the caller.
    """
    image_path = plot["image"]
    if not os.path.isabs(image_path):
        image_path = os.path.join(data_path, image_path)
    # Encode the chart once and reuse it for every question on this plot.
    image_data_url = file_to_data_url(image_path)

    use_cot = answer_prompt == 'cot'
    template = cot if use_cot else direct

    answers = []
    for question in plot["QA"]["question_list"]:
        print("cot" if use_cot else "dir")
        # Time the request itself; reading the reply afterwards is instantaneous.
        t1 = time.perf_counter()
        message = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an expert in chart analysis"},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": image_data_url}},
                        {"type": "text", "text": template.format(original_question=question)},
                    ]
                },
            ]
        )
        t2 = time.perf_counter()
        answer = message.choices[0].message.content
        print(f"Response time for question '{question}': {t2 - t1:.2f} seconds")
        print(answer)
        answers.append(answer)
    return answers, plot

def _read_jsonl(path):
    """Return the JSON values stored one per non-blank line of the file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _load_processed_plot_ids(path):
    """Return the set of ``plot_id`` values already recorded in the output file *path*.

    Returns an empty set when the file does not exist or is empty.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return set()
    return {record["plot_id"] for record in _read_jsonl(path)}


def generate_answer_data(client, data_path, num_data, num_workers, cot, model, save_path):
    """
    Query the model for every plot in ``<data_path>/qa_data.jsonl`` and append
    one JSON record per plot to ``save_path``, resuming from earlier runs.

    Args:
        client: OpenAI-style client, forwarded to ``process_plot``.
        data_path: Directory containing ``qa_data.jsonl``; also the base for
            relative image paths (see ``process_plot``).
        num_data: Only the first ``num_data`` records of ``qa_data.jsonl`` are used.
        num_workers: Thread-pool size; each task handles one whole plot.
        cot: Prompting mode forwarded to ``process_plot`` as ``answer_prompt``.
            (Inside this function the name shadows the module-level ``cot``
            template.)
        model: Model name forwarded to ``process_plot``.
        save_path: Output JSONL path, opened in append mode.

    Plots whose ``plot_id`` already appears in ``save_path`` are skipped. A plot
    whose processing raises is reported and not written, so it is retried on
    the next run.
    """
    output_file_path = save_path

    # Resume by membership rather than by position. Results are written in
    # completion order (not input order) and failed plots are never written, so
    # "everything up to the largest finished id" is not a safe resume point.
    processed_ids = _load_processed_plot_ids(output_file_path)

    meta_file = os.path.join(data_path, "qa_data.jsonl")
    meta_data = _read_jsonl(meta_file)[:num_data]
    print(f"Loaded {len(meta_data)} plots for answer generation.")

    pending = [plot for plot in meta_data if plot["plot_id"] not in processed_ids]
    print(f"Skipped {len(meta_data) - len(pending)} processed plot(s).")

    # Write each record as soon as its plot finishes, so an interrupted run keeps
    # its completed work. Only this (main) thread writes, so no locking is needed.
    with open(output_file_path, "a", encoding="utf-8") as out_file:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(process_plot, plot, cot, data_path, client, model): plot
                for plot in pending
            }
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing plots"):
                try:
                    answers, plot_info = future.result()
                    record = {
                        "plot_id": plot_info["plot_id"],
                        "image": plot_info["image"],
                        "chart_type": plot_info["chart_type"],
                        "QA": {
                            # Keep the original Q&A pairs
                            "question_list": plot_info["QA"]["question_list"],
                            "answer_list": plot_info["QA"]["answer_list"],
                            # Put the model's answers into the same structure
                            "model_answer": answers,
                        },
                    }
                except Exception as e:
                    # Best effort: report and move on; this plot is retried next run.
                    print(f"Error processing plot {futures[future]['plot_id']}:", e)
                    continue
                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                out_file.flush()
    print(f"All answer data has been saved to {output_file_path}")

def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--gen", type=bool, default=True)
    parser.add_argument("--data_path", type=str)
    parser.add_argument("--num_data", type=int, default=10000)
    parser.add_argument("--num_workers", type=int, default=20)
    parser.add_argument("--cot", type=str, default='lmm')
    parser.add_argument("--model", type=int, default='7')
    parser.add_argument("--save_path", type=str)
    
    return parser.parse_args()

if __name__ == "__main__":
    args = arg_parser()
    print("Current runtime parameters:")
    for arg_name, arg_value in vars(args).items():
        print(f"  {arg_name}: {arg_value}")

    # --model id -> (OpenAI-compatible API base URL, model name, thread-pool size).
    # Ids not listed here fall back to _FALLBACK_ENDPOINT.
    # NOTE(review): the pool size chosen here is the one used; the parsed
    # --num_workers value is never read.
    _ENDPOINTS = {
        3: ("http://172.17.0.2:8010/v1", "Qwen/Qwen2.5-VL-3B-Instruct", 10),
        72: ("http://10.199.254.175:8000/v1", "Qwen2.5-VL-72B-Instruct", 15),
        7: ("http://172.17.0.2:8007/v1", "Qwen/Qwen2.5-VL-7B-Instruct", 10),
        32: ("http://172.17.0.2:8003/v1", "Qwen/Qwen2.5-VL-32B-Instruct", 10),
        16: ("http://172.17.0.2:8011/v1", "ovis", 10),
        0: ("http://172.17.0.2:8000/v1", "internvl3", 10),
        100: ("http://172.17.0.2:8011/v1", "gemma", 10),
    }
    _FALLBACK_ENDPOINT = ("http://172.17.0.2:8005/v1", "Qwen/Qwen2.5-VL-72B-Instruct", 10)

    openai_api_base, model, num_workers = _ENDPOINTS.get(args.model, _FALLBACK_ENDPOINT)
    openai_api_key = "EMPTY"  # same placeholder key for every endpoint
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    if args.gen:
        generate_answer_data(client, args.data_path, args.num_data, num_workers,
                             args.cot, model, args.save_path)

    # Evaluation runs on save_path whether or not generation ran in this call.
    evaluate_accuracy(args.save_path)

