import argparse
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import requests
import base64

import regex as re
# from rouge_score import rouge_scorer
from tqdm import tqdm
from openai import OpenAI

# Run from this script's own directory so the relative "./prompt/..." path
# loaded below resolves no matter where the script is launched from.
# NOTE: because this happens at import time, relative paths given on the
# command line (e.g. --data_path) are also resolved against this directory,
# not against the caller's working directory.
os.chdir(os.path.split(os.path.realpath(__file__))[0])

def read_prompt_from_file(file_path):
    """Return the UTF-8 text of *file_path* with leading/trailing whitespace stripped."""
    with open(file_path, encoding="utf-8") as prompt_file:
        prompt_text = prompt_file.read()
    return prompt_text.strip()

# Instruction text sent alongside every chart image; read once at import time
# (relative to the script directory chosen by the chdir above).
desc_prompt = read_prompt_from_file(file_path="./prompt/desc_prompt_image.txt")

def file_to_data_url(file_path: str) -> str:
    """
    Convert a local image file to a base64-encoded ``data:`` URL.

    The MIME type is looked up from the file extension with ``mimetypes`` so
    that common extensions map to their registered types (e.g. ``.jpg`` ->
    ``image/jpeg``; ``image/jpg`` is not a registered MIME type). Extensions
    that ``mimetypes`` does not know, or that do not map to an ``image/*``
    type, fall back to ``image/<lower-cased extension>``.

    Args:
        file_path: Path of the image file to embed.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    import mimetypes  # local import: only this helper needs it

    with open(file_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

    _, extension = os.path.splitext(file_path)
    mime_type, _ = mimetypes.guess_type(file_path)
    if not mime_type or not mime_type.startswith("image/"):
        # Keep the previous behaviour for anything mimetypes can't classify.
        mime_type = f"image/{extension[1:].lower()}"

    return f"data:{mime_type};base64,{encoded_string}"


def extract_description(input_str):
    """Pull the value of a JSON-style ``"description": "..."`` pair out of text.

    The model reply may wrap the JSON in other text (e.g. code fences), so the
    pair is located with a regex instead of parsing the whole reply. The
    captured value is then run through the JSON string decoder so escape
    sequences (escaped newlines, escaped quotes, unicode escapes, ...) become
    the real characters instead of leaking through as backslash sequences.

    Args:
        input_str: Raw model reply text. ``None`` or ``""`` is accepted.

    Returns:
        The decoded description string; the raw matched text if it contains
        escapes that are not valid JSON; or ``None`` when there is no input or
        no ``"description"`` field is found.
    """
    if not input_str:
        # Chat completions can return content=None (e.g. refusals).
        return None

    pattern = r'"description"\s*:\s*"((?:\\.|[^"\\])*)"'
    match = re.search(pattern, input_str, re.DOTALL)
    if match is None:
        return None

    raw = match.group(1)
    try:
        # The regex guarantees `raw` has no unescaped quote and no dangling
        # backslash, so wrapping it in quotes yields a single JSON string.
        # strict=False tolerates literal newlines/tabs the model left unescaped.
        return json.loads(f'"{raw}"', strict=False)
    except json.JSONDecodeError:
        # Invalid JSON escape (e.g. backslash-apostrophe): return text as matched.
        return raw

def process_plot(plot, desc_prompt, data_path, client, model="gpt-4o"):
    """Ask the vision model to describe a single chart image.

    Args:
        plot: Metadata record for one chart; its ``"image"`` entry is the image
            path, joined onto ``data_path`` when it is not absolute.
        desc_prompt: Instruction text sent alongside the image.
        data_path: Directory that relative image paths are resolved against.
        client: OpenAI-style client exposing ``chat.completions.create``.
        model: Chat model to query (defaults to ``"gpt-4o"``, the model this
            function always used before the parameter existed).

    Returns:
        ``(description, plot)``: ``description`` is what ``extract_description``
        returns for the model's reply (``None`` when no ``"description"`` field
        is matched), and ``plot`` is the input record, returned so callers
        using ``as_completed`` can tell which plot the result belongs to.
    """
    image_path = plot["image"]
    if not os.path.isabs(image_path):
        image_path = os.path.join(data_path, image_path)

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": 'You are both an expert chartist and a professional maths teacher.'},
            {
                "role": "user",
                "content": [
                    # The image is inlined as a base64 data URL.
                    {"type": "image_url", "image_url": {"url": file_to_data_url(image_path)}},
                    {"type": "text", "text": desc_prompt},
                ],
            },
        ],
    )
    reply = response.choices[0].message.content
    return extract_description(reply), plot

def _load_processed_ids(output_file_path):
    """Return the set of ``plot_id`` values already present in *output_file_path*.

    A missing or empty file yields an empty set; blank lines are ignored.
    """
    processed_ids = set()
    if not os.path.exists(output_file_path):
        return processed_ids
    with open(output_file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                processed_ids.add(json.loads(line)["plot_id"])
    return processed_ids


def generate_desc_data(client, data_path, num_data=200, num_workers=10):
    """Describe every well-rated chart and append the results to a JSONL file.

    Reads chart metadata from ``<data_path>/all_info.jsonl``. It keeps plots
    whose first rating value is greater than 3 and restricts them to the first
    ``num_data`` of those. It then skips any plot whose ``id`` already appears
    as ``plot_id`` in ``<data_path>/desc_data.jsonl`` and queries the model for
    the rest using ``num_workers`` threads. Each successful result is appended
    (and flushed) as soon as it completes. Re-running after an interruption or
    partial failure therefore only processes the plots that are still missing.

    Args:
        client: OpenAI-style client passed through to ``process_plot``.
        data_path: Directory holding ``all_info.jsonl``; output is written here.
        num_data: Maximum number of (rating-filtered) plots to cover, counted
            from the start of ``all_info.jsonl``. ``None`` means no limit.
        num_workers: Number of concurrent model requests.
    """
    output_file_path = os.path.join(data_path, "desc_data.jsonl")
    # Resume from the set of ids already written rather than "everything after
    # the max id". Results arrive in completion order and failures leave gaps,
    # so a max-id cursor would skip those plots forever.
    processed_ids = _load_processed_ids(output_file_path)

    meta_file = os.path.join(data_path, "all_info.jsonl")
    with open(meta_file, "r", encoding="utf-8") as f:
        meta_data = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(meta_data)} plots for answer generation.")

    # Remove poorly performing charts (first rating value must exceed 3).
    meta_data = [plot for plot in meta_data if list(plot["rating"].values())[0] > 3]
    print(f"Loaded {len(meta_data)} collected plot(s)")

    # Apply the num_data cap (slicing with None keeps everything).
    meta_data = meta_data[:num_data]

    # Skip processed images (if any)
    pending = [plot for plot in meta_data if plot["id"] not in processed_ids]
    print(f"Skipped {len(meta_data) - len(pending)} processed plot(s).")

    num_saved = 0
    with open(output_file_path, "a", encoding="utf-8") as out_file:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(process_plot, plot, desc_prompt, data_path, client): plot
                for plot in pending
            }
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing plots"):
                plot_id = futures[future]["id"]
                try:
                    answers, plot_info = future.result()
                    if answers is None:
                        raise ValueError("no description field found in model reply")
                    sample = {
                        "plot_id": plot_info["id"],
                        "image": plot_info["image"],
                        "level": plot_info["level"],
                        # NOTE(review): no separator is inserted, so the prompt
                        # presumably asks the model to continue this sentence.
                        "description": f'This chart is a {plot_info["minor_chart_type"]}' + answers,
                    }
                except Exception as e:
                    # Best-effort: report and move on; the plot is retried next run.
                    print(f"Error processing plot {plot_id}:", e)
                    continue
                # Only the main thread writes, so no locking is needed.
                out_file.write(json.dumps(sample, ensure_ascii=False) + "\n")
                out_file.flush()
                num_saved += 1

    print(f"All answer data has been saved to {output_file_path} ({num_saved} new sample(s))")

def arg_parser():
    """Parse the command-line options of this script.

    Returns:
        ``argparse.Namespace`` with ``data_path`` (required), ``num_data``
        (default 4000) and ``num_workers`` (default 10).
    """
    parser = argparse.ArgumentParser(
        description="Generate chart descriptions with a vision model."
    )
    # Required: it is joined into file paths, so a missing value would only
    # fail later with a confusing TypeError.
    parser.add_argument(
        "--data_path", type=str, required=True,
        help="Directory containing all_info.jsonl; desc_data.jsonl is written there. "
             "Relative paths resolve against this script's directory.",
    )
    parser.add_argument(
        "--num_data", type=int, default=4000,
        help="Passed to generate_desc_data as num_data.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=10,
        help="Number of concurrent worker threads.",
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = arg_parser()

    # Use OPENAI_API_KEY from the environment when set. Passing a hard-coded
    # api_key would override it. "EMPTY" is kept as the fallback placeholder
    # (presumably for a key-less OpenAI-compatible server — confirm).
    openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
    client = OpenAI(
        api_key=openai_api_key,
    )

    generate_desc_data(client, args.data_path, args.num_data, args.num_workers)
