import argparse
import json
import os
import torch
import base64
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import set_seed
set_seed(42) 

from evaluate import evaluate_accuracy

# Switch to the directory where the current script is located, so that the
# relative prompt-file paths read further down resolve no matter where the
# script was launched from.
# NOTE: because the working directory changes here, any *relative* path later
# passed on the command line (e.g. --data_path / --save_path) is resolved
# against the script directory too, not the shell's original cwd.
os.chdir(os.path.dirname(os.path.realpath(__file__)))

def read_prompt_from_file(file_path):
    """Return the UTF-8 text stored at *file_path*, stripped of surrounding whitespace."""
    with open(file_path, encoding='utf-8') as prompt_file:
        contents = prompt_file.read()
    return contents.strip()
    
def file_to_data_url(file_path: str):
    """
    Convert a local image file to a data URL.

    The MIME type is looked up from the file extension with the standard
    ``mimetypes`` table, so ``.jpg`` becomes ``image/jpeg`` rather than the
    invalid ``image/jpg``. Extensions the table does not know fall back to
    ``image/<extension>``; a path without any extension is labelled
    ``application/octet-stream``.

    Args:
        file_path: Path of the file to embed.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Local import keeps this function self-contained (mimetypes is stdlib).
    import mimetypes

    with open(file_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

    extension = os.path.splitext(file_path)[1].lower()
    # Guess from the extension only, so directory names or drive letters in
    # file_path can never influence the lookup.
    mime_type, _ = mimetypes.guess_type(f"file{extension}")
    if mime_type is None:
        if extension:
            mime_type = f"image/{extension[1:]}"
        else:
            mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{encoded_string}"

# Prompt templates, read once when the module is loaded. The paths are relative
# and therefore depend on the os.chdir() to the script directory above.
# generate_answer_data() picks one of them based on the --cot option.
# NOTE(review): the four files live under three different directories
# (./make_data/prompt, ./test/prompt, ./prompt) — confirm this is intended.
ANSWER_PROMPT = read_prompt_from_file("./make_data/prompt/chart_answer.txt")  # fallback for any other --cot value
ANSWER_PROMPT_newcot = read_prompt_from_file("./test/prompt/cot_nothink.txt")  # --cot nothink
stepbystep_prompt = read_prompt_from_file("./test/prompt/cot_step.txt")  # --cot step
ANSWER_PROMPT_cot = read_prompt_from_file("./prompt/cot_answer.txt")  # --cot cot

def process_plot(plot, answer_prompt, data_path, model, processor):
    """Run the model on every question of a single plot and collect its answers.

    Args:
        plot: Plot record; must provide ``plot["image"]`` (image path) and
            ``plot["QA"]["question_list"]`` (iterable of question strings).
        answer_prompt: Prompt text chosen by the caller.
            NOTE(review): this value is currently NOT sent to the model. Each
            question is passed verbatim as the only text of a single user
            turn, so the caller's prompt choice has no effect on generation.
            An earlier variant sent it as a system message and prefixed the
            question with the CoT prompt; confirm which behaviour is wanted
            before wiring it back in.
        data_path: Directory used to resolve a relative ``plot["image"]``.
        model: Generation model; ``model.generate`` is called, and
            ``model.device`` is used when the attribute exists.
        processor: Processor providing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.

    Returns:
        Tuple ``(answers, plot)``: ``answers`` holds one decoded string per
        question, in question order; ``plot`` is the input record itself.
    """
    image_path = plot["image"]
    if not os.path.isabs(image_path):
        image_path = os.path.join(data_path, image_path)

    # Place inputs on the model's own device instead of hard-coding "cuda";
    # fall back to the previous value if the model exposes no `.device`.
    device = getattr(model, "device", "cuda")

    answers = []
    for question in plot["QA"]["question_list"]:
        # Single user turn: the chart image followed by the raw question text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": question},
                ],
            },
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Only image inputs are used; the second return value is discarded.
        image_inputs, _ = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Inference: generate, then drop the echoed prompt tokens so that only
        # the newly generated continuation is decoded.
        generated_ids = model.generate(**inputs, max_new_tokens=4096)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        answers.append(output_text[0])
        print(output_text[0])
    return answers, plot

def generate_answer_data(model, processor, data_path, num_data, num_workers, cot, output_file_path):
    """Answer the questions of each plot with the model and append results as JSONL.

    Reads ``qa_data.jsonl`` from *data_path*, keeps the first *num_data*
    records, skips every plot whose ``plot_id`` already appears in
    *output_file_path*, and runs ``process_plot`` on the rest.

    Each finished plot is appended to the output file immediately, so an
    interrupted run keeps its completed work and can be resumed. Plots that
    raised an error are not written and are therefore retried on the next run.

    Args:
        model: Model forwarded to ``process_plot``.
        processor: Processor forwarded to ``process_plot``.
        data_path: Directory containing ``qa_data.jsonl``; also forwarded to
            ``process_plot``.
        num_data: Maximum number of records to take from the start of the
            metadata file.
        num_workers: Thread-pool size.
        cot: Prompt selector: ``'cot'``, ``'nothink'`` or ``'step'``; any
            other value selects the default answer prompt.
        output_file_path: JSONL file to append results to.

    Returns:
        None. Records are written in completion order, which may differ from
        input order when ``num_workers`` > 1.
    """
    # Ids already present in the output file. A set is used instead of a
    # "last id" marker so that an id of 0 counts as processed and so that
    # earlier plots which failed are picked up again.
    # (Assumes plot_id values are hashable, e.g. int or str.)
    processed_ids = set()
    if os.path.exists(output_file_path) and os.path.getsize(output_file_path) > 0:
        with open(output_file_path, "r", encoding="utf-8") as f:
            processed_ids = {json.loads(line)["plot_id"] for line in f if line.strip()}

    meta_file = os.path.join(data_path, "qa_data.jsonl")
    with open(meta_file, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are ignored rather than
        # crashing json.loads.
        meta_data = [json.loads(line) for line in f if line.strip()][:num_data]
    print(f"Loaded {len(meta_data)} plots for answer generation.")

    pending = [plot for plot in meta_data if plot["plot_id"] not in processed_ids]
    print(f"Skipped {len(meta_data) - len(pending)} processed plot(s).")

    prompts_by_mode = {
        'cot': ANSWER_PROMPT_cot,
        'nothink': ANSWER_PROMPT_newcot,
        'step': stepbystep_prompt,
    }
    answer_prompt = prompts_by_mode.get(cot, ANSWER_PROMPT)

    # The output file is opened before any inference starts, so an unwritable
    # path fails immediately rather than after all the generation work.
    with open(output_file_path, "a", encoding="utf-8") as out_file:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(process_plot, plot, answer_prompt, data_path, model, processor): plot
                for plot in pending
            }
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing plots"):
                try:
                    answers, plot_info = future.result()
                    record = {
                        "plot_id": plot_info["plot_id"],
                        "image": plot_info["image"],
                        "chart_type": plot_info["chart_type"],
                        "QA": {
                            # Keep the original Q&A pairs
                            "question_list": plot_info["QA"]["question_list"],
                            "answer_list": plot_info["QA"]["answer_list"],
                            # Put the model's answers into the same structure
                            "model_answer": answers,
                        },
                    }
                except Exception as e:
                    # Best effort: one failing plot must not abort the run.
                    print(f"Error processing plot {futures[future].get('plot_id')!r}:", e)
                    continue
                # Written from this (the consuming) thread only; flushed so a
                # crash loses at most the plot currently in flight.
                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                out_file.flush()
    print(f"All answer data has been saved to {output_file_path}")

def _str_to_bool(value):
    """Convert a command-line token such as 'true' / 'False' / '0' to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under the former ``type=bool``.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``gen``, ``data_path``, ``num_data``,
        ``num_workers``, ``cot``, ``model`` and ``save_path``.
    """
    parser = argparse.ArgumentParser()
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True, so an explicit converter is used instead.
    parser.add_argument("--gen", type=_str_to_bool, default=True,
                        help="generate answers before evaluating (true/false)")
    parser.add_argument("--data_path", type=str)
    parser.add_argument("--num_data", type=int, default=20)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--cot", type=str, default='cot')
    # Restrict to the sizes the entry point knows how to load.
    parser.add_argument("--model", type=int, default=3, choices=(3, 7, 32),
                        help="model size in billions of parameters")
    parser.add_argument("--save_path", type=str)

    return parser.parse_args(argv)

if __name__ == "__main__":
    args = arg_parser()

    if args.gen:
        # Checkpoint location per supported model size. These are relative
        # paths, so they resolve against the script directory that the module
        # chdir()s into at import time.
        # NOTE(review): presumably local copies of the
        # 'Qwen/Qwen2.5-VL-<size>B-Instruct' Hub checkpoints — confirm.
        model_paths = {
            3: "Qwen2.5-VL-3B-Instruct",
            7: "Qwen2.5-VL-7B-Instruct",
            32: "Qwen2.5-VL-32B-Instruct",
        }
        if args.model not in model_paths:
            # Fail with a clear message instead of a NameError on model_path.
            raise ValueError(
                f"Unsupported --model {args.model!r}; expected one of {sorted(model_paths)}"
            )
        model_path = model_paths[args.model]

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(model_path)

        generate_answer_data(model, processor, args.data_path, args.num_data, args.num_workers, args.cot, args.save_path)

    # Evaluation always runs, on whatever the save file contains, even when
    # generation was skipped with --gen false.
    evaluate_accuracy(args.save_path)

# Preparation for inference
