
# Root directory that holds one sub-directory of result files per method.
BASE_PATH = "/data1/yubnub/data/style_transfer/ICLR"

# Naming fragments shared by the per-domain file names below.
_DOMAINS = ("amazon", "blogs", "reddit")
_CORPORA = ("amazon_12000", "blogs_7000", "reddit_12000_correct")

# Every method below is a 4-item list: three data paths in the fixed order
# (amazon, blogs, reddit), followed by the record key that the script section
# passes to mtd_score.py as --eval_key.
HUMAN = [
    f"{BASE_PATH}/HUMAN/MTD_{corpus}.jsonl.ready" for corpus in _CORPORA
] + ["content_text"]

MACHINE = [
    f"{BASE_PATH}/MACHINE/MTD_{corpus}.jsonl.ready" for corpus in _CORPORA
] + ["respond_reddit"]

LLMOPT = [
    f"{BASE_PATH}/LLMOPT/MTD-{tag}_checkpoint-{step}_merged-FastDetectGPT-"
    f"{domain}_temperature=0.7_top-p=0.9_ng=2-preference.jsonl"
    for tag, step, domain in (
        ("amazon-12000", "7500", "amazon"),
        ("blogs-7000", "4500", "blogs"),
        ("reddit-12000-correct", "7500", "reddit"),
    )
] + ["respond_reddit"]

PARAPHRASING = [
    f"{BASE_PATH}/PARAPHRASING/{domain}_paraphrase.jsonl" for domain in _DOMAINS
] + ["generation"]

PROMPTING = [
    f"{BASE_PATH}/PROMPTING/{domain}_prompting.jsonl" for domain in _DOMAINS
] + ["generation"]

DIPPER = [
    f"{BASE_PATH}/DIPPER/{domain}_dipper.jsonl" for domain in _DOMAINS
] + ["paraphrase_dipper"]

OUTFOX = [
    f"{BASE_PATH}/OUTFOX/{domain}.jsonl" for domain in _DOMAINS
] + ["generation"]

TINYSTYLER = [
    f"{BASE_PATH}/TINYSTYLER/MTD_{corpus}_Mistral-7B-Instruct-v0.3_N=5.jsonl"
    for corpus in _CORPORA
] + ["transfer_pick"]

# CAUTION -- Reddit only: the three entries below reuse the single Reddit file
# in all three path slots, so the amazon/blogs positions do NOT hold
# amazon/blogs data.
OURS_NO_DPO = [
    "./neurips/MTD_reddit_12000_correct_Mistral-7B-Instruct-v0.3_N=5.jsonl.nodpo"
] * 3 + ["transfer_pick"]

DG = [
    "/data1/yubnub/data/style_transfer/olivia/fullmgt/MTD_D+G_generations.jsonl"
] * 3 + ["selected"]

DG04 = [
    "/data1/yubnub/data/style_transfer/olivia/fullmgt/MTD_D+G_0.4_generations.jsonl"
] * 3 + ["selected"]

OURS = [
    f"{BASE_PATH}/OURS/MTD_{corpus}_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=3"
    for corpus in _CORPORA
] + ["transfer_pick"]

OURS_NE8 = [
    f"./neurips/MTD_{corpus}_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=3.ne=8"
    for corpus in _CORPORA
] + ["transfer_pick"]

if __name__ == "__main__":
    # Validate the data files of every selected method, then write one shell
    # script per method into ./mtd_score_scripts/. Each script pins a GPU via
    # CUDA_VISIBLE_DEVICES and runs mtd_score.py once per data file.
    #
    # Raises:
    #   FileNotFoundError: a configured data file does not exist.
    #   ValueError: a data file is empty, or its first record lacks a
    #       required key.
    import json
    import os

    SCRIPT = """#!/bin/sh
export CUDA_VISIBLE_DEVICES={}
"""
    COMMAND = "\npython mtd_score.py --mtd_data_path {} --eval_key {}\n"
    OUT_DIR = "./mtd_score_scripts"
    NUM_GPUS = 4  # GPU ids 0..3 are handed out round-robin

    # Only these module-level lists get a script; the other path lists in
    # this file are deliberately left out.
    methods = ["HUMAN", "MACHINE", "LLMOPT", "PARAPHRASING", "PROMPTING", "DIPPER", "OUTFOX", "TINYSTYLER", "OURS"]

    # open() does not create missing parent directories, so make sure the
    # output directory exists before writing the first script.
    os.makedirs(OUT_DIR, exist_ok=True)

    for ii, name in enumerate(methods):
        # Each method list is [path, path, path, eval_key].
        *paths, eval_key = globals()[name]
        cuda = ii % NUM_GPUS
        current = SCRIPT.format(cuda)

        for path in paths:
            # Explicit raises instead of `assert`, so the validation still
            # runs when Python is started with -O.
            if not os.path.exists(path):
                raise FileNotFoundError(path)

            # Only the first record is inspected; `with` closes the handle.
            with open(path, encoding="utf-8") as fin:
                first_line = fin.readline()
            if not first_line.strip():
                raise ValueError("{}: file is empty".format(path))
            d = json.loads(first_line)

            if eval_key not in d:
                raise ValueError("{}: missing eval key {!r}".format(path, eval_key))
            # NOTE(review): presumably mtd_score.py needs one of these two
            # fields regardless of the eval key -- confirm against that script.
            if "respond_reddit" not in d and "generation" not in d:
                raise ValueError(
                    "{}: neither 'respond_reddit' nor 'generation' present".format(path)
                )

            current += COMMAND.format(path, eval_key)

        with open("{}/{}_{}.sh".format(OUT_DIR, cuda, ii), "w") as fout:
            fout.write(current)
