
# Metrics files for the exemplar-count ablation (ne = 1, 2, 4, 8). Order must
# match the x-values [1, 2, 4, 8] plotted on the left panel; each name is
# resolved relative to ./neurips/transfer_metrics/ by collect_sims.
# NOTE(review): the ne=8 entry uses transfer_N=1 and np=5, while ne=1/2/4 use
# transfer_N=20 with no np field, and it is the same file as the last entry
# of NP_files — confirm that mixing run configurations here is intended.
NE_files = [
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=20_temp=0.6_top-p=0.9_ne=1_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=20_temp=0.6_top-p=0.9_ne=2_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=20_temp=0.6_top-p=0.9_ne=4_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=5_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
]

# Metrics files for the paraphrase-count ablation (np = 1..5, all with ne=8
# and transfer_N=1). Order must match the x-values [1, 2, 3, 4, 5] plotted on
# the right panel; each name is resolved relative to
# ./neurips/transfer_metrics/ by collect_sims.
NP_files = [
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=1_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=2_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=3_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=4_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
    "stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=1_temp=0.6_top-p=0.9_np=5_ne=8_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3-CN=3750_preference.json",
]

import json
import os

import matplotlib.pyplot as plt

# A single row of two side-by-side panels. The y-axis is shared so the two
# ablation curves are drawn on the same similarity scale.
fig, axes = plt.subplots(
    nrows=1,
    ncols=2,
    sharey=True,
    figsize=(10, 4.35),
    dpi=300,
)

def collect_sims(data: list[str], base_dir: str = "./neurips/transfer_metrics/") -> list:
    """Load the ``"cisr_sim"`` value from each metrics JSON file.

    Args:
        data: File names, relative to ``base_dir``, of JSON files whose
            top-level object contains a ``"cisr_sim"`` key.
        base_dir: Directory holding the files. The default is the location
            the script has always used, so existing callers are unaffected.

    Returns:
        One ``"cisr_sim"`` value per file, in the same order as ``data``.

    Raises:
        FileNotFoundError: if a file does not exist.
        KeyError: if a file's top-level object has no ``"cisr_sim"`` key.
    """
    sims = []
    for name in data:
        # os.path.join also handles a base_dir given without a trailing separator.
        path = os.path.join(base_dir, name)
        # `with` closes each handle promptly. The previous open(...).read()
        # left every file open until garbage collection. JSON is UTF-8 by spec,
        # so decode explicitly rather than relying on the locale default.
        with open(path, encoding="utf-8") as f:
            sims.append(json.load(f)["cisr_sim"])
    return sims

# Load each ablation's scores exactly once. Previously collect_sims ran twice
# per file list (once to print, once to plot), re-reading every JSON file and
# allowing the printed and plotted values to diverge.
exemplar_sims = collect_sims(NE_files)
paraphrase_sims = collect_sims(NP_files)
print(exemplar_sims)
print(paraphrase_sims)

# One entry per subplot: (axis, x-values, y-values, x-label, title).
# The x-values must line up one-to-one with the matching file list above.
panels = [
    (axes[0], [1, 2, 4, 8], exemplar_sims,
     "Number of Exemplars", "Ablation: Exemplars"),
    (axes[1], [1, 2, 3, 4, 5], paraphrase_sims,
     "Number of Paraphrases", "Ablation: Paraphrases"),
]
for ax, xs, ys, xlabel, title in panels:
    # Fail with a descriptive message if a file list and its x-values drift
    # apart, instead of matplotlib's generic shape-mismatch error.
    if len(xs) != len(ys):
        raise ValueError(
            f"{title}: {len(xs)} x-values but {len(ys)} similarity scores"
        )
    ax.plot(xs, ys, marker='o')
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel("Stylistic Similarity", fontsize=16)
    ax.set_title(title, fontsize=16)

# Save through the explicit figure handle rather than pyplot's implicit
# "current figure".
fig.savefig("./ablations.pdf")