import json
from tqdm import tqdm
import argparse

# Prompt template for multiple-choice questions about a medical image.
# It asks for step-by-step reasoning inside <think>...</think> and a
# single-letter choice inside <answer>...</answer>; the `{Question}`
# placeholder receives the question text.
# NOTE(review): "Look the given" is probably meant to read "Look at the given".
# The text is left untouched because it is emitted verbatim into the data and
# may have to match prompts used elsewhere — confirm before changing.
MED_QUESTION_PROMPT = '''
Look the given medical image carefully, and complete the tasks below.
Your task:
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.
2. Then provide the correct single-letter choice(A, B, C, D,...) inside <answer>...</answer> tags.
3. No extra information or text outside of these tags.
{Question} 
'''

# Prompt template for open-ended questions about a medical image.
# Same layout as the multiple-choice variant, except that the answer inside
# <answer>...</answer> is not restricted to a single letter. The `{Question}`
# placeholder is filled with `str.format(Question=...)`.
# NOTE(review): "Look the given" is probably meant to read "Look at the given".
# The text is left untouched because it is emitted verbatim into the data and
# may have to match prompts used elsewhere — confirm before changing.
MED_QUESTION_PROMPT_OPEN = '''
Look the given medical image carefully, and complete the tasks below.
Your task:
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.
2. Then provide the correct answer inside <answer>...</answer> tags.
3. No extra information or text outside of these tags.
{Question} 
'''

# Prompt template for text-only medical cases (the wording refers to a
# "medical case" rather than an image). It asks for reasoning inside
# <think>...</think> and a natural-language answer inside <answer>...</answer>;
# the `{Question}` placeholder receives the case text.
MED_QUESTION_PROMPT_NO_IMAGE = '''
Read the following medical case carefully, and complete the tasks below.
Your task:
1. Think through the case step by step, and enclose your reasoning process in <think>...</think> tags.
2. Then provide the final answer in natural language inside <answer>...</answer> tags.
3. No extra information or text outside of these tags.
{Question}
'''


# Bare template: the formatted result is just the question itself, with no
# task instructions or tag requirements wrapped around it.
ZERO_SHOT_MED_QUESTION_PROMPT = "{Question}"

def process_json_sft(input_file, output_file):
   """Turn a JSON dataset into two-turn chat samples and write them out.

   The input file is loaded with ``json.load`` and iterated; every record
   must provide the keys ``id``, ``image``, ``problem`` and ``solution``.
   Each record becomes a dict with ``id``, ``images`` and ``messages``
   (one ``user`` turn holding the formatted prompt, one ``assistant`` turn
   holding ``solution``). The question is prefixed with two ``<image>``
   tags, each followed by a newline.

   NOTE(review): ``SFT_TRANCE_QUESTION_PROMPT`` is not defined in the visible
   part of this module. Unless it is provided elsewhere, this function
   raises ``NameError`` on the first record — confirm before using it.

   Args:
      input_file: Path of the UTF-8 JSON file to read.
      output_file: Path of the UTF-8 JSON file to write (list of samples,
         non-ASCII characters kept as-is, indented by 4 spaces).
   """
   with open(input_file, 'r', encoding='utf-8') as src:
      records = json.load(src)

   samples = []
   for record in tqdm(records):
      # Fields are read in the same order as the keys appear in the output.
      sample_id = record["id"]
      images = record["image"]
      user_turn = {
         "content": SFT_TRANCE_QUESTION_PROMPT.format(Question=f"<image>\n<image>\n{record['problem']}"),
         "role": "user",
      }
      assistant_turn = {
         "content": record["solution"],
         "role": "assistant",
      }
      samples.append({
         "id": sample_id,
         "images": images,
         "messages": [user_turn, assistant_turn],
      })

   with open(output_file, 'w', encoding='utf-8') as dst:
      json.dump(samples, dst, ensure_ascii=False, indent=4)

def process_json_cot_sft(input_file, output_file):
   """Build chain-of-thought SFT samples from a JSON dataset.

   The input file is loaded with ``json.load`` and iterated; every record
   must provide the keys ``id``, ``image``, ``problem``, ``cot`` and
   ``answer``. Each record becomes a dict with ``id``, ``images`` and
   ``messages``:

   * the ``user`` turn is ``MED_QUESTION_PROMPT_OPEN`` formatted with the
     question, prefixed by one ``<image>`` tag and a newline;
   * the ``assistant`` turn wraps ``cot`` in ``<think>`` tags and ``answer``
     in ``<answer>`` tags, each on its own lines.

   Args:
      input_file: Path of the UTF-8 JSON file to read.
      output_file: Path of the UTF-8 JSON file to write (list of samples,
         non-ASCII characters kept as-is, indented by 4 spaces).
   """
   with open(input_file, 'r', encoding='utf-8') as src:
      records = json.load(src)

   samples = []
   for record in tqdm(records):
      # Fields are read in the same order as the keys appear in the output.
      sample_id = record["id"]
      images = record["image"]

      # Other templates defined in this module (e.g. MED_QUESTION_PROMPT or
      # ZERO_SHOT_MED_QUESTION_PROMPT) can be swapped in here if needed.
      user_turn = {
         "content": MED_QUESTION_PROMPT_OPEN.format(Question=f"<image>\n{record['problem']}"),
         "role": "user",
      }

      # Plain `+` concatenation is kept so that non-string values still fail
      # loudly instead of being silently converted to text.
      reasoning_part = "<think>\n" + record["cot"] + "\n</think>\n"
      answer_part = "<answer>\n" + record["answer"] + "\n</answer>"
      assistant_turn = {
         "content": reasoning_part + answer_part,
         "role": "assistant",
      }

      samples.append({
         "id": sample_id,
         "images": images,
         "messages": [user_turn, assistant_turn],
      })

   with open(output_file, 'w', encoding='utf-8') as dst:
      json.dump(samples, dst, ensure_ascii=False, indent=4)

def process_json_caption_cot_sft(input_file, output_file):
   """Build caption-style SFT samples whose reply is the raw ``cot`` field.

   The input file is loaded with ``json.load`` and iterated; every record
   must provide the keys ``id``, ``image``, ``problem`` and ``cot``. Each
   record becomes a dict with ``id``, ``images`` and ``messages`` (one
   ``user`` turn holding the formatted prompt, one ``assistant`` turn
   holding ``cot`` unchanged). The question is prefixed with two
   ``<image>`` tags, each followed by a newline.

   NOTE(review): ``TRANCE_QUESTION_CAPTION_PROMPT`` is not defined in the
   visible part of this module. Unless it is provided elsewhere, this
   function raises ``NameError`` on the first record — confirm before
   using it.

   Args:
      input_file: Path of the UTF-8 JSON file to read.
      output_file: Path of the UTF-8 JSON file to write (list of samples,
         non-ASCII characters kept as-is, indented by 4 spaces).
   """
   with open(input_file, 'r', encoding='utf-8') as src:
      records = json.load(src)

   samples = []
   for record in tqdm(records):
      # Fields are read in the same order as the keys appear in the output.
      sample_id = record["id"]
      images = record["image"]
      user_turn = {
         "content": TRANCE_QUESTION_CAPTION_PROMPT.format(Question=f"<image>\n<image>\n{record['problem']}"),
         "role": "user",
      }
      assistant_turn = {
         "content": record["cot"],
         "role": "assistant",
      }
      samples.append({
         "id": sample_id,
         "images": images,
         "messages": [user_turn, assistant_turn],
      })

   with open(output_file, 'w', encoding='utf-8') as dst:
      json.dump(samples, dst, ensure_ascii=False, indent=4)

if __name__ == "__main__":
   # Command-line entry point: read the input/output paths (both have
   # hard-coded defaults) and run the chain-of-thought SFT conversion.
   cli = argparse.ArgumentParser(description="Process dataset with CoT generation")
   cli.add_argument(
      "--input_json",
      default="/home/duyuetian/projects/MedVLM-R1/dataset/cot/qwen_slake_train_cot_cleaned.json",
      help="Path to the input JSON file.",
   )
   cli.add_argument(
      "--output_json",
      default="/home/duyuetian/projects/MedVLM-R1/dataset/cot/qwen_slake_train_cot_cleaned_sft.json",
      help="Path to the output JSON file.",
   )
   options = cli.parse_args()

   # Only the CoT converter is run here; process_json_sft and
   # process_json_caption_cot_sft take the same two arguments.
   process_json_cot_sft(options.input_json, options.output_json)