# Instruction prefix sent to the LLM ahead of the user's caption.
# Adjacent string literals are joined by Python with no separator, so every
# fragment carries its own trailing space (or newline for the last one) to keep
# the sentences from running together and to keep the text appended after this
# prefix on its own line.
prompt_layout_OG = (
    "Your task is to generate the bounding boxes of objects mentioned in the caption, along with direction that objects facing. "
    "The image is size 512x512. "
    "The bounding box should be in the format of (x, y, width, height) from 0 to 1. "
    "The direction that object is facing should be one of these options, [front, back, left, right]. "
    "Please considering the frame of reference of caption and direction of reference object. "
    "The answer should be in the form of \"Reasoning: Explanation\nLayout: Layout\" The example of layout is [(cat, [0.1, 0.3, 0.5, 0.4], right), (cow, [0.6, 0.5, 0.3, 0.4], right)]\n"
)


import os
import sys
sys.path.append("../../SLD")
sys.path.append("")
import ast
from dotenv import load_dotenv

load_dotenv()

import argparse
import torch
import json
import copy
from tqdm import tqdm
from sld.llm_template import spot_difference_template_FoR2
from qwen_models import Qwen2_5Model, Qwen3Model

    

def _parse_updated_layout(raw_response):
    """Evaluate the bracketed literal that follows "Updated Objects".

    The text after the first "Updated Objects" marker is cut from its first
    "[" to its last "]" and handed to ``ast.literal_eval``. Any failure
    (missing marker, missing brackets, malformed literal) propagates to the
    caller as the underlying exception.
    """
    bbox_data = raw_response.split("Updated Objects")[1]
    start_index = bbox_data.index("[")
    end_index = bbox_data.rindex("]") + 1
    return ast.literal_eval(bbox_data[start_index:end_index])


def _write_results(records, path):
    """Write *records* under the "data" key of a JSON file at *path*."""
    # Context manager so the handle is flushed and closed deterministically.
    with open(path, 'w') as file:
        json.dump({"data": records}, file, indent=3)


def main(args):
    """Generate layouts with an LLM, or extract layouts from a saved run.

    Reads ``args.device``, ``args.reasoning_tokens``, ``args.model``,
    ``args.model_size``, ``args.use_update_prompt`` and
    ``args.extract_results``.

    Inference mode (default): loads the benchmark JSON, queries the model
    once per entry with ``prompt_layout_OG`` followed by the entry's prompt,
    and saves each entry plus ``layout_thinking`` / ``layout_raw`` to the
    result file.

    Extraction mode (``args.extract_results``): reloads the result file,
    tries to parse a layout out of each entry, stores it (or ``None``) under
    ``llm_layout_suggestions`` and rewrites the result file in place.
    """
    cur_device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"

    suffix = "" if args.reasoning_tokens else "_no-reasoning-token"
    # Any value other than "qwen3" (case-insensitive) selects Qwen2.5.
    model = "qwen3" if args.model.lower() == "qwen3" else "qwen2.5"
    update_file = f"results/FoR_editing_benchmark_convert_prompt_round-0_{model}{suffix}.json"

    # The input file name never carries "_update_prompt"; only the output does.
    suffix += "_update_prompt" if args.use_update_prompt else ""
    save_file = f"results/benchmark_{model}{suffix}_generate_layout.json"
    print("Saving result at", save_file)

    if args.extract_results:
        # Extraction only needs the previously saved results, so the input
        # benchmark file is not touched in this mode.
        with open(save_file, 'r') as file:
            saved_entries = json.load(file)["data"]

        # NOTE(review): this reads "update_layout_raw" and splits on
        # "Updated Objects", whereas inference mode below stores the response
        # under "layout_raw" and the prompt asks for a "Layout:" section.
        # Confirm which key/marker this mode is meant to parse.
        new_data = []
        count = 0
        for data in tqdm(saved_entries, desc="Inference"):
            try:
                updated_bboxes = _parse_updated_layout(data["update_layout_raw"])
            except Exception:
                # Best effort: an entry without a parsable layout is recorded
                # as None. Unlike a bare ``except``, this still lets
                # KeyboardInterrupt / SystemExit stop the run.
                updated_bboxes = None
            else:
                print("Found updated layout")
                count += 1

            new_info = copy.deepcopy(data)
            new_info["llm_layout_suggestions"] = updated_bboxes
            new_data.append(new_info)

        print(count)
        _write_results(new_data, save_file)
        return

    with open(update_file, 'r') as file:
        benchmark = json.load(file)["data"]

    # load model
    print(f"Loading and Infering from {model}")
    if model == "qwen3":
        llm_model = Qwen3Model(model_size=args.model_size, device=cur_device,
                               enable_reasoning=args.reasoning_tokens)
    else:
        llm_model = Qwen2_5Model(model_size=args.model_size, device=cur_device)

    update_infos = []
    for data in tqdm(benchmark, desc="Inference"):
        if args.use_update_prompt:
            # Fall back to the original prompt when no rewritten one exists.
            prompt = data.get("update_prompt", data["prompt"])
        else:
            prompt = data["prompt"]
        cur_info = f"User Prompt: {prompt}\n"
        messages = [
            {"role": "user", "content": prompt_layout_OG + cur_info}
        ]
        thinking, update = llm_model(messages, enable_thinking=args.reasoning_tokens)

        update_data = copy.deepcopy(data)
        update_data["layout_thinking"] = thinking
        update_data["layout_raw"] = update
        update_infos.append(update_data)

    _write_results(update_infos, save_file)
    return


# Command-line entry point: parse the options and hand them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=int, help="CUDA device to use", default=0)
    # main() compares this case-insensitively with "qwen3"; any other value selects Qwen2.5.
    parser.add_argument("--model", type=str, help="LLM family to use: Qwen3 or Qwen2.5", default="Qwen3")
    parser.add_argument("--model_size", type=str, help="size of LLM", default="32B")
    parser.add_argument("--reasoning_tokens", help="enable reasoning token", action="store_true")
    parser.add_argument("--extract_results",
                        help="extract layouts from the saved result file instead of running inference",
                        action="store_true")
    parser.add_argument("--use_update_prompt", help="enable using update prompt instead", action="store_true")
    # NOTE(review): --first_half / --second_half are parsed but main() never
    # reads them, so they currently have no effect.
    parser.add_argument("--first_half", help="Only query the first half of benchmark", action="store_true")
    parser.add_argument("--second_half", help="Only query the second half of benchmark", action="store_true")
    args = parser.parse_args()

    main(args)