import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sys
import os

# Add the parent of this file's directory to the import path so that the
# ``utils`` import below can resolve when the script is run directly.
# NOTE(review): presumably ``utils`` lives in that parent directory; verify
# against the repository layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.plots import set_plt

# Project plotting setup, executed at import time.
# NOTE(review): presumably applies shared matplotlib styling; confirm in
# utils.plots.
set_plt()

# One row per result set: (results file stem, label shown in the plot).
# Keeping each pair on a single line stops the two lists below from drifting
# out of step with each other.
_RESULT_SETS = (
    ("coco_images_with_captions", "COCO2014"),
    ("tm_images_with_prompts", "TM"),
    ("gen_images_with_prompts", "Generated Images"),
    ("generated_tm_images_with_prompts", "Generated TM"),
    ("vm_images_with_prompts", "VM"),
)

names = [stem for stem, _ in _RESULT_SETS]
plot_names = [label for _, label in _RESULT_SETS]


def main():
    """Plot SSCD score and L2 norm against step for every result set.

    For each entry of ``names`` the archive ``results/<name>.npz`` is read and
    its ``saved_steps``, ``sscd_scores`` and ``norms`` arrays are flattened
    into a long-form frame tagged with the matching ``plot_names`` label.
    The combined frame is drawn as two side-by-side seaborn line plots (one
    hue per set) and saved to ``plots/out/yadv_differences.pdf``.

    Raises:
        FileNotFoundError: if one of the ``results/<name>.npz`` files is
            missing.
        KeyError: if an archive lacks one of the three expected arrays.
    """
    out_path = "plots/out/yadv_differences.pdf"
    frames = []

    for label, name in zip(plot_names, names):
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive comes from an untrusted source; only load files produced by
        # this project.
        # The context manager closes the underlying zip file handle.
        with np.load(f"results/{name}.npz", allow_pickle=True) as data:
            frame = pd.DataFrame(
                {
                    "Step": data["saved_steps"].flatten(),
                    "SSCD$_{orig}$": data["sscd_scores"].flatten(),
                    "L2 Norm": data["norms"].flatten(),
                }
            )
        frame["Set"] = label
        # Steps are plotted as strings rather than numbers
        # (presumably so seaborn spaces them evenly as categories).
        frame["Step"] = frame["Step"].astype(str)
        frames.append(frame)

    # Every per-set frame carries its own 0..n-1 index; renumber so the
    # combined frame has no duplicate index labels.
    df = pd.concat(frames, ignore_index=True)

    fig, axs = plt.subplots(1, 2, figsize=(20, 5))

    ax = axs[0]
    sns.lineplot(data=df, x="Step", y="SSCD$_{orig}$", hue="Set", ax=ax)
    ax.axhline(y=0.7, color="black", linestyle=":", label="Memorization Threshold")
    # lineplot() has already created the legend, so a line added afterwards
    # is not listed in it. Rebuild the legend (keeping seaborn's title) so the
    # threshold label actually appears.
    existing_legend = ax.get_legend()
    legend_title = None
    if existing_legend is not None:
        legend_title = existing_legend.get_title().get_text() or None
    ax.legend(title=legend_title)

    ax = axs[1]
    sns.lineplot(data=df, x="Step", y="L2 Norm", hue="Set", ax=ax)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)


# Only build the figure when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
