
import json
import os

import matplotlib.pyplot as plt

# Checkpoint directory used for the "Human / Paraphrase" series.
base_path = "./data/MTD_reddit_12000_Mistral-7B-Instruct-v0.3_N=5_transfer_N=20_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3_average_style=1.0_content=1.0_N=100/checkpoints"

# Known checkpoint directories, keyed by a short run name.  Kept separate
# from `paths` so that `paths` is only ever bound to one kind of object.
checkpoint_dirs = {
    "DPO-1": "./data/MTD_reddit_12000_transfer-short_transfer_N=20_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3_preference_average_style=1.0_content=1.0_N=100/checkpoints",
    "DPO-2": "./data/MTD_reddit_12000_transfer-short_transfer_N=20_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3_preference_2_average_style=1.0_content=1.0_N=100/checkpoints",
    "GST": "./data/MTD_reddit_12000_olivia/checkpoints",
}

# Series to plot: legend label -> results-folder prefix.  The plotting loop
# appends "-<#datapoints>/results.json" to the prefix inside the series'
# checkpoint directory.
# (Another folder prefix, "roberta-base_transfer_text", is currently not plotted.)
data = {
    "Human / Paraphrase": "roberta-base_paraphrase_content_text",
    "Human / Instruction-Tuned": "roberta-base_transfer_pick",
    "Human / D+G": "roberta-base_generated",
}

# Checkpoint directories for every series other than "Human / Paraphrase".
# They are consumed by position, so their order must match the order in
# which those series appear in `data`.
paths = [
    checkpoint_dirs["DPO-2"],
    checkpoint_dirs["GST"],
]


# Training-set sizes on the x axis.  For each size x the loop below reads
# "<checkpoint dir>/<folder prefix>-<x>/results.json".
Xs = list(range(2000, 20000 + 1, 2000))


def _read_eval_accuracy(results_path):
    """Return the ``eval_accuracy`` entry of the JSON file at *results_path*.

    Raises:
        RuntimeError: if the file cannot be opened, is not valid JSON, or
            does not contain an ``eval_accuracy`` key.  The message names
            the offending path so a missing checkpoint is easy to spot.
    """
    try:
        # `with` closes the handle even if parsing fails.
        with open(results_path, encoding="utf-8") as results_file:
            return json.load(results_file)["eval_accuracy"]
    except (OSError, ValueError, KeyError, TypeError) as err:
        # Fail loudly rather than plotting a stale or undefined value.
        raise RuntimeError(
            f"could not read 'eval_accuracy' from {results_path!r}"
        ) from err


# Index of the next entry of `paths` to hand out; advanced once per series
# that is not "Human / Paraphrase".
j = 0
plt.figure(dpi=300, figsize=(7, 7))
for label, foldername in data.items():
    # The paraphrase series lives under `base_path`; every other series
    # takes the next directory from `paths`, in order.
    if label == "Human / Paraphrase":
        curr = os.path.join(base_path, foldername)
    else:
        curr = os.path.join(paths[j], foldername)
        j += 1

    Ys = [_read_eval_accuracy(f"{curr}-{x}/results.json") for x in Xs]
    plt.plot(Xs, Ys, label=label)

plt.xticks(Xs)
plt.xlabel("#Training Datapoints")
plt.ylabel("Accuracy")
plt.title("Classification of Human vs LLM-Paraphrase / Style Transfer")
plt.grid()
plt.ylim([0.50, 1.0])
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig("./quickplot.png")
plt.close()