import json
import os
import glob
import matplotlib.pyplot as plt
import numpy as np

def get_results(base_path_name, patch_type, dataset, run=1, t=1, base=False):
    """Load one evaluation metric from a ``results*.json`` file.

    The file is searched for under
    ``<base_path_name>/<patch_type>/<subdir>/results*.json`` when *base* is
    true, and under
    ``<base_path_name>/<run>/<t>/<patch_type>/<subdir>/results*.json``
    otherwise. ``<subdir>`` is the alphabetically first sub-directory found
    there. If several results files match, the alphabetically first is used.

    Args:
        base_path_name: Root directory of the evaluation outputs.
        patch_type: Directory name selecting the variant (e.g. "merged", "base").
        dataset: Task name. Any name containing "math_genie" reads
            ``results["math_genie"]["bleu,none"]``; every other name reads
            ``results[dataset]["acc,none"]``.
        run: Value of the ``<run>`` path component (ignored when *base*).
        t: Value of the ``<t>`` path component (ignored when *base*).
        base: If true, read from ``<base_path_name>/<patch_type>`` directly.

    Returns:
        The metric value stored in the JSON file.

    Raises:
        FileNotFoundError: if no sub-directory or no matching results file
            exists (``os.listdir`` errors on the directory itself propagate).
        KeyError: if the JSON lacks the expected task or metric key.
    """
    if base:
        path = os.path.join(base_path_name, patch_type)
    else:
        path = os.path.join(base_path_name, str(run), str(t), patch_type)

    # os.listdir and glob return entries in arbitrary order: sort so the same
    # file is chosen on every machine, and ignore plain files (e.g. .DS_Store)
    # since results are looked up inside a sub-directory.
    subdirs = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )
    if not subdirs:
        raise FileNotFoundError(f"no result sub-directory in {path!r}")
    path = os.path.join(path, subdirs[0])

    # Escape the directory part so characters like '[' in names are literal.
    matches = sorted(glob.glob(os.path.join(glob.escape(path), "results*.json")))
    if not matches:
        raise FileNotFoundError(f"no results*.json file in {path!r}")

    with open(matches[0], "r", encoding="utf-8") as f:
        patch_json = json.load(f)

    if "math_genie" in dataset:
        return patch_json["results"]["math_genie"]["bleu,none"]
    return patch_json["results"][dataset]["acc,none"]
    
# Compare "merged" (patch) vs "base" (no patch) evaluation curves per dataset.
#
# Point 0 of every curve is the result read with base=True (shared by all
# runs); points 1..NUM_STEPS are read per run. Two figures are written per
# dataset: plot_<dataset>.png with one curve per run, and
# mean_plot_<dataset>.png with the across-run mean +/- std.

NUM_STEPS = 12  # t values read per run: t = 1..NUM_STEPS
RUNS = range(1, 6)  # run directories 1..5

# Errors get_results can raise for a missing or malformed result file. A run
# that hits one is reported and excluded rather than aborting the script.
_LOAD_ERRORS = (OSError, IndexError, KeyError, TypeError, ValueError)

datasets = ["boolq", "winogrande", "math_genie", "arc_easy"]


def _style_axes(dataset):
    """Apply the title, x ticks, axis labels and grid shared by both figures."""
    plt.title(dataset + " Performance")
    plt.xticks(range(NUM_STEPS + 1))
    plt.xlabel("t")
    # get_results reads BLEU for math_genie and accuracy for everything else.
    plt.ylabel("BLEU" if "math_genie" in dataset else "Accuracy")
    plt.grid(True, linestyle='--', alpha=0.5)


def _summarize(label, runs):
    """Print and return the per-step mean and std across equal-length *runs*.

    np.std is used with its default ddof=0 (population standard deviation).
    """
    arr = np.array(runs)
    means, stds = arr.mean(axis=0), arr.std(axis=0)
    print(label)
    print(means)
    print(stds)
    return means, stds


for dataset in datasets:
    print(dataset)
    root = os.path.join("eval_confirm", "mistral", dataset)

    # The t = 0 points do not depend on the run, so read them once.
    try:
        patch_base = get_results(root, "merged", dataset, base=True)
        no_patch_base = get_results(root, "base", dataset, base=True)
    except _LOAD_ERRORS as e:
        print(f"{dataset}: cannot load base results ({e!r}); skipping dataset")
        print("-" * 20)
        continue

    patch_runs = []
    no_patch_runs = []
    for run in RUNS:
        try:
            patch = [patch_base]
            no_patch = [no_patch_base]
            for t in range(1, NUM_STEPS + 1):
                patch.append(get_results(root, "merged", dataset, run, t))
                no_patch.append(get_results(root, "base", dataset, run, t))
        except _LOAD_ERRORS as e:
            # Incomplete runs are left out so all kept curves have equal length.
            print(f"{dataset} run {run}: skipped ({e!r})")
            continue
        patch_runs.append(patch)
        no_patch_runs.append(no_patch)
        steps = range(len(patch))
        plt.plot(steps, patch, label=f"Patch {run}")
        plt.plot(steps, no_patch, label=f"No Patch {run}")

    _style_axes(dataset)
    plt.legend()
    plt.savefig(f"plot_{dataset}.png")
    plt.clf()

    if not patch_runs:
        # Averaging zero runs yields nan scalars and the error-bar plot below
        # would fail; report and continue with the next dataset instead.
        print(f"{dataset}: no complete runs; skipping mean plot")
        print("-" * 20)
        continue

    print(patch_runs)
    print(no_patch_runs)
    patch_means, patch_stds = _summarize("Patch", patch_runs)
    no_patch_means, no_patch_stds = _summarize("No Patch", no_patch_runs)
    print("-" * 20)

    _style_axes(dataset)
    steps = range(len(patch_means))
    errorbar_style = dict(fmt='-o', capsize=4, elinewidth=1.5, markeredgewidth=1.5)
    plt.errorbar(steps, patch_means, yerr=patch_stds, label="Patch", **errorbar_style)
    plt.errorbar(steps, no_patch_means, yerr=no_patch_stds, label="No Patch", **errorbar_style)
    plt.legend()
    plt.savefig(f"mean_plot_{dataset}.png")
    plt.clf()
