import os
import csv
import pdb
import json
import regex as re
from typing import Dict, Tuple, List, Iterable, Optional, Any, Union

from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from .utils import load_object_names, load_box_data, remove_empty

# Not referenced by the code visible in this file section.
NUM_BOXES = 7
# Number of object columns a prediction row must match for "full" accuracy
# (compared against the per-row count of correct entries in main_local_ops).
N_OBJECTS = 100
# Default for the `max_query_ops` parameter of the plotting entry points.
MAX_QUERY_OPS = 7
# Not referenced by the code visible in this file section.
MAX_GLOBAL_OPS = 12


def _per_example_accuracies(box_content, mentioned, pred):
    """Return the four per-example accuracy vectors for one layer's predictions.

    Args:
        box_content: label matrix, one row per example.
        mentioned: matrix of the same shape used as a 0/1 weight to restrict
            scoring to the objects mentioned in each example.
        pred: prediction matrix of the same shape as ``box_content``.

    Returns:
        Tuple ``(avg_acc, full_acc, non_triv_avg_acc, non_triv_full_acc)``,
        each a 1-D float array with one value per example.
    """
    correct = box_content == pred

    # accuracy across all objects
    avg_acc = correct.mean(1)
    full_acc = (correct.sum(1) == N_OBJECTS).astype(float)

    # accuracy across objects mentioned in the example (non-trivial);
    # rows with no mentioned object give 0/0 -> NaN for the average.
    mentioned_total = mentioned.sum(1)
    mentioned_correct = (correct * mentioned).sum(1)
    non_triv_avg_acc = mentioned_correct / mentioned_total
    # the cast applies to the comparison result so the dtype matches full_acc
    non_triv_full_acc = (mentioned_correct == mentioned_total).astype(float)

    return avg_acc, full_acc, non_triv_avg_acc, non_triv_full_acc


def main_local_ops(
    prob_result_dir: str,
    box_test_data_path: str,
    object_vocab_path: str,
    out_dir: str,
    max_query_ops: int=MAX_QUERY_OPS,
    filter_move_content_out: bool=False,
):
    """Plot per-layer probing accuracy heatmaps grouped by number of operations.

    For every ``op_idx`` in ``0..max_query_ops`` the labels are loaded with
    ``load_box_data(..., num_prior_state=-op_idx-1)`` and compared with each
    layer's ``predictions.txt`` found under ``prob_result_dir`` (for
    ``op_idx == 0``) or ``f"{prob_result_dir}_prior_state_-{op_idx+1}"``.
    Each entry of that directory is expected to be named ``layer<k>_...``
    with a 1-based ``k``. Four metrics are computed per example and layer:
    average / full accuracy over all objects and over mentioned objects only.

    One figure per metric is written as PNG and PDF to
    ``<out_dir>/prob_results_across_layers_varying_query_ops/<split>/`` where
    ``<split>`` is the path component right before ``"probing"`` in
    ``prob_result_dir``.

    Args:
        prob_result_dir: probe result directory for the final state; must
            contain a ``probing`` path component.
        box_test_data_path: JSON-lines test data file.
        object_vocab_path: object vocabulary file passed to
            ``load_object_names``.
        out_dir: root directory for the plots.
        max_query_ops: largest operation index / operation count to plot.
        filter_move_content_out: when true, drop examples whose
            ``box_contents`` sums to zero (empty box) before scoring.

    Raises:
        AssertionError: if predictions and datapoints differ in length.
        ValueError: if ``"probing"`` is not a component of ``prob_result_dir``.
    """
    plt.rc('font', family='serif', serif=['Times New Roman'])

    # load box objects
    object_map, _ = load_object_names(object_vocab_path)
    # load box data (full data not filtered by numops)
    df = remove_empty(pd.read_json(box_test_data_path, lines=True, orient="records"))

    metric_names = (
        "avg_accuracy",
        "full_accuracy",
        "non_triv_avg_accuracy",
        "non_triv_full_accuracy",
    )
    # metric name -> list over op_idx of arrays shaped [n_data, n_layers]
    metrics_by_name = {name: [] for name in metric_names}
    df_invalid_indices = {}

    for op_idx in range(max_query_ops + 1):
        # load the right data labels with the correct prior states
        df_sub = load_box_data(box_test_data_path, object_map, num_prior_state=-op_idx-1)

        keep_mask = None
        if filter_move_content_out:
            # if previous operation is move content out, then box content is empty,
            # we want to remove these cases because they are easy
            keep = df_sub.box_contents.apply(lambda x: int(sum(x)) > 0).astype(bool)
            # store the invalid data indices so the same rows can be dropped when plotting;
            # `~` is the supported way to invert a boolean mask (unary minus is not)
            df_sub_neg = df_sub[~keep]
            df_invalid_indices[op_idx] = df_sub_neg.index
            print(f"removed {len(df_sub_neg)} datapoints that are moveContent out (empty box content)")

            # now keep only ones we want
            df_sub = df_sub[keep]
            # positional mask for the (index-less) prediction matrix
            keep_mask = keep.to_numpy(dtype=bool)

        prob_result_op_dir = prob_result_dir if op_idx == 0 else f"{prob_result_dir}_prior_state_-{op_idx+1}"
        layer_entries = os.listdir(prob_result_op_dir)
        n_layers = len(layer_entries)

        # NaN placeholders keep the stacked matrix rectangular when a layer has
        # no predictions file (same convention as main_global_ops)
        per_layer = {
            name: [np.full(len(df_sub), np.nan) for _ in range(n_layers)]
            for name in metric_names
        }

        for entry in tqdm(layer_entries):
            layer_idx = int(entry.split('_')[0].replace('layer', ''))
            pred_file_path = os.path.join(prob_result_op_dir, entry, "predictions.txt")
            if not os.path.isfile(pred_file_path):
                continue

            # pred shape = (n_data X n_objects)
            pred = pd.read_table(pred_file_path, sep=" ").to_numpy()
            if keep_mask is not None:
                pred = pred[keep_mask]
            assert len(df_sub) == len(pred), "prediction and datapoints len does not match up"

            box_content = np.vstack(df_sub.box_contents)
            mentioned = np.vstack(df_sub.all_mentioned_objects)
            layer_metrics = _per_example_accuracies(box_content, mentioned, pred)
            for name, values in zip(metric_names, layer_metrics):
                per_layer[name][layer_idx - 1] = values

        for name in metric_names:
            metrics_by_name[name].append(np.vstack(per_layer[name]).T)

    # plot matrix as groups of heatmap, group by number of query ops of the datapoint
    out_dir = os.path.join(out_dir, "prob_results_across_layers_varying_query_ops")
    path_parts = prob_result_dir.split("/")
    probe_split_name = path_parts[path_parts.index("probing") - 1]
    plot_dir = os.path.join(out_dir, probe_split_name)
    os.makedirs(plot_dir, exist_ok=True)

    file_suffix = (
        f"{'-filter_move_out' if filter_move_content_out else ''}"
        f"{'' if max_query_ops == MAX_QUERY_OPS else 'maxOp' + str(max_query_ops)}"
    )
    fig_size = [20, 10] if max_query_ops == MAX_QUERY_OPS else [6, 6]
    n_cols = max_query_ops + 1

    for metric_name, metrics in metrics_by_name.items():
        # squeeze=False so indexing also works with a single column
        fig, axes = plt.subplots(
            ncols=n_cols,
            figsize=fig_size,
            sharey=True,
            squeeze=False,
            gridspec_kw={"width_ratios": list(range(1, n_cols + 1))},
        )
        axes = axes[0]

        # group by datapoints with different number of local operations
        for numops in range(n_cols):
            heatmap_metric = np.zeros((n_layers, n_cols))
            x_ticks = [" "] * n_cols

            # looking at each operation index probing accuracy
            for op_idx in range(numops + 1):
                df_sub = df[df.numops >= op_idx]
                if filter_move_content_out:
                    df_sub = df_sub[~df_sub.index.isin(df_invalid_indices[op_idx])]
                assert len(df_sub) == len(metrics[op_idx]), "prediction and datapoints len does not match up"
                numops_mask = (df_sub.numops == numops).to_numpy()
                heatmap_metric[:, op_idx] = metrics[op_idx][numops_mask].mean(0)

                x_ticks[op_idx] = f"Prior State={op_idx*-1 - 1}"
            x_ticks[0] = "Final State"

            ax = axes[numops]
            # only the last panel carries the color bar
            sns.heatmap(
                heatmap_metric[:, : numops+1],
                cmap="YlGnBu",
                ax=ax,
                vmin=0.5,
                vmax=1.0,
                cbar=numops == max_query_ops,
            )
            ax.set_xticks(range(numops+1))
            ax.set_xticklabels(x_ticks[: numops+1], rotation=70, fontsize=15)
            # n is the example count of the last op_idx column of this panel
            ax.set_title(f"{numops}-op (n={numops_mask.sum()})", rotation=15, fontsize=15)

            # invert the layers so the plot runs from bottom (embedding) to top (end of residual stream)
            ax.invert_yaxis()

        fig.supylabel("Layer Index", x=0.01, fontsize=25)
        fig.tight_layout()
        base_path = os.path.join(plot_dir, f"{metric_name}{file_suffix}")
        fig.savefig(base_path + ".png", dpi=900)
        fig.savefig(base_path + ".pdf", dpi=900)
        # release the figure; otherwise every call leaves four large figures open
        plt.close(fig)


def main_global_ops(
    prob_result_dir: str,
    box_test_data_path: str,
    object_vocab_path: str,
    out_dir: str,
    max_query_ops: int=MAX_QUERY_OPS,
    filter_move_content_out: bool=False,
):
    """Plot a layer-by-prior-state heatmap of average probing accuracy.

    Unlike ``main_local_ops`` this reads each layer's ``tensorboard.txt``
    (parsed as JSON), which doesn't include non-trivial accuracies (yet).
    For every ``op_idx`` in ``0..max_query_ops`` the last element of
    ``test_acc_cont`` (and of ``test_strat_acc_cont``) is taken from each
    layer directory under ``prob_result_dir`` (``op_idx == 0``) or
    ``f"{prob_result_dir}_prior_state_-{op_idx+1}"``. Layers without a
    ``tensorboard.txt`` stay NaN.

    The heatmap is saved as ``avg_accuracy.png`` / ``avg_accuracy.pdf`` in
    ``<out_dir>/prob_results_across_layers_varying_global_ops/<split>/`` where
    ``<split>`` is the path component right before ``"probing"`` in
    ``prob_result_dir``.

    Args:
        prob_result_dir: probe result directory for the final state; must
            contain a ``probing`` path component.
        box_test_data_path: unused here; kept for signature parity.
        object_vocab_path: unused here; kept for signature parity.
        out_dir: root directory for the plots.
        max_query_ops: largest prior-state index to read and plot.
        filter_move_content_out: unused here; kept for signature parity.

    Raises:
        ValueError: if ``"probing"`` is not a component of ``prob_result_dir``.
    """
    # collect results across layers; one (1, n_layers) array per op_idx
    avg_accuracies = []
    # collected for the stratified plot, whose code is currently disabled
    strat_avg_accuracies = []

    for op_idx in range(max_query_ops + 1):

        prob_result_op_dir = prob_result_dir if op_idx == 0 else f"{prob_result_dir}_prior_state_-{op_idx+1}"
        layer_entries = os.listdir(prob_result_op_dir)
        n_layers = len(layer_entries)

        op_avg_accuracies = [np.nan] * n_layers
        op_strat_avg_accuracies = [np.nan] * n_layers

        for entry in tqdm(layer_entries):
            # these may need to change depending on result file formatting
            layer_idx = int(entry.split('_')[0].replace('layer', ''))
            result_file_path = os.path.join(prob_result_op_dir, entry, "tensorboard.txt")
            if not os.path.isfile(result_file_path):
                continue

            # context manager so the handle is closed right after parsing
            with open(result_file_path) as result_file:
                pred_res = json.load(result_file)

            # get last epoch test accuracy
            op_avg_accuracies[layer_idx-1] = pred_res["test_acc_cont"][-1]
            op_strat_avg_accuracies[layer_idx-1] = pred_res["test_strat_acc_cont"][-1]

        avg_accuracies.append(np.vstack(op_avg_accuracies).T)
        strat_avg_accuracies.append(np.vstack(op_strat_avg_accuracies).T)

    # plot matrix as a heatmap, one column per prior-state index
    out_dir = os.path.join(out_dir, "prob_results_across_layers_varying_global_ops")
    path_parts = prob_result_dir.split("/")
    probe_split_name = path_parts[path_parts.index("probing") - 1]
    plot_dir = os.path.join(out_dir, probe_split_name)
    os.makedirs(plot_dir, exist_ok=True)

    ## plot avg accuracy
    # Column count follows the `max_query_ops` argument (not the module
    # default) so it always matches the number of collected result arrays.
    n_cols = max_query_ops + 1
    fig = plt.figure(figsize=[10, 10])
    heatmap_metric = np.zeros((n_layers, n_cols))

    x_ticks = [" "] * n_cols
    for op_idx in range(n_cols):
        # each entry is shaped (1, n_layers); take its single row
        heatmap_metric[:, op_idx] = avg_accuracies[op_idx][0]
        x_ticks[op_idx] = f"{op_idx * -1 - 1}"

    sns.heatmap(heatmap_metric, cmap="YlGnBu", vmin=0.5, vmax=1.0)
    plt.xticks(range(len(x_ticks)), x_ticks)
    # invert the layers so the plot runs from bottom (embedding) to top (end of residual stream)
    plt.gca().invert_yaxis()

    plt.ylabel("Layer Index")
    plt.xlabel("Average accuracy for the n-th prior Operations (from the end)")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, "avg_accuracy.png"), dpi=900)
    plt.savefig(os.path.join(plot_dir, "avg_accuracy.pdf"), dpi=900)
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

    ## plot stratified accuracy
    # fig, axes = plt.subplots(ncols=MAX_QUERY_OPS+1, figsize=[30, 10], sharey=True)
    #
    # # group by datapoints with different number of local operations
    # for numops in range(MAX_QUERY_OPS+1):
    #     heatmap_metric = np.zeros((n_layers, MAX_QUERY_OPS+1))
    #     x_ticks = [" " for _ in range(MAX_QUERY_OPS+1)]
    #
    #     # looking at each operation index probing accuracy
    #     for op_idx in range(min(numops+1, max_query_ops+1)):
    #         df_sub = df[df.numops >= op_idx]
    #         assert len(df_sub) == len(metrics[op_idx]), "prediction and datapoints len does not match up"
    #         f = df_sub.numops == numops
    #         metrics_sub = metrics[op_idx][f]
    #         heatmap_metric[:, op_idx] = metrics_sub.mean(0)
    #
    #         x_ticks[op_idx] = f"prior_state={op_idx*-1 - 1}"
    #     # invert the layers so the plot follows the computational graph from bottom (embedding) to top (end of residual stream)
    #     show_color_bar = numops == MAX_QUERY_OPS
    #     sns.heatmap(heatmap_metric, cmap="YlGnBu",
    #                 ax=axes[numops], vmin=0.5, vmax=1.0, cbar=show_color_bar)
    #     axes[numops].set_xticks(range(len(x_ticks)))
    #     axes[numops].set_xticklabels(x_ticks, rotation=70)
    #     axes[numops].title.set_text(f"Boxes with {numops} (local) operations (n={f.sum()})")
    #     axes[numops].invert_yaxis()
    #
    #
    # fig.supylabel("Layer Index")
    # fig.supxlabel(f"{metric_name} for the n-th Operations (from the end)")
    # plt.tight_layout()
    # plt.savefig(os.path.join(out_dir, probe_split_name, f"{metric_name}.png"))




if __name__ == '__main__':
    ROOT = "<anonymized>"

    # arguments shared by every model run below
    shared_kwargs = dict(
        box_test_data_path=f"{ROOT}/data/boxes-dataset-v1/few_shot_boxes_nso_exp2_max3/test-subsample-states-gpt.jsonl",
        object_vocab_path=f"{ROOT}/data/objects_with_bnc_frequency.csv",
        out_dir=f"{ROOT}/plots",
    )

    # (model directory name, max_query_ops) in the order the plots are produced
    model_runs = (
        ("llama3.1-405b", 7),
        ("codellama-13b", 3),  # 7
    )

    for model_name, query_ops in model_runs:
        main_local_ops(
            prob_result_dir=f"{ROOT}/probe_checkpoints/boxes_data_v1/{model_name}/probing/state_binary_exclude_empty",
            max_query_ops=query_ops,
            **shared_kwargs,
        )




