from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np 


### PLOTTING


# def _plot_similarity(df, model="BAAI_bge-small-en-v1.5", dataset="paul_graham"):
#     rel_columns = [col for col in df.columns if "text_" in col]

#     mode = "insert" if "insert" in rel_columns[0] else "remove"
#     needle_keywords = list(set([col.split("_")[-3] for col in rel_columns]))
#     needle_posns = list(
#         set([float(col.split("_")[-2].replace("pos", "")) for col in rel_columns])
#     )
#     needle_sizes = list(
#         set([float(col.split("_")[-1].replace("sz", "")) for col in rel_columns])
#     )

#     for size in needle_sizes:
#         rel_columns = [col for col in df.columns if "cosine_similarity_" in col]
#         rel_columns = [
#             col
#             for col in rel_columns
#             if f"_{size}sz" in col or f"_{int(size)}sz" in col
#         ]
#         data = df[rel_columns].values

#         fig, ax = plt.subplots()
#         ax.boxplot(data)

#         # Adding titles and labels
#         ax.set_title(f"Similarity with Needle of Size {size}", pad=15)
#         ax.set_xlabel("Needle Placement", labelpad=15)
#         ax.set_ylabel("Cosine Similarity", labelpad=15)

#         ax.set_xticklabels(["Beginning", "Middle", "End"])

#         plt.tight_layout()
#         Path(f"./plots/{model}/{dataset}").mkdir(parents=True, exist_ok=True)
#         plt.savefig(f"./plots/{model}/{dataset}/{mode}_{size}_{model}.png")
#         plt.close()


# def create_model_plots(model_name):
#     # path = Path(f"./data/{model_name}")
#     path = Path(f"{DATA_PATH}/{model_name}")
#     assert path.exists(), f"Model {model_name} not found"
#     datasets = path.glob("*_insert.pkl")
#     combined = None
#     for dataset in datasets:
#         df = pd.read_pickle(dataset)
#         if combined is None:
#             combined = df
#         else:
#             combined = pd.concat([combined, df])
#     _plot_similarity(combined, model_name, dataset="combined")

# def graph_total_avgs(model_names, mode="insert"):
#     assert model_names
#     inserts = [0.05, 0.1, 0.2, 0.5, 1]
#     removes = [0.05, 0.1, 0.2, 0.5]
#     input_sizes = inserts if mode == 'insert' else removes
#     posns = [0, 0.5, 1]

#     total_avgs = get_total_avgs(model_names, posns, input_sizes, mode=mode)
#     for i in input_sizes:
#         # n_rows = (len(model_names) + 1) // 2
#         # n_colns = 2
#         fig, ax = plt.subplots()
#         data_by_size = total_avgs[i]
#         # plt.style.use('ggplot')
#         for model_idx, model_name in enumerate(model_names):
#             ax.plot(posns,  data_by_size[model_name], label=model_name)

#         if mode == 'insert':
#             ax.set_title(f"Similarity with Insertion Ablation of Size {i}", pad=15)
#         else:
#             ax.set_title(f"Similarity with Removal Ablation of Size {i}", pad=15)
#         ax.set_xlabel("Needle Placement", labelpad=5)
#         ax.set_ylabel("Cosine Similarity", labelpad=5)

#         plt.locator_params(axis='x', nbins=3)
#         ax.set_xticks(posns)
#         ax.set_xticklabels(["Beginning", "Middle", "End"])
#         # y_ticks = [0.5, .6, .7, 1] # list(range(.70, 1.05, .05))
#         # y_ticks = list(range(.7, 1.1, .1))

#         # print(f"y_ticks is: {y_ticks}")


#         plt.legend()
#         plt.tight_layout()
#         plot_dir = f"./plots/combined/combined"
#         Path(plot_dir).mkdir(parents=True, exist_ok=True)
#         plot_path = f"{plot_dir}/{mode}_{i}_combined.png"
#         plt.savefig(plot_path)
#         print(f"Plot saved at: {plot_path}")
#         plt.close()
def graph_total_avgs(model_names, mode="insert", average_by_group=False):
    """Render one combined similarity plot per ablation size.

    Args:
        model_names: Non-empty collection of model names to include.
        mode: "insert" uses the insertion size list; any other value uses
            the (shorter) removal size list.
        average_by_group: Forwarded to the plotting helper; True draws one
            averaged line per group, False one line per model.

    Raises:
        AssertionError: If ``model_names`` is empty.
    """
    assert model_names

    positions = [0, 0.5, 1]
    if mode == 'insert':
        sizes = [0.05, 0.1, 0.2, 0.5, 1]
    else:
        sizes = [0.05, 0.1, 0.2, 0.5]

    # Averages come back keyed by size, then by model name.
    averages = get_total_avgs(model_names, positions, sizes, mode=mode)

    # One figure per ablation size.
    for size in sizes:
        _graph_individual_by_input_size(
            size,
            positions,
            model_names,
            averages[size],
            mode,
            average_by_group=average_by_group,
        )


# Maps a model name (as used for its data directory) to the label of its
# positional-encoding group. _graph_individual_by_input_size looks models up
# here, falling back to "DEFAULT" for names that are absent, to pick the
# plot color and the group a model is averaged into.
# NOTE(review): the labels presumably stand for absolute position embeddings
# (APE), ALiBi and rotary embeddings (RoPE) -- confirm against the model docs.
model_to_posn_encoding = {
    "BAAI_bge-m3": "APE",
    "jinaai_jina-embeddings-v2-base-en": "ALIBI",
    "dwzhu_e5rope-base": "ROPE",
    "mosaicml_mosaic-bert-base-seqlen-1024": "ALIBI",
    "nomic-ai_nomic-embed-text-v1.5": "ROPE",
    "intfloat_e5-large-v2": "APE",
}


def _graph_individual_by_input_size(
    input_size, posns, model_names, data_by_size, mode, average_by_group=None
):
    """Plot cosine similarity against needle position for one ablation size.

    Draws either one averaged line per positional-encoding group or one line
    per model, then saves the figure to
    ``./plots/combined/combined/{mode}_{input_size}_combined.png``.

    Args:
        input_size: Ablation size; shown in the title and the file name.
        posns: X coordinates of the needle positions. The tick labels
            "Beginning", "Middle", "End" assume exactly three positions.
        model_names: Models to plot, in plotting/legend order.
        data_by_size: Mapping of model name to its per-position values
            (one value per entry of ``posns``).
        mode: "insert" selects the insertion title, any other value the
            removal title; also used in the output file name.
        average_by_group: True for one averaged line per group, False for
            one line per model. Must be passed explicitly.

    Raises:
        ValueError: If ``average_by_group`` is left as None.
    """
    # Validate before creating a figure so a bad call does not leave an
    # orphaned, never-closed figure registered with pyplot.
    if average_by_group is None:
        raise ValueError("Set argument: average_by_group.")

    fig, ax = plt.subplots()
    try:
        if average_by_group:
            _plot_group_averages(ax, posns, model_names, data_by_size)
        else:
            _plot_each_model(ax, posns, model_names, data_by_size)

        # Set the title based on the mode
        ablation = "Insertion" if mode == "insert" else "Removal"
        ax.set_title(
            f"Similarity with {ablation} Ablation of Size {input_size}", pad=15
        )

        # Set axis labels and ticks
        ax.set_xlabel("Needle Placement", labelpad=5)
        ax.set_ylabel("Cosine Similarity", labelpad=5)
        ax.locator_params(axis="x", nbins=3)
        ax.set_xticks(posns)
        ax.set_xticklabels(["Beginning", "Middle", "End"])

        # Add legend and format layout
        ax.legend()
        fig.tight_layout()

        # Save the plot to a file
        plot_dir = "./plots/combined/combined"
        Path(plot_dir).mkdir(parents=True, exist_ok=True)
        plot_path = f"{plot_dir}/{mode}_{input_size}_combined.png"
        fig.savefig(plot_path)
        print(f"Plot saved at: {plot_path}")
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def _plot_group_averages(ax, posns, model_names, data_by_size):
    """Draw one line per group: the per-position mean over the group's models."""
    group_data = {group: [] for group in group_colors}

    # Collect each model's values under its group ("DEFAULT" when unmapped).
    for model_name in model_names:
        group = model_to_posn_encoding.get(model_name, "DEFAULT")
        if group in group_data:
            group_data[group].append(data_by_size[model_name])

    for group, data_list in group_data.items():
        if not data_list:  # Skip groups without any plotted model
            continue
        # Rows are models, columns are positions; average over the models.
        avg_data = np.mean(np.array(data_list), axis=0)
        ax.plot(
            posns,
            avg_data,
            label=f"{group} (Average)",
            color=group_colors[group],
            linestyle="-",
            linewidth=2,
        )


def _plot_each_model(ax, posns, model_names, data_by_size):
    """Draw one line per model: color by group, line style cycling within a group."""
    styles_used = {group: 0 for group in group_colors}

    for model_name in model_names:
        group = model_to_posn_encoding.get(model_name, "DEFAULT")
        ax.plot(
            posns,
            data_by_size[model_name],
            label=model_name,
            color=group_colors[group],
            # Models of the same group share a color, so vary the line style.
            linestyle=line_styles[styles_used[group] % len(line_styles)],
        )
        styles_used[group] += 1


### COMBINE DATA


def get_total_avgs(model_names, posns, input_sizes, mode="insert"):
    """Gather per-position averages for every model, regrouped by input size.

    Returns a dictionary keyed by input size. Each value is a dictionary
    keyed by model name whose value is that model's list of averages, one
    per entry of ``posns``.
    """
    # First pass: load each model's combined data and average it per size.
    per_model = {
        name: _get_size_avg_by_df(
            _get_combined_df_of_model(name, mode=mode), mode, posns, input_sizes
        )
        for name in model_names
    }

    # Second pass: flip the nesting from model -> size to size -> model.
    return {
        size: {name: per_model[name][size] for name in model_names}
        for size in input_sizes
    }


def _get_combined_df_of_model(model_name, mode="insert"):
    """Load every ``*_{mode}.pkl`` dataset of a model into one DataFrame.

    All datasets found in ``{DATA_PATH}/{model_name}`` are concatenated
    row-wise so that averages are later taken over the combined data.

    Args:
        model_name: Name of the model's sub-directory under ``DATA_PATH``.
        mode: Suffix selecting which pickles to load (e.g. "insert", "remove").

    Returns:
        The row-wise concatenation of all matching DataFrames.

    Raises:
        AssertionError: If the model directory does not exist.
        FileNotFoundError: If the directory holds no ``*_{mode}.pkl`` file.
    """
    path = Path(f"{DATA_PATH}/{model_name}")
    assert path.exists(), f"Model {model_name} not found"

    # Sorted so the row order of the result does not depend on the
    # filesystem's directory-listing order.
    dataset_paths = sorted(path.glob(f"*_{mode}.pkl"))
    if not dataset_paths:
        # Previously this case returned None and failed later with an opaque
        # "'NoneType' object is not subscriptable"; fail here with context.
        raise FileNotFoundError(
            f"No '*_{mode}.pkl' datasets found for model {model_name} in {path}"
        )

    # NOTE(review): unpickling can execute arbitrary code -- only point
    # DATA_PATH at trusted, locally produced files.
    frames = [pd.read_pickle(dataset) for dataset in dataset_paths]

    # A single concat avoids re-copying the accumulated frame per dataset.
    return pd.concat(frames)


def _get_size_avg_by_df(df, mode, posns, input_sizes):
    """Map each input size to its list of per-position averages from ``df``."""
    return {
        size: __get_posn_avg_by_size(df, mode, posns, size)
        for size in input_sizes
    }


def __get_posn_avg_by_size(df, mode, posns, input_size):
    """Average the similarity column of each position for one input size.

    Returns a list with one column mean per entry of ``posns``; exactly
    three positions are expected. ``input_size`` must be a single value,
    not a list.
    """
    assert not isinstance(input_size, list)

    # Insert-mode column names carry an extra "lorem" segment.
    if mode == "insert":
        prefix = f"cosine_similarity_needle_{mode}_lorem"
    else:
        prefix = f"cosine_similarity_needle_{mode}"

    means = [
        df[f"{prefix}_{position}pos_{input_size}sz"].mean() for position in posns
    ]
    assert len(means) == 3
    return means


# Line color per positional-encoding group. The keys also define the set of
# groups that _graph_individual_by_input_size can plot; "DEFAULT" is the group
# used for models that are missing from model_to_posn_encoding.
group_colors = {
    "APE": "blue",
    "ALIBI": "red",
    "ROPE": "green",
    "DEFAULT": "purple",
}
# Line styles cycled through, per group, when each model is drawn as its own
# line (models of one group share a color).
line_styles = ["-", "--", "-.", ":"]

# DATA_PATH = "./data/data"
# Root directory that holds one sub-directory per model; each sub-directory is
# searched for "*_{mode}.pkl" files.
DATA_PATH = "./data"


def main():
    """Plot the insert and remove ablation results for the selected models.

    The selected model names are matched against the sub-directories of
    ``DATA_PATH``. Selected models without a directory are reported with a
    warning, and directories that are not selected are listed as ignored.
    Models are plotted in the order they are listed here, so the legend and
    line-style assignment do not depend on directory-listing order.

    Raises:
        SystemExit: If none of the selected models has a data directory.
    """
    available_models = {
        entry.name for entry in Path(DATA_PATH).glob("*") if entry.is_dir()
    }  # Note the path here jic

    cs_model_keywords = ["embed-english-v3.0", "text-embedding-3-small"]
    os_model_keywords = [
        # "junnyu_roformer_chinese_base",
        # "intfloat_e5-mistral-7b-instruct",
        # "google-bert_bert-base-uncased",
        "BAAI_bge-m3",
        "jinaai_jina-embeddings-v2-base-en",
        "dwzhu_e5rope-base",
        "mosaicml_mosaic-bert-base-seqlen-1024",
        "nomic-ai_nomic-embed-text-v1.5",
        "intfloat_e5-large-v2",
    ]
    model_keywords = cs_model_keywords + os_model_keywords

    average_by_group = True
    # average_by_group = False

    relevant_models = []
    for model in model_keywords:
        if model in available_models:
            relevant_models.append(model)
        else:
            # A selected model really has no data directory.
            print(f"WARNING: {model} not found")

    # Directories that exist but were not selected are not an error; mention
    # them so an accidental omission from the lists above is still visible.
    ignored_models = sorted(available_models.difference(model_keywords))
    if ignored_models:
        print(f"Ignoring unselected model directories: {', '.join(ignored_models)}")

    if not relevant_models:
        raise SystemExit(f"No selected model directories found under {DATA_PATH}")

    for mode in ("insert", "remove"):
        graph_total_avgs(
            relevant_models, mode=mode, average_by_group=average_by_group
        )


if __name__ == "__main__":
    main()
