"""
Create the label file for graphs.
"""

import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

import hephaestus.label_generation.utils.label_utils as label_utils
import hephaestus.utils.general_utils as hutils
import hephaestus.utils.load_general_config as hconfig

# Directory where the label CSV and the summary SVG of a dataset are written.
LABEL_DIR = Path(hconfig.LABEL_DIR)
# Directory that is listed and read to find the per-graph score files.
SCORE_DIR = Path(hconfig.SCORE_DIR)


def normalize_z_score(zscore, denom):
    """
    Normalize a single z-score by a precomputed denominator.
    :param `zscore`: The z-score to normalize.
    :param `denom`: The normalization denominator, e.g. the value returned by `calculate_denominator`.
    :return: 0 when `denom` is 0 or `zscore` is `+inf`, otherwise `zscore / denom`.

    Note: only positive infinity is mapped to 0; `-inf` and NaN go through the division.
    """
    # A +inf z-score shows up when the occurrence in the original is small but
    # higher than the 0 from the random model; it is reported as 0, exactly
    # like the case where there is nothing to normalize by.
    if denom == 0 or zscore == np.inf:
        return 0
    return zscore / denom


def calculate_denominator(zscores, limits):
    """
    Calculate the denominator used for normalization, see 10.1126/science.1089167.
    :param `zscores`: Sliceable collection of z-scores.
    :param `limits`: Pair `(start, stop)`; only `zscores[start:stop]` contributes.
    :return: Square root of the sum of squares of the selected z-scores, with `+inf` entries counted as 0.
    """
    lower, upper = limits[0], limits[1]
    # +inf entries must not blow up the norm, so they contribute nothing.
    finite_or_zero = [0 if value == np.inf else value for value in zscores[lower:upper]]
    sum_of_squares = np.sum(np.square(finite_or_zero))
    return np.sqrt(sum_of_squares)


def build_row(zscores, network_name, target_value, blc_logger):
    """
    Build one row for the CSV file of `create_csv_and_image` with normalized z-scores.
    Uses the files generated by create_subgraph_score.py.
    :param `zscores`: Collection of z-scores for `network_name`.
    :param `str` `network_name`: A graph name.
    :param `float` `target_value`: The old target for the `network_name`.
    :param `blc_logger`: Logger whose `warning` method is called when a NaN is found.
    :return: A list `[network_name, target_value, score_0, score_1, ...]`. Every score is NaN
        when at least one z-score is NaN; otherwise each score is normalized with the
        denominator of the `hconfig.MARGINS_ZSCORE` range it falls in.

    Note: `target_value` is a legacy parameter and in the foreseeable future should be handled by `create_csv_and_image`.
    """

    # Both checks are always evaluated: real NaN values and the literal string "nan".
    np_nan_found = np.any(np.isnan(zscores))
    str_nan_found = np.any(np.apply_along_axis(lambda x: x == "nan", 0, zscores))

    row = [network_name, target_value]

    # A single NaN invalidates the whole profile: put NaN in all positions.
    if np_nan_found or str_nan_found:
        blc_logger.warning(f"Detected at least a NaN in {network_name}!")
        blc_logger.warning("Putting all z-scores to NaN!")
        row.extend(np.nan for _ in zscores)
        return row

    margins = hconfig.MARGINS_ZSCORE
    # One denominator per configured range of z-scores.
    denoms = [calculate_denominator(zscores, limits) for limits in margins]

    current_margin = 0
    for position, zscore in enumerate(zscores):
        # Move to the next range once the upper limit of the current one is reached.
        if position >= margins[current_margin][1]:
            current_margin += 1
        row.append(normalize_z_score(zscore, denoms[current_margin]))

    return row


def create_csv_and_image(dataset_name, original_labels, blc_logger):
    """
    Build the CSV with the motif labels of all graphs.
    Build an image to see mean and std of motif significance.
    Uses the files generated by create_subgraph_score.py.
    :param `str` `dataset_name`: A dataset name used to select the correct graph names e.g. ndEBA
    :param `np.array` `original_labels`: The old original labels of each graph from `dataset_name`.
    :param `blc_logger`: Logger used to report problems; it is also forwarded to `build_row`.
    :raises `AssertionError`: If a score file does not have `hconfig.NUM_SUBGRAPHS` rows.

    Exits the process with status 1 when `original_labels` is not empty and its length
    differs from the number of score files found for `dataset_name`.

    Note: `original_labels` is a legacy parameter and in the foreseeable future should be an empty `np.array`.
    """

    # The directory listing is sorted so that the i-th score file is always paired
    # with the i-th entry of `original_labels`, whatever order the OS returns.
    files_of_interest = [f for f in sorted(os.listdir(SCORE_DIR)) if dataset_name in f]
    num_rows = len(files_of_interest)  # As many rows as files with scores

    # TODO: correct for cases like else statement
    if num_rows != original_labels.shape[0]:
        if original_labels.shape[0] != 0:
            blc_logger.error(
                "This error message should NEVER be triggered in the foreseeable experiences!"
            )
            blc_logger.error("Num of result files != Num graphs.")
            blc_logger.error("Does the dataset has a label p/graph or labels p/node?")
            blc_logger.error("e.g. Datasets for node classification like KarateClub")
            sys.exit(1)
        # In case there are no old labels, every graph gets NaN as its old target
        original_labels = np.full((num_rows,), np.nan)

    data = []
    for row_index, raw_graph_score_file in enumerate(files_of_interest):
        result_profile = pd.read_csv(
            os.path.join(SCORE_DIR, raw_graph_score_file), skipinitialspace=True
        )

        # Has to have all subgraphs, sanity check
        if result_profile.shape[0] != hconfig.NUM_SUBGRAPHS:
            blc_logger.error("Subgraphs without z-score.")
            blc_logger.error(f"Skipping, {raw_graph_score_file}")
            blc_logger.error(
                f"Got {result_profile.shape[0]}, expected {hconfig.NUM_SUBGRAPHS}"
            )
            raise AssertionError("[build_label_csv] Graph profile is missing z-scores!")

        graph_name = hutils.get_graph_name(raw_graph_score_file, with_extension=True)
        data.append(
            build_row(
                result_profile["z_score"],
                graph_name,
                original_labels[row_index],
                blc_logger,
            )
        )

    id_cols = ["GraphName", "OldTarget"]
    subgraph_cols = [f"Subgraph{i}" for i in range(hconfig.NUM_SUBGRAPHS)]
    cols = id_cols + subgraph_cols

    pd_data = pd.DataFrame(data=data, columns=cols)
    pd_data = pd_data.sort_values(by="GraphName", ignore_index=True)

    # Should not do anything because `build_row` already returns real NaN values,
    # but it stays as a safety net. The result is assigned back to the column:
    # an in-place replace on `pd_data[col]` acts on a temporary selection and
    # is not guaranteed to reach the frame.
    for col in cols:
        if col != "GraphName":
            pd_data[col] = pd_data[col].replace("nan", np.nan)

    pd_data.to_csv(
        path_or_buf=os.path.join(LABEL_DIR, dataset_name + "_labels" + ".csv"),
        index=False,
    )

    # Long format for plotting: one row per (graph, subgraph) pair. Only the
    # subgraph columns are melted; the id columns must not be listed as values.
    pd_data = pd.melt(
        pd_data,
        id_vars=id_cols,
        value_vars=subgraph_cols,
        var_name="Type",
        value_name="Norm_Z_Score",
    )

    # Substitute NaN in OldTarget (used as hue) so those rows survive the dropna,
    # then drop the rows whose z-score is NaN since they cannot be plotted.
    pd_data["OldTarget"] = pd_data["OldTarget"].fillna(-1)
    pd_data = pd_data.dropna(ignore_index=True)

    g = sns.pointplot(
        data=pd_data,
        x="Type",
        y="Norm_Z_Score",
        hue="OldTarget",
        alpha=1,
        errorbar="sd",
    )
    g.set(xticklabels=[])
    g.set_title(dataset_name)
    plt.savefig(
        os.path.join(LABEL_DIR, dataset_name + "_labels" + ".svg"),
        format="svg",
        dpi=800,
    )
    # Leave the axes empty so a later call does not draw on top of this plot.
    g.clear()


def build_missing_file(error_call, dataset_name, blc_logger):
    """
    Append dummy content with NaNs for each subgraph size whose result could not be calculated.
    Assumes that the final score file `DATASETNAME@GRAPHNAME.score` exists for all graphs
    whose score will be 'calculated' and that said file has the PREAMBLE from `label_utils.py`
    already written. This is typically handled by an earlier call to `create_subgraph_score.merge_files()`.

    The destination file is derived from the first entry of `error_call` and the content
    of every entry is appended to it.

    :param `list` `error_call`: list with names of files whose scores could not be calculated,
        each in the form `.../DATASETNAME@GRAPHNAME.score-sizeSIZE`.
    :param `str` `dataset_name`: name of the dataset whose graphs in `error_call` belong to.
    :param `blc_logger`: Logger used for progress (`info`) and unsupported sizes (`warning`).
    :raises `AssertionError`: If the destination score file does not exist.
    """

    # No failed call means there is nothing to append.
    if not error_call:
        return

    # returns .../DATASETNAME@GRAPHNAME.score
    # Split from the right: only the trailing "-sizeSIZE" suffix is removed, even
    # when a directory or graph name happens to contain "-size" as well.
    score_file_name = error_call[0].rsplit("-size", 1)[0]
    if not os.path.exists(score_file_name):
        raise AssertionError(
            "[build_label_csv] Building missing file without preexiting destination."
        )

    # Builders of the dummy content, keyed by subgraph size.
    size_builders = {
        "3": label_utils.build_size_3_string,
        "4": label_utils.build_size_4_string,
    }

    chunks = []
    for file in error_call:  # for DATASETNAME@GRAPHNAME.score-sizeSIZE in error_call
        size_not_completed = file.rsplit("-size", 1)[1]
        graph_name = hutils.get_graph_name(file, with_extension=True)

        blc_logger.info(
            f"Generating {dataset_name + graph_name}, size {size_not_completed} since call failed"
        )

        builder = size_builders.get(size_not_completed)
        if builder is None:
            # Make the gap visible instead of silently writing nothing for this size.
            blc_logger.warning(
                f"No dummy content available for size {size_not_completed}, nothing written for it"
            )
            continue
        chunks.append(builder())

    # Append after the content that is already in the score file.
    with open(score_file_name, mode="a", encoding="utf-8") as f:
        f.write("".join(chunks))
