import argparse
import json
import os
import time
import base64
from tqdm import tqdm
from transformers import set_seed
set_seed(42)     
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
from evaluate import evaluate_accuracy

# Make the script's own directory the working directory.  Every relative path
# opened after this point (including relative command-line paths) is therefore
# resolved against the location of this file, not the caller's shell directory.
_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
os.chdir(_SCRIPT_DIR)

def read_prompt_from_file(file_path):
    """Return the UTF-8 text stored at ``file_path`` without surrounding whitespace."""
    with open(file_path, mode='r', encoding='utf-8') as prompt_file:
        raw_text = prompt_file.read()
    return raw_text.strip()

def file_to_data_url(file_path: str):
    """
    Convert a local image file to a base64 ``data:`` URL.

    The MIME type is looked up from the file name with the standard
    ``mimetypes`` table, so ``.jpg`` produces the registered ``image/jpeg``
    rather than ``image/jpg``.  When the table has no image type for the
    extension, the type falls back to ``image/<lower-cased extension>``.

    Args:
        file_path: Path of the image file to embed.

    Returns:
        A string of the form ``data:<mime type>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Function-scope import keeps this helper self-contained.
    import mimetypes

    with open(file_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

    guessed_type, content_encoding = mimetypes.guess_type(file_path)
    if guessed_type and content_encoding is None and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        # Unknown, compressed or non-image extension: derive the subtype
        # directly from the extension.
        _, extension = os.path.splitext(file_path)
        mime_type = f"image/{extension[1:].lower()}"

    return f"data:{mime_type};base64,{encoded_string}"

# Markdown instruction block asking for the reasoning inside <think> tags and
# the final answer inside <answer> tags.  Each source line is its own literal
# so the trailing double spaces stay visible and cannot be silently removed by
# an editor that strips trailing whitespace.
prompt_cot = (
    "### Output Format  \n"
    "\n"
    "You should first thinks about the reasoning process internally and then provides the user with the answer. The **reasoning process** and **answer** are enclosed within specific tags:  \n"
    "\n"
    "- **Reasoning process**: Enclosed within `<think>...</think>`  \n"
    "- **Final answer (sequence of functions only)**: Enclosed within `<answer>...</answer>`  \n"
    "\n"
    "Now, it's your turn!"
)

# Prompt templates, each read from disk once while this module is loading.
# All paths are relative, so they are resolved against the working directory
# in effect at that moment.
#
# Example content of the direct-answer prompt:
# "Please provide a direct answer to the question without any additional explanations or comments. Only output the final answer, keeping the format concise and clear."
ANSWER_PROMPT = read_prompt_from_file("./make_data/prompt/chart_answer.txt")

ANSWER_PROMPT_newcot = read_prompt_from_file("./test/prompt/cot_nothink.txt")
stepbystep_prompt = read_prompt_from_file("./test/prompt/cot_step.txt")
lmmr1_prompt = read_prompt_from_file("./test/prompt/lmm-r1.txt")

ANSWER_PROMPT_cot = read_prompt_from_file("./prompt/cot_answer.txt")
# NOTE(review): r1_prompt loads the same file as ANSWER_PROMPT_cot -- confirm
# this is intended and not a leftover path.
r1_prompt = read_prompt_from_file("./prompt/cot_answer.txt")

def process_plot(plot, answer_prompt, data_path, client, model="gpt-4o"):
    """
    Ask the chat model every question attached to one plot.

    One request is sent per question.  Each request carries ``answer_prompt``
    as the system message (previously the argument was accepted but never
    sent) and a user message holding the plot image plus the question text.

    Args:
        plot: Record with an ``"image"`` path and a ``"QA"`` mapping that
            contains ``"question_list"``.
        answer_prompt: Instruction text for the system message.  When it is
            empty or ``None`` no system message is added.
        data_path: Directory that a relative ``plot["image"]`` is joined to.
        client: Object whose ``chat.completions.create`` method is called for
            each question.
        model: Model name forwarded to the API; defaults to the previously
            hard-coded ``"gpt-4o"``.

    Returns:
        A tuple ``(answers, plot)``: the model replies in question order and
        the unchanged input record.
    """
    # If the image field is a relative path, resolve it against data_path
    image_path = plot["image"]
    if not os.path.isabs(image_path):
        image_path = os.path.join(data_path, image_path)
    # The image is the same for every question, so it is encoded only once
    image_data_url = file_to_data_url(image_path)

    # Messages shared by all requests for this plot
    shared_messages = []
    if answer_prompt:
        shared_messages.append({"role": "system", "content": answer_prompt})

    answers = []
    for question in plot["QA"]["question_list"]:
        messages = shared_messages + [
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": question},
            ]}
        ]
        # Start the clock before the request so the printed duration covers
        # the API round trip itself.
        started = time.perf_counter()
        response = client.chat.completions.create(
            model=model,
            messages=messages,
        )
        elapsed = time.perf_counter() - started
        answer = response.choices[0].message.content
        print(f"Response time for question '{question}': {elapsed:.2f} seconds")
        print(answer)
        answers.append(answer)
    return answers, plot

def _load_processed_plot_ids(output_file_path):
    """Return the set of ``plot_id`` values already stored in the JSONL output file."""
    if not os.path.exists(output_file_path):
        return set()
    with open(output_file_path, "r", encoding="utf-8") as f:
        # Blank lines are skipped so a stray empty line cannot break resuming
        return {json.loads(line)["plot_id"] for line in f if line.strip()}


def _select_answer_prompt(cot):
    """Map the ``cot`` mode name to its prompt text; unknown modes use ``ANSWER_PROMPT``."""
    prompts_by_mode = {
        'cot': ANSWER_PROMPT_cot,
        'nothink': ANSWER_PROMPT_newcot,
        'step': stepbystep_prompt,
        'lmm': lmmr1_prompt,
    }
    return prompts_by_mode.get(cot, ANSWER_PROMPT)


def generate_answer_data(client, data_path, num_data, cot, save_path):
    """
    Generate model answers for the plots listed in ``qa_data.jsonl``.

    The run is resumable: every plot whose ``plot_id`` is already present in
    ``save_path`` is skipped, and each finished plot is appended to the file
    immediately.  Plots that failed or were not reached in an earlier run are
    therefore retried, and an interrupted run keeps what it has completed.

    Args:
        client: Client object handed to ``process_plot``.
        data_path: Directory containing ``qa_data.jsonl``; also passed to
            ``process_plot`` for resolving image paths.
        num_data: Maximum number of records taken from the start of the file.
        cot: Prompt mode: ``'cot'``, ``'nothink'``, ``'step'`` or ``'lmm'``;
            any other value selects ``ANSWER_PROMPT``.
        save_path: JSONL output file, opened in append mode.

    Returns:
        None.  Results are written to ``save_path``.
    """
    output_file_path = save_path

    # Ids of plots that already have a record in the output file
    processed_plot_ids = _load_processed_plot_ids(output_file_path)

    # Load the original data (qa_data.jsonl, one JSON object per line)
    meta_file = os.path.join(data_path, "qa_data.jsonl")
    with open(meta_file, "r", encoding="utf-8") as f:
        meta_data = [json.loads(line) for line in f if line.strip()][:num_data]
    print(f"Loaded {len(meta_data)} plots for answer generation.")

    # Skip by id rather than by position: results are stored in completion
    # order, so "everything before the highest stored id" is not reliable.
    pending_plots = [
        plot for plot in meta_data if plot["plot_id"] not in processed_plot_ids
    ]
    print(f"Skipped {len(meta_data) - len(pending_plots)} processed plot(s).")

    answer_prompt = _select_answer_prompt(cot)

    # Only this thread writes to the file; workers just return their results.
    with open(output_file_path, "a", encoding="utf-8") as out_file, \
            ThreadPoolExecutor(max_workers=10) as executor:
        futures = {
            executor.submit(process_plot, plot, answer_prompt, data_path, client): plot
            for plot in pending_plots
        }
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing plots"):
            try:
                answers, plot_info = future.result()
                sample = {
                    "plot_id": plot_info["plot_id"],
                    "image": plot_info["image"],
                    "chart_type": plot_info["chart_type"],
                    "QA": {
                        # Keep the original Q&A pairs
                        "question_list": plot_info["QA"]["question_list"],
                        "answer_list":   plot_info["QA"]["answer_list"],
                        # Put the model's answers into the same structure
                        "model_answer": answers,
                    }
                }
            except Exception as e:
                # Best effort: one failing plot must not abort the batch.  It
                # stays out of the output file and is retried on the next run.
                print(f"Error processing plot {futures[future].get('plot_id')!r}:", e)
                continue
            out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
            # Flush per record so completed work survives an interruption
            out_file.flush()
    print(f"All answer data has been saved to {output_file_path}")

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept as False, matching what bool("") returned before
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser():
    """
    Parse the command-line options of this script.

    ``--gen`` is parsed with ``_str2bool`` instead of ``bool``: ``bool`` treats
    every non-empty string as true, so ``--gen False`` used to enable
    generation instead of disabling it.

    Returns:
        The ``argparse.Namespace`` with ``gen``, ``data_path``, ``num_data``,
        ``cot`` and ``save_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gen", type=_str2bool, default=True,
                        help="whether to generate answers before evaluating (true/false)")
    parser.add_argument("--data_path", type=str,
                        help="directory containing qa_data.jsonl")
    parser.add_argument("--num_data", type=int, default=10000,
                        help="maximum number of records to process")
    parser.add_argument("--cot", type=str, default='lmm',
                        help="prompt mode: cot, nothink, step or lmm; anything else uses the direct-answer prompt")
    parser.add_argument("--save_path", type=str,
                        help="JSONL file the answers are appended to")

    return parser.parse_args()
    
if __name__ == "__main__":
    args = arg_parser()
    if args.gen:
        # The client was previously never created, so this branch failed with
        # a NameError.  The key comes from the environment instead of being
        # hard-coded in the source; if the variable is unset, None is passed
        # and the OpenAI client applies its own default lookup.
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        generate_answer_data(client, args.data_path, args.num_data, args.cot, args.save_path)

    # Evaluation runs on the saved file whether or not generation was requested
    evaluate_accuracy(args.save_path)

