import json
import argparse
from tqdm import tqdm


def transform_item(item):
    """
    Reduce a single item to its "messages" and "images" fields.

    Args:
        item: Mapping that contains the keys "messages" and "images".
            Any other keys are dropped from the result.

    Returns:
        A new dict with exactly two keys:
        "messages" -- the value of item["messages"], passed through unchanged;
        "images" -- always a list: a list is passed through as-is, a single
        string is wrapped into a one-element list, and None becomes [].

    Raises:
        KeyError: If "messages" or "images" is missing from item.
    """
    messages = item["messages"]
    images = item["images"]

    # Ensure images are a list, so every output record has the same shape
    # regardless of how the input record spelled a single/absent image.
    if images is None:
        images = []
    elif isinstance(images, str):
        images = [images]

    return {
        "messages": messages,
        "images": images,
    }


def process_json(input_file, output_file):
    """
    Convert every item of a JSON file with transform_item and save the result.

    The input file is parsed as UTF-8 JSON and the loaded value is iterated
    item by item (so it is expected to be a list of items), with a tqdm
    progress bar. The converted items are written to output_file as a JSON
    list, indented by 4 spaces and with non-ASCII characters kept unescaped.
    A confirmation line is printed at the end.
    """
    with open(input_file, "r", encoding="utf-8") as src:
        records = json.load(src)

    converted = [
        transform_item(record)
        for record in tqdm(records, desc="Processing items")
    ]

    with open(output_file, "w", encoding="utf-8") as dst:
        json.dump(converted, dst, indent=4, ensure_ascii=False)

    print(f"Transformation complete! Output saved to {output_file}")


if __name__ == "__main__":
    # Script entry point: both options are optional and fall back to the
    # default paths given below.
    cli = argparse.ArgumentParser(
        description="Transform JSON items into the desired format.",
    )
    cli.add_argument(
        "--input",
        default="/home/duyuetian/projects/MedVLM-R1/dataset/cot/qwen_slake_train_cot_cleaned_sft.json",
        help="Path to the input JSON file.",
    )
    cli.add_argument(
        "--output",
        default="/home/duyuetian/projects/MedVLM-R1/dataset/think/qwen2vl_format_think_qwen_medical_sft_slake.json",
        help="Path to the output JSON file.",
    )
    options = cli.parse_args()

    process_json(options.input, options.output)
