import argparse
import base64
import concurrent.futures
import json
import os
import random
import re
import time
from functools import partial
from io import BytesIO

import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

# The real helper comes from `qwen_vl_utils` (a local module or the
# `qwen-vl-utils` package), imported as:
# from qwen_vl_utils import process_vision_info
# When that module is not available, the placeholder below stands in for it.
def process_vision_info(messages):
    """Minimal stand-in for ``qwen_vl_utils.process_vision_info``.

    Walks every message's ``"content"`` list in order and gathers the
    ``"image"`` value of each entry whose ``"type"`` is ``"image"``. Nothing
    is loaded or preprocessed.

    Returns:
        A tuple ``(None, image_values)``; the first element is always None.

    NOTE(review): upstream code typically unpacks the real helper as
    ``image_inputs, video_inputs = process_vision_info(...)``; this placeholder
    puts the images in the second slot. Confirm before relying on it. It is
    not called by the visible code in this file.
    """
    collected = []
    for msg in messages:
        for part in msg["content"]:
            if part["type"] == "image":
                collected.append(part["image"])
    return None, collected


from openai import OpenAI
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# --- Prompt Formats ---

# Instruction template asking the model to answer with step-by-step reasoning
# in a <think>...</think><answer>...</answer> layout.
# Placeholders filled via str.format: {original_question}, {original_answer}.
# NOTE(review): not referenced by the visible code in this file.
PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to answer it requiring step-by-step Chain-of-Thought (CoT) reasoning process. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.

Input Format:
Question: {original_question}
Original Answer: {original_answer}

Output Format:
Answer: [answer with reasoning steps]
<think>step-by-step reasoning process</think>
<answer>Original Answer here</answer>
"""

# Template that process_item fills via str.format (original_question,
# original_answer) to request a numbered rationale for a known Q/A pair.
# process_item's acceptance checks (ends with '.', at least 200 words,
# contains a numbered step) mirror requirements 1-3 below.
REVERSE_THINKING_PROMPT_FORMAT = """
Based on the following question and image, generate a detailed thought process to explain how to derive the answer from the inputs.
Question: {original_question}
Answer: {original_answer}

Requirements:
1. The reasoning must be step-by-step and clearly divided into points (1., 2., 3., etc.)
2. The total length must be at least 200 words
3. End with a clear summary and proper punctuation (must end with '.')
4. Do not output the answer, only generate the reasoning process.

Output Format:
1.
2.
...

"""


# --- Helper Functions ---


def get_image_data_url(image_input):
    """Encode an image as a base64 JPEG ``data:`` URL.

    Args:
        image_input: A ``PIL.Image.Image``, or a path (``str`` or
            ``os.PathLike``) to an image file that PIL can open.

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # Image.open keeps the underlying file open until the image is closed;
        # the context manager releases the handle once encoding is done
        # instead of leaking one descriptor per call.
        with Image.open(image_input) as opened:
            return _encode_pil_as_jpeg_data_url(opened)

    if not isinstance(image_input, Image.Image):
        raise ValueError("Unsupported image input type")

    return _encode_pil_as_jpeg_data_url(image_input)


def _encode_pil_as_jpeg_data_url(image):
    """Return *image* re-encoded as JPEG and wrapped in a base64 data URL."""
    # JPEG cannot store alpha/palette modes, so normalise to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
    """Ask GPT-4o about an image, retrying with exponential backoff.

    Args:
        image: Anything ``get_image_data_url`` accepts; ``None`` short-circuits.
        prompt: User text sent alongside the image.
        max_retries: Total number of API attempts.
        initial_delay: Base delay in seconds; after failed attempt k the call
            sleeps ``initial_delay * 2**k`` plus up to 10% random jitter.

    Returns:
        The assistant message content, or ``None`` when ``image`` is ``None``
        (also ``None`` when ``max_retries`` <= 0, since no call is made).

    Raises:
        Whatever the last API attempt raised, once all retries are used up.
    """
    if image is None:
        return None

    image_url = get_image_data_url(image)

    client = OpenAI(
        # Replace with your actual API key and base URL if needed
        # base_url='https://api.openai-proxy.org/v1',
        # api_key='YOUR_API_KEY',
    )

    # The request payload does not change between attempts, so build it once.
    conversation = [
        {
            "role": "system",
            "content": "You are an expert to analyze the image and provide useful information for users.",
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    final_attempt = max_retries - 1

    for attempt_no in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model="gpt-4o",
                messages=conversation,
                temperature=0.2,
                max_tokens=8192,
            )
            return completion.choices[0].message.content
        except Exception as err:
            if attempt_no == final_attempt:
                print(f"GPT-4o query failed after {max_retries} attempts. Last error: {err}")
                raise
            backoff = initial_delay * (2**attempt_no)
            jitter = random.uniform(0, 0.1 * initial_delay * (2**attempt_no))
            time.sleep(backoff + jitter)


def qwenQuery(image, prompt, model, processor):
    """Run a single image + text prompt through a local Qwen2.5-VL model.

    Args:
        image: Path to an image file (``str`` / ``os.PathLike``) or an
            already-loaded PIL image.
        prompt: Text placed after the image in the user turn.
        model: The loaded vision-language model; ``generate`` is called on it
            and inputs are moved to ``model.device``.
        processor: The processor paired with ``model``; used for the chat
            template, input preparation and decoding.

    Returns:
        The list returned by ``processor.batch_decode``: one decoded string per
        batch entry (a single entry here), containing only newly generated text.
    """
    if isinstance(image, (str, os.PathLike)):
        # The processor's image pipeline expects pixel data (PIL / arrays /
        # tensors), not a file path, so load the file here. convert() returns a
        # fully loaded copy, which lets the file handle be closed right away.
        with Image.open(image) as opened:
            pixels = opened.convert("RGB")
    else:
        pixels = image

    messages = [
        {
            "role": "user",
            "content": [
                # With tokenize=False the template presumably only uses this
                # entry to place the image placeholder tokens.
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[pixels], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=8192)
    # generate() returns prompt + continuation; drop the prompt tokens so only
    # the model's answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )


def process_question(item):
    """Return the question text stored under *item*'s ``"question"`` key.

    Raises:
        KeyError: If the record has no ``"question"`` entry.
    """
    question_text = item["question"]
    return question_text


def process_item(item, dataset_name, image_dir, progress_bar, model, processor, max_attempts=50):
    """
    Generate a Chain-of-Thought rationale for a single dataset record.

    Each dataset stores the image name, reference answer and id under different
    keys, so those are extracted per dataset. The model is then queried with
    REVERSE_THINKING_PROMPT_FORMAT until a response passes the quality checks
    (ends with '.', at least 200 words, contains a numbered step such as
    "1. "), or ``max_attempts`` generations have been tried.

    Args:
        item: One record from the input JSON (dict).
        dataset_name: One of "slake", "path_vqa", "vqa_rad".
        image_dir: Directory that the record's image file name is joined onto.
        progress_bar: Object with an ``update(n)`` method (tqdm); advanced by
            exactly one per call, on success or failure.
        model, processor: Passed through to ``qwenQuery``.
        max_attempts: Maximum number of generations to try (default 50).

    Returns:
        ``(new_item, None)`` on success. If no generation passed the checks,
        ``new_item`` is still returned, with an empty "cot".
        ``(None, error_message)`` if an exception was raised; the message is
        also appended to ``error_<dataset_name>.txt``.
    """
    try:
        original_question = process_question(item)

        # --- Dataset-specific field extraction ---
        if dataset_name == "slake":
            image_name = str(item["img_name"])
            original_answer = item["answer"]
            item_id = item["qid"]
        elif dataset_name == "path_vqa":
            image_name = str(item["images"][0])
            original_answer = item["response"]
            item_id = os.path.splitext(image_name)[0]
        elif dataset_name == "vqa_rad":
            image_name = str(item["img_name"])
            original_answer = item["answer"]
            item_id = os.path.splitext(image_name)[0]
        else:
            raise ValueError(f"Unknown dataset processing logic for: {dataset_name}")
        image_path = os.path.join(image_dir, image_name)

        formatted_prompt = REVERSE_THINKING_PROMPT_FORMAT.format(
            original_question=original_question, original_answer=original_answer
        )

        numbered_step = re.compile(r"\d+\.\s")
        response = ""
        # Regenerate until the output meets the prompt's requirements.
        # NOTE(review): retries only help if generation is stochastic; with
        # greedy decoding every attempt yields the same text. Confirm the
        # model's generation config.
        for _ in range(max_attempts):
            # Strip so trailing whitespace/newlines don't fail the "ends with
            # '.'" check and force a needless regeneration.
            candidate = qwenQuery(image_path, formatted_prompt, model, processor)[0].strip()
            if (
                candidate.endswith(".")
                and len(candidate.split()) >= 200
                and numbered_step.search(candidate)
            ):
                response = candidate
                break

        if not response:
            print(f"Failed to generate valid CoT for item {item_id} after {max_attempts} attempts.")

        new_item = {
            "id": f"train-{dataset_name}-{item_id}",
            "image": [image_path],
            "problem": original_question,
            "cot": response,
            "solution": f"<think>{response}</think><answer>{original_answer}</answer>",
            "answer": original_answer,
        }
        return new_item, None  # Successful result

    except Exception as e:
        error_message = f"Error processing item for dataset {dataset_name}: {e}\n"
        print(error_message.strip())

        with open(f"error_{dataset_name}.txt", "a", encoding="utf-8") as error_file:
            error_file.write(error_message)

        return None, error_message  # Return error message

    finally:
        # Advance exactly once per item, whichever path was taken.
        progress_bar.update(1)


def _save_json_atomically(path, payload):
    """Write *payload* to *path* as indented UTF-8 JSON without risking a torn file.

    The data goes to ``<path>.tmp`` first and is then moved over *path* with
    ``os.replace``. If the write is interrupted, the previous checkpoint stays
    intact and can still be loaded on resume.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
    os.replace(tmp_path, path)


def main(args):
    """Generate CoT annotations for the selected dataset and write them to JSON.

    Steps:
    - Resolve paths, falling back to per-dataset defaults.
    - Load the input records, optionally truncated to ``--num_samples``.
    - Resume from an existing output file if present.
    - Run ``process_item`` over the remaining records in a thread pool.
    - Checkpoint every 50 collected results, then save the final list.

    Raises:
        ValueError: If ``args.dataset_name`` has no known defaults.
    """
    # --- Default (image_dir, input_json) per dataset, used when not provided ---
    default_paths = {
        "slake": (
            "/sda/duyuetian/dataset/SLAKE/imgs",
            "/sda/duyuetian/dataset/SLAKE/train.json",
        ),
        "path_vqa": (
            "/sda/duyuetian/dataset/PATH-VQA",
            "/sda/duyuetian/dataset/PATH-VQA/train.json",
        ),
        "vqa_rad": (
            "/sda/duyuetian/dataset/VQA-RAD/VQA_RAD_Image_Folder",
            "/sda/duyuetian/dataset/VQA-RAD/train.json",
        ),
    }
    if args.dataset_name not in default_paths:
        # argparse `choices` already restricts this; kept as a fallback.
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    default_image_dir, default_input_json = default_paths[args.dataset_name]
    image_dir = args.image_dir or default_image_dir
    input_json = args.input_json or default_input_json
    output_json = args.output_json or f"./qwen_{args.dataset_name}_train_cot.json"

    # --- Load data ---
    with open(input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    # `is not None` so that --num_samples 0 selects nothing instead of everything.
    if args.num_samples is not None:
        data = data[: args.num_samples]

    # --- Resume from checkpoint if output file exists ---
    # Results are stored in input order (see the loop below), so the number of
    # saved results is the index of the next record to process.
    # NOTE(review): failed records (process_item returned None) are not saved,
    # so each failure shifts this offset by one. On resume, a previously
    # processed record may be processed again.
    if os.path.exists(output_json):
        with open(output_json, "r", encoding="utf-8") as f:
            data_with_cot = json.load(f)
    else:
        data_with_cot = []

    processed_count = len(data_with_cot)
    print(f"Resuming from item {processed_count}/{len(data)}...")
    items_to_process = data[processed_count:]

    if items_to_process:
        # --- Initialize Model and Processor (only when there is work to do) ---
        print("Loading Qwen-VL model...")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.model_path, torch_dtype="auto", device_map="auto"
        )
        processor = AutoProcessor.from_pretrained(args.model_path)
        print("Model loaded.")

        # --- Main processing loop with ThreadPoolExecutor ---
        with tqdm(total=len(items_to_process)) as progress_bar:
            process_item_partial = partial(
                process_item,
                dataset_name=args.dataset_name,
                image_dir=image_dir,
                progress_bar=progress_bar,
                model=model,
                processor=processor,
            )

            with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
                futures = [executor.submit(process_item_partial, item) for item in items_to_process]

                # Consume in submission order (not as_completed) so saved results
                # keep the input order that the resume logic above relies on.
                for future in futures:
                    result, _error = future.result()
                    if not result:
                        continue  # process_item has already logged the failure

                    data_with_cot.append(result)

                    # Save checkpoint periodically
                    if len(data_with_cot) % 50 == 0:
                        _save_json_atomically(output_json, data_with_cot)
                        print(f"\nCheckpoint saved at {len(data_with_cot)} items.")

    # --- Final save ---
    _save_json_atomically(output_json, data_with_cot)

    print(f"\nProcessing complete. Processed dataset saved to: {output_json}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Generate Chain-of-Thought reasoning for medical VQA datasets.")

    # Each entry is (flag, add_argument keyword arguments); registered in order
    # so --help lists the options in this sequence.
    option_specs = [
        (
            "--dataset_name",
            {
                "type": str,
                "required": True,
                "choices": ["slake", "path_vqa", "vqa_rad"],
                "help": "The name of the dataset to process.",
            },
        ),
        (
            "--model_path",
            {
                "default": "/sda/duyuetian/models/Qwen2.5-VL-7B-Instruct",
                "type": str,
                "help": "Path to the pretrained Qwen-VL model.",
            },
        ),
        (
            "--image_dir",
            {
                "type": str,
                "default": None,
                "help": "Path to the image directory. If not provided, a default for the specified dataset will be used.",
            },
        ),
        (
            "--input_json",
            {
                "type": str,
                "default": None,
                "help": "Path to the input JSON file. If not provided, a default for the specified dataset will be used.",
            },
        ),
        (
            "--output_json",
            {
                "type": str,
                "default": None,
                "help": "Path to the output JSON file. If not provided, a default for the specified dataset will be created.",
            },
        ),
        (
            "--num_samples",
            {
                "type": int,
                "default": None,
                "help": "Number of samples to process from the dataset (for testing purposes).",
            },
        ),
        (
            "--max_workers",
            {
                "type": int,
                "default": 2,
                "help": "Maximum number of worker threads for parallel processing.",
            },
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())