import os
import argparse
import json
import random
import copy
import collections
import io
import base64

import numpy as np
import torch
import transformers
from PIL import Image
import tqdm
import more_itertools

parser = argparse.ArgumentParser()

# Model / backend selection
parser.add_argument('--pretrained-model-name', default="Qwen/Qwen2.5-VL-7B-Instruct",
                    help="Hugging Face model id; also selects model-specific loading code.")
parser.add_argument('--dtype', default="bfloat16",
                    help="float32, float16 or bfloat16 (any other value is treated as bfloat16).")
parser.add_argument('--runner', default="huggingface", choices=["huggingface", "vllm", "openai_batch_export"],
                    help="Backend; openai_batch_export prepares OpenAI-style chat messages instead of loading a local model.")

# Data locations
parser.add_argument('--images-path',
                    help="Directory containing the target images, named <qid>.jpg.")
parser.add_argument('--icl-examples-path',
                    help="Directory of in-context examples: <key>.txt (thought process), <key>.bbox (bounding-box count), <key>.jpg (image).")
parser.add_argument('--annotations-consolidated-path',
                    help="JSON Lines file with per-question sentences and correction annotations.")
parser.add_argument('--annotations-path',
                    help="JSON file mapping question ids to records with 'question' and 'fullAnswer'.")

parser.add_argument('--output-path',
                    help="JSON Lines results file; instances already present in it are skipped.")
parser.add_argument('--cache-dir', default=None,
                    help="Cache directory for downloaded models.")
parser.add_argument('--image-within-examples', default=False, action="store_true",
                    help="Interleave each image with its example text instead of placing all images first.")
parser.add_argument('--disable-flash-attention', default=False, action="store_true",
                    help="Use eager attention (Hugging Face runner).")
parser.add_argument('--enable-flash-attention-2', default=False, action="store_true",
                    help="Use flash_attention_2 on CUDA (Hugging Face runner).")
parser.add_argument('--batch-size', default=1, type=int,
                    help="Number of prompts per generation batch.")
parser.add_argument('--seed', default=42, type=int,
                    help="Seed for random, numpy and torch.")

# Prompt construction
parser.add_argument('--generate-explanation-first', default=False, action="store_true",
                    help="Ask for the thought process before the answer.")
parser.add_argument('--enable-step-by-step', default=False, action="store_true",
                    help="Ask for a step-by-step writing style.")
parser.add_argument('--enable-additional-conditions', default=False, action="store_true",
                    help="Append a list of conditions the thought process must meet.")
parser.add_argument('--enable-additional-condition-relevant', default=False, action="store_true",
                    help="Add the condition that only relevant bounding boxes need to be considered.")
parser.add_argument('--no-bbox', default=False, action="store_true",
                    help="Build prompts without bounding-box references.")

# Sharding
parser.add_argument('--split-by', default=-1, type=int,
                    help="Number of shards to divide the work into (<= 0 disables sharding).")
parser.add_argument('--split', default=0, type=int,
                    help="Index of the shard to process when --split-by > 0.")

args = parser.parse_args()

# Reject an out-of-range shard index up front with a clear message instead of
# an IndexError later. Negative indices keep Python indexing semantics
# (e.g. -1 selects the last shard).
if args.split_by > 0 and not -args.split_by <= args.split < args.split_by:
    parser.error(
        f"--split must be in [{-args.split_by}, {args.split_by - 1}] when --split-by={args.split_by}"
    )

random.seed(args.seed)
np.random.seed(args.seed)


def get_instances(sentences: list[str], annotations: list[dict[str, str|bool]]) -> list[list[str|tuple[int, int]]]:
    """
    Merge corrections into a full paragraph based on annotations.

    Args:
        sentences (list): List of original sentences
        annotations (list): List of annotation objects with 'correct' and 'correction' keys

    Returns:
        list: A list of self-critique instances with the format:
                [
                    [(start, end), "human_correction", "context"]
                    ...
                ]
    """
    # Check if all lines are marked as incorrect; if so, return empty list
    all_incorrect = all(not ann['correct'] for ann in annotations)
    if all_incorrect:
        return []

    # Process line by line for partial corrections
    output = []
    result = []
    i = 0
    while i < len(sentences):
        if annotations[i]['correct']:
            # If line is correct, use original
            result.append(sentences[i])
            i += 1
        else:
            # Find consecutive incorrect lines
            start = i
            end = i

            while end + 1 < len(annotations) and not annotations[end + 1]['correct']:
                end += 1

            # Concatenate corrections for the incorrect lines
            correction = ""
            for j in range(start, end + 1):
                if annotations[j]['correction'].strip():
                    _raw_corr = annotations[j]['correction'].strip()
                    if not _raw_corr.endswith('.'):
                        _raw_corr += '.'
                    correction += _raw_corr + " "
            
            # Concatenate the original sentences for the incorrect lines
            incorrect_lines = " ".join(sentences[start:end + 1])

            # If no correction is provided, means the annotator is intended to delete these lines so we can skip
            if correction:
                # Add the correction to the result
                result.append(correction.strip())
            
            output.append(
                [
                    (start, end),
                    correction.strip(),
                    " ".join(result) + " " + incorrect_lines
                ]
            )

            # Move past all the incorrect lines we've processed
            i = end + 1
    
    return output

# Read ICL examples
# Need to do this first due to vLLM limit_mm_per_prompt below....
# Each example <key> in --icl-examples-path is a set of files:
#   <key>.txt  - the example's thought-process text
#   <key>.bbox - the number of bounding boxes on the example image (an int)
#   <key>.jpg  - the example image (opened further below)
icl_examples = {}

# Without --icl-examples-path the prompt is zero-shot; os.listdir(None) would
# otherwise silently scan the current working directory.
if args.icl_examples_path:
    # os.listdir order is arbitrary; sort so the in-context example order
    # (and therefore the prompt) is reproducible.
    for icl_fname in sorted(os.listdir(args.icl_examples_path)):
        icl_key, icl_ext = os.path.splitext(icl_fname)
        if icl_ext != ".txt":
            continue

        with open(os.path.join(args.icl_examples_path, icl_fname), encoding="utf-8") as icl_f:
            icl_explanation = icl_f.read()

        with open(os.path.join(args.icl_examples_path, icl_key + ".bbox"), encoding="utf-8") as icl_f_bbox:
            icl_num_bboxes = int(icl_f_bbox.read())

        icl_examples[icl_key] = {"explanation": icl_explanation, "num_bboxes": icl_num_bboxes}

# Torch dtype: translate --dtype into a torch dtype; unrecognized names fall
# back to bfloat16.
_DTYPE_BY_NAME = {"float32": torch.float32, "float16": torch.float16}
dtype = _DTYPE_BY_NAME.get(args.dtype, torch.bfloat16)

if args.runner == "vllm":
    import vllm

    # Try to get the default generation config from HF repo to be loaded properly.
    # Best effort: without a usable config nothing is overridden. Catch
    # Exception (not a bare `except:`) so Ctrl-C / SystemExit still propagate.
    try:
        generation_config = transformers.GenerationConfig.from_pretrained(args.pretrained_model_name, cache_dir=args.cache_dir)
        gen_config_dict = generation_config.to_diff_dict()
    except Exception:
        gen_config_dict = {}

    # Per-model image preprocessing limits for the multimodal processor.
    if args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
        mm_processor_kwargs = {"max_dynamic_patch": 3}
    elif args.pretrained_model_name.startswith(("Qwen/Qwen2.5-VL", "Qwen/QVQ")):
        mm_processor_kwargs = {"max_pixels": 768*28*28}
    else:
        mm_processor_kwargs = None

    # Phi-4 specific model loading
    # https://github.com/vllm-project/vllm/blob/25f560a62c4f955672e2c6080b17ab3a48f96201/examples/offline_inference/vision_language_multi_image.py#L325
    if args.pretrained_model_name.startswith("microsoft/Phi-4-multimodal-instruct"):
        from huggingface_hub import snapshot_download
        from vllm.lora.request import LoRARequest

        model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct", cache_dir=args.cache_dir)
        vision_lora_path = os.path.join(model_path, "vision-lora")
    else:
        model_path = args.pretrained_model_name

    engine_args = {
        "model": model_path,
        "download_dir": args.cache_dir,
        "gpu_memory_utilization": 0.99,
        "mm_processor_kwargs": mm_processor_kwargs,
        # One image per in-context example plus the query image.
        "limit_mm_per_prompt": {"image": len(icl_examples) + 1},
        "max_num_seqs": args.batch_size,
        # bitsandbytes by default; AWQ checkpoints switch to "awq" below.
        "quantization": "bitsandbytes",
        "load_format": "bitsandbytes",
        "tensor_parallel_size": torch.cuda.device_count(),
        "override_generation_config": gen_config_dict,
        "trust_remote_code": True,
    }

    if "awq" in args.pretrained_model_name.lower():
        engine_args["quantization"] = "awq"
        engine_args["load_format"] = "auto"

    # Model-family specific context length / dtype settings.
    if args.pretrained_model_name.startswith("microsoft/Phi-4-multimodal-instruct"):
        engine_args["enable_lora"] = True
        engine_args["max_lora_rank"] = 320
        engine_args["dtype"] = "auto"
    elif args.pretrained_model_name.startswith(("Qwen/Qwen2.5-VL", "Qwen/QVQ")):
        engine_args["max_model_len"] = 32768
        engine_args["dtype"] = dtype
    elif "Llama-3.2" in args.pretrained_model_name:
        engine_args["max_model_len"] = 38500
        engine_args["dtype"] = dtype
    elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
        engine_args["hf_overrides"] = {"architectures": ["DeepseekVLV2ForCausalLM"]}
        engine_args["max_model_len"] = 4096
        engine_args["dtype"] = dtype
    else:
        engine_args["max_model_len"] = 8192
        engine_args["dtype"] = dtype

    model = vllm.LLM(**engine_args)

    # InternVL only needs its tokenizer here; other models use their processor.
    if args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
        processor = transformers.AutoTokenizer.from_pretrained(args.pretrained_model_name, trust_remote_code=True, use_fast=False, padding_side='left', cache_dir=args.cache_dir)
    else:
        processor = transformers.AutoProcessor.from_pretrained(args.pretrained_model_name, max_pixels=768*28*28, padding_side='left', trust_remote_code=True, cache_dir=args.cache_dir, use_fast=True)
elif args.runner == "huggingface":
    torch.manual_seed(args.seed)
    torch.use_deterministic_algorithms(True, warn_only=True)

    if torch.cuda.is_available():
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
        torch_device = torch.device("cuda")
        device_map = "cuda"

        if args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
            quantization_config = transformers.QuantoConfig(weights="int4")
        else:
            quantization_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )
    elif torch.backends.mps.is_available():
        torch_device = torch.device("mps")
        device_map = "mps"

        quantization_config = transformers.QuantoConfig(weights="int4")
    else:
        device_map = "cpu"

    if args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
        from deepseek_vl2.utils.io import load_pil_images

        processor = DeepseekVLV2Processor.from_pretrained(args.pretrained_model_name, cache_dir=args.cache_dir)
        tokenizer = processor.tokenizer

        model = DeepseekVLV2ForCausalLM.from_pretrained(args.pretrained_model_name, torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True, quantization_config=quantization_config, device_map=device_map, cache_dir=args.cache_dir).eval()
    elif args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
        import math
        import torchvision.transforms as T
        from PIL import Image
        from torchvision.transforms.functional import InterpolationMode

        IMAGENET_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_STD = (0.229, 0.224, 0.225)

        def build_transform(input_size):
            """Return the InternVL preprocessing pipeline for square ``input_size`` crops.

            Converts to RGB when needed, resizes to (input_size, input_size)
            with bicubic interpolation, converts to a tensor and normalizes
            with IMAGENET_MEAN / IMAGENET_STD.
            """
            def _ensure_rgb(img):
                return img if img.mode == 'RGB' else img.convert('RGB')

            steps = [
                T.Lambda(_ensure_rgb),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ]
            return T.Compose(steps)

        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
            """Pick the (cols, rows) grid from ``target_ratios`` whose ratio best matches ``aspect_ratio``.

            On an exact tie, a later candidate replaces the current pick only
            when the image area exceeds half the area that candidate's grid of
            ``image_size`` tiles would cover. Returns (1, 1) if no candidate
            is given.
            """
            chosen = (1, 1)
            smallest_gap = float('inf')
            pixel_area = width * height
            for candidate in target_ratios:
                gap = abs(aspect_ratio - candidate[0] / candidate[1])
                if gap < smallest_gap:
                    smallest_gap = gap
                    chosen = candidate
                elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
                    chosen = candidate
            return chosen

        def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
            """Tile ``image`` into ``image_size`` x ``image_size`` crops on a best-fitting grid.

            The grid (cols, rows) has between ``min_num`` and ``max_num`` tiles
            and is chosen with ``find_closest_aspect_ratio``. The image is
            resized to fill the grid exactly and cut into tiles in row-major
            order. When ``use_thumbnail`` is set and more than one tile was
            produced, a resized copy of the whole image is appended.
            """
            width, height = image.size

            # Candidate grids sorted by tile count. The set is built exactly
            # this way so that, after the stable sort, equal tile counts keep
            # the same relative order (this affects tie-breaking).
            candidate_grids = set(
                (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
                i * j <= max_num and i * j >= min_num)
            candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

            cols, rows = find_closest_aspect_ratio(width / height, candidate_grids, width, height, image_size)

            grid_width = image_size * cols
            grid_height = image_size * rows
            tile_count = cols * rows

            resized = image.resize((grid_width, grid_height))
            tiles_per_row = grid_width // image_size
            tiles = []
            for idx in range(tile_count):
                row, col = divmod(idx, tiles_per_row)
                left = col * image_size
                top = row * image_size
                tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
            assert len(tiles) == tile_count

            if use_thumbnail and len(tiles) != 1:
                tiles.append(image.resize((image_size, image_size)))
            return tiles

        def load_image(image_file, input_size=448, max_num=12):
            """Open ``image_file`` as RGB and return its InternVL tiles stacked along dim 0.

            Tiles (plus a thumbnail when several tiles are produced) come from
            ``dynamic_preprocess`` and are each run through ``build_transform``.
            """
            rgb_image = Image.open(image_file).convert('RGB')
            to_tensor = build_transform(input_size=input_size)
            tiles = dynamic_preprocess(rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
            return torch.stack([to_tensor(tile) for tile in tiles])

        def split_model(model_name):
            """Build a ``device_map`` spreading InternVL's language-model layers over all CUDA GPUs.

            GPU 0 counts as half a GPU for decoder layers because it also holds
            ``vision_model``, ``mlp1``, the embeddings, output head, final norm,
            rotary embedding and the last decoder layer.

            With no CUDA device there is nothing to split, so the module-level
            ``device_map`` string chosen earlier (e.g. "mps" or "cpu") is
            returned as-is.
            """
            world_size = torch.cuda.device_count()
            if world_size < 1:
                # The formula below divides by (world_size - 0.5) and indexes
                # GPU 0, which fails without any GPU.
                return device_map

            layer_map = {}
            config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True, cache_dir=args.cache_dir)
            num_layers = config.llm_config.num_hidden_layers
            # Since the first GPU will be used for ViT, treat it as half a GPU.
            num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
            num_layers_per_gpu = [num_layers_per_gpu] * world_size
            num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
            layer_cnt = 0
            for gpu_index, num_layer in enumerate(num_layers_per_gpu):
                for _ in range(num_layer):
                    layer_map[f'language_model.model.layers.{layer_cnt}'] = gpu_index
                    layer_cnt += 1
            layer_map['vision_model'] = 0
            layer_map['mlp1'] = 0
            layer_map['language_model.model.tok_embeddings'] = 0
            layer_map['language_model.model.embed_tokens'] = 0
            layer_map['language_model.output'] = 0
            layer_map['language_model.model.norm'] = 0
            layer_map['language_model.model.rotary_emb'] = 0
            layer_map['language_model.lm_head'] = 0
            layer_map[f'language_model.model.layers.{num_layers - 1}'] = 0

            return layer_map

        # Load the model
        model = transformers.AutoModel.from_pretrained(
            args.pretrained_model_name,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_flash_attn=device_map == "cuda",
            trust_remote_code=True,
            quantization_config=quantization_config, device_map=split_model(args.pretrained_model_name), cache_dir=args.cache_dir).eval()

        tokenizer = transformers.AutoTokenizer.from_pretrained(args.pretrained_model_name, trust_remote_code=True, use_fast=False, padding_side='left')

        generation_config = dict(max_new_tokens=1024, do_sample=True)
    else:
        processor = transformers.AutoProcessor.from_pretrained(args.pretrained_model_name, max_pixels=768*28*28, padding_side='left', trust_remote_code=True, cache_dir=args.cache_dir)

        if args.disable_flash_attention:
            attn_implementation = "eager"
        else:
            if device_map == "cuda" and args.enable_flash_attention_2:
                attn_implementation = "flash_attention_2"
            else:
                attn_implementation = "sdpa"

        if args.pretrained_model_name.startswith("microsoft/Phi-4-multimodal-instruct"):
            transformers_model_class = transformers.AutoModelForCausalLM
        else:
            transformers_model_class = transformers.AutoModelForImageTextToText

        model = transformers_model_class.from_pretrained(args.pretrained_model_name,
                                                        torch_dtype=dtype,
                                                        attn_implementation=attn_implementation,
                                                        quantization_config=quantization_config,
                                                        device_map=device_map,
                                                        trust_remote_code=True,
                                                        cache_dir=args.cache_dir)

        generation_config = transformers.GenerationConfig.from_pretrained(args.pretrained_model_name, cache_dir=args.cache_dir)

# Read questions.json: maps question ids to records; "question" and
# "fullAnswer" are read for the in-context examples below.
with open(args.annotations_path, "r", encoding="utf-8") as questions_file:
    questions = json.load(questions_file)

# The prompt is assembled either as a list of chat content parts (images
# interleaved with the examples) or as a single text template.
if args.image_within_examples:
    examples_content = []
else:
    examples_text = ""

# Images of the in-context examples, in prompt order.
image_inputs_icl = []

# If we have at least 1 ICL example, prepend it to prompt
if len(icl_examples) > 0:
    if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
        num_patches_list_icl = []

    for i, fname_key in enumerate(icl_examples.keys()):
        # Load the example image in the form the backend consumes: a stacked
        # tile tensor (InternVL on Hugging Face), a file path (DeepSeek-VL2),
        # or a PIL image (everything else).
        if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
            image_icl = load_image(os.path.join(args.icl_examples_path, fname_key + ".jpg"), max_num=3).to(dtype)
        elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
            image_icl = os.path.join(args.icl_examples_path, fname_key + ".jpg")
        else:
            image_icl = Image.open(os.path.join(args.icl_examples_path, fname_key + ".jpg"))

        image_inputs_icl.append(image_icl)

        if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
            num_patches_list_icl.append(image_icl.size(0))

        icl_header = "## Image {im_num} ({num_bbox} bounding box(es))\n".format(im_num=i+1, num_bbox=icl_examples[fname_key]["num_bboxes"])
        icl_question = "**Question**: {question}\n".format(question=questions[fname_key]["question"])
        icl_answer = "**Answer**: {answer}\n".format(answer=questions[fname_key]["fullAnswer"])
        icl_thought = "**Thought Process**: {explanation}\n".format(explanation=icl_examples[fname_key]["explanation"])

        # Present answer and thought process in the same order the model is
        # asked to produce them for the target question (see "# Answer").
        if args.generate_explanation_first:
            icl_body = icl_question + icl_thought + icl_answer + "\n"
        else:
            icl_body = icl_question + icl_answer + icl_thought + "\n"

        if args.image_within_examples:
            # Header text, then the image token, then the example text.
            examples_content.append({"type": "text", "text": icl_header})
            if args.runner == "openai_batch_export":
                examples_content.append({"type": "image_url", "image_url": {"url": ""}})
            else:
                examples_content.append({"type": "image"})
            examples_content.append({"type": "text", "text": icl_body})
        else:
            examples_text += icl_header + icl_body

    # Main instruction
    if args.no_bbox:
        main_instruction = "Similarily to the examples provided above, please write down the thought process for solving the question, that concludes with the final answer."
    else:
        main_instruction = "Similarily to the examples provided above, please imagine this person's thought process behind the areas of the image the person looked at in order to solve the question, which are marked with the red bounding boxes. "
else:
    if args.no_bbox:
        main_instruction = "Please write down the thought process for solving the question, that concludes with the final answer."
    else:
        main_instruction = "Please imagine this person's thought process behind the areas of the image the person looked at in order to solve the question, which are marked with the red bounding boxes. "

if args.image_within_examples:
    examples_content.append({"type": "text", "text": main_instruction})
else:
    # prompt_template is filled in later with str.format(), so literal braces
    # in the example text (questions, answers, explanations) are escaped.
    prompt_template = examples_text.replace("{", "{{").replace("}", "}}") + main_instruction

# Detailed conditions
if args.enable_additional_conditions:
    if args.no_bbox:
        additional_conditions_text = (
            "\n\nThe thought process should meet the following conditions:\n"
            "- It should describe the broader setting of the image and how that context informs the search for specific details.\n"
            "- It should clearly connect each observation to the final answer. It should explain how, based on the evidence gathered from the image, the conclusion was reached (e.g., identifying the man holding a cell phone).\n"
        )
    else:
        # Each condition appears once (the "clearly connect each observation"
        # bullet was previously listed twice).
        additional_conditions_text = (
            "\n\nThe thought process should meet the following conditions:\n"
            "- It should describe the broader setting of the image and how that context informs the search for specific details.\n"
            "- It should explain how the bounding boxes (e.g., `R1`, `R2`) were used to narrow down the areas of interest. For instance, it should mention how the image’s orientation or the placement of objects guided the focus to certain regions.\n"
            "- It should only consider the bounding boxes that does appear in the image and not refer to any bounding boxes that actually do not exist. Note that all bounding boxes are numbered sequentially, so if we have 4 bounding boxes in the image, there are exactly the following bounding boxes: `R1`, `R2`, `R3`, and `R4`.\n"
            "- It should list specific visual cues noticed within each highlighted region and how these observations lead to identifying the relevant subject (e.g., a man with a backpack).\n"
            "- It should clearly connect each observation to the final answer. It should explain how, based on the evidence gathered from the image, the conclusion was reached (e.g., identifying the man holding a cell phone).\n"
            "- If you realize you wrote something wrong in your thought process previously, you should identify and correct the mistakes you have made.\n"
        )

    # The "only relevant bounding boxes" condition is opt-in, and is always
    # added for QVQ models when in-context examples are present.
    if args.enable_additional_condition_relevant or (args.pretrained_model_name.startswith("Qwen/QVQ") and len(icl_examples) > 0):
        additional_conditions_text += "- You only have to consider the bounding boxes that are relevant to answering the question, and there is no need to use every single bounding boxes that appear in the image.\n\n"
    else:
        additional_conditions_text += "\n"

    if args.image_within_examples:
        examples_content.append({
            "type": "text",
            "text": additional_conditions_text})
    else:
        prompt_template = (
            prompt_template + additional_conditions_text
        )

# Writing style: choose at most one style sentence, then append it to the
# prompt in whichever form the prompt is being built.
if args.enable_step_by_step:
    style_instruction = "Write as if you are this person, in a step-by-step manner."
elif args.no_bbox:
    style_instruction = None
elif len(icl_examples) > 0:
    style_instruction = "Write as if you are this person, in a similar style as above examples."
else:
    style_instruction = "Write as if you are this person."

if style_instruction is not None:
    if args.image_within_examples:
        examples_content.append({"type": "text", "text": style_instruction})
    else:
        prompt_template = prompt_template + style_instruction

# Header: per-sample image heading (with a bounding-box count unless
# --no-bbox) followed by the question line.
header_template_part = (
    "\n\n## Image {im_num}\n"
    if args.no_bbox
    else "\n\n## Image {im_num} ({num_bbox} bounding box(es))\n"
)

question_template_part = "**Question**: {question}\n"

if not args.image_within_examples:
    prompt_template = prompt_template + header_template_part + question_template_part
else:
    # Header text, image placeholder, question text. The per-sample loop
    # fills these in by their position from the end of the content list.
    examples_content.extend([
        {"type": "text", "text": header_template_part},
        {"type": "image_url", "image_url": {"url": ""}} if args.runner == "openai_batch_export" else {"type": "image"},
        {"type": "text", "text": question_template_part},
    ])

# Answer: the tail of the prompt the model continues from.
if args.generate_explanation_first:
    answer_template_part = "**Thought Process**: \n**Answer**: "
else:
    answer_template_part = "**Answer**: {answer}\n**Thought Process**: "

if args.image_within_examples:
    examples_content.append({"type": "text", "text": answer_template_part})
else:
    prompt_template += answer_template_part

# If the output file already exists, check the completed examples.
# Keys are (qid, (correction_start, correction_end)); unknown keys read as False.
existing_ids = collections.defaultdict(bool)

if os.path.exists(args.output_path):
    with open(args.output_path) as results_file:
        for line in results_file:
            if not line.strip():
                continue
            try:
                l = json.loads(line)
            except json.JSONDecodeError:
                # e.g. a record cut short by an interrupted run: leave that
                # instance to be regenerated instead of aborting the restart.
                print("Skipping unreadable line in", args.output_path)
                continue
            existing_ids[(l["qid"], (l["correction_start"], l["correction_end"]))] = True

# Record the index where the correction takes place for each qid.
# Each entry: (qid, question, ground_truth, num_bboxes, (start, end), context, correction).
all_ids_to_process = []

with open(args.annotations_consolidated_path, "r") as cf:
    for li, line in enumerate(cf):
        l = json.loads(line)
        num_sentences = len(l["sentences"])

        l_all_instances = get_instances(l["sentences"], l["annotations"])

        for ci in l_all_instances:
            correction_end = ci[0][1]
            # Skip runs that end on one of the first two or last two sentences.
            if correction_end < num_sentences and (correction_end < 2 or correction_end >= num_sentences - 2):
                continue

            all_ids_to_process.append(
                (l["qid"], l["question"], l["ground_truth"], len(l["bounding_box_labels"]), ci[0], ci[2], ci[1])
            )

# Optionally keep only one shard of the instances.
ids_to_process = (
    more_itertools.divide(args.split_by, all_ids_to_process)[args.split]
    if args.split_by > 0
    else all_ids_to_process
)

# Drop instances that already appear in the output file.
ids_to_process = [
    item for item in ids_to_process
    if not existing_ids[(item[0], item[4])]
]

print(len(ids_to_process), "samples to process.")
print()

print("Generating...")
print()

for batch in tqdm.tqdm(more_itertools.batched(ids_to_process, args.batch_size)):
    messages = []
    
    if len(icl_examples) > 0:
        image_inputs = copy.copy(image_inputs_icl)
    else:
        image_inputs = []

    for d in batch:
        qid, question, answer, num_bbox, _, _, _ = d
        
        # Prepare messages
        message = []
        
        if args.runner == "openai_batch_export":
            message.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a helpful assistant."}]
                }
            )
        
        if args.image_within_examples:
            message.append(
                {
                    "role": "user",
                    "content": copy.deepcopy(examples_content)
                }
            )
        elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
            message.append(
                {
                    "role": "<|User|>",
                    "content": []
                }
            )
        else:
            message.append(
                {
                    "role": "user",
                    "content": []
                }
            )

        if not args.image_within_examples:
            if args.runner == "openai_batch_export":
                for _ in range(len(image_inputs_icl) + 1):
                    message[-1]["content"].append({"type": "image_url"})
            elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
                message[-1]["images"] = []

                for iiii in range(len(image_inputs_icl)):
                    message[-1]["images"].append(image_inputs_icl[iiii])
            elif not (args.runner == "vllm" and args.pretrained_model_name.startswith("OpenGVLab/InternVL")):
                for _ in range(len(image_inputs_icl) + 1):
                    message[-1]["content"].append({"type": "image"})

        prompt_format_inputs_header = {"im_num": len(image_inputs)+1}

        if not args.no_bbox:
            prompt_format_inputs_header["num_bbox"] = num_bbox
            
        prompt_format_inputs_question = {"question": question}
        prompt_format_inputs_answer = {"answer": answer}
        prompt_format_inputs_wo_answer = prompt_format_inputs_header | prompt_format_inputs_question
        prompt_format_inputs = prompt_format_inputs_wo_answer | prompt_format_inputs_answer        

        if args.generate_explanation_first:            
            if args.image_within_examples:
                message[-1]["content"][-4]["text"] = message[-1]["content"][-4]["text"].format(**prompt_format_inputs_header)
                message[-1]["content"][-2]["text"] = message[-1]["content"][-2]["text"].format(**prompt_format_inputs_question)
            else:
                prompt = prompt_template.format(**prompt_format_inputs_wo_answer)
        else:
            if args.image_within_examples:
                message[-1]["content"][-4]["text"] = message[-1]["content"][-4]["text"].format(**prompt_format_inputs_header)
                message[-1]["content"][-2]["text"] = message[-1]["content"][-2]["text"].format(**prompt_format_inputs_question)
                message[-1]["content"][-1]["text"] = message[-1]["content"][-1]["text"].format(**prompt_format_inputs_answer)
            else:
                prompt = prompt_template.format(**prompt_format_inputs)

        if not args.image_within_examples:
            if args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2") or (args.runner == "vllm" and args.pretrained_model_name.startswith("OpenGVLab/InternVL")):
                for _ in range(len(icl_examples) + 1):
                    prompt = "<image>\n" + prompt

            if args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
                message[-1]["content"] = prompt
            else:
                message[-1]["content"].append({"type": "text", "text": prompt})

        # Image inputs
        if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
            image = load_image(os.path.join(args.images_path, qid + ".jpg"), max_num=3).to(dtype)
        elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
            image = os.path.join(args.images_path, qid + ".jpg")
        else:
            image = Image.open(os.path.join(args.images_path, qid + ".jpg"))

        if args.runner == "vllm" or args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
            image_inputs = []
            
            if len(icl_examples) > 0:
                image_inputs_qid = copy.copy(image_inputs_icl)
            else:
                image_inputs_qid = []

            image_inputs_qid.append(image)
            
            image_inputs.append(image_inputs_qid)

            if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
                num_patches_lists = []
                
                if len(icl_examples) > 0:
                    num_patches_lists_qid = copy.copy(num_patches_list_icl)
                else:
                    num_patches_lists_qid = []

                num_patches_lists_qid.append(image.size(0))

                num_patches_lists.append(num_patches_lists_qid)
        elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
            message[-1]["images"].append(image)
        else:
            image_inputs.append(image)

        messages.append(message)

    # OpenAI batch export specific logic
    if args.runner == "openai_batch_export":
        for mi in range(len(messages)):
            if args.image_within_examples:
                image_url_indexes = [iui for iui in range(len(messages[mi][-1]["content"])) if messages[mi][-1]["content"][iui]["type"] == "image_url"]
            for mii in range(len(image_inputs)):
                # encode all images into base64
                output = io.BytesIO()
                image_inputs[mii].save(output, format="PNG")
                hex_data = output.getvalue()

                base64_data = base64.b64encode(hex_data).decode("utf-8")
                image_uri = f"data:image/png;base64,{base64_data}"
                
                if args.image_within_examples:
                    messages[mi][-1]["content"][image_url_indexes[mii]]["image_url"] = {"url": image_uri}
                else:
                    messages[mi][-1]["content"][mii]["image_url"] = {"url": image_uri}

    # Chat template
    if args.runner == "huggingface" and args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
        texts = []

        if args.image_within_examples:
            for m in messages:
                full_text = ""
                for uc in m[-1]["content"]:
                    if uc["type"] == "image":
                        full_text += "<image>\n"
                    elif uc["type"] == "text":
                        full_text += uc["text"]

                texts.append(full_text)
        else:
            for m in messages:
                texts.append("<image>\n" * (len(icl_examples) + 1) + m[-1]["content"][-1]["text"])

    elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
        texts = []

        for m in messages:
            m.append({"role": "<|Assistant|>", "content": ""})
            pil_images = load_pil_images(m)

            prepare_inputs = processor(
                conversations=m,
                images=pil_images,
                force_batchify=True,
                system_prompt=""
            )

            texts.append(prepare_inputs)
    elif args.pretrained_model_name.startswith("microsoft/Phi-4-multimodal-instruct"):
        placeholders = "".join(f"<|image_{i}|>"
                               for i in range(1, len(icl_examples) + 1 + 1))

        texts = []

        for m in messages:
            texts.append(f'<|user|>{placeholders}{m[0]["content"][-1]["text"]}<|end|><|assistant|>')
    elif not args.runner == "openai_batch_export":
        texts = [
            processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages
        ]

    # Append whatever the model generated up to the point where the correction was made
    for b_i, d in enumerate(batch):
        _, _, _, _, _, correction_context, _ = d

        texts[b_i] += correction_context

    if args.runner == "vllm":
        texts_vllm = []

        for tvi, text in enumerate(texts):
            texts_vllm.append({"prompt": text, "multi_modal_data": {"image": image_inputs[tvi]}})

        sampling_params = model.get_default_sampling_params()
        sampling_params.max_tokens = 10240 if args.pretrained_model_name.startswith("Qwen/QVQ") else 1024

        if args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
            stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
            sampling_params.stop_token_ids = [processor.convert_tokens_to_ids(i) for i in stop_tokens]

        if args.pretrained_model_name.startswith("microsoft/Phi-4-multimodal-instruct"):
            outputs = model.generate(texts_vllm, sampling_params, lora_request=[LoRARequest("vision", 1, vision_lora_path)])
        else:
            outputs = model.generate(texts_vllm, sampling_params)

        output_texts = [o.outputs[0].text for o in outputs]
    elif args.runner == "huggingface":
        # Generate output
        with torch.no_grad():
            if args.pretrained_model_name.startswith("OpenGVLab/InternVL"):
                output_texts = []

                for ti, text in enumerate(texts):
                    pixel_values = torch.cat(tuple(image_inputs[ti]), dim=0).to(torch_device)
                    output_text = model.chat(tokenizer, pixel_values, text,
                                                    num_patches_list=num_patches_lists[ti],
                                                    generation_config=generation_config)
                    output_texts.append(output_text)
            elif args.pretrained_model_name.startswith("deepseek-ai/deepseek-vl2"):
                output_texts = []
                
                for prepared_input in texts:
                    # run the model to get the response
                    outputs = model.generate(
                        input_ids=prepared_input.input_ids.to(model.device),
                        attention_mask=prepared_input.attention_mask.to(model.device),
                        images=prepared_input.images.to(model.device, dtype=torch.bfloat16),
                        images_seq_mask=prepared_input.images_seq_mask.to(model.device),
                        images_spatial_crop=prepared_input.images_spatial_crop.to(model.device),
                        pad_token_id=tokenizer.eos_token_id,
                        bos_token_id=tokenizer.bos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        max_new_tokens=1024,
                        do_sample=True,
                        temperature=0.4,
                        top_p=0.9,
                        repetition_penalty=1.1,
                        use_cache=True,
                    )

                    generated_ids_trimmed = [
                        out_ids[len(in_ids):] for in_ids, out_ids in zip(prepared_input.input_ids, outputs)
                    ]

                    answer = tokenizer.decode(generated_ids_trimmed[0].cpu().tolist(), skip_special_tokens=False)
                    output_texts.append(answer)
            else:
                inputs = processor(
                    text=texts,
                    images=image_inputs,
                    padding=True,
                    return_tensors="pt",
                )

                inputs = inputs.to(torch_device)

                # Batch Inference
                generation_args = {"max_new_tokens": 1024, "generation_config": generation_config, "use_cache": True}
                inputs.update(generation_args)
                
                generated_ids = model.generate(**inputs)
                generated_ids_trimmed = [
                    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_texts = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )

    # Write to the output file
    if args.runner == "openai_batch_export":
        for j in range(len(batch)):
            with open(args.output_path, "a") as f:
                json.dump(
                    {
                        "custom_id": batch[j],
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": "gpt-4o-mini",
                            "messages": messages[j],
                            "temperature": 1, "store": True
                        }
                    }, f)

                f.write("\n")
    else:
        for j in range(len(batch)):
            with open(args.output_path, "a") as f:
                json.dump(
                    {
                        "qid": batch[j][0],
                        "question": batch[j][1],
                        "answer": batch[j][2],
                        "num_bbox": batch[j][3],
                        "correction_start": batch[j][4][0],
                        "correction_end": batch[j][4][1],
                        "correction_context": batch[j][5],
                        "correction_human": batch[j][6],
                        "output_text": output_texts[j].strip()
                    }, f)

                f.write("\n")
