import collections
import os
import pdb
import json
import glob
import pickle
from typing import List,Dict,Tuple,Optional, Any
from pathlib import Path
import pprint


from git.objects.util import get_object_type_by_name
from matplotlib.pyplot import xlabel
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity

from PIL import Image
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns

from src.utils import get_objects, get_box_ids, is_object, is_box_id, get_token_pos_given_span_types, get_object_mapping
from src.dataset import PhraseProbeDataLoader

# Default locations of cached probe input texts (one example per line, see get_input_texts).
INPUT_PATH_PERIOD_COMMA = "entity-tracking-probing/model_cache/boxes_data_altForm_default/codellama13b_period_comma/test_inputs.txt"
INPUT_PATH_PERIOD_COMMA_PRIOR = "entity-tracking-probing/model_cache/boxes_data_altForm_default/codellama13b_period_comma_prior/test_inputs.txt"
INPUT_PATH_NUMBER_ALL = "entity-tracking-probing/model_cache/boxes_data_altForm_default/codellama13b_number_all/test_inputs_subset.txt"
# Loaded once at import time. In this file OTI is indexed by object name and yields an
# integer id, and ITO is indexed by integer id and yields a printable object name.
# NOTE(review): the exact container types come from src.utils.get_object_mapping -- confirm there.
OTI, ITO = get_object_mapping("entity-tracking-probing/data/llama_friendly_objects.csv")
# Flat (box, object) index: box_id * 100 + object id, so ``index % 100`` recovers the object id
# (assumes fewer than 100 objects -- TODO confirm against the CSV).
OTI_BOX = lambda obj, box_id: int(box_id)*100 + OTI[obj]
# Keys of the mask dictionary that add_results evaluates accuracy over.
MASK_FIELDS = ["local_box", "local_obj", "local_box_obj", "cum_box", "cum_obj", "cum_box_obj"]

def fix_fonts(title=20, label=20, xtick=15, ytick=15, default=15):
    """Set the global matplotlib font family and font sizes used by the plots in this file.

    Args:
        title: Font size for axes titles.
        label: Font size for axis labels.
        xtick: Font size for x-axis tick labels.
        ytick: Font size for y-axis tick labels.
        default: Fallback font size for all other text.
    """
    # Serif family with Times New Roman as the preferred face.
    plt.rc('font', family='serif', serif=['Times New Roman'])

    # Apply all size overrides in one pass.
    plt.rcParams.update({
        "font.size": default,
        "xtick.labelsize": xtick,
        "ytick.labelsize": ytick,
        "axes.labelsize": label,
        "axes.titlesize": title,
    })


def get_input_texts(input_data_path:str=INPUT_PATH_PERIOD_COMMA):
    """Read ``input_data_path`` and return its lines with surrounding whitespace stripped.

    Args:
        input_data_path: Path to a text file containing one input per line.

    Returns:
        A list with one stripped string per line of the file.
    """
    with open(input_data_path) as handle:
        return list(map(str.strip, handle))


def get_data_masks(texts: List[str]):
    """Build the probe masks for ``texts`` and hand them back together with the texts.

    Args:
        texts: Input strings to compute masks for.

    Returns:
        A ``(masks, texts)`` tuple, where ``masks`` is whatever
        ``PhraseProbeDataLoader.get_data_masks`` returns for the texts and the
        ``OTI_BOX`` index function, and ``texts`` is the argument unchanged.
    """
    return PhraseProbeDataLoader.get_data_masks(texts, OTI_BOX), texts


def add_results(res_dir: str, masks) -> Dict[str, Any]:
    """Compute masked accuracy and confusion matrices for one experiment directory.

    Args:
        res_dir: Experiment directory containing ``predictions.txt`` (first line is
            skipped as a header). If the path contains ``"local"`` the local labels
            are used, otherwise the cumulative ones.
        masks: Mapping that provides ``"local_label"`` / ``"cum_label"`` and one
            tensor per entry of ``MASK_FIELDS``.
            NOTE(review): masks are assumed to have the same shape as the
            predictions -- confirm against PhraseProbeDataLoader.

    Returns:
        A dict with, for every mask field, ``accuracy_<field>`` (float) and
        ``cm_<field>`` (confusion matrix over the masked positions).
    """
    pred = torch.from_numpy(np.loadtxt(f"{res_dir}/predictions.txt", skiprows=1))
    label = masks["local_label"] if "local" in res_dir else masks["cum_label"]
    correct = pred == label
    output: Dict[str, Any] = {}
    for mask_field in MASK_FIELDS:
        mask = masks[mask_field]
        selected = mask.bool()  # computed once, reused for predictions and labels
        mask_pred = torch.masked_select(pred, selected)
        mask_label = torch.masked_select(label, selected)
        # accuracy restricted to the positions picked by the mask
        acc = (correct * mask).sum() / mask.sum()
        output[f"accuracy_{mask_field}"] = acc.item()
        output[f"cm_{mask_field}"] = confusion_matrix(mask_label, mask_pred)
    return output



def plot_probe_results(result_dir:str, out_dir:str, exp_name:str="phrase_probe_accuracy"):
    """Plot last-epoch masked test accuracy across layers for one probe experiment.

    Every sub-directory of ``result_dir`` is expected to be named ``layer<N>_...`` and
    to contain a JSON file ``tensorboard.txt``. Entries whose log cannot be read or
    parsed are skipped.

    Args:
        result_dir: Directory holding one sub-directory per layer. The model and dataset
            names are taken from path components ``[-3]`` and ``[-4]``, and the tag type
            from the last three ``_``-separated pieces of the path.
        out_dir: Root directory for the figures.
        exp_name: Sub-directory of ``out_dir`` to write into.

    Returns:
        The long-format DataFrame that was plotted, with extra ``"Layer Normalized"``
        and ``"Model"`` columns.
    """
    df, strat_df = [], []

    for exp_dir in tqdm(os.listdir(result_dir)):
        layer = exp_dir.split("_")[0].replace("layer", "")
        try:
            # context manager so the log file handle is always released
            with open(os.path.join(result_dir, exp_dir, "tensorboard.txt")) as log_file:
                result = json.load(log_file)
        except (OSError, ValueError):
            # missing / unreadable / malformed log: not a finished experiment, skip it
            continue
        result["layer"] = int(layer)
        result["last_epoch_test_acc"] = result["test_acc_cont"][-1]

        # field names come from the train log, values from the last test epoch
        for field in result['train_acc_mask_cont'].keys():
            result[field] = result["test_acc_mask_cont"][field][-1]
        df.append(result)
        # per-"number of operations" accuracy of the last epoch
        # (only needed by the stratified plot, which is currently disabled)
        for i, strat_acc in enumerate(result["test_strat_acc_cont"][-1]):
            strat_df.append({
                "layer": int(layer),
                "num_ops_before": i,
                "accuracy": strat_acc,
            })

    df = pd.DataFrame(df)
    strat_df = pd.DataFrame(strat_df)

    model_name = result_dir.split("/")[-3]
    dataset_name = result_dir.split("/")[-4]
    tag_type = '_'.join(result_dir.split("_")[-3:])
    out_exp_dir = f"{out_dir}/{exp_name}/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)

    # plot different masks
    value_fields = ["local_box", "local_obj", "local_box_obj"]
    if "cumulative" in result_dir:
        value_fields = ["local_box", "local_obj", "local_box_obj", "cum_box_obj"]
    df_melt = pd.melt(df, id_vars=["layer"],
                      value_vars=value_fields,
                      var_name="Mask Types", value_name="Accuracy")
    # human-readable legend labels
    m = {
        "local_box": "Local Box",
        "local_obj": "Local Object",
        "local_box_obj": "Local Box-Object",
        "cum_box_obj": "Cumulative Box-Object",
    }
    df_melt["Mask Types"] = df_melt["Mask Types"].map(m)
    fix_fonts(label=20, title=20, xtick=15, ytick=15)
    sns.lineplot(x="layer", y="Accuracy", hue="Mask Types",data=df_melt)
    plt.ylim(0, 1.05)
    plt.title(f"Conditioning on {'Object' if 'object' in tag_type else 'Box ID'}")
    plt.ylabel("Mean Probe Accuracy")
    plt.xlabel("Layer")
    plt.tight_layout()
    plt.savefig(f"{out_exp_dir}/test_acc_different_masks_{tag_type}.png")
    plt.savefig(f"{out_exp_dir}/test_acc_different_masks_{tag_type}.pdf", dpi=600)
    plt.close()

    # normalised layer index lets models of different depth share an x-axis
    df_melt["Layer Normalized"] = df_melt["layer"]/df_melt["layer"].max()
    df_melt["Model"] = model_name
    return df_melt


def plot_probe_results_multi_model(result_dirs: List[str], out_dir: str):
    """Plot masked probe accuracy of several models on one normalised-layer axis.

    Each directory is first plotted on its own via ``plot_probe_results``; the
    returned frames are concatenated and drawn with one line style per model.

    The combined figure is written to
    ``<out_dir>/phrase_probe_accuracy/<dataset>/test_acc_different_masks_<tag>.png``,
    where ``<dataset>`` and ``<tag>`` are derived from the first entry of
    ``result_dirs`` using the same path conventions as ``plot_probe_results``.

    Args:
        result_dirs: Result directories, one per model.
        out_dir: Root directory for the figures.

    Raises:
        ValueError: If ``result_dirs`` is empty.
    """
    if not result_dirs:
        raise ValueError("result_dirs must contain at least one result directory")

    all_df = pd.concat([plot_probe_results(result_dir, out_dir) for result_dir in result_dirs])

    # These names were previously referenced without being defined in this function.
    reference_dir = result_dirs[0]
    dataset_name = reference_dir.split("/")[-4]
    tag_type = '_'.join(reference_dir.split("_")[-3:])
    # one level above the per-model folders so single-model figures are not overwritten
    out_exp_dir = f"{out_dir}/phrase_probe_accuracy/{dataset_name}"
    os.makedirs(out_exp_dir, exist_ok=True)

    sns.lineplot(x="Layer Normalized", y="Accuracy", hue="Mask Types", style="Model", data=all_df)
    # title and layout must be applied before saving to appear in the file
    plt.title(f"Conditioning on {'Object' if 'object' in tag_type else 'Box ID'}")
    plt.tight_layout()
    plt.savefig(f"{out_exp_dir}/test_acc_different_masks_{tag_type}.png")
    plt.close()


def get_last_op(sentence:str) -> str:
    """Return the lower-cased operation verb that starts the final phrase of ``sentence``.

    The final phrase is the text after the last ``,`` and, within that, after the
    last ``.``. If its first word is not exactly ``Remove``, ``Move`` or ``Put``
    the phrase is classified as ``"description"``.
    """
    tail = sentence.strip().rsplit(",", 1)[-1].rsplit(".", 1)[-1]
    first_word = tail.strip().partition(" ")[0]
    if first_word in ("Remove", "Move", "Put"):
        return first_word.lower()
    return "description"


def get_last_operant(sentence:str) -> int:
    """Count the operands of the final phrase of ``sentence``.

    The final phrase is the text after the last ``,`` and, within that, after the
    last ``.``. Operands are assumed to be joined by ``" and "``, so the result is
    the number of ``" and "``-separated pieces (always at least 1).
    """
    tail = sentence.strip().rsplit(",", 1)[-1].rsplit(".", 1)[-1]
    return len(tail.split(" and "))

def get_obj_idx_in_phrase(sentence: str, box_obj_idx: int) -> Tuple[int,int]:
    """Locate an object inside the final phrase of ``sentence``.

    Args:
        sentence: Full input text; only the part after the last ``,`` / ``.`` is searched.
        box_obj_idx: Flat box-object index; its value modulo 100 is compared with ``OTI``.

    Returns:
        ``(position, count)`` where ``position`` is the index of the first matching
        object among those returned by ``get_objects`` for the final phrase (``-100``
        when none matches) and ``count`` is how many objects that phrase has.
    """
    tail = sentence.strip().rsplit(",", 1)[-1].rsplit(".", 1)[-1]
    phrase_objects = get_objects(tail)
    wanted = box_obj_idx % 100
    position = next(
        (pos for pos, name in enumerate(phrase_objects) if OTI[name] == wanted),
        -100,
    )
    return position, len(phrase_objects)

def correlate_one_epoch_errors_across_features(pred_path:str, out_dir:str, cache_path:str) -> None:
    """Print statistics about missed "remove" labels and plot accuracy vs. label frequency.

    Args:
        pred_path: ``.npy`` file with probe predictions. Dataset, model and tag names are
            read from path components ``[-6]``, ``[-5]`` and ``[-3]``.
        out_dir: Root directory for figures; they are written below
            ``<out_dir>/phrase_probe_accuracy/<dataset>/<model>/``.
        cache_path: Pickle holding a 6-tuple whose 2nd item is the per-example label
            tensors and whose last item is the input texts. A sibling file matching
            ``exploded_*_all_local_test_subset.pkl`` must exist in the same directory.

    NOTE(review): the code assumes 7 boxes x 100 objects (700 probes) and label value 2
    meaning "removed" -- confirm against the data generator.
    SECURITY: ``pickle.load`` / ``torch.load`` execute arbitrary code on untrusted files;
    only point this at caches you produced yourself.
    """
    pred = torch.from_numpy(np.load(pred_path))
    masks = torch.load(glob.glob(cache_path.rsplit("/", 1)[0]+"/exploded_*_all_local_test_subset.pkl")[0])
    with open(cache_path, "rb") as f:
        output = pickle.load(f)
    # only labels and texts are needed from the cache tuple
    _, y, _, _, _, texts = output
    del output
    y = np.stack([t.numpy() for t in y])
    # errors when groundtruth is remove
    local_gt_remove_incorrect_indices = ((y == 2) & (pred != 2)).nonzero()

    # look at stats on operation type
    local_ops = np.array([get_last_op(text) for text in texts])
    counter = collections.Counter(local_ops[local_gt_remove_incorrect_indices[:,0]])
    print(counter)

    # (a leftover interactive breakpoint used to stop execution here)
    # TODO: for the move-out cases, check whether the move-in prediction is correct

    # look at index of obj in the remove phrase
    obj_idx_in_phrase_list = [get_obj_idx_in_phrase(texts[i], box_obj_idx) for (i, box_obj_idx) in local_gt_remove_incorrect_indices]
    obj_idx_in_phrase_list = np.array(obj_idx_in_phrase_list)
    only_obj_cnt = (obj_idx_in_phrase_list[:,1]==1).sum()
    multi_obj_cnt = (obj_idx_in_phrase_list[:, 1]>1).sum()
    desc = pd.Series(obj_idx_in_phrase_list[obj_idx_in_phrase_list[:, 1] > 1, 0]).describe()
    print(f"{only_obj_cnt=}, {multi_obj_cnt=}\n{desc}")  # pretty evenly split between 1st or 2nd obj

    # look at operation idx (number of clauses = commas + periods)
    ops = np.array([text.count(",")+text.count(".") for text in texts])
    print(pd.Series(ops[local_gt_remove_incorrect_indices[:,0]]).describe())

    dataset_name = pred_path.split("/")[-6]
    model_name = pred_path.split("/")[-5]
    tag_type = pred_path.split("/")[-3]
    # make sure the figure directory exists before saving
    os.makedirs(f"{out_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}", exist_ok=True)

    # look at correlation to training frequency
    train_masks = torch.load(INPUT_PATH_NUMBER_ALL.replace("inputs","masks").replace(".txt", ".pt").replace("test", "train"))
    rare_label_frequency = (train_masks["local_label"]!=0).sum(0).tolist()
    all_accuracy_by_probe = (y==pred).float().mean(0).tolist()
    rare_accuracy_by_probe = (((y==pred) * masks["local_box_obj"]).sum(0) / masks["local_box_obj"].sum(0)).tolist()
    sns.regplot(x=rare_label_frequency, y=all_accuracy_by_probe)
    sns.regplot(x=rare_label_frequency, y=rare_accuracy_by_probe)
    # empty strings skip the extra artists regplot adds between the two series
    plt.legend(["all accuracy", "", "", "exist/remove accuracy", "", ""])
    plt.xlabel("Exist / Remove Label Frequency in Training")
    plt.ylabel("Accuracy in Test")
    plt.savefig(f"{out_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}/rare_label_frequency_vs_acc_{tag_type}.png")
    plt.close()

    # hue by box id: probe index i belongs to box i // 100
    by_box_hue = torch.arange(7).repeat((100,1)).T.reshape(-1)
    sns.scatterplot(x=rare_label_frequency, y=rare_accuracy_by_probe, hue=by_box_hue)
    for i in range(700):   # add label for obj name
        plt.text(x=rare_label_frequency[i] + 0.01,  # Adjust x-position for better visibility
                 y=rare_accuracy_by_probe[i] + 0.01,  # Adjust y-position for better visibility
                 s=ITO[i%100],
                 fontdict=dict(color='black', size=3))
    plt.legend([f"Box {i}" for i in range(7)])
    plt.xlabel("Exist / Remove Label Frequency in Training")
    plt.ylabel("Accuracy in Test")
    plt.savefig(f"{out_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}/rare_label_frequency_vs_acc_{tag_type}_by_box.png",dpi=600)
    plt.close()

    return


def plot_intervention_results(result_dir: str, output_dir: str, hue="intervention_layer_type", ax=None, **filter_kwargs):
    """Plot intervention success rates per layer from a directory of JSON-lines results.

    Args:
        result_dir: Directory with one JSON-lines file per experiment. Model and dataset
            names are read from path components ``[-1]`` and ``[-3]``.
        output_dir: Root directory for figures (a sub-directory is always created).
        hue: Column used for line colour. With ``"Success Type"`` the two success
            metrics are distinguished by colour, otherwise by line style.
        ax: Optional axes to draw into. When given, the function returns right after
            drawing and leaves limits, title and saving to the caller.
        **filter_kwargs: ``column=value`` pairs; only rows equal to every value are kept,
            and each pair is appended to the output file name.

    Files that cannot be parsed are reported and skipped. Older result files lacking
    some columns get them back-filled from the file name.
    """
    all_df = []
    for exp_path in tqdm(os.listdir(result_dir)):
        try:
            df = pd.read_json(os.path.join(result_dir, exp_path), orient="records", lines=True)
        except Exception:
            # `Exception` rather than a bare except, so Ctrl-C / SystemExit still propagate
            print(f"weird parsing error: {os.path.join(result_dir, exp_path)}")
            continue

        # backwards compatibility
        if 'intervention_alpha' not in df.columns:
            alpha = 1.0 if 'alpha=' not in exp_path else float(exp_path.split('alpha=')[1].split('_')[0])
            df['intervention_alpha'] = alpha
        if "at_n" not in df.columns:
            at_n = int(exp_path.split("_n=")[1].replace(".jsonl", "")) if "_n=" in exp_path else None
            df["at_n"] = at_n
        if "intervention_layer_type" not in df.columns:
            # NOTE(review): the original read `df.fisrt_n`, which looks like a misspelling
            # of `first_n`; accept whichever column the file actually has.
            first_n_col = "first_n" if "first_n" in df.columns else "fisrt_n"
            if df.at_n[0] is not None:
                df["layers"] = df.at_n
                df["intervention_layer_type"] = "at_n"
            elif df.last_n[0] is not None:
                df["layers"] = df.last_n
                df["intervention_layer_type"] = "last_n"
            elif df[first_n_col][0] is not None:
                df["layers"] = df[first_n_col]
                df["intervention_layer_type"] = "first_n"
        if "intervention_probe_class" not in df.columns:
            df["intervention_probe_class"] = 2
        all_df.append(df)
    all_df = pd.concat(all_df)
    id_cols = ["layers", "intervention_layer_type", "intervention_probe_class", "intervention_direction", "intervention_operation", "intervention_alpha"]
    metric_cols = ["intervention_obj_success", "intervention_rest_success"]
    suffix_str = ""
    for k, v in filter_kwargs.items():
        all_df = all_df[all_df[k].apply(lambda val: val == v)]
        suffix_str += f"_{k}={v}"
    model_name = result_dir.split("/")[-1]
    dataset_name = result_dir.split("/")[-3]
    out_exp_dir = f"{output_dir}/phrase_probe_intervention_accuracy/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)
    df_melt = pd.melt(
        all_df, id_vars=id_cols,
        value_vars=metric_cols,
        var_name="Success Type",
        value_name="Success"
    )
    df_melt["Success Type"] = df_melt["Success Type"].map({"intervention_obj_success":"Target Object", "intervention_rest_success":"Original Object(s)"})
    hue_order = sorted(df_melt[hue].unique().tolist())
    fix_fonts(title=15, label=15, default=12)
    if hue != "Success Type":
        sns.lineplot(x="layers", y="Success", hue=hue, hue_order=hue_order, data=df_melt, style="Success Type",ax=ax)
    else:
        sns.lineplot(x="layers", y="Success", hue=hue, hue_order=hue_order, data=df_melt, ax=ax)
    if ax is not None:
        # caller owns the figure: no limits, title or saving here
        return
    plt.ylim(0, 1.05)
    # title is inferred from the dataset name
    if "remove" in dataset_name:
        title = "1-REMOVE"
    elif "put" in dataset_name:
        title = "1-PUT"
    elif "moreObj" in dataset_name:
        title = "1-MOVE-OUT"
    else:
        title = "1-MOVE-IN"
    plt.title(title)
    plt.savefig(os.path.join(out_exp_dir, f"success{suffix_str}.png"))
    plt.savefig(os.path.join(out_exp_dir, f"success{suffix_str}.pdf"), dpi=600)
    plt.close()


def plot_pca_2d(X, label_vec, out_path:str):
    """Project ``X`` onto its first two principal components and save a scatter plot.

    Args:
        X: Data matrix passed to ``PCA.fit_transform`` (one row per point).
        label_vec: Per-row labels used to colour the points.
        out_path: File the figure is saved to.
    """
    projected = PCA(n_components=2).fit_transform(X)

    plt.figure(figsize=(8, 6))
    points = plt.scatter(projected[:, 0], projected[:, 1], c=label_vec, cmap="tab10", s=20)

    plt.title("PCA Projection (2D)")
    plt.xlabel("PC1")
    plt.ylabel("PC2")

    # a legend is only readable when there are few distinct labels
    distinct_labels = np.unique(label_vec)
    if len(distinct_labels) < 20:
        legend_handles, _ = points.legend_elements()
        plt.legend(legend_handles, distinct_labels, title="Labels", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()

def plot_cosine_sim(X, out_path:str, **kwargs):
    """Save a heat map of the pairwise cosine similarity between the rows of ``X``.

    Args:
        X: Matrix whose rows are compared with ``cosine_similarity``.
        out_path: File the figure is saved to.
        **kwargs: Forwarded to ``plt.savefig`` (for example ``dpi``).
    """
    plt.imshow(cosine_similarity(X), interpolation='nearest')
    plt.colorbar()
    plt.title("Probe Cosine Similarity Matrix")
    plt.savefig(out_path, **kwargs)
    plt.close()

def plot_ranks(X, out_path:str):
    """Plot the cumulative variance spectrum of ``X`` and its residual energy histogram.

    Two figures are written: the cumulative explained variance of the singular values
    (to ``out_path``) and a histogram of the fraction of each row's squared norm that
    lies outside the "essential" subspace (to ``out_path`` with ``.png`` replaced by
    ``_2.png``). The essential rank is the smallest number of singular directions
    whose cumulative squared singular values reach 95% of the total.

    Args:
        X: 2-D matrix, either a ``torch.Tensor`` or anything ``np.asarray`` accepts.
        out_path: Destination of the first figure.
    """
    # Convert once to NumPy. The previous code called `.numpy()` on the SVD outputs,
    # which only worked when X was a torch tensor that NumPy wrapped back into tensors.
    X_np = X.detach().cpu().numpy() if isinstance(X, torch.Tensor) else np.asarray(X)

    _, S, Vt = np.linalg.svd(X_np, full_matrices=False)
    total_var = np.sum(S ** 2)
    explained = np.cumsum(S ** 2) / total_var
    plt.plot(explained)

    essential_rank = np.searchsorted(explained, 0.95) + 1
    print("Essential rank:", essential_rank)
    plt.xlabel("Rank / Singular Component")
    plt.ylabel("Cumulative variance explained")
    plt.title(f"Rank / Shared Directions (essential rank={essential_rank})")
    plt.savefig(out_path)
    plt.close()

    # split every row into its projection on the top-k right singular vectors and the rest
    k = essential_rank
    X_ess = (X_np @ Vt[:k].T) @ Vt[:k]
    X_resid = X_np - X_ess
    ess_energy = np.sum(X_ess ** 2, 1)
    resid_energy = np.sum(X_resid ** 2, 1)
    frac_resid = resid_energy / (ess_energy + resid_energy)

    plt.hist(frac_resid, bins=30)
    plt.title("Fraction of probe magnitude in non-essential subspace")
    plt.xlabel("Residual (non-essential) energy")
    plt.ylabel("Probe count")
    plt.savefig(out_path.replace(".png", "_2.png"))
    plt.close()


def get_cosine_sim(A, B):
    """Return the row-wise cosine similarity between two equally shaped 2-D tensors.

    Row ``i`` of the result is ``<A[i], B[i]> / (|A[i]| * |B[i]|)``; no epsilon is
    added, so an all-zero row yields ``nan``.
    """
    numerator = (A * B).sum(dim=1)
    denominator = torch.norm(A, dim=1) * torch.norm(B, dim=1)
    return numerator / denominator

def plot_cosine_sim_distribution(A, B, out_path:str):
    """Save a 30-bin histogram of the row-wise cosine similarities between ``A`` and ``B``.

    Args:
        A: First matrix, passed to ``get_cosine_sim``.
        B: Second matrix, paired row by row with ``A``.
        out_path: File the figure is saved to.
    """
    similarities = get_cosine_sim(A, B)

    plt.hist(similarities, bins=30)
    plt.title("Histogram of Cosine Similarities")
    plt.xlabel("Cosine similarity")
    plt.ylabel("Frequency")
    plt.savefig(out_path)
    plt.close()


def plot_cosine_sim_distribution_across_layer(sim_ne_ex, sim_ex_re, sim_ne_re, out_path:str):
    """Plot mean +/- sd of inter-class probe cosine similarity as a function of layer.

    Args:
        sim_ne_ex: Per-layer sequence of per-probe similarities, non-exist vs. exist.
        sim_ex_re: Same layout, exist vs. removed.
        sim_ne_re: Same layout, non-exist vs. removed.
        out_path: Path of the PNG; a PDF with the same stem is written as well.

    Each argument is indexed ``[layer][probe]``, and all three must share that shape.
    """
    n_layers = len(sim_ne_ex)
    layers = np.arange(n_layers)

    # Convert each input once to a (n_layers, n_probes) array; doing this inside the
    # per-probe loop made the conversion cost grow with the number of probes.
    ne_ex = np.stack([np.asarray(row) for row in sim_ne_ex])
    ne_re = np.stack([np.asarray(row) for row in sim_ne_re])
    ex_re = np.stack([np.asarray(row) for row in sim_ex_re])
    n_probes = ne_ex.shape[1]
    pair_names = ["Non-Exist vs. Exist", "Non-Exist vs. Removed", "Exist vs. Removed"]

    # Build the tidy frame in one shot. Row order is probe-major, then pair, then layer,
    # which is the order the per-probe frames used to be concatenated in.
    stacked = np.stack([ne_ex, ne_re, ex_re], axis=0)  # (pair, layer, probe)
    all_df = pd.DataFrame({
        "layer": np.tile(layers, len(pair_names) * n_probes),
        "cosine_similarity": stacked.transpose(2, 0, 1).reshape(-1),
        "pair": np.tile(np.repeat(pair_names, n_layers), n_probes),
    })

    # --------------------------- Plot ---------------------------
    fix_fonts()
    sns.lineplot(
        data=all_df,
        x="layer",
        y="cosine_similarity",
        hue="pair",
        estimator="mean",
        errorbar="sd"  # shaded standard-deviation band
    )
    plt.xlabel("Layer")
    plt.ylabel("Cosine similarity")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.savefig(out_path.replace(".png", ".pdf"),dpi=600)
    plt.close()


def probe_similarities_across_class_and_layers(result_dir: str, output_dir: str, per_layer_plot:bool=True):
    """Visualise the geometry of the trained probe weights within and across layers.

    For every ``layer<N>_...`` sub-directory of ``result_dir`` the ``proj.weight`` matrix
    of ``checkpoint.ckpt`` is loaded and viewed as 700 probes x 3 classes. Inter-class
    cosine similarities are collected per layer and plotted at the end; with
    ``per_layer_plot`` a set of diagnostic figures is also written into each layer
    directory.

    Args:
        result_dir: Directory with one sub-directory per layer. Model, dataset and probe
            type are read from path components ``[-3]``, ``[-4]`` and ``[-1]``.
        output_dir: Root directory for the across-layer figure.
        per_layer_plot: Whether to also write the per-layer diagnostic figures.

    NOTE(review): class indices 0/1/2 are treated as non-exist/exist/removed, matching
    the plot labels used in this file -- confirm against the label encoding.
    """
    # colour labels for the flattened (probe, class) weight rows
    probe_class_labels = torch.arange(3).repeat(1, 700).T.reshape(-1)
    box_id_labels = torch.arange(7).repeat(100,1).T.reshape(-1)
    obj_labels = torch.arange(100).repeat(1,7).T.reshape(-1)

    layer_dirs = os.listdir(result_dir)
    n_dirs = len(layer_dirs)
    # slot `layer - 1` is filled per directory; untouched slots stay 0
    sim_ne_ex = [0] * n_dirs
    sim_ne_re = [0] * n_dirs
    sim_ex_re = [0] * n_dirs

    for layer_path in tqdm(layer_dirs):
        layer_dir = os.path.join(result_dir, layer_path)
        weights = torch.load(os.path.join(layer_dir, "checkpoint.ckpt"), map_location="cpu")['proj.weight']
        h_dim = weights.shape[1]
        layer = int(layer_path.split("_")[0].replace("layer",""))

        per_class = weights.reshape(700, 3, h_dim)
        non_exist_w = per_class[:, 0]
        exist_w = per_class[:, 1]
        removed_w = per_class[:, 2]

        # keep the inter-class similarities for the across-layer plot
        slot = layer - 1
        sim_ne_ex[slot] = get_cosine_sim(exist_w, non_exist_w)
        sim_ex_re[slot] = get_cosine_sim(exist_w, removed_w)
        sim_ne_re[slot] = get_cosine_sim(non_exist_w, removed_w)

        if not per_layer_plot:
            continue

        # largest structure: the three classes
        plot_pca_2d(weights, probe_class_labels, os.path.join(layer_dir, "pca_3class.png"))

        # structure inside a single class (removed), coloured by box and by object
        plot_pca_2d(removed_w, box_id_labels, os.path.join(layer_dir, "pca_1class_7box.png"))
        plot_pca_2d(removed_w, obj_labels, os.path.join(layer_dir, "pca_1class_100obj.png"))

        # pairwise cosine similarity, inside the removed class and over all rows
        plot_cosine_sim(removed_w, os.path.join(layer_dir, "cos_1class.png"), dpi=600)
        plot_cosine_sim(weights, os.path.join(layer_dir, "cos_3class.png"), dpi=600)

        # per-probe similarity between pairs of classes
        plot_cosine_sim_distribution(exist_w, removed_w, os.path.join(layer_dir, "remove_exist_cos_hist.png" ))
        plot_cosine_sim_distribution(exist_w, non_exist_w, os.path.join(layer_dir, "exist_non_exist_cos_hist.png"))
        plot_cosine_sim_distribution(removed_w, non_exist_w, os.path.join(layer_dir, "remove_non_exist_cos_hist.png"))

        # rank / spectrum of the removed-class probes
        plot_ranks(removed_w, os.path.join(layer_dir, "rank_1class.png") )

    path_parts = result_dir.split("/")
    model_name = path_parts[-3]
    dataset_name = path_parts[-4]
    probe_type = path_parts[-1]
    out_exp_dir = f"{output_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)
    plot_cosine_sim_distribution_across_layer(sim_ne_ex, sim_ex_re, sim_ne_re, os.path.join(out_exp_dir, f"inter_class_cos_across_layers_{probe_type}.png"))

    plt.close()
    return


def aggeregate_plots(plots: List[str], layers:List[int], out_path: str, layers_dir: str):
    """Tile previously saved per-layer PNGs into one grid figure.

    Rows correspond to ``plots`` and columns to ``layers``; the image for a cell is read
    from ``<layers_dir>/layer<layer>_token1/<plot>.png``. Missing images leave an empty
    cell with a "No file" title.

    Args:
        plots: File-name stems of the figures to collect (one grid row each).
        layers: Layer numbers to collect (one grid column each).
        out_path: File the combined figure is saved to.
        layers_dir: Directory that contains the per-layer sub-directories.
    """
    n_rows = len(plots)
    n_cols = len(layers)

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 3 * n_rows),
        squeeze=False
    )

    for r, plot_prefix in enumerate(plots):
        for c, layer in enumerate(layers):
            ax = axes[r][c]
            ax.axis("off")
            path = f"{layers_dir}/layer{layer}_token1/{plot_prefix}.png"
            if not os.path.exists(path):
                ax.set_title(f"No file\n{plot_prefix} @ layer {layer}")
                continue

            # context manager closes the image file once matplotlib has copied the pixels
            with Image.open(path) as img:
                ax.imshow(img)
            # imshow switches the axis back on, so hide it again
            ax.axis("off")
            ax.set_title(f"{plot_prefix}\nLayer {layer}")

    plt.tight_layout()
    plt.savefig(out_path)
    # release the figure; it used to stay open and leak into later plots
    plt.close(fig)
    return


def aggeregate_intervention_plots(model:str, out_path: str):
    """Draw a 2 x 4 grid of intervention-success curves and save it as PNG and PDF.

    Columns are the four operations (put, move-in, remove, move-out); rows are the
    intervention site (object probe on top, box-id probe below). Every panel is drawn by
    ``plot_intervention_results`` into the corresponding axes.

    Args:
        model: Model name; used as the sub-directory of ``out_path``.
        out_path: Root output directory. The figure goes to
            ``<out_path>/<model>/intervention_aggregate.{png,pdf}``.
    """
    os.makedirs(f"{out_path}/{model}", exist_ok=True)
    # One result directory per operation, in the same order as `ops`.
    # The entries were previously not comma-separated, so Python concatenated
    # them into a single string.
    intervention_paths = [
        # 1 put
        "<anonymized>",
        # 1 move in
        "<anonymized>",
        # 1 remove
        "<anonymized>",
        # 1 move out
        "<anonymized>",
    ]
    ops = ["1-Put", "1-Move(in)", "1-Remove", "1-Move(out)"]
    n_rows = 2  # number / object
    n_cols = 4  # put / move-in / remove / move-out
    fix_fonts()
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 2 * n_rows),
        squeeze=True,
        sharex = True, sharey = True,
    )
    # common Y label for the entire figure
    fig.supylabel('Success Intervening at', fontsize=20)

    for col, path in enumerate(intervention_paths):
        for row, site in enumerate(["object-op", "number-op"]):
            ax = axes[row][col]
            args = {
                "hue":"Success Type",
                # removal-style operations are evaluated with the "remove" query
                "intervention_operation":"query-remove" if ops[col] in ["1-Remove", "1-Move(out)"] else "query-exist",
                "intervention_direction":"null",
                "intervention_alpha":1,
                "intervention_site":site,
                "intervention_layer_type":"at_n"
            }
            # `out_dir` was not defined in this function; use this function's own output root
            plot_intervention_results(path, out_path, ax=ax, **args)
            ax.set_ylim(0, 1.05)
            if row == 0:
                ax.set_title(ops[col], fontsize=25)
            elif row == 1:
                ax.set_xlabel("Layer")
            if col == 0:
                ax.set_ylabel("Box ID" if row == 1 else "Object")
            # keep a single legend, in the bottom-right panel
            if (row, col) != (1,3):
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

    plt.tight_layout()
    plt.savefig(f"{out_path}/{model}/intervention_aggregate.png")
    plt.savefig(f"{out_path}/{model}/intervention_aggregate.pdf", dpi=600)
    plt.close(fig)
    return


def get_confusion_matrix_by_operation(pred_path: str, cache_path: str, out_path: str):
    """Plot probe confusion matrices split by the operation in each example's last phrase.

    Confusion matrices are accumulated over all probes, once for the whole data set
    (stored only in the local dict) and once per operation type (``description``,
    ``put``, ``remove``, ``move``). The four per-operation matrices are printed and
    drawn as log-scaled heat maps.

    Args:
        pred_path: ``.npy`` file with probe predictions, one row per example and one
            column per probe. Model, dataset and probe type are read from path
            components ``[-5]``, ``[-6]`` and ``[-3]``.
        cache_path: Pickle holding a 6-tuple whose 2nd item is the per-example label
            tensors and whose last item is the input texts.
        out_path: Root directory; figures go to
            ``<out_path>/<dataset>/<model>/<probe_type>_confusion_matrix_by_operation.{png,pdf}``.

    SECURITY: ``pickle.load`` executes arbitrary code on untrusted files; only use
    caches you produced yourself.
    """
    from matplotlib.colors import LogNorm
    plt.rc('font', family='serif', serif=['Times New Roman'])
    # load groundtruth label and text strings
    with open(cache_path, "rb") as f:
        output = pickle.load(f)
    _, y, _, _, _, texts = output
    del output
    local_ops = np.array([get_last_op(text) for text in texts])
    y = np.stack([t.numpy() for t in y])

    # load probe predictions
    pred = np.load(pred_path)  # [15K, 700]

    num_classes = 3
    labels = list(range(num_classes))
    matrices = {}
    # Summing one confusion matrix per probe column equals the confusion matrix of the
    # flattened arrays, so a single call replaces the per-column loop. Cast to float to
    # keep the dtype the accumulation used to produce.
    matrices["all"] = confusion_matrix(y.ravel(), pred.ravel(), labels=labels).astype(float)

    # conf mat where local operation is a particular operation
    for op in ["remove", "put", "description", "move"]:
        op_rows = local_ops == op
        matrices[op] = confusion_matrix(
            y[op_rows].ravel(), pred[op_rows].ravel(), labels=labels
        ).astype(float)
        print(f"confusion matrix ({op}):{matrices[op]}")

    # plot confusion matrices
    fig, axes = plt.subplots(1, 4, figsize=(12, 3), sharey=True)
    class_labels = ["Does Not Exist", "Exist", "Removed"]
    for i, op in enumerate(["description", "put", "remove", "move"]):
        cm = matrices[op]
        ax = axes[i]
        sns.heatmap(cm, annot=True, cmap="Blues", cbar=False,
            xticklabels=class_labels,
            yticklabels=class_labels,
            # tiny offset keeps vmin positive for the log scale when a cell is 0
            norm=LogNorm(vmin=cm.min()+1e-20, vmax=cm.max()),
            ax=ax)
        ax.set_title(op.upper())
        ax.set_xlabel("Predicted")
        ax.tick_params(axis='x', rotation=30)
        if i == 0:
            ax.set_ylabel("True")
            ax.tick_params(axis='y', rotation=0)

    plt.tight_layout()
    model_name = pred_path.split("/")[-5]
    dataset_name = pred_path.split("/")[-6]
    probe_type = pred_path.split("/")[-3]
    # make sure the destination exists before saving
    os.makedirs(f"{out_path}/{dataset_name}/{model_name}", exist_ok=True)
    plt.savefig(f"{out_path}/{dataset_name}/{model_name}/{probe_type}_confusion_matrix_by_operation.png")
    plt.savefig(f"{out_path}/{dataset_name}/{model_name}/{probe_type}_confusion_matrix_by_operation.pdf", dpi=600)
    plt.close(fig)

    # Author's note: accuracy on "put" examples looked unexpectedly low; an ad-hoc
    # inspection of sampled right/wrong "exist" predictions showed nothing salient.


def is_new_datapoint(probe_context:str) -> bool:
    """Tell whether a probe context is the first probe position of a new example.

    Only the text up to the second ``**`` marker is inspected. Any comma in it means
    an earlier clause exists, so the context is not new. Otherwise a box-id probe
    (marker right after ``Box``) is always new, and an object probe is new unless
    ``"and the"`` precedes the marked object.
    """
    head = "**".join(probe_context.split("**", 2)[:2])
    if "," in head:
        return False
    if "Box**" in head:  # number/box_id probe
        return True
    # object probe
    return "and the" not in head

def is_last_operation(probe_context:str, pred_prefix:str, obj: str) -> bool:
    """Report whether ``obj`` is absent from ``pred_prefix`` after the probe context.

    The probe context is normalised first: ``**`` markers become spaces, the space
    before ``,`` and ``.`` is removed, and the ends are stripped. The remainder of
    ``pred_prefix`` after as many characters as the normalised context is long is then
    searched for ``obj``.

    If ``pred_prefix`` does not start with the normalised context, a debugger
    breakpoint is triggered before continuing.
    """
    normalised = probe_context
    for old, new in (("**", " "), (" ,", ","), (" .", ".")):
        normalised = normalised.replace(old, new)
    normalised = normalised.strip()

    if not pred_prefix.startswith(normalised):
        pdb.set_trace()

    remainder = pred_prefix[len(normalised):]
    return obj not in remainder


def get_correlation_probe_vs_pred_accuracy(pred_path: str, cache_path: str, inference_result_path: str) -> None:
    """
    Cross-tabulate probe predictions against the model's parsed text answers, per operation.

    For every probe row that passes the "last operation" filter, each
    (object, box) pair adds one count to a ``[probe class, text-predicted]``
    matrix keyed by ``"<operation>-<tgt|src|global>"``. The raw and the
    row-normalised matrices are printed; nothing is returned.

    Args:
        pred_path: ``.npy`` file of probe predictions, loaded with ``np.load``.
            A path containing ``"number_all"`` selects the number (box id)
            probe handling, anything else the object probe handling.
        cache_path: pickle holding a 6-tuple; only the labels (2nd item) and
            the texts (last item) are used here.
        inference_result_path: JSON-lines file of inference results; the
            ``prefix`` and ``parsed_original_answer`` columns are read.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so ``cache_path``
    must point to a trusted file.
    """
    # load groundtruth label and text strings
    with open(cache_path, "rb") as f:
        output = pickle.load(f)
    activation_indices, y, num_ops, counts, all_mentioned_objects, texts = output
    del activation_indices, output, num_ops, counts, all_mentioned_objects
    # get_last_op is defined elsewhere in this file; below its result is compared
    # against "put" / "description" / "remove", and anything else is treated as a move.
    local_ops = np.array([get_last_op(text) for text in texts])
    y = np.stack([t.numpy() for t in y])  # [15K, 700]

    # load probe predictions
    pred = np.load(pred_path)  # [15K, 700]
    # load model inference results (actually doing entity tracking)
    inf_result = pd.read_json(inference_result_path, orient='records', lines=True)

    num_classes = 3
    probe_type = "number" if "number_all" in pred_path else "object"

    # One count matrix per (operation, box role). Rows are indexed by the probe's
    # predicted class, columns by whether the object shows up in the parsed answer.
    matrices = {
        "put-tgt": np.zeros((num_classes, 2)),  # 2 classes represent {not-predicted, predicted}
        "put-global": np.zeros((num_classes, 2)),
        "description-tgt": np.zeros((num_classes, 2)),
        "description-global": np.zeros((num_classes, 2)),
        "remove-tgt": np.zeros((num_classes, 2)),
        "remove-global": np.zeros((num_classes, 2)),
        "move-tgt": np.zeros((num_classes, 2)),
        "move-global": np.zeros((num_classes, 2)),
        "move-src": np.zeros((num_classes, 2)),
    }
    # now try to see if these probe failure leads to failures in prediction
    # Starts at -7 so that the first new datapoint lands on inference row 0.
    # NOTE(review): assumes each datapoint owns 7 consecutive inference rows
    # (one per box, see `range(7)` below) -- confirm against the inference file.
    data_bunch_idx = -7
    for pred_idx in tqdm(range(pred.shape[0])):
        probe_context = texts[pred_idx]
        if is_new_datapoint(probe_context):
            data_bunch_idx += 7
        if probe_type == "object":
            # The probed object is the first `**`-delimited span of the text.
            pred_obj = texts[pred_idx].split("**")[1]
            label_probe_indices = y[pred_idx].nonzero()[0]
            # `i // 100` recovers the box id from a label index, mirroring the
            # `box_id * 100 + object index` encoding of OTI_BOX.
            # A box is only kept when exactly one label index carries the value.
            exist_box = [i//100 for i in label_probe_indices if y[pred_idx, i]==1]
            exist_box = exist_box[0] if len(exist_box)==1 else None
            remove_box = [i//100 for i in label_probe_indices if y[pred_idx, i]==2]
            remove_box = remove_box[0] if len(remove_box)==1 else None

            # here we only consider examples where probe would have causally contributed to the object-box
            # tag signal by considering the last interaction an object has with any box
            is_last_op = is_last_operation(probe_context, inf_result["prefix"][data_bunch_idx], pred_obj)
            pred_objs = [pred_obj]
        else: # probe_type == "number":
            # Objects are parsed from the text after the last "," and the last "." of the context.
            pred_objs = get_objects(probe_context.split(",")[-1].split(".")[-1])
            pred_box = int(texts[pred_idx].split("**")[1])
            exist_box, remove_box = None, None
            if local_ops[pred_idx] in ["put", "description"]:
                exist_box = pred_box
            elif local_ops[pred_idx] in ["remove"]:
                remove_box = pred_box
            else: # move
                # NOTE(review): the 4th-from-last token before the probed span being "Box"
                # presumably means the probed box is the move's source -- confirm against the text format.
                if probe_context.split("**")[0].split()[-4]=="Box":
                    remove_box = pred_box
                else:
                    exist_box = pred_box

            # here we only consider examples where probe would have causally contributed to the object-box
            # tag signal by considering the last interaction an object has with any box
            # The last three space-separated tokens of the prefix are dropped before the check
            # (reason not visible here -- TODO confirm).
            is_last_op = all([is_last_operation(probe_context, " ".join(inf_result["prefix"][data_bunch_idx].split(" ")[:-3]), pred_obj) for pred_obj in pred_objs])
        if not is_last_op:
            continue

        # now we will consider each operation separately
        # for put/description, we want to see if target box's prediction match up with probe accuracy
        src_box = None
        if local_ops[pred_idx] in ["put", "description"]:
            tgt_box = exist_box
        elif local_ops[pred_idx] in ["remove"]:
            tgt_box = remove_box
        else: # move
            tgt_box, src_box = exist_box, remove_box

        for pred_obj in pred_objs:
            for box in range(7):
                inf_result_instance = inf_result.iloc[data_bunch_idx+box]
                text_predicted = 1 if pred_obj in inf_result_instance["parsed_original_answer"] else 0
                # Probe class predicted for this (object, box) pair; it is used as a row index below.
                probe_predicted = pred[pred_idx, OTI_BOX(pred_obj, box)]
                if box == tgt_box:
                    matrices[f"{local_ops[pred_idx]}-tgt"][probe_predicted, text_predicted] += 1
                elif src_box is not None and box == src_box:
                    matrices[f"{local_ops[pred_idx]}-src"][probe_predicted, text_predicted] += 1
                else:
                    matrices[f"{local_ops[pred_idx]}-global"][probe_predicted, text_predicted] += 1

    for k, v in matrices.items():
        print(f"{k}:\n{v}")
    # Normalise each row (probe class) by its total; a row without counts becomes NaN (0/0).
    norm_matrices = {}
    for k, v in matrices.items():
        norm_matrices[k] = matrices[k]/ matrices[k].sum(axis=1)[:, np.newaxis]
    for k, v in norm_matrices.items():
        print(f"{k}:\n{v}")
    # pdb.set_trace()
    return


def analyze_probe_class_hierarchy(probe_path: str, variance_threshold: float=0.95, plot=True):
    """Measure how low-rank each probe class is and how well it spans the other classes.

    The ``proj.weight`` matrix of ``<probe_path>/checkpoint.ckpt`` is viewed as
    ``[#probes, 3, h_dim]``. For each source class the per-probe weights are
    decomposed with an SVD, and the essential rank is the number of leading
    singular components whose cumulative squared singular values reach
    ``variance_threshold``. The weights of every target class are then
    projected onto those leading left singular vectors, and the ratio
    ``||projection|| / ||weight||`` is summarised over probes.

    Args:
        probe_path: directory containing ``checkpoint.ckpt``.
        variance_threshold: fraction of squared singular values to retain.
        plot: when True, print the ranks and summary matrices and save one
            cumulative-variance figure per class into ``probe_path``. When
            False, no matplotlib state is touched.

    Returns:
        ``(ranks, norm_matrix, norm_std_matrix)``: ``ranks`` is a list with
        one essential rank per class; the two ``[3, 3]`` arrays hold the mean
        and standard deviation of the norm ratio, indexed ``[src_cls, tgt_cls]``.
    """
    num_classes = 3
    weights = torch.load(os.path.join(probe_path, "checkpoint.ckpt"), map_location="cpu")['proj.weight'].numpy()
    h_dim = weights.shape[1]
    # [#probes, class, h_dim]; the probe count is inferred from the checkpoint.
    per_class = weights.reshape(-1, num_classes, h_dim)

    norm_matrix = np.zeros((num_classes, num_classes))
    norm_std_matrix = np.zeros((num_classes, num_classes))
    ranks = [0] * num_classes
    for src_cls in range(num_classes):
        # reshape to [h_dim, #probes]
        cls_weights = per_class[:, src_cls].transpose(1, 0)
        U, S, _ = np.linalg.svd(cls_weights, full_matrices=False)
        explained = np.cumsum(S ** 2) / np.sum(S ** 2)

        essential_rank = np.searchsorted(explained, variance_threshold) + 1
        ranks[src_cls] = essential_rank
        if plot:
            # Draw only on request so that plot=False never leaves curves on the current figure.
            plt.plot(explained)
            print(f"Class {src_cls} Essential rank (for {variance_threshold} variance):", essential_rank)
            plt.xlabel("Rank / Singular Component")
            plt.ylabel("Cumulative variance explained")
            plt.title(f"Rank / Shared Directions (essential rank for {variance_threshold} variance)={essential_rank})")
            plt.savefig(f"{probe_path}/essential_rank_cls={src_cls}.png")
            plt.close()

        U_k = U[:, :essential_rank]

        for tgt_cls in range(num_classes):
            tgt_probe_weights = per_class[:, tgt_cls].transpose(1, 0)  # [h_dim, #probes]
            # Project all probes at once; columns are independent, so this equals
            # the per-probe projection up to floating-point rounding.
            reconstructed = U_k @ (U_k.T @ tgt_probe_weights)
            norm_ratios = np.linalg.norm(reconstructed, axis=0) / np.linalg.norm(tgt_probe_weights, axis=0)
            norm_matrix[src_cls, tgt_cls] = norm_ratios.mean()
            norm_std_matrix[src_cls, tgt_cls] = norm_ratios.std()
    if plot:
        print(f"norm_matrix\n{norm_matrix}")
        print(f"norm_std_matrix\n{norm_std_matrix}")
    return ranks, norm_matrix, norm_std_matrix

def plot_probe_class_hierarchy_across_layers(probe_path_dir: str, output_dir:str, variance_threshold: float=0.95):
    """Plot essential rank and cross-class reconstruction ratio against layer.

    Every entry of ``probe_path_dir`` is analysed with
    ``analyze_probe_class_hierarchy``; the layer number is parsed from the
    part of the entry name before the first ``_`` (with ``"layer"`` removed).
    Two line plots are written as PNG and PDF under
    ``<output_dir>/phrase_probe_accuracy/<dataset>/<model>/``, where dataset,
    model and probe type are taken from the 4th-last, 3rd-last and last ``/``
    components of ``probe_path_dir``.

    Args:
        probe_path_dir: directory whose entries are per-layer probe directories.
        output_dir: root directory for the figures.
        variance_threshold: forwarded to ``analyze_probe_class_hierarchy``.
    """
    cls_labels = ["Non-exist", "Exist", "Removed"]
    class_ids = range(len(cls_labels))

    # One record per probe directory: layer, per-class rank, per-pair norm ratio.
    records = []
    for entry in tqdm(os.listdir(probe_path_dir)):
        layer = int(entry.split("_")[0].replace("layer", ""))
        ranks, norm_matrix, _ = analyze_probe_class_hierarchy(
            os.path.join(probe_path_dir, entry),
            variance_threshold=variance_threshold,
            plot=False,
        )
        record = {"Layer": layer}
        for src in class_ids:
            record[f"{cls_labels[src]} Rank"] = ranks[src]
            for tgt in class_ids:
                record[f"{cls_labels[src]} → {cls_labels[tgt]}"] = norm_matrix[src, tgt]
        records.append(record)
    df = pd.DataFrame(records)

    path_parts = probe_path_dir.split("/")
    model_name, dataset_name, probe_type = path_parts[-3], path_parts[-4], path_parts[-1]
    out_exp_dir = f"{output_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)

    def save_and_close(stem):
        # Write the current figure as PNG and PDF, then release it.
        plt.savefig(f"{out_exp_dir}/{stem}_{probe_type}.png")
        plt.savefig(f"{out_exp_dir}/{stem}_{probe_type}.pdf", dpi=600)
        plt.close()

    # --- essential rank per class across layers ---
    rank_long = pd.melt(
        df,
        id_vars=["Layer"],
        value_vars=[f"{label} Rank" for label in cls_labels],
        var_name="Probe Class",
        value_name="Rank",
    )
    # for some reason there are these weird datapoints
    layer_cap = len(os.listdir(probe_path_dir)) + 1
    rank_long = rank_long[rank_long.Layer <= layer_cap]
    fix_fonts(default=12)
    plt.close()
    rank_ax = sns.lineplot(data=rank_long, x="Layer", y="Rank", hue="Probe Class")
    rank_ax.legend(loc='lower right')
    plt.tight_layout()
    save_and_close("rank")

    # --- reconstruction norm ratio for every source/target class pair ---
    recon_long = pd.melt(
        df,
        id_vars=["Layer"],
        value_vars=[column for column in df.columns if "→" in column],
        var_name="Reconstruction Directions",
        value_name="Reconstruction Norm Ratio",
    )
    recon_ax = sns.lineplot(data=recon_long, x="Layer", y="Reconstruction Norm Ratio", hue="Reconstruction Directions")
    recon_ax.legend(loc='lower right')
    plt.tight_layout()
    save_and_close("low-rank_recon")
    return


def is_adjacent_objects(all_objs: list[str], objs: list[str])->bool:
    """Check whether ``objs`` occupy one contiguous run of positions in ``all_objs``.

    Positions are looked up with ``save_index``, so an object missing from
    ``all_objs`` contributes a random negative position.

    Args:
        all_objs: reference ordering of objects.
        objs: objects to test; must be non-empty (``min`` raises otherwise).

    Returns:
        True when the distinct positions form a gap-free integer range.
    """
    positions = {save_index(all_objs, name) for name in objs}
    lowest, highest = min(positions), max(positions)
    # A set of integers inside [lowest, highest] fills the range exactly when
    # it has as many members as the range has slots.
    return len(positions) == highest - lowest + 1

def save_index(l: List[Any], item: Any) -> int:
    """Return the position of ``item`` in ``l`` without raising for a missing item.

    Args:
        l: list to search.
        item: value to locate.

    Returns:
        The index of the first occurrence of ``item``; when ``item`` is not in
        ``l``, a random integer drawn from ``[-10000, -2)``.
    """
    if item not in l:
        # Missing items get a random negative sentinel rather than an exception.
        return np.random.randint(-10000, -2)
    return l.index(item)

def intervention_1put_error_categorization(intervention_path: str):
    """Attribute failed interventions to error types and return the mean shares.

    Rows of the JSON-lines file with ``intervention_obj_success == True`` and
    ``intervention_rest_success == False`` are examined. Each such row yields
    a 3-vector ``[enum, oid, other]``:

    * ``enum`` gets one count when the intervened answer has more items than
      the original answer and its items are adjacent (``is_adjacent_objects``)
      among the objects parsed from the row's ``prefix``.
    * ``oid`` gets one count per newly introduced object whose position is
      within 1 of the position of some original answer object.
    * The vector is normalised to sum to 1; when neither applies, ``other``
      is set to 1.

    Args:
        intervention_path: path to the JSON-lines intervention results.

    Returns:
        Array of shape ``(3,)`` with the mean vector over the examined rows.
        When no row matches the filter the result is all-NaN (the shares are
        undefined) instead of raising from ``np.stack`` on an empty list.
    """
    df = pd.read_json(intervention_path, lines=True, orient='records')
    df_wrong = df[(df.intervention_rest_success == False) & (df.intervention_obj_success==True)]
    results = []
    for _, row in df_wrong.iterrows():
        orig_items = row["orig_answer"]
        intervened_items = row["intervened_answer_items"]
        supporting_mechanisms = np.array([0.0, 0.0, 0.0])  # 1 is enum, 2 is oid, 3 is other
        world_objs = get_objects(row["prefix"])
        if len(intervened_items) > len(orig_items) and is_adjacent_objects(world_objs, intervened_items):
            supporting_mechanisms[0] += 1.0
        new_objs = [o for o in intervened_items if o not in orig_items]
        for new_obj in new_objs:
            if any([np.abs(save_index(world_objs, og_obj) - save_index(world_objs, new_obj)) <= 1 for og_obj in orig_items]):
                supporting_mechanisms[1] += 1.0
        total = supporting_mechanisms.sum()
        if total > 0:
            supporting_mechanisms /= total
        else:
            supporting_mechanisms[-1] = 1.0
        results.append(supporting_mechanisms)

    if not results:
        # Nothing to categorise: report it and return undefined shares.
        print("Error types: n/a, support=0")
        return np.full(3, np.nan)

    results = np.stack(results)
    mean_shares = results.mean(0)
    print(f"Error types: {mean_shares}, support={len(results)}")
    return mean_shares

def plot_intervention_1put_error_categorization_across_layer(intervention_path_pattern: str, output_dir:str):
    """Plot the error-type shares of failed interventions against layer.

    Every file matching ``intervention_path_pattern`` is summarised with
    ``intervention_1put_error_categorization``; its layer is parsed from the
    text after the last ``=`` in the path (with ``".jsonl"`` removed). The
    line plot is saved as PNG and PDF under
    ``<output_dir>/phrase_probe_intervention_error/<dataset>/<model>/``, where
    model and dataset are the last and 3rd-last ``/`` components of the
    pattern's directory.

    Args:
        intervention_path_pattern: glob pattern of per-layer result files.
        output_dir: root directory for the figures.
    """
    error_columns = ["error_enum", "error_oid", "error_other"]
    file_pattern = intervention_path_pattern.split("/")[-1]

    records = []
    for path in glob.glob(intervention_path_pattern):
        shares = intervention_1put_error_categorization(path)
        record = {"layer": int(path.split("=")[-1].replace(".jsonl", ""))}
        for position, column in enumerate(error_columns):
            record[column] = shares[position]
        records.append(record)

    # Long format: one row per (layer, error type).
    long_df = pd.melt(
        pd.DataFrame(records),
        id_vars=["layer"],
        value_vars=error_columns,
        var_name="Error Type",
        value_name="Error Percentage",
    )

    result_dir = intervention_path_pattern.rsplit("/", maxsplit=1)[0]
    dir_parts = result_dir.split("/")
    model_name, dataset_name = dir_parts[-1], dir_parts[-3]
    out_exp_dir = f"{output_dir}/phrase_probe_intervention_error/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)

    fix_fonts()
    sns.lineplot(data=long_df, x="layer", y="Error Percentage", hue="Error Type")
    plt.tight_layout()
    figure_stem = f"{out_exp_dir}/error_pct_{file_pattern}"
    plt.savefig(f"{figure_stem}.png")
    plt.savefig(f"{figure_stem}.pdf", dpi=600)



def get_probe_accuracy_across_context(pred_path: str, cache_path: str):
    """
    Relate per-example probe accuracy to how far into the context the probe sits.

    Accuracy is the fraction of positions selected by the ``local_box_obj``
    mask (read from ``test_masks_subset.pt`` next to ``cache_path``) where the
    probe prediction equals the label. The phrase index of an example is the
    number of ``". "`` plus ``", "`` occurrences in its text. Counts and mean
    accuracy per phrase index are printed.

    Args:
        pred_path: ``.npy`` file with probe predictions.
        cache_path: pickle holding a 6-tuple; the labels (2nd item) and the
            texts (last item) are used. The file must be trusted, since
            unpickling can execute arbitrary code.

    Returns:
        DataFrame with columns ``Accuracy``, ``Phrase Index`` and
        ``is_not_move`` (False when the text after the last ``.`` contains
        ``"Move"``), one row per example.
    """
    # load groundtruth label and text strings
    with open(cache_path, "rb") as fh:
        _, labels, _, _, _, texts = pickle.load(fh)

    phrase_index = np.array([text.count(". ") + text.count(", ") for text in texts])
    is_not_move = np.array(["Move" not in text.split(".")[-1] for text in texts])
    y = np.stack([label.numpy() for label in labels])

    # load probe predictions
    pred = np.load(pred_path)  # [15K, 700]

    # load probe masks
    mask_path = cache_path.rsplit("/", 1)[0] + "/test_masks_subset.pt"
    masks = {name: tensor.numpy() for name, tensor in torch.load(mask_path).items()}
    local_mask = masks["local_box_obj"]

    # Plain (pred == y).mean(1) is too high, so only masked positions are scored.
    acc = ((y == pred) * local_mask).sum(1) / local_mask.sum(1)

    # np.unique returns the distinct phrase indices in ascending order.
    per_bin = [acc[phrase_index == op] for op in np.unique(phrase_index)]
    binned_acc = np.array([values.mean() for values in per_bin])
    binned_cnt = np.array([len(values) for values in per_bin])
    print(f"Cnt by global ops+: {binned_cnt}")
    print(f"Acc by global ops+: {binned_acc}")

    return pd.DataFrame({
        "Accuracy": acc,
        "Phrase Index": phrase_index,
        "is_not_move": is_not_move,
    })

def plot_probe_accuracy_across_context(input_dict: Dict[str, List[str]], output_dir: str):
    """Plot probe accuracy against phrase index for several probe types.

    For each ``probe_type -> (pred_path, cache_path)`` entry the per-example
    accuracies from ``get_probe_accuracy_across_context`` are collected. Solid
    lines use all examples, dashed lines only those flagged ``is_not_move``.
    The figure is saved to
    ``<output_dir>/phrase_probe_accuracy/<dataset>/<model>/probe_accuracy_across_context.png``,
    with dataset and model taken from the 6th-last and 5th-last ``/``
    components of the first prediction path. The directory is created if it
    does not exist yet.

    Args:
        input_dict: maps a probe-type label to its ``(pred_path, cache_path)`` pair.
        output_dir: root directory for the figure.
    """
    all_df = []
    for probe_type, (pred_path, cache_path) in input_dict.items():
        df = get_probe_accuracy_across_context(pred_path, cache_path)
        df["Probe Type"] = probe_type
        all_df.append(df)

    all_df = pd.concat(all_df)
    fix_fonts()
    plt.figure(figsize=(7, 3.5))
    ax = sns.lineplot(data=all_df, y="Accuracy", x="Phrase Index", hue="Probe Type")
    sns.lineplot(data=all_df[all_df.is_not_move==True], y="Accuracy", x="Phrase Index", hue="Probe Type", linestyle='--', ax=ax)
    ax.set_ylabel('Accuracy') #  (Local Box-Object)
    ax.set_xlabel('')
    # manually set legend to differentiate the two accuracy settings
    handles, labels = ax.get_legend_handles_labels()
    # Keep only the first two hue entries; the second lineplot repeats them.
    handles, labels = handles[:2], labels[:2]
    handles.extend([
        Line2D([0], [0], color='black', linestyle='-', label='Full Data'),
        Line2D([0], [0], color='black', linestyle='--', label='No MOVE'),
    ])
    labels.extend(["Full Data", "No MOVE"])
    ax.legend(handles, labels).set_title("Probe Type")
    # fix the ticks to integer
    ax.set_xticks(np.arange(0, 19, 1))
    sec = ax.secondary_xaxis(location=0)

    # Midpoints of the two groups, put the names
    sec.set_xticks(
        [3, 12.5],
        labels=["\nDescriptions", "\nOperations"]
    )
    sec.tick_params(axis="x", length=0)

    # add vertical group separator
    sec2 = ax.secondary_xaxis(location=0)
    # Boundaries between groups
    sec2.set_xticks([-0.5, 6.5, 18.5], labels=[])
    sec2.tick_params(axis="x", length=35, width=1.5)
    sec2.set_xlabel('Phrase Index')

    plt.tight_layout()
    first_pred_path = list(input_dict.values())[0][0]
    model_name = first_pred_path.split("/")[-5]
    dataset_name = first_pred_path.split("/")[-6]
    out_exp_dir = f"{output_dir}/phrase_probe_accuracy/{dataset_name}/{model_name}"
    # savefig does not create missing directories, so create the target first
    # (the other plotting helpers in this file do the same).
    os.makedirs(out_exp_dir, exist_ok=True)
    plt.savefig(f"{out_exp_dir}/probe_accuracy_across_context.png", dpi=600)




if __name__ == '__main__':
    "<anonymized>"