import os
import pdb
import json
from typing import List, Optional, Tuple, Dict, Union, Any
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def fix_fonts(title=20, label=20, xtick=15, ytick=15, default=15):
    """Configure matplotlib's global font family and font sizes.

    The settings are written to pyplot's global rc configuration, so they
    apply to figures drawn after this call.

    Args:
        title: Font size for axes titles (``axes.titlesize``).
        label: Font size for axis labels (``axes.labelsize``).
        xtick: Font size for x-axis tick labels (``xtick.labelsize``).
        ytick: Font size for y-axis tick labels (``ytick.labelsize``).
        default: Base font size (``font.size``).
    """
    # Serif family, with Times New Roman as the only listed serif face.
    plt.rc('font', family='serif', serif=['Times New Roman'])

    # rc key -> requested size; applied in the order listed.
    size_by_rc_key = {
        "font.size": default,
        "xtick.labelsize": xtick,
        "ytick.labelsize": ytick,
        "axes.labelsize": label,
        "axes.titlesize": title,
    }
    for rc_key, font_size in size_by_rc_key.items():
        plt.rcParams[rc_key] = font_size



def plot_multi_put(exp_dir: str):
    """Plot accuracy against total PUT count for the result files in ``exp_dir``.

    Result files are expected to be named ``total<N>_query<M>.json`` and to
    hold a ``"full_results"`` list whose entries carry ``"avg_obj_accuracy"``
    (preferred) or ``"accuracy"``. One line is drawn per query count,
    random-guess baseline lines are overlaid on the same axes, and the figure
    is saved as ``<exp_dir>/lineplot.png``.

    Args:
        exp_dir: Directory holding the per-configuration JSON result files.

    Raises:
        ValueError: If no result rows could be read from ``exp_dir``.
    """
    rows = []
    for file_name in os.listdir(exp_dir):
        # Only JSON result files are parsed; the previously saved plot and any
        # other stray file are skipped instead of breaking the name parsing.
        if not file_name.endswith(".json"):
            continue
        name_parts = file_name.split("_")
        total_cnt = int(name_parts[0].replace("total", ""))
        query_cnt = int(name_parts[1].replace("query", "").replace(".json", ""))

        # A distinct handle name keeps the loop variable intact.
        with open(os.path.join(exp_dir, file_name), "r", encoding="utf-8") as fh:
            data = json.load(fh)
        for row in data["full_results"]:
            rows.append({
                "total_put_count": total_cnt,
                "query_put_count": query_cnt,
                "accuracy": row.get("avg_obj_accuracy", row.get("accuracy")),
            })

    if not rows:
        raise ValueError(f"No result rows found in {exp_dir!r}")

    df = pd.DataFrame(rows)

    sns.lineplot(x="total_put_count", hue="query_put_count", y="accuracy", data=df, marker="X")

    # Random-guess baselines: i guesses out of (7 + n) candidates.
    # NOTE(review): n runs over 0..k-1 (k = number of distinct total counts),
    # not over the observed total counts themselves; confirm the experiments
    # really use totals 0..k-1, otherwise the baselines are misaligned on x.
    # NOTE(review): presumably 7 is the candidate count before any PUT - confirm.
    n_totals = df["total_put_count"].nunique()
    baselines = [
        {
            "total_put_count": n,
            "query_put_count": f"random {i} guess accuracy",
            "accuracy": i / (7 + n),
        }
        for n in range(n_totals)
        for i in range(1, 5)
    ]
    baselines_df = pd.DataFrame(baselines)
    sns.lineplot(x="total_put_count", hue="query_put_count", y="accuracy", data=baselines_df, marker="o")

    plt.savefig(os.path.join(exp_dir, "lineplot.png"))


def plot_global_local_remove(exp_dir: str, ax=None):
    """Box-plot REMOVE logit differences, split by whether the model was correct.

    Every ``*.json`` file in ``exp_dir`` is read. The file name without its
    ``.json`` suffix names the removed box, and each entry of the file's
    ``"full_results"`` list contributes one observation. A prediction counts
    as correct when the stripped, lower-cased ``"ctf_argmax_token"`` matches
    any entry of ``"labels"`` other than ``"ctf_label"``.

    Args:
        exp_dir: Directory holding one JSON result file per removal target.
        ax: Optional matplotlib axes to draw on. When omitted, a new 7x3.5
            figure is created and saved as ``<exp_dir>/hist.png`` at 600 dpi.
            When given, the plot is drawn on ``ax``, its title is set to the
            model name and accuracy, and nothing is saved.

    Raises:
        ValueError: If no result rows could be read from ``exp_dir``.
    """
    rows = []
    for file_name in os.listdir(exp_dir):
        # Skip the saved plot and anything else that is not a result file.
        if not file_name.endswith(".json"):
            continue
        removal_target = file_name.replace(".json", "")
        # A distinct handle name keeps the loop variable intact.
        with open(os.path.join(exp_dir, file_name), "r", encoding="utf-8") as fh:
            data = json.load(fh)
        for row in data["full_results"]:
            accepted_labels = [lab for lab in row["labels"] if lab != row["ctf_label"]]
            predicted = row["ctf_argmax_token"].strip().lower()
            rows.append({
                "REMOVE Box": removal_target.capitalize(),
                "Logit Diff": row["logit_diff"],
                "Model Correctness": "Correct" if predicted in accepted_labels else "Incorrect",
            })

    if not rows:
        raise ValueError(f"No result rows found in {exp_dir!r}")

    df = pd.DataFrame(rows)
    # normpath drops a trailing separator so the last path component is never empty.
    model_name = os.path.basename(os.path.normpath(exp_dir))
    accuracy = (df['Model Correctness'] == 'Correct').mean()

    fix_fonts()

    # Create a figure only in standalone mode. Creating one when the caller
    # supplied `ax` would leave a stray figure as pyplot's current figure and
    # pyplot-level drawing calls would land there instead of on `ax`.
    standalone = ax is None
    if standalone:
        fig, target_ax = plt.subplots(figsize=(7, 3.5))
    else:
        target_ax = ax

    sns.boxplot(
        data=df,
        x="REMOVE Box",
        y="Logit Diff",
        hue="Model Correctness",
        hue_order=["Incorrect", "Correct"],
        ax=target_ax,
    )
    target_ax.get_legend().set_title("")
    # Zero reference line, drawn on the same axes as the boxes.
    target_ax.axhline(y=0, color='black', linestyle='--', linewidth=1)

    if standalone:
        target_ax.set_title(f"REMOVE Logit Argmax Accuracy={accuracy:.2f}")
        fig.tight_layout()
        fig.savefig(os.path.join(exp_dir, "hist.png"), dpi=600)
    else:
        target_ax.set_title(f"{model_name}, removal_accuracy={accuracy:.2f}")


def plot_global_local_remove_multi_model(exp_dirs: List[str]):
    """Draw one REMOVE box plot per experiment directory in a single row.

    Each directory is rendered by ``plot_global_local_remove`` into its own
    column. Only the right-most panel keeps its legend. The combined figure is
    saved as ``joint_hist.png`` in the parent directory of ``exp_dirs[0]``.

    Args:
        exp_dirs: Experiment directories, one per panel, in left-to-right order.

    Raises:
        ValueError: If ``exp_dirs`` is empty.
    """
    if not exp_dirs:
        raise ValueError("exp_dirs must contain at least one experiment directory")

    n_cols = len(exp_dirs)
    n_rows = 1
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(6 * n_cols, 4 * n_rows),
        squeeze=False,
    )
    for col, exp_dir in enumerate(exp_dirs):
        plot_global_local_remove(exp_dir, axes[0, col])

    # Keep a single legend on the last panel.
    for panel in axes[0, :-1]:
        legend = panel.get_legend()
        if legend is not None:
            legend.remove()

    # Work on `fig` explicitly rather than through pyplot's current figure,
    # which the per-panel helper may have changed.
    fig.tight_layout()
    fig.savefig(Path(exp_dirs[0]).parent / "joint_hist.png")



def plot_global_local_remove_2_v_0_shot_model_family(
    exp_dirs: List[str],
    family_name: str,
    output_dir: str = "../outputs/behavioral_global_local_remove",
):
    """Plot a two-row grid of REMOVE box plots for one model family.

    The first half of ``exp_dirs`` fills the top row and the second half the
    bottom row. A row's y-label reads "Two-Shot" when the directory path
    contains ``"two-shot"`` and "Zero-Shot" otherwise. Directories that do not
    exist, or that yield no result rows, get an empty panel titled
    ``Missing: <name>``. The figure is saved as
    ``<output_dir>/<family_name>_family.png`` at 600 dpi.

    Args:
        exp_dirs: Experiment directories; the count must be even and non-zero.
        family_name: Name used as the prefix of the output file.
        output_dir: Directory the figure is written to. The default is
            relative to the current working directory.

    Raises:
        ValueError: If ``exp_dirs`` is empty or has an odd number of entries.
    """
    if not exp_dirs or len(exp_dirs) % 2 != 0:
        raise ValueError("exp_dirs must contain a non-zero, even number of directories")

    plt.rc('font', family='serif', serif=['Times New Roman'])

    n_cols = len(exp_dirs) // 2
    # squeeze=False keeps `axs` two-dimensional even when there is one column.
    fig, axs = plt.subplots(2, n_cols, figsize=(12, 5), squeeze=False)

    for i, exp_dir in enumerate(exp_dirs):
        panel = axs[i // n_cols, i % n_cols]
        # normpath drops a trailing separator so the name is never empty.
        model_name = os.path.basename(os.path.normpath(exp_dir))
        # Derived from the directory path alone, so it is defined even when
        # the directory holds no result files.
        mode = "Two-Shot" if "two-shot" in exp_dir else "Zero-Shot"

        rows = []
        if os.path.exists(exp_dir):
            for file_name in os.listdir(exp_dir):
                # Skip saved plots and anything else that is not a result file.
                if not file_name.endswith(".json"):
                    continue
                removal_target = file_name.replace(".json", "")
                with open(os.path.join(exp_dir, file_name), "r", encoding="utf-8") as fh:
                    data = json.load(fh)
                for row in data["full_results"]:
                    accepted_labels = [lab for lab in row["labels"] if lab != row["ctf_label"]]
                    # NOTE(review): plot_global_local_remove lower-cases the
                    # token before this comparison; here it is only stripped.
                    # Confirm which of the two is intended.
                    predicted = row["ctf_argmax_token"].strip()
                    rows.append({
                        "REMOVE Box": removal_target.capitalize(),
                        "Logit Diff": row["logit_diff"],
                        "Model Correctness": "Correct" if predicted in accepted_labels else "Incorrect",
                    })

        if not rows:
            # Leave an empty, fully bordered panel and flag it in the title.
            for side in ("top", "right", "left", "bottom"):
                panel.spines[side].set_visible(True)
            panel.set_title(f"Missing: {model_name}")
            continue

        df = pd.DataFrame(rows)
        accuracy = (df['Model Correctness'] == 'Correct').mean()

        sns.boxplot(
            data=df,
            x="REMOVE Box",
            y="Logit Diff",
            hue="Model Correctness",
            hue_order=["Incorrect", "Correct"],
            ax=panel,
        )
        # Zero reference line.
        panel.axhline(y=0, color='black', linestyle='--', linewidth=1)
        panel.get_legend().set_title("")
        panel.set_title(f"{model_name}, Logit Acc.={accuracy:.2f}")

        # Only the left-most panel of each row carries the row's mode label.
        if i % n_cols == 0:
            panel.set_ylabel(f"{mode}\nLogit Diff")
        else:
            panel.set_ylabel("")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{family_name}_family.png"), dpi=600)



if __name__ == "__main__":
    # Script entry point. Only the single-model REMOVE box plot below is
    # active; the other invocations are kept commented out as recipes.
    # All paths are relative, so they resolve against the working directory.

    # Multi-PUT accuracy line plot for one model:
    # plot_multi_put("entity-tracking-gemma/outputs/behavioral_multi_put/codellama-13b")

    # Active run: REMOVE box plot for one model, saved as hist.png in that directory.
    plot_global_local_remove("entity-tracking-gemma/outputs/behavioral_global_local_remove/codellama-13b")

    # Side-by-side REMOVE box plots for several models:
    # plot_global_local_remove_multi_model(exp_dirs=[
    #     "entity-tracking-gemma/outputs/behavioral_global_local_remove/gemma-2-2b",
    #     "entity-tracking-gemma/outputs/behavioral_global_local_remove/codellama-13b",
    #     "entity-tracking-gemma/outputs/behavioral_global_local_remove/Llama-3.1-70B",
    # ])

    # Root of the REMOVE results; referenced only by the commented-out
    # per-family recipes below.
    LOG_DIR = "entity-tracking-gemma/outputs/behavioral_global_local_remove"

    # Per-family grids: list the two-shot directories first (top row), then
    # the zero-shot directories in the same model order (bottom row).

    # # Plot Llama Family: 2*3
    # llama_exp_dirs = [
    #     os.path.join(LOG_DIR, "two-shot", "meta-llama", "Llama-3.1-8B"),
    #     os.path.join(LOG_DIR, "two-shot", "meta-llama", "Llama-3.1-70B"),
    #     os.path.join(LOG_DIR, "two-shot", "deepseek-ai", "DeepSeek-R1-Distill-Llama-8B"),
    #     os.path.join(LOG_DIR, "zeroshot", "meta-llama", "Llama-3.1-8B"),
    #     os.path.join(LOG_DIR, "zeroshot", "meta-llama", "Llama-3.1-70B"),
    #     os.path.join(LOG_DIR, "zeroshot", "deepseek-ai", "DeepSeek-R1-Distill-Llama-8B")
    # ]
    # plot_global_local_remove_2_v_0_shot_model_family(llama_exp_dirs, "llama")
    #
    # # Plot Qwen Family: 2 * 3
    # qwen_exp_dirs = [
    #     os.path.join(LOG_DIR, "two-shot", "Qwen", "Qwen3-1.7B"),
    #     os.path.join(LOG_DIR, "two-shot", "Qwen", "Qwen3-4B"),
    #     os.path.join(LOG_DIR, "two-shot", "Qwen", "Qwen3-8B"),
    #     os.path.join(LOG_DIR, "zeroshot", "Qwen", "Qwen3-1.7B"),
    #     os.path.join(LOG_DIR, "zeroshot", "Qwen", "Qwen3-4B"),
    #     os.path.join(LOG_DIR, "zeroshot", "Qwen", "Qwen3-8B")
    # ]
    # plot_global_local_remove_2_v_0_shot_model_family(qwen_exp_dirs, "qwen")
    #
    # # Plot Gemma Family: 2 * 4
    # gemma_exp_dirs = [
    #     os.path.join(LOG_DIR, "two-shot", "google", "gemma-2-2b"),
    #     os.path.join(LOG_DIR, "two-shot", "google", "gemma-2-9b"),
    #     os.path.join(LOG_DIR, "two-shot", "google", "gemma-3-1b-pt"),
    #     os.path.join(LOG_DIR, "two-shot", "google", "gemma-3-4b-pt"),
    #     os.path.join(LOG_DIR, "zeroshot", "google", "gemma-2-2b"),
    #     os.path.join(LOG_DIR, "zeroshot", "google", "gemma-2-9b"),
    #     os.path.join(LOG_DIR, "zeroshot", "google", "gemma-3-1b-pt"),
    #     os.path.join(LOG_DIR, "zeroshot", "google", "gemma-3-4b-pt"),
    # ]
    # plot_global_local_remove_2_v_0_shot_model_family(gemma_exp_dirs, "gemma")
