import argparse
import os
import json
import shutil

def get_args(argv=None):
    """Parse the command-line options for renaming, copying and registering a dataset.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before; passing
            a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``src_data_dir``, ``tgt_data_dir``,
        ``src_data_name``, ``tgt_data_name``, ``length`` (default ``"8k"``) and
        ``data_info_file_path``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Rename preprocessed dataset splits, copy them to a target "
            "directory and register them in a dataset-info JSON file."
        )
    )
    parser.add_argument("--src_data_dir", type=str, required=True,
                        help="Directory holding the freshly generated split files.")
    parser.add_argument("--tgt_data_dir", type=str, required=True,
                        help="Directory the renamed split files are copied into.")
    parser.add_argument("--src_data_name", type=str, required=True,
                        help="Dataset-name suffix of the existing source files.")
    parser.add_argument("--tgt_data_name", type=str, required=True,
                        help="Dataset-name suffix the files are renamed to.")
    parser.add_argument("--length", type=str, default="8k",
                        help="Length tag used as the file-name prefix (default: 8k).")
    parser.add_argument("--data_info_file_path", type=str, required=True,
                        help="Path of the dataset-info JSON file to update.")

    return parser.parse_args(argv)

def llama_factory_data_register(data_name, data_info_file_path):
    """Add (or overwrite) the entry for ``data_name`` in a dataset-info JSON file.

    The entry schema is chosen from a marker contained in ``data_name``:

    * ``_s2l_`` -- ranking data with ``chosen``/``rejected`` columns plus a
      ``long_messages`` column mapped to ``long_conversations``;
    * ``_po_``  -- ranking data with ``chosen``/``rejected`` columns;
    * ``_sft_`` -- plain ``conversations`` data.

    If several markers occur (e.g. the user-chosen dataset name itself contains
    ``_sft_``), the marker appearing earliest wins: ``main`` builds names as
    ``{length}_{kind}_..._{tgt_data_name}``, so the kind marker precedes the
    user-chosen part. A name containing no marker is not registered and the
    file is left untouched.

    The updated mapping is first written to a temporary file next to the
    (symlink-resolved) target and then moved over it with ``os.replace``, so a
    failure while serialising does not truncate the existing file.

    Args:
        data_name: Name under which the dataset is registered; also the stem
            of its ``.json`` data file.
        data_info_file_path: Path to the JSON file (a top-level object) that
            maps dataset names to their descriptions.
    """
    file_name = f"{data_name}.json"
    entries = {
        "_s2l_": {
            "file_name": file_name,
            "ranking": True,
            "formatting": "sharegpt",
            "columns": {
                "messages": "conversations",
                "chosen": "chosen",
                "rejected": "rejected",
                "long_messages": "long_conversations",
            },
        },
        "_po_": {
            "file_name": file_name,
            "ranking": True,
            "formatting": "sharegpt",
            "columns": {
                "messages": "conversations",
                "chosen": "chosen",
                "rejected": "rejected",
            },
        },
        "_sft_": {
            "file_name": file_name,
            "formatting": "sharegpt",
            "columns": {
                "messages": "conversations",
            },
        },
    }

    matches = [marker for marker in entries if marker in data_name]
    if not matches:
        # Nothing to register: avoid needlessly rewriting/reformatting the file.
        return
    # Markers all start with "_" followed by distinct letters, so they can never
    # begin at the same index; the earliest one is the kind prefix.
    marker = min(matches, key=data_name.index)

    with open(data_info_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    data[data_name] = entries[marker]

    # Resolve symlinks so os.replace updates the real file, not the link.
    target_path = os.path.realpath(data_info_file_path)
    tmp_path = f"{target_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)
        os.replace(tmp_path, target_path)
    except BaseException:
        # The original file is still intact; just drop the partial temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def main():
    """Rename, copy and register every known split of one dataset.

    For each split prefix (``{length}_po_long``, ``{length}_po_short``,
    ``{length}_sft_long``, ``{length}_sft_short``,
    ``{length}_s2l_short2long_short``):

    1. if ``{src_data_dir}/{prefix}_{src_data_name}.json`` exists, it is renamed
       in place to ``{prefix}_{tgt_data_name}.json`` and copied into
       ``tgt_data_dir`` (the directory is created if needed; the copy is
       skipped when source and target resolve to the same file);
    2. if ``{tgt_data_dir}/{prefix}_{tgt_data_name}.json`` exists -- whether it
       was just copied or was already there -- it is registered in
       ``data_info_file_path`` and its name is printed.
    """
    args = get_args()
    prefixes = [
        f"{args.length}_po_long",
        f"{args.length}_po_short",
        f"{args.length}_sft_long",
        f"{args.length}_sft_short",
        f"{args.length}_s2l_short2long_short",
    ]

    print("==========new data name==========")
    for prefix in prefixes:
        new_name = f"{prefix}_{args.tgt_data_name}"
        src_path = os.path.join(args.src_data_dir, f"{prefix}_{args.src_data_name}.json")
        renamed_src_path = os.path.join(args.src_data_dir, f"{new_name}.json")
        tgt_path = os.path.join(args.tgt_data_dir, f"{new_name}.json")

        if os.path.exists(src_path):
            # rename in the source directory
            os.rename(src_path, renamed_src_path)
            # copy to the target directory; when both directories are the same,
            # the rename already put the file in place and shutil.copy would
            # raise SameFileError.
            if os.path.realpath(renamed_src_path) != os.path.realpath(tgt_path):
                os.makedirs(args.tgt_data_dir, exist_ok=True)
                shutil.copy(renamed_src_path, tgt_path)

        # register
        if os.path.exists(tgt_path):
            llama_factory_data_register(new_name, args.data_info_file_path)
            print(new_name)
    print("==========new data name==========")

# Script entry point: parse the CLI arguments, then rename/copy/register the splits.
if __name__ == "__main__":
    main()