import os
import json
import argparse
import pandas as pd
from datasets import load_from_disk, load


def build_prompt_chatbot(problems, sceneGraphs, is_bounding_box=False):
    """Build one ``(prompt, answer)`` pair for every row of ``problems``.

    Args:
        problems: pandas DataFrame with the columns ``ambiguous_entity``,
            ``entity_id``, ``image_id``, ``additional_question`` and
            ``label``.
        sceneGraphs: Mapping from ``str(image_id)`` to a scene graph holding
            ``width``, ``height`` and an ``objects`` mapping. ``objects`` is
            keyed by ``str(entity_id)`` and each object provides ``name``,
            ``x``, ``y``, ``w`` and ``h``.
        is_bounding_box: When exactly ``True``, the prompt lists the
            normalized boxes of every object that shares the target entity's
            name, followed by the target box and the sub-question. Otherwise
            a short ``Entity: ...`` prompt is produced.

    Returns:
        Dict mapping the positional row index (as a string) to a
        ``(prompt, answer)`` tuple. ``answer`` is ``"yes"`` when the row's
        label is ``"O"`` and ``"no"`` for any other label.

    Raises:
        KeyError: If a row references an image id or entity id that is not
            present in ``sceneGraphs`` (checked in both prompt modes).
    """
    examples = {}
    # Every positional row is processed, including row 0; the keys double as
    # ``iloc`` positions for callers that look the row up again.
    for idx in range(len(problems)):
        row = problems.iloc[idx]
        entity = row["ambiguous_entity"]
        entity_id = str(row["entity_id"])
        image_id = str(row["image_id"])
        additional_question = row["additional_question"]
        label = row["label"]

        # Looked up in both modes, so a row whose image or entity is missing
        # from the scene graphs raises KeyError.
        scene = sceneGraphs[image_id]
        objects = scene["objects"]
        target_entity = objects[entity_id]
        target_entity_name = target_entity["name"]
        width = scene["width"]
        height = scene["height"]

        if is_bounding_box is True:
            bounding_boxes = []
            for object_id, object_value in objects.items():
                if object_value["name"] != target_entity_name:
                    continue
                # Corners as fractions of the image size:
                # (x_min, y_min) is the top-left value stored in the graph,
                # (x_max, y_max) adds the normalized width/height to it.
                x_min = object_value["x"] / width
                y_min = object_value["y"] / height
                x_max = x_min + object_value["w"] / width
                y_max = y_min + object_value["h"] / height
                box_text = (
                    f"{target_entity_name}: "
                    f"[{x_min:.3f}, {y_min:.3f}, {x_max:.3f}, {y_max:.3f}]"
                )
                bounding_boxes.append(box_text)
                if object_id == entity_id:
                    target_entity = box_text

            prompt = (
                ",".join(bounding_boxes)
                + f"\n Target Entity: {target_entity}"
                + f"\n Sub-Question: {additional_question}"
                + "\n Question: Does the sub-question classify the target entity? Answer:"
            )
        else:
            prompt = f"Entity: {entity}\nAnswer:\n"

        answer = "yes" if label == "O" else "no"
        # Collapse double spaces once and trim surrounding whitespace.
        prompt = prompt.replace("  ", " ").strip()
        examples[f"{idx}"] = prompt, answer

    return examples


def _strip_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix``; later occurrences are kept."""
    return text[len(prefix):] if text.startswith(prefix) else text


def _image_file_name(image_id):
    """Return ``"<image_id>.jpg"``, or ``None`` when the id is missing (None/NaN)."""
    if pd.isna(image_id):
        return None
    return str(image_id) + ".jpg"


def convert_to_llava(args):
    """Convert a GQA Q2Q split into LLaVA-style training or evaluation files.

    Reads ``GQA_Q2Q_<split>.csv`` and ``train_sceneGraphs.json`` from
    ``args.base_dir``, builds prompts with ``build_prompt_chatbot`` and writes
    the result back into ``args.base_dir``:

    * ``split == "train"``: a single ``llava_checker_train[_bb].json`` file
      containing a list of conversation records; the number of samples is
      printed.
    * any other split: ``llava_checker_<split>[_bb]_question.jsonl`` and
      ``llava_checker_<split>[_bb]_answer.jsonl`` with one JSON object per
      line.

    The ``_bb`` suffix is added when ``args.bounding_box`` is exactly ``True``.

    Args:
        args: Namespace-like object exposing ``base_dir``, ``split`` and
            ``bounding_box``.

    NOTE(review): the scene graphs are always loaded from
    ``train_sceneGraphs.json`` regardless of ``split`` -- confirm this is
    intended for non-train splits.
    """
    base_dir = args.base_dir
    split = args.split
    # Same strictness as the prompt builder: only a literal True enables boxes.
    bounding_box = args.bounding_box is True

    problems = pd.read_csv(os.path.join(base_dir, f"GQA_Q2Q_{split}.csv"))
    scene_graph_path = os.path.join(base_dir, "train_sceneGraphs.json")
    with open(scene_graph_path, "r", encoding="utf-8") as scene_file:
        sceneGraphs = json.load(scene_file)

    split_problems = build_prompt_chatbot(problems, sceneGraphs, bounding_box)

    file_name = f"llava_checker_{split}"
    if bounding_box:
        file_name += "_bb"

    # Normalize every example once; both output formats share this step.
    samples = []
    for prob_id, (prompt, answer) in split_problems.items():
        # Only a *leading* marker is removed; markers inside the text
        # (e.g. "Sub-Question: ") must stay intact.
        prompt = _strip_prefix(prompt, "Question: ")
        answer = _strip_prefix(answer, "Answer: ")
        # prob_id is the positional row index used as the example key.
        raw_prob_data = problems.iloc[int(prob_id)]
        image_name = _image_file_name(raw_prob_data["image_id"])
        samples.append((prob_id, prompt, answer, image_name))

    if split == "train":
        target_format = []
        for prob_id, prompt, answer, image_name in samples:
            record = {"id": prob_id}
            if image_name is None:
                human_value = prompt
            else:
                record["image"] = image_name
                # The image placeholder precedes the prompt text.
                human_value = f"<image>\n{prompt}"
            record["conversations"] = [
                {"from": "human", "value": human_value},
                {"from": "gpt", "value": answer},
            ]
            target_format.append(record)

        print(f"Number of samples: {len(target_format)}")
        out_path = os.path.join(base_dir, file_name + ".json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(target_format, f, indent=2)
    else:
        q_path = os.path.join(base_dir, file_name + "_question.jsonl")
        a_path = os.path.join(base_dir, file_name + "_answer.jsonl")
        # Context managers close both files even if a write fails midway.
        with open(q_path, "w", encoding="utf-8") as q_writer, open(
            a_path, "w", encoding="utf-8"
        ) as a_writer:
            for prob_id, prompt, answer, image_name in samples:
                question_data = {"question_id": prob_id}
                if image_name is not None:
                    question_data["image"] = image_name
                question_data["text"] = prompt
                question_data["category"] = "conv"

                answer_data = {
                    "question_id": prob_id,
                    "text": answer,
                    "category": "conv",
                }

                q_writer.write(json.dumps(question_data) + "\n")
                a_writer.write(json.dumps(answer_data) + "\n")


def str_to_bool(value):
    """Parse a command-line boolean such as ``"true"``, ``"False"`` or ``"0"``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or its textual form.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False, matching bool("").
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Required: every input and output path is resolved relative to it.
    parser.add_argument("--base_dir", type=str, required=True)
    parser.add_argument("--split", default="train", type=str)
    # Accepts "--bounding_box", "--bounding_box True" and "--bounding_box False".
    parser.add_argument(
        "--bounding_box",
        default=False,
        type=str_to_bool,
        nargs="?",
        const=True,
    )
    parser.add_argument("--image_num", default=0, type=int)

    args = parser.parse_args()
    convert_to_llava(args)
