import json
import os
import fire
import re
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot


def _strip_prefix(text, prefix):
    """Return *text* with one leading *prefix* removed, if present.

    Unlike ``str.replace``, later occurrences of *prefix* inside the text
    are left untouched.
    """
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def convert_to_llava(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert one ScienceQA split into a LLaVA conversation JSON file.

    Reads ``pid_splits.json`` and ``problems.json`` from *base_dir*, builds
    prompt/answer pairs with ``build_prompt_chatbot`` and writes a list of
    ``{"id", ["image",] "conversations"}`` records to
    ``<base_dir>/llava_<split>_<prompt_format>.json``.

    Args:
        base_dir: Directory containing ``pid_splits.json`` and
            ``problems.json``; the output file is written here as well.
        split: Key selecting the problem ids from ``pid_splits.json``.
        prompt_format: Prompt layout, passed through unchanged to
            ``build_prompt_chatbot``.

    Raises:
        KeyError: If *split* is not present in ``pid_splits.json``.
    """
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)

    # Iterated below as ``prob_id -> (prompt, answer)`` pairs; the exact
    # text layout is defined by build_prompt_chatbot (not visible here).
    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    target_format = []
    for prob_id, (question, answer) in split_problems.items():
        # Drop only the leading label: str.replace would also delete any
        # "Question: " / "Answer: " text occurring later in the body.
        question = _strip_prefix(question, 'Question: ')
        answer = _strip_prefix(answer, 'Answer: ')

        image = problems[prob_id]['image']
        # Insertion order keeps the original key order: id, [image], conversations.
        record = {"id": prob_id}
        if image is None:
            human_value = question
        else:
            # Image path is stored as "<prob_id>/<image file name>".
            record["image"] = os.path.join(prob_id, image)
            human_value = f"{question}\n<image>"
        record["conversations"] = [
            {'from': 'human', 'value': human_value},
            {'from': 'gpt', 'value': answer},
        ]
        target_format.append(record)

    print(f'Number of samples: {len(target_format)}')

    out_path = os.path.join(base_dir, f"llava_{split}_{prompt_format}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(target_format, f, indent=2)


def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert one ScienceQA split into an instruction/output JSONL file.

    Reads ``pid_splits.json`` and ``problems.json`` from *base_dir*, builds
    prompt/answer pairs with ``build_prompt_chatbot`` and writes one JSON
    object per line (``{"id", ["image",] "instruction", "output"}``) to
    ``<base_dir>/scienceqa_<split>_<prompt_format>.jsonl``.

    Args:
        base_dir: Directory containing ``pid_splits.json`` and
            ``problems.json``; the output file is written here as well.
        split: Key selecting the problem ids from ``pid_splits.json``.
        prompt_format: Prompt layout, passed through unchanged to
            ``build_prompt_chatbot``.

    Raises:
        KeyError: If *split* is not present in ``pid_splits.json``.
    """
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)

    # Iterated below as ``prob_id -> (prompt, answer)`` pairs; the exact
    # text layout is defined by build_prompt_chatbot (not visible here).
    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    question_label = 'Question: '
    answer_label = 'Answer: '
    out_path = os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl")
    # ``with`` guarantees the output is flushed and closed even on error.
    with open(out_path, "w", encoding="utf-8") as writer:
        for prob_id, (question, answer) in split_problems.items():
            # Drop only the leading label: str.replace would also delete any
            # later occurrence of the same text in the body.
            if question.startswith(question_label):
                question = question[len(question_label):]
            if answer.startswith(answer_label):
                answer = answer[len(answer_label):]

            image = problems[prob_id]['image']
            # Insertion order keeps the original key order:
            # id, [image], instruction, output.
            data = {"id": prob_id}
            if image is None:
                data["instruction"] = question
            else:
                # Image path is stored as "<prob_id>/<image file name>".
                data["image"] = os.path.join(prob_id, image)
                data["instruction"] = f"{question}\n<image>"
            data["output"] = answer
            writer.write(json.dumps(data) + '\n')


def main(task, **kwargs):
    """Call the module-level object named *task* with *kwargs*.

    The name is looked up in this module's global namespace, so an unknown
    name raises ``KeyError`` and a non-callable global raises ``TypeError``.
    The callee's return value is discarded.
    """
    module_namespace = globals()
    task_fn = module_namespace[task]
    task_fn(**kwargs)


if __name__ == "__main__":
    # To drive a single conversion from the command line instead, replace
    # the loop below with ``fire.Fire(main)``.
    base_dir = "../ScienceQA/data/scienceqa"
    # A tuple (not a set) so the splits are always processed in the same,
    # documented order; set iteration order of strings varies between runs.
    splits = ("train", "val", "minival", "test", "minitest")
    for split_name in splits:
        convert_to_llava(base_dir, split_name)
