import os
import json
import glob
import pdb
from tqdm import tqdm

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def add_remove_obj_not_predicted_feature(row):
    """Tell whether every removed object is absent from the parsed answer.

    An entry of ``row["gold_items"]`` is treated as removed when the text
    ``"Remove the <entry>"`` occurs in ``row["prefix"]``.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing the keys
            ``gold_items``, ``prefix`` and ``parsed_original_answer``.

    Returns:
        True if none of the removed entries is contained in
        ``row["parsed_original_answer"]`` (also True when nothing was
        removed), False otherwise.
    """
    removed_items = []
    for item in row["gold_items"]:
        if f"Remove the {item}" in row["prefix"]:
            removed_items.append(item)
    return not any(item in row["parsed_original_answer"] for item in removed_items)


def print_errors(path: str):
    """Load a JSON-lines result file and attach the derived error columns.

    Despite its name, this function prints nothing; it returns the annotated
    frame so the caller can compute the summaries it needs, for example:

    * ``df.correct.mean()`` -- overall accuracy
    * ``len(df) - df.correct.sum()`` -- number of failed examples
    * ``df.removed_obj_not_predicted.sum()`` -- examples flagged by
      ``add_remove_obj_not_predicted_feature``
    * when the file has ``intervention_obj_success`` /
      ``intervention_rest_success`` columns, the number of rows where both are
      True (full success) or only the first is True (partial success).

    Args:
        path: Path to a JSON-lines file with one record per line.

    Returns:
        The loaded ``DataFrame`` with a ``removed_obj_not_predicted`` column
        added. If the file has no ``correct`` column, ``correct`` and
        ``gold_items`` are derived from ``parsed_original_answer`` and
        ``gold_answer``.
    """
    df = pd.read_json(path, lines=True, orient="records")
    if "correct" not in df.columns:
        # Files without a precomputed "correct" flag: derive it by comparing
        # the parsed answer with the gold answer, and use the gold answer as
        # the item list as well.
        df["correct"] = df.parsed_original_answer == df.gold_answer
        df["gold_items"] = df.gold_answer
    df["removed_obj_not_predicted"] = df.apply(add_remove_obj_not_predicted_feature, axis=1)
    return df


def _load_intervention_results(path):
    """Read one JSON-lines result file and add the derived success columns."""
    df = pd.read_json(path, lines=True, orient="records")
    if "correct" not in df.columns:
        df["correct"] = df.parsed_original_answer == df.gold_answer
        df["gold_items"] = df.gold_answer
    df["removed_obj_not_predicted"] = df.apply(add_remove_obj_not_predicted_feature, axis=1)
    # Explicit comparisons against True/False are intentional: a row only
    # counts as partial success when the "rest" flag is literally False.
    obj_success = df.intervention_obj_success == True
    df["full_success"] = obj_success & (df.intervention_rest_success == True)
    df["partial_success"] = obj_success & (df.intervention_rest_success == False)
    return df


def _plot_success_by_layer(df, title, out_path):
    """Plot full/partial success against layer, save it, return the long-form frame."""
    melted = pd.melt(
        df,
        id_vars=["layers"],
        value_vars=["full_success", "partial_success"],
        var_name="Success Type",
        value_name="Success Value",
    )
    sns.lineplot(data=melted, x="layers", y="Success Value", hue="Success Type")
    plt.title(title)
    plt.savefig(out_path)
    # Close the figure so the next plot does not draw on top of this one.
    plt.close()
    return melted


def plot_across_layers(path_patterns: str, output_dir: str):
    """Plot intervention success per layer for all files matching a glob.

    Two line plots are written to
    ``{output_dir}/phrase_probe_intervention_accuracy/{dataset}/{model}``:
    ``intervention_success_{file_pattern}.png`` over all rows, and
    ``intervention_success_partial_{file_pattern}.png`` restricted to rows
    where ``removed_obj_not_predicted`` is True. The per-layer mean of each
    success type for the restricted subset is printed afterwards.

    Args:
        path_patterns: Glob pattern using ``/`` separators. The last directory
            component is used as the model name and the third-from-last as the
            dataset name, so the pattern needs at least three directory levels.
        output_dir: Root directory under which the plot folder is created.

    Returns:
        None. If no file matches, a message is printed and nothing is plotted.
    """
    paths = glob.glob(path_patterns)
    frames = [_load_intervention_results(path) for path in tqdm(paths)]
    if not frames:
        print(f"No data to plot, {path_patterns}")
        return
    all_df = pd.concat(frames)

    result_dir, file_pattern = path_patterns.rsplit("/", maxsplit=1)
    dir_parts = result_dir.split("/")
    model_name = dir_parts[-1]
    dataset_name = dir_parts[-3]
    out_exp_dir = f"{output_dir}/phrase_probe_intervention_accuracy/{dataset_name}/{model_name}"
    os.makedirs(out_exp_dir, exist_ok=True)

    title = f"error_pct,{file_pattern}"

    # All examples.
    _plot_success_by_layer(
        all_df,
        title,
        f"{out_exp_dir}/intervention_success_{file_pattern}.png",
    )

    # Only the examples flagged by add_remove_obj_not_predicted_feature.
    flagged_melt = _plot_success_by_layer(
        all_df[all_df.removed_obj_not_predicted == True],
        title,
        f"{out_exp_dir}/intervention_success_partial_{file_pattern}.png",
    )
    print(flagged_melt.groupby(["layers", "Success Type"])["Success Value"].mean())

if __name__ == "__main__":
    # Sample input files kept for manual inspection, e.g. print_errors(intervene).
    base_pred = "entity-tracking-gemma/outputs/behavioral_put_globally_removed_obj/boxes_altAlways_1put_1remove_1fixObj/codellama-13b/results_filter=True.jsonl"
    intervene = "entity-tracking-probing/results/boxes_altAlways_1put_1remove_1fixObj/phrase_probe_intervention/codellama-13b/intervention_CodeLlama-13b-hf_8bit_global-remove_null_object-op_alpha=1_fn=3.jsonl"
    out_dir = "entity-tracking-probing/plots"

    # One pair of plots per probe type and per file-name suffix group.
    for probe in ["object", "number"]:
        for suffix in ["fn=", "n=", "ln"]:
            pattern = f"entity-tracking-probing/results/boxes_altAlways_1put_1remove_1fixObj/phrase_probe_intervention/codellama-13b/intervention_CodeLlama-13b-hf_8bit_global-remove_null_{probe}-op_alpha=1_{suffix}*.jsonl"
            plot_across_layers(pattern, out_dir)