import argparse
import os
from pathlib import Path

import pandas as pd
import numpy as np

from matplotlib import pyplot as plt
import seaborn as sns

import sklearn.cluster as skclus

import tqdm

import yaml
import datetime
import hashlib

# Colour palette shared by the plots in this file; plot_stuff uses p[3] for
# the bars and p[4] for the difference line.
p = [
    "#000000",
    "#E69F00",
    "#56B4E9",
    "#009E73",
    "#FB6467FF",
    "#808282",
    "#F0E442",
    "#440154FF",
    "#0072B2",
    "#D55E00",
    "#CC79A7",
    "#C2CD23",
    "#918BC3",
    "#FFFFFF",
]

# Point-wise tolerance used by naive_distance (strict "<" comparison).
NAIVE_CUTOFF = 0.01
# Bounds and step of the distance-threshold scan run by the script when no
# --threshold is supplied; the scan includes MAX_THRESHOLD itself.
MIN_THRESHOLD = 0.0
# NOTE(review): presumably "2 * (largest absolute SP value, 1)", i.e. the
# largest possible max-abs distance between two patterns -- confirm.
MAX_THRESHOLD = 2 * 1
STEP_THRESHOLD = 0.01
# Placeholder base directory for outputs; the script appends the data type
# to it and creates the directory before writing.
SAVE_PATH = Path(
    "where/the/results/from/evaluate/are/stored/persistent_patterns"
)


class SP:
    """Record bundling one pattern with its clustering bookkeeping.

    Attributes are stored exactly as passed in:

    name
        Graph name the pattern belongs to.
    sp
        The pattern values themselves.
    category
        Category label attached to the pattern.
    group
        Cluster label assigned to the pattern.
    id
        Identifier given by the caller.
    """

    def __init__(self, name, sp, category, group, id) -> None:
        # Identity of the pattern and its raw values.
        self.name, self.sp = name, sp
        # Labels: the given category and the assigned cluster.
        self.category, self.group = category, group
        self.id = id


# Graph labeling functions
################################################################################################################################
def get_class(name):
    """Return the class part of the dataset name of a predicted graph.

    Looks in the module-level ``patterns`` table for rows whose ``GraphName``
    equals ``name`` and whose ``Type`` is ``"Pred"``, splits each
    ``DatasetName`` on the module-level ``DATASET_TYPE`` string, and returns
    the second piece of the first matching row.

    Raises ``IndexError`` when no row matches or the split yields fewer than
    two pieces.
    """
    is_requested_graph = patterns["GraphName"] == name
    is_prediction = patterns["Type"] == "Pred"
    dataset_names = patterns[is_requested_graph & is_prediction]["DatasetName"]

    pieces_per_row = dataset_names.str.split(DATASET_TYPE).to_list()
    first_row_pieces = pieces_per_row[0]
    return first_row_pieces[1]


def node_lab(l):
    """Return the label to display for a node; currently the identity.

    The previous special case for the "power-662-bus" / "power-685-bus"
    graphs was disabled with ``if False and ...`` and returned the same value
    as the fallback, so the unreachable branch has been removed.

    Parameters
    ----------
    l : object
        Raw node label.

    Returns
    -------
    object
        ``l`` unchanged.
    """
    return l


################################################################################################################################
################################################################################################################################


# Clustering distance and identification
################################################################################################################################
def max_abs_distance(x, y):
    r"""Maximum absolute point-wise distance between two 1d arrays.

    Parameters
    ----------
    x : array-like
        First array
    y : array-like
        Second array, broadcastable against ``x``

    Returns
    -------
    scalar
        ``max(|x - y|)`` over all positions.

    Raises
    ------
    ValueError
        If the inputs are empty (there is no maximum to take).
    """
    # Coerce so plain lists/tuples work too; ndarray inputs pass through
    # without being copied.
    return np.max(np.abs(np.asarray(x) - np.asarray(y)))


def naive_distance(p1, p2, cutoff=None):
    """Tell whether two patterns agree point-wise within a cutoff.

    Despite its name this returns a boolean, not a distance: it is true only
    when every entry of ``|p1 - p2|`` is strictly below the cutoff.

    Parameters
    ----------
    p1 : array-like
        First pattern
    p2 : array-like
        Second pattern, broadcastable against ``p1``
    cutoff : float, optional
        Point-wise tolerance. Defaults to the module-level ``NAIVE_CUTOFF``.

    Returns
    -------
    numpy.bool_
        Whether all point-wise differences are below ``cutoff``.
    """
    if cutoff is None:
        # Resolved at call time so the module constant stays the default.
        cutoff = NAIVE_CUTOFF
    return np.all(np.abs(np.asarray(p1) - np.asarray(p2)) < cutoff)


def clust_lookup(graph_name1, graph_name2):
    """Tell whether two graphs were assigned to the same cluster.

    Both names are looked up in the module-level ``GROUPED_SPS`` mapping and
    the ``group`` attributes of the two entries are compared. A name missing
    from the mapping raises ``KeyError``.
    """
    first_group = GROUPED_SPS[graph_name1].group
    second_group = GROUPED_SPS[graph_name2].group
    return first_group == second_group


def are_same_class(p1, p2, metric="naive"):
    """Decide whether two patterns belong to the same class.

    Parameters
    ----------
    p1, p2 : sequence
        Pattern records indexed positionally: element 0 is the graph name and
        element 1 is the pattern array.
    metric : {"naive", "clust"}
        ``"naive"`` compares the pattern arrays with ``naive_distance``;
        ``"clust"`` compares the cluster groups of the two graph names with
        ``clust_lookup``.

    Returns
    -------
    bool-like
        Result of the selected comparison.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names. This used to fall
        through and return ``None``, which callers would silently read as
        "not the same class".
    """
    if metric == "naive":
        return naive_distance(p1[1], p2[1])
    if metric == "clust":
        return clust_lookup(p1[0], p2[0])
    raise ValueError(f"Unknown metric {metric!r}; expected 'naive' or 'clust'")


################################################################################################################################
################################################################################################################################


def plot_stuff(elbow, path, tick_step=10, dpi=1200):
    """Plot the cluster count per threshold with its change overlaid.

    Draws a bar chart of ``number_clusters`` against ``threshold`` and, on a
    twin y-axis, a line of ``number_clusters_diff``; the figure is saved to
    ``path`` and then closed.

    Parameters
    ----------
    elbow : pandas.DataFrame
        Must provide the columns ``threshold``, ``number_clusters`` and
        ``number_clusters_diff``.
    path : str or path-like
        Destination handed to ``plt.savefig``.
    tick_step : int, optional
        Keep one x tick label out of every ``tick_step`` (default 10, the
        stride that was previously hard-coded).
    dpi : int, optional
        Resolution handed to ``plt.savefig`` (default 1200, as before).

    Raises
    ------
    ValueError
        If ``tick_step`` is smaller than 1.
    """
    if tick_step < 1:
        raise ValueError(f"tick_step must be >= 1, got {tick_step}")

    sns.set_context("paper", font_scale=2)
    plt.figure(figsize=(19, 11))
    g = sns.barplot(data=elbow, x="threshold", y="number_clusters", color=p[3])

    # The difference curve gets its own y-axis but shares the bars' x-axis.
    g2 = sns.lineplot(
        data=elbow,
        x="threshold",
        y="number_clusters_diff",
        markers=True,
        dashes=False,
        marker="o",
        markersize=3.5,
        markeredgecolor=p[4],
        color=p[4],
        linewidth=1.5,
        ax=g.axes.twinx(),
    )

    g.set(xlabel="Threshold", ylabel="Number of Clusters")

    # There is one tick per bar; thin them out to every `tick_step`-th one so
    # the labels stay readable.
    current_labels = g.get_xticklabels()
    new_positions = list(range(0, len(current_labels), tick_step))
    new_labels = [current_labels[i].get_text() for i in new_positions]
    g.set_xticks(new_positions)
    g.set_xticklabels(new_labels)

    g2.set_ylabel(ylabel="Difference in Number of Clusters", rotation=270)
    # Push the rotated right-hand label clear of the tick labels.
    g2.yaxis.set_label_coords(1.09, 0.5)
    for item in g.get_xticklabels():
        item.set_rotation(90)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi)
    plt.close()


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true; this converter interprets the usual spellings.

    Parameters
    ----------
    value : str or bool
        Raw token from the command line, or an already-parsed boolean.

    Returns
    -------
    bool
        The parsed value. The empty string maps to ``False``, as it did with
        ``type=bool``.

    Raises
    ------
    argparse.ArgumentTypeError
        If the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Script entry point: select the requested SPs from the predictions CSV,
# cluster them on their max-abs distance, and write a YAML summary (plus its
# SHA-256 checksum) next to the optional elbow plot.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--file",
        type=str,
        dest="file_path",
        required=True,
        help="Full path to CSV with predictions should come from evaluate_models.py with CSV_SAVE_PREDICTIONS.\n Should have the information required as setted by the other options. For example, if sreal is selected, the file should hold SPs of the sreal dataset.",
    )
    # Required options carry no default: it could never be used, and the old
    # "real" default for --sp-type was not even one of the valid choices.
    parser.add_argument(
        "--sp-type",
        type=str,
        choices=["True", "Pred"],
        dest="sp_type",
        required=True,
        help="Operate over true or predicted SPs?",
    )
    parser.add_argument(
        "--data-type",
        type=str,
        choices=["sreal", "mlreal", "nd", "d", "both_synth"],
        dest="data_type",
        required=True,
        help="Choose what is the type of dataset to operate on.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        dest="model_name",
        required=True,
        help="The name/id of the model that made the predictions.",
    )
    parser.add_argument(
        "--calculate-overlap",
        type=str2bool,
        dest="calculate_overlap",
        default=False,
        required=False,
        help="Should the overlap between real and predictions be calcualted?",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        dest="threshold",
        default=None,
        required=False,
        help="Threshold used for clustering. In theory, should only be used on Pred and should be the result of True. If not given it will be discovered.",
    )
    args, _ = parser.parse_known_args()

    # Sanity check that the CSV matches the requested dataset type. Reported
    # through argparse so it still runs under ``python -O``.
    if args.data_type not in args.file_path:
        parser.error(
            f"--file path {args.file_path!r} does not contain the data type "
            f"{args.data_type!r}"
        )
    patterns = pd.read_csv(args.file_path, index_col=0)

    SAVE_PATH = SAVE_PATH / args.data_type
    print(SAVE_PATH)
    SAVE_PATH.mkdir(parents=True, exist_ok=True)

    inferred_threshold = False
    run_tag = f"{args.data_type}_{args.sp_type}_{args.model_name}"
    image_save_path = SAVE_PATH / f"clust_breakdown_{run_tag}.pdf"
    details_save_path = SAVE_PATH / f"details_{run_tag}.yaml"

    # Read by get_class.
    DATASET_TYPE = args.data_type

    # Column layout relied on here: the last column holds the SP type, the
    # second-to-last the graph name, the third-to-last the category, and all
    # remaining columns the SP values.
    type_matches = (patterns.iloc[:, -1] == args.sp_type).to_numpy()
    selected_rows = patterns[type_matches]
    if len(selected_rows) < 2:
        raise SystemExit(
            f"Need at least two rows of type {args.sp_type!r} to cluster, "
            f"found {len(selected_rows)} in {args.file_path}"
        )

    names = selected_rows.iloc[:, -2].to_list()
    categories = selected_rows.iloc[:, -3].to_list()
    # One float row per selected pattern.
    X = selected_rows.iloc[:, :-3].to_numpy(dtype=float)
    sps = list(X)
    selected_patterns = list(zip(names, sps, categories))

    # Pairwise max-abs distance matrix, built one row at a time: row j holds
    # max_abs_distance(X[i], X[j]) for every i.
    dists = np.stack([np.abs(X - row).max(axis=1) for row in X])

    if args.threshold is None:
        # No threshold given: scan the configured range, fit one clustering
        # per threshold and pick the threshold from the resulting curve.
        agg_clusts = {}

        for threshold in tqdm.tqdm(
            np.arange(MIN_THRESHOLD, MAX_THRESHOLD + STEP_THRESHOLD, STEP_THRESHOLD)
        ):
            # Round away the accumulated floating-point noise of np.arange.
            threshold = np.round(threshold, 5)
            agg_clust = skclus.AgglomerativeClustering(
                n_clusters=None,
                metric="precomputed",
                linkage="complete",
                distance_threshold=threshold,
            )
            agg_clust.fit(dists)
            agg_clusts[str(threshold)] = agg_clust

        thres = list(agg_clusts)
        n_clst = [int(agg_clusts[key].n_clusters_) for key in thres]

        elbow = pd.DataFrame({"threshold": thres, "number_clusters": n_clst})
        elbow["threshold_num"] = [float(key) for key in thres]
        # Drop in cluster count relative to the previous threshold; the first
        # threshold has no predecessor and gets 0.
        elbow["number_clusters_diff"] = [0] + [
            previous - current for previous, current in zip(n_clst, n_clst[1:])
        ]

        plot_stuff(elbow, image_save_path)

        # Largest drop wins; ties go to the smallest threshold. The model is
        # fetched with its original key, avoiding a str(float) round-trip.
        diffs = elbow["number_clusters_diff"].to_list()
        thres_num = elbow["threshold_num"].to_list()
        best_row = max(
            range(len(thres)), key=lambda row: (diffs[row], -thres_num[row])
        )
        picked_thresold = thres_num[best_row]
        print(f"Picked {picked_thresold}")
        cluster_result = agg_clusts[thres[best_row]]
        inferred_threshold = True
    else:
        # Threshold supplied on the command line: a single fit, no plot.
        image_save_path = None
        picked_thresold = args.threshold
        print(f"Using {picked_thresold}")
        cluster_result = skclus.AgglomerativeClustering(
            n_clusters=None,
            metric="precomputed",
            linkage="complete",
            distance_threshold=picked_thresold,
        )
        cluster_result.fit(dists)

    cluster_sizes = np.unique(cluster_result.labels_, return_counts=True)[1]
    size_stats = pd.Series(cluster_sizes).describe().to_dict()

    # Everything is cast to built-in types so the YAML contains plain scalars
    # rather than serialized NumPy objects.
    details = {
        "metadata": {
            "run_date": datetime.datetime.now(),
            "image_path": str(image_save_path),
            "model_name": args.model_name,
            "data_type": args.data_type,
            "sp_type": args.sp_type,
            "threshold_inferred": inferred_threshold,
        },
        "results": {
            "threshold": float(picked_thresold),
            "max_num_patterns_allowed": len(selected_patterns),
            "number_patterns_inferred": int(cluster_result.n_clusters_),
            "clusters_size_stats": {
                stat: float(value) for stat, value in size_stats.items()
            },
        },
    }
    with open(details_save_path, "w") as yaml_file:
        yaml.dump(details, yaml_file, default_flow_style=False)

    # Create SHA-256 checksum of the YAML exactly as written to disk and store
    # it beside the file under the same stem.
    checksum = hashlib.sha256(details_save_path.read_bytes()).hexdigest()
    details_save_path.with_suffix(".sha256").write_text(checksum)

    # Read by clust_lookup: graph name -> SP record carrying the cluster
    # label. A repeated graph name keeps only its last occurrence.
    GROUPED_SPS = {}
    for sp_id, (name, group, cat, sp) in enumerate(
        zip(names, cluster_result.labels_, categories, sps)
    ):
        GROUPED_SPS[name] = SP(name, sp, cat, group, sp_id)
