import argparse
import json
import os
import time
import regex as re
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from openai import OpenAI

# Work from this script's own directory so relative paths such as
# "./prompt/QA_prompt.txt" resolve regardless of where the script is launched.
# Side effect: relative CLI paths (e.g. --data_path) are then interpreted
# relative to this directory too, not the caller's working directory.
_script_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(_script_dir)

def read_prompt_from_file(file_path):
    """Return the UTF-8 text of *file_path* with leading/trailing whitespace removed."""
    with open(file_path, encoding='utf-8') as prompt_file:
        content = prompt_file.read()
    return content.strip()

# Prompt templates, read once at import time (paths are relative to the
# current working directory).
QA_PROMPT = read_prompt_from_file("./prompt/QA_prompt.txt")
DESC_PROMPT = read_prompt_from_file("./prompt/desc_prompt_image.txt")

# Module-level OpenAI client shared by the request helper.
# NOTE(review): "EMPTY" is a placeholder key; presumably the target endpoint
# does not validate it — confirm.
openai_api_key = "EMPTY"
openai_client = OpenAI(api_key=openai_api_key)

def is_numeric(value):
    """Return True if *value* can be converted with ``float()``.

    Accepts ints, floats and numeric strings such as ``"3.5"`` or ``" 42 "``.
    Anything ``float`` rejects returns False instead of raising. This covers
    non-numeric strings (``ValueError``) and values like ``None``, lists or
    dicts (``TypeError``), which an LLM-produced JSON answer can contain.
    Note that, like ``float``, it also accepts booleans and "nan"/"inf".
    """
    try:
        float(value)
    except (TypeError, ValueError):
        return False
    return True

def contains_year(text):
    """
    Return True if *text* contains a four-digit year in the range 1900-2100.

    Strings are searched directly. Ints and floats (e.g. a JSON number answer
    such as 2020) are converted with ``str()`` first, so numeric year answers
    are detected too. Any other type, including ``bool``, returns False.

    A year only counts when it is not part of a longer run of digits, so
    "12020" does not match. Digit look-arounds are used instead of word
    boundaries. Years directly adjacent to letters or CJK characters, such as
    "2020年" or "FY2021", are therefore still detected, whereas a word boundary
    sees no break there because both neighbours are word characters.
    """
    if isinstance(text, bool):
        return False
    if isinstance(text, (int, float)):
        text = str(text)
    elif not isinstance(text, str):
        return False

    pattern = r'(?<!\d)(?:19\d{2}|20\d{2}|2100)(?!\d)'
    return re.search(pattern, text) is not None

def filter_numeric_qa_for_plot(plot_data, exclude_year_in_question=True):
    """
    Keep only the QA pairs whose answer is numeric and that mention no year.

    Expected QA layout::

        {
            "question_list": [...],
            "answer_list": [...]
        }

    A pair is kept when the answer passes ``is_numeric`` and ``contains_year``
    finds no year in the answer. When *exclude_year_in_question* is true (the
    default), the question is checked for a year as well. Both values are
    passed through ``str()`` before the year check, so numeric answers such as
    a JSON number 2020 are caught as years too. Questions and answers are
    paired with ``zip``, so any surplus items in the longer list are dropped.

    Args:
        plot_data: dict holding a "QA" entry; it is modified in place. A
            missing or non-dict "QA" entry is replaced by an empty one, and
            ``None`` lists are treated as empty.
        exclude_year_in_question: set to False to apply the year check to the
            answer only.

    Returns:
        The same *plot_data* object, whose QA lists contain only the kept pairs.
    """
    qa = plot_data.get("QA")
    if not isinstance(qa, dict):
        qa = {}
        plot_data["QA"] = qa
    questions = qa.get("question_list") or []
    answers = qa.get("answer_list") or []

    kept_questions = []
    kept_answers = []

    for q, a in zip(questions, answers):
        # The answer must be a number.
        if not is_numeric(a):
            continue
        # Neither the answer nor (by default) the question may contain a year.
        if contains_year(str(a)):
            continue
        if exclude_year_in_question and contains_year(str(q)):
            continue

        kept_questions.append(q)
        kept_answers.append(a)

    qa["question_list"] = kept_questions
    qa["answer_list"] = kept_answers
    return plot_data

def _find_balanced_json_object(text):
    """Return the first ``{...}`` span of *text* whose braces balance, or None.

    Braces inside double-quoted strings (honouring backslash escapes) are not
    counted, so a value such as "a } b" does not end the object early. If a
    ``{`` never closes, the search restarts at the next ``{``.
    """
    start = text.find('{')
    while start != -1:
        depth = 0
        in_string = False
        escaped = False
        for pos in range(start, len(text)):
            ch = text[pos]
            if in_string:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return text[start:pos + 1]
        start = text.find('{', start + 1)
    return None


def _escape_stray_backslashes(json_str):
    """Double every backslash except those starting a quote, backslash, slash or uXXXX escape.

    LLM output often puts LaTeX (frac, times, ...) inside JSON strings, which is
    not valid JSON. Doubling those backslashes turns them into literal text.
    The escapes for a quote, a backslash, a slash and uXXXX keep their JSON
    meaning, because doubling them would, for example, end a string at an
    escaped quote. The other single-letter JSON escapes (b, f, n, r, t) are
    doubled as well, so a LaTeX command such as "times" stays literal text.
    """
    return re.sub(
        r'\\(?:["\\/]|u[0-9a-fA-F]{4})?',
        lambda m: m.group(0) if len(m.group(0)) > 1 else '\\\\',
        json_str,
    )


def extract_and_validate_json(input_str):
    """Extract the first JSON object embedded in *input_str* and parse it.

    The first brace-balanced ``{...}`` span (string-aware) is located. Stray
    backslashes are escaped and raw CR/LF characters are replaced by spaces.
    The result is then parsed with ``json.loads`` (``strict=False``, so other
    raw control characters such as tabs inside strings are tolerated).

    Args:
        input_str: raw LLM reply text; non-string input (e.g. ``None``) is
            treated as containing no JSON.

    Returns:
        The parsed dict, or None if no object was found or parsing failed
        (a message is printed in either case).
    """
    if not isinstance(input_str, str):
        print("未在输入字符串中找到 JSON。")
        return None

    json_str = _find_balanced_json_object(input_str)
    if json_str is None:
        print("未在输入字符串中找到 JSON。")
        return None

    json_str = _escape_stray_backslashes(json_str)
    json_str = json_str.replace('\n', ' ').replace('\r', ' ')
    try:
        return json.loads(json_str, strict=False)
    except json.JSONDecodeError as e:
        print("JSON解析错误:", e)
        return None

def create_json_mode_chat_response_by_messages(
        model="claude-3-5-sonnet-20241022",
        client=None,
        messages=None,
        max_tokens=1000,
        temperature=0,
        max_retries=3,
):
    """Send chat *messages* to the LLM endpoint and return the reply text.

    Args:
        model: not used; the request is always sent with the model pinned
            below (kept for interface compatibility).
        client: OpenAI-compatible client to use; defaults to the module-level
            ``openai_client``.
        messages: non-empty list of ``{"role": ..., "content": ...}`` dicts,
            forwarded to the API unchanged.
        max_tokens: not forwarded to the API.
        temperature: not forwarded to the API (the endpoint's default applies).
        max_retries: number of additional attempts after a failed request, with
            exponential back-off (1s, 2s, 4s, ...).

    Returns:
        ``choices[0].message.content`` of the response (may be None if the
        API returned no content).

    Raises:
        ValueError: if *messages* is empty or None.
        Exception: whatever the client raised on the final attempt.
    """
    if not messages:
        raise ValueError("messages must be a non-empty list of chat messages")

    t1 = time.time()

    # The model is pinned here on purpose; the ``model`` argument is ignored.
    model = 'gpt-4o'
    chat_client = client if client is not None else openai_client

    total_attempts = 1 + max(0, max_retries)
    for attempt in range(total_attempts):
        try:
            message = chat_client.chat.completions.create(
                model=model,
                # Forward the caller's messages as-is; no synthetic system message.
                messages=list(messages),
            )
            break
        except Exception as err:  # any client error; re-raised after the last attempt
            if attempt + 1 >= total_attempts:
                raise
            delay = 2 ** attempt
            print(f"LLM request failed ({err!r}); retrying in {delay}s "
                  f"(retry {attempt + 1}/{total_attempts - 1})...")
            time.sleep(delay)

    t2 = time.time()
    print('########################### result, time:', t2 - t1)
    print(message.choices[0].message.content)

    return message.choices[0].message.content

def _load_processed_plot_ids(output_file_path):
    """Return the set of ``plot_id`` values already written to *output_file_path*.

    A missing file yields an empty set. Blank lines are ignored, and unreadable
    lines are skipped with a warning so an interrupted run does not block resuming.
    """
    processed_ids = set()
    if not os.path.exists(output_file_path):
        return processed_ids
    with open(output_file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                processed_ids.add(json.loads(line)["plot_id"])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                print(f"[WARN] Ignoring unreadable line {line_no} in {output_file_path}: {e}")
    return processed_ids


def _is_well_rated(plot, threshold=4):
    """Return True when the first value of ``plot["rating"]`` is greater than *threshold*.

    Plots with a missing or empty ``rating`` dict count as not well rated.
    """
    ratings = plot.get("rating") or {}
    first_rating = next(iter(ratings.values()), None)
    return first_rating is not None and first_rating > threshold


def generate_instruction_data(
        model,
        data_path,
        num_workers=5,
        num_data=200,
        prompt=DESC_PROMPT
):
    """Generate numeric QA pairs for each well-rated plot and append them to JSONL.

    The steps are:

    1. Read ``<data_path>/all_info.jsonl`` and keep entries whose ``image`` file
       exists (resolved against *data_path*) and whose first rating is > 4.
    2. Skip plots whose id already appears as ``plot_id`` in
       ``<data_path>/qa_data.jsonl``.
    3. Query the LLM concurrently about each remaining plot's code file.
    4. Append one record per plot with the keys ``plot_id``, ``image``,
       ``chart_type`` and ``QA``.

    Plots whose processing raises are logged and left out of the output, so a
    later run picks them up again.

    Args:
        model: passed through to ``create_json_mode_chat_response_by_messages``.
        data_path: directory holding all_info.jsonl, the plot files and the output.
        num_workers: thread-pool size for concurrent LLM calls.
        num_data: currently unused.
        prompt: template formatted with ``code=<plot source>`` to build the request.
    """
    output_file_path = os.path.join(data_path, "qa_data.jsonl")

    # Resume support: results are appended in completion order (as_completed),
    # so track every id already written rather than only the largest one.
    processed_ids = _load_processed_plot_ids(output_file_path)

    # Load metadata, keeping only plots whose image exists (relative to
    # data_path, the same resolution used before writing a record).
    meta_data = []
    with open(os.path.join(data_path, "all_info.jsonl"), "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            image_path = item.get("image", "")
            if image_path and os.path.exists(os.path.join(data_path, image_path)):
                meta_data.append(item)

    # Drop low-quality charts.
    meta_data = [plot for plot in meta_data if _is_well_rated(plot)]
    print(f"Loaded {len(meta_data)} collected plot(s)")

    pending = [plot for plot in meta_data if plot["id"] not in processed_ids]
    print(f"Skipped {len(meta_data) - len(pending)} processed plot(s)")

    def process_plot(plot, question_template):
        """Ask the LLM (up to max_attempts times) for QA pairs about *plot*'s code.

        Returns ``(result, plot)``, where result is ``{"QA": {...}}``. The QA
        lists are empty when the code file is missing or no attempt yielded at
        least one pair that survives ``filter_numeric_qa_for_plot``.
        """
        code_file_path = os.path.join(data_path, plot["code"])

        if not os.path.isfile(code_file_path):
            print(f"[WARN] Code file not found: {code_file_path}. Skipping plot id {plot['id']}.")
            return {"QA": {"question_list": [], "answer_list": []}}, plot

        with open(code_file_path, "r", encoding="utf-8", errors="replace") as f:
            code = f.read()
        request_text = question_template.format(code=code)

        max_attempts = 10
        for attempt in range(max_attempts):
            print(f"\nCalling LLM for Generate Questions for plot id {plot['id']} (attempt {attempt+1})...")
            question_dict_string = create_json_mode_chat_response_by_messages(
                model=model,
                messages=[{"role": "user", "content": request_text}],
                max_tokens=8192,
            )

            describe = extract_and_validate_json(question_dict_string)
            if not describe:
                print(f"Attempt {attempt+1}: 未成功解析 JSON，重试...")
                continue

            # Keep only pairs with a numeric, year-free answer.
            filtered_dict = filter_numeric_qa_for_plot({"QA": describe})
            if filtered_dict.get("QA", {}).get("question_list"):
                print(f"Plot id {plot['id']} 得到符合要求的 QA 数据")
                return filtered_dict, plot
            print(f"Attempt {attempt+1}: 生成的 QA 数据未满足要求，重试...")

        print(f"Plot id {plot['id']} 经过 {max_attempts} 次尝试后依然未生成符合要求的 QA 数据，返回空 QA。")
        return {"QA": {"question_list": [], "answer_list": []}}, plot

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(process_plot, plot, prompt): plot
            for plot in pending
        }
        with open(output_file_path, "a", encoding="utf-8") as f:
            for future in tqdm(as_completed(futures), total=len(futures)):
                plot = futures[future]
                try:
                    new_instructions, _ = future.result()
                except Exception as e:
                    # Not written, so the next run retries this plot.
                    print(f"[ERROR] Plot id {plot['id']} failed: {e!r}. It will be retried on the next run.")
                    continue

                # Re-check the image file; skip the plot if it has disappeared.
                image_path = os.path.join(data_path, plot["image"])
                if not os.path.exists(image_path):
                    print(f"Image file {image_path} does not exist. Skipping plot id {plot['id']}.")
                    continue

                sample = {
                    "plot_id": plot["id"],
                    "image": plot["image"],
                    "chart_type": plot["major_chart_type"],
                    "QA": new_instructions["QA"],
                }
                f.write(json.dumps(sample) + "\n")
                # Flush per record so an interrupted run keeps everything completed so far.
                f.flush()

def arg_parser(argv=None):
    """Parse the command-line options.

    Args:
        argv: argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with model_name, api_key, base_url, data_path,
        num_instruction_per_plot and num_workers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str)
    # When omitted, the module-level client created at import time is used unchanged.
    parser.add_argument("--api_key", type=str, default=None)
    parser.add_argument("--base_url", type=str, default=None)
    # Relative paths resolve against the working directory, which this module
    # changes to the script's own directory at import.
    parser.add_argument("--data_path", type=str, default=".")
    # NOTE: not consumed by generate_instruction_data.
    parser.add_argument("--num_instruction_per_plot", type=int, default=3)
    parser.add_argument("--num_workers", type=int, default=30)
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = arg_parser()
    # Rebuild the shared client only when connection options are given; the
    # request helper reads the module-level ``openai_client`` at call time.
    if args.api_key is not None or args.base_url is not None:
        openai_client = OpenAI(
            api_key=args.api_key if args.api_key is not None else openai_api_key,
            base_url=args.base_url,
        )
    generate_instruction_data(
        model=args.model_name,
        data_path=args.data_path,
        num_workers=args.num_workers,
        num_data=200,
        prompt=QA_PROMPT
    )
