import json
import os
import pdb
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Union, Literal
import argparse
from functools import reduce

from tqdm import tqdm

from jaxtyping import Float

import numpy as np


from run_subspace_patching import load_single_experiment, plot_single_experiment_full

# these libs have to be imported after nnsight import from run_subspace_patching
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def fix_fonts(title=20, label=20, xtick=15, ytick=15, default=15):
    """Apply the paper-figure font settings to matplotlib's global rcParams.

    Selects the serif family with Times New Roman as the preferred serif face,
    then sets the base font size and the per-element font sizes. Because this
    writes to the global ``plt.rcParams``, it affects every figure drawn
    afterwards in the process.

    Args:
        title: Axes title size (``axes.titlesize``).
        label: Axis label size (``axes.labelsize``).
        xtick: X-axis tick label size (``xtick.labelsize``).
        ytick: Y-axis tick label size (``ytick.labelsize``).
        default: Base font size (``font.size``).
    """
    # NOTE(review): if Times New Roman is not installed, matplotlib falls back
    # to another serif face and logs a findfont warning.
    plt.rc("font", family="serif", serif=["Times New Roman"])

    size_settings = {
        "font.size": default,
        "xtick.labelsize": xtick,
        "ytick.labelsize": ytick,
        "axes.labelsize": label,
        "axes.titlesize": title,
    }
    for rc_key, size in size_settings.items():
        plt.rcParams[rc_key] = size


def _overlap_by_layer(df: pd.DataFrame) -> pd.DataFrame:
    """Compute, per layer, the overlap coefficient of the experiments' masks.

    The overlap coefficient is ``|AND of all masks| / min_i |mask_i|``, where
    ``|mask|`` is the sum of the mask's entries.

    Args:
        df: Rows with a ``layer`` column and a ``mask`` column. Assumes the
            masks within one layer are equal-length boolean / 0-1 vectors
            (TODO confirm against ``load_single_experiment``).

    Returns:
        DataFrame with one row per layer and columns
        ``"Overlap Coefficient"`` and ``"Layer"``.
    """
    rows = []
    # Group by the column name rather than a one-element list so `layer` is a
    # scalar on every pandas version; groupby(["layer"]) yields 1-tuples only
    # on pandas >= 2.0 and scalars before that.
    for layer, group in df.groupby("layer"):
        masks = group["mask"].tolist()
        inter = np.sum(reduce(np.logical_and, masks))
        min_rank = np.array(masks).sum(axis=1).min()
        rows.append({
            # The epsilon guards against division by zero when a mask is empty.
            "Overlap Coefficient": inter / (min_rank + 1e-20),
            "Layer": layer,
        })
    return pd.DataFrame(rows)


def plot_subspace_similarity(exp_dirs: List[str], exp_names: List[str]):
    """Plot per-layer subspace rank of several experiments and their mask overlap.

    For each experiment directory, loads the results (with masks) via
    ``load_single_experiment`` and keeps only ``patch_type == "subspace"`` rows.
    It then draws the rank per layer, one line per experiment, on the left axis.
    On a twin right axis it draws the per-layer overlap coefficient of the
    experiments' masks.

    The figure is saved as both PNG and PDF next to the last experiment
    directory, named ``<last_dir_name>_rank_sim.{png,pdf}``.

    Args:
        exp_dirs: Experiment output directories to compare.
        exp_names: Display name for each directory, in the same order.

    Raises:
        ValueError: If ``exp_dirs`` is empty or its length differs from
            ``exp_names``.
    """
    if not exp_dirs:
        raise ValueError("exp_dirs must contain at least one experiment directory")
    if len(exp_dirs) != len(exp_names):
        raise ValueError(
            f"Got {len(exp_dirs)} exp_dirs but {len(exp_names)} exp_names; "
            "they must have the same length"
        )

    dfs = []
    for exp_dir, exp_name in zip(exp_dirs, exp_names):
        df = load_single_experiment(exp_dir, load_mask=True)
        # .copy() so the column assignments below modify an independent frame
        # instead of a filtered view (avoids SettingWithCopyWarning).
        df = df[df.patch_type == "subspace"].copy()
        df["exp_name"] = exp_name
        # Capitalised copies so seaborn uses them verbatim as axis labels.
        df["Rank"] = df["rank"]
        df["Layer"] = df["layer"]
        dfs.append(df)
    df = pd.concat(dfs)

    fix_fonts()
    fig = plt.figure(figsize=(7, 3.5))
    ax = sns.lineplot(x="Layer", y="Rank", hue="exp_name", data=df)

    agg_df = _overlap_by_layer(df)
    ax2 = ax.twinx()
    sns.lineplot(
        x="Layer", y="Overlap Coefficient", data=agg_df,
        label="Overlap Coefficient", ax=ax2, color="gray", marker="o",
    )

    # Merge the entries of both axes into a single legend on the twin axis,
    # and drop the per-experiment legend seaborn attached to the primary axis.
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    if ax.get_legend() is not None:
        ax.get_legend().remove()
    ax2.legend(h1 + h2, l1 + l2)
    fig.tight_layout()

    # Path(...).name tolerates trailing separators, unlike split("/")[-1].
    last_dir = Path(exp_dirs[-1])
    plot_path = last_dir.parent.joinpath(f"{last_dir.name}_rank_sim.png").resolve()
    print(f"Saving plot to: {plot_path}")
    fig.savefig(plot_path, dpi=600)
    # with_suffix swaps only the extension; str.replace(".png", ".pdf") would
    # also rewrite any ".png" occurring earlier in the path.
    fig.savefig(plot_path.with_suffix(".pdf"), dpi=600)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


# def plot_subspace_similarity_all(exp_dirs: List[str], exp_names: List[str]):

if __name__ == "__main__":
    # Experiments to compare; the figure is written next to the last directory.
    #
    # Other comparisons used previously (swap into `experiment_dirs` as needed):
    #   gemma-2-2b, noop basis vs. 1put_1put_irrelevant:
    #     "entity-tracking-gemma/outputs/nnsight_patch_noop/gemma-2-2b/dcm_pos_phrase_ctf_op/subspace_lamb6",
    #     "entity-tracking-gemma/outputs/nnsight_patch_1put_1put_irrelevant/gemma-2-2b/dcm_pos_phrase_ctf_op/subspace_noopBasis_lamb6"
    #   codellama-13b, description vs. put:
    #     "entity-tracking-gemma/outputs/nnsight_patch_1put_1put_irrelevant/codellama-13b/dcm_pos_phrase_ctf_op/subspace_description",
    #     "entity-tracking-gemma/outputs/nnsight_patch_1put_1put_irrelevant/codellama-13b/dcm_pos_phrase_ctf_op/subspace_put"
    experiment_dirs = [
        "entity-tracking-gemma/outputs/nnsight_patch_1put_1put_irrelevant_noFixObj/gemma-2-2b/dcm_pos_phrase_ctf_op/subspace_description",
        "entity-tracking-gemma/outputs/nnsight_patch_1put_1put_irrelevant_noFixObj/gemma-2-2b/dcm_pos_phrase_ctf_op/subspace_put",
    ]
    display_names = ["description", "put"]
    plot_subspace_similarity(experiment_dirs, exp_names=display_names)
