
import json
import os
import re
from glob import glob

import fire
import matplotlib.pyplot as plt
import pandas as pd

def plot_vary_generations(
    outdir: str,
    metrics_dir: str = "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9.jsonl",
):
    """Plot CISR and SBERT similarity against the number of generations.

    For each style weight W in 1.0..5.0 and each generation count n in
    (20, 40, 60, 80, 100), reads the single JSON metrics file in
    ``metrics_dir`` whose name contains ``style=W`` and exactly ``N=n``.
    It then draws one curve per style weight: CISR on the left subplot and
    SBERT on the right. The figure is written to
    ``<outdir>/vary_num_generations.pdf``.

    Args:
        outdir: Directory the PDF is written to (must already exist).
        metrics_dir: Directory holding the per-configuration metrics files
            (it is globbed as a directory despite the ``.jsonl`` suffix in
            the default).

    Raises:
        ValueError: If zero or more than one metrics file matches a given
            style weight and N. The resulting curve would otherwise be
            misaligned with the N axis.
    """
    style_weights = ["1.0", "2.0", "3.0", "4.0", "5.0"]
    metrics = ("cisr_sim", "luar_sim", "sbert_sim")
    N = [20, 40, 60, 80, 100]
    # Match N exactly: a plain substring test would let "N=20" match "N=200".
    n_patterns = {n: re.compile(rf"N={n}(?!\d)") for n in N}

    curves = {}
    for key in style_weights:
        curves[key] = {metric: [] for metric in metrics}
        filenames = glob(os.path.join(metrics_dir, f"*style={key}*"))

        for n in N:
            matches = [
                path for path in filenames
                if n_patterns[n].search(os.path.basename(path))
            ]
            if len(matches) != 1:
                raise ValueError(
                    f"Expected exactly one metrics file for style={key}, N={n} "
                    f"in {metrics_dir!r}, found {len(matches)}: {sorted(matches)}"
                )
            with open(matches[0]) as f:
                data = json.load(f)
            for metric in metrics:
                curves[key][metric].append(data[metric])

    # Side-by-side subplots: CISR (left) and SBERT (right).
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    for key, curve in curves.items():
        ax1.plot(N, curve["cisr_sim"], label=f"W={key}")
        ax2.plot(N, curve["sbert_sim"], label=f"W={key}")
    ax1.set_ylabel('CISR Similarity')
    ax2.set_ylabel('SBERT Similarity')
    ax1.legend(loc='lower center', fancybox=True, ncol=3)
    ax1.set_xticks(N)
    ax2.set_xticks(N)
    ax1.set_xlabel('Number of Generations')
    ax2.set_xlabel('Number of Generations')
    plt.suptitle("Similarity Metrics")
    save_path = os.path.join(outdir, "vary_num_generations.pdf")
    plt.savefig(save_path)
    plt.close()

def plot_vary_num_paraphrases(
    outdir: str,
):
    """Compare prompt types (target-only vs. 1-5 paraphrases) at NG=100, SW=4.

    Reads ``average_style=4.0_content=1.0_N=100.json`` from each prompt-type
    metrics directory. The left subplot shows the style similarities (CISR,
    and ``luar_sim`` under the "CRUD" label); the right subplot shows SBERT
    similarity. The figure is written to ``<outdir>/prompt_type.pdf``.

    Args:
        outdir: Directory the PDF is written to (must already exist).
    """
    label_to_dirnames = {
        "Target Only": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_target-only.jsonl/",
        "NP=1": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=1.jsonl/",
        "NP=2": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=2.jsonl/",
        "NP=3": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=3.jsonl/",
        "NP=4": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=4.jsonl/",
        "NP=5": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9.jsonl/",
    }
    fname = "average_style=4.0_content=1.0_N=100.json"

    # Use a real list: numpy turns a dict_keys view into a 0-d object array,
    # so matplotlib would not see one x value per label.
    X = list(label_to_dirnames)
    CISR = []
    CRUD = []
    SBERT = []
    for dirname in label_to_dirnames.values():
        with open(os.path.join(dirname, fname)) as f:
            data = json.load(f)
        CISR.append(data["cisr_sim"])
        CRUD.append(data["luar_sim"])  # plotted under the "CRUD" label
        SBERT.append(data["sbert_sim"])
    # NOTE: StyleDistance ("sd_sim") is intentionally not plotted.

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    ax1.plot(X, CISR, label="CISR")
    ax1.plot(X, CRUD, label="CRUD")
    ax1.set_xlabel("Prompt Type")
    ax1.set_ylabel("Style Similarity")
    ax1.legend(loc='lower center', fancybox=True, ncol=3)

    ax2.plot(X, SBERT, label="SBERT")
    ax2.set_xlabel("Prompt Type")
    ax2.set_ylabel("SBERT Similarity")
    plt.suptitle("Performance of Various Prompt Types - NG=100, SW=4")
    save_path = os.path.join(outdir, "prompt_type.pdf")
    plt.savefig(save_path)
    plt.close()
    
def table(
    outdir: str,
    clean: bool = True,
    print_table: bool = False,
):
    """Write a markdown table of similarity metrics per prompt type and style weight.

    For every prompt-type metrics directory and style weight in 1.0..5.0:

    * read ``average_style=<w>_content=1.0_clean=<clean>_N=100.json``;
    * drop the ``sd_sim`` metric if present;
    * add one row named ``"<label> Style Weight=<w>"``.

    A hard-coded StyleMC (GPT-Revise) reference row is appended. Rows are
    sorted ascending by ``avg``, the mean of ``cisr_sim`` and ``sbert_sim``.

    Args:
        outdir: Directory the table is written to (must already exist).
        clean: Selects the ``clean=True`` / ``clean=False`` metrics files; it
            also appears in the output filename.
        print_table: If True, also print the markdown table to stdout.

    Side effects:
        Writes ``<outdir>/table_clean=<clean>.md``.
    """
    label_to_dirnames = {
        "Target Only": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_target-only.jsonl/",
        "NP=1": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=1.jsonl/",
        "NP=2": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=2.jsonl/",
        "NP=3": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=3.jsonl/",
        "NP=4": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_np=4.jsonl/",
        "NP=5": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9.jsonl/",
        "Trained (small)": "./metrics/stylemc_paraphrase_transfer_N=100_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=True.jsonl",
    }
    style_weights = [1.0, 2.0, 3.0, 4.0, 5.0]

    records = []
    for label, dirname in label_to_dirnames.items():
        for weight in style_weights:
            path = os.path.join(
                dirname,
                f"average_style={weight}_content=1.0_clean={clean}_N=100.json",
            )
            with open(path) as f:
                data = json.load(f)
            # StyleDistance is excluded from the table.
            data.pop("sd_sim", None)
            data["Name"] = f"{label} Style Weight={weight}"
            records.append(data)

    # Fixed reference numbers for StyleMC (GPT-Revise); not read from a file.
    records.append({
        "Name": "StyleMC (GPT-Revise)",
        "luar_sim": 0.791,
        "cisr_sim": 0.802,
        "sbert_sim": .728,
    })

    df = pd.DataFrame(data=records)
    df.set_index("Name", inplace=True)
    df["avg"] = (df["cisr_sim"] + df["sbert_sim"]) / 2
    df.sort_values(by=["avg"], axis=0, inplace=True)
    markdown = df.to_markdown()
    if print_table:
        print(markdown)
    out_path = os.path.join(outdir, f"table_clean={clean}.md")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(markdown)

def main(
    outdir: str = "./figures"
):
    """Create ``outdir`` and write both markdown result tables into it.

    One table is produced from the ``clean=False`` metrics and one from the
    ``clean=True`` metrics, in that order. The figure helpers are currently
    disabled; call ``plot_vary_generations(outdir)`` and
    ``plot_vary_num_paraphrases(outdir)`` here to regenerate the plots.

    Returns:
        Always 0.
    """
    os.makedirs(outdir, exist_ok=True)
    for clean_flag in (False, True):
        table(outdir, clean=clean_flag, print_table=False)
    return 0

if __name__ == "__main__":
    # Run main() through python-fire's command-line dispatcher.
    fire.Fire(component=main)