"""The dreaded utilities module, home to many a weird function."""

import os
import inspect
import pickle
import torch
import random
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel


#  ╭──────────────────────────────────────────────────────────╮
#  │ Diffusion Functions                                      │
#  ╰──────────────────────────────────────────────────────────╯


def calculate_diffusion_operator(X, epsilon):
    """Calculate diffusion operator on data.

    Builds an RBF (Gaussian) affinity matrix ``K`` with
    ``gamma = 1 / epsilon`` and row-normalises it, which yields the
    diffusion operator ``P = D^-1 K``, where ``D`` is the diagonal matrix
    of row sums (degrees) of ``K``.

    Parameters
    ----------
    X : array_like
        Input data, passed unchanged to ``rbf_kernel`` (rows are samples).

    epsilon : float
        Kernel bandwidth; the kernel is evaluated with ``gamma = 1 / epsilon``,
        so larger values give wider kernels.

    Returns
    -------
    np.array
        Row-stochastic diffusion operator of shape ``(n, n)`` for ``n``
        rows in ``X``.
    """
    # K: Affinity matrix defined by the kernel
    # Q: Vector of degrees (row sums)
    K = rbf_kernel(X, gamma=1.0 / epsilon)
    Q = np.sum(K, axis=1)

    # P: D^-1 K (diffusion operator). Scaling row i of K by 1 / Q[i] is
    # exactly D^-1 K, but avoids materialising a dense n x n diagonal
    # matrix and an O(n^3) matrix product.
    P = (1.0 / Q)[:, np.newaxis] * K

    return P


def von_neumann_entropy(P, T):
    """Calculate von Neumann entropy of a diffusion operator.

    For each step ``t = 1, ..., T``, the magnitudes of the eigenvalues of
    ``P^t`` (i.e. ``|lambda|^t``) are normalised into a probability vector,
    and its Shannon entropy is recorded.

    Parameters
    ----------
    P : array_like
        Square diffusion operator. It does not need to be symmetric; a
        row-normalised operator ``D^-1 K`` generally is not.

    T : int
        Maximum step number for the diffusion operator

    Returns
    -------
    np.array
        Array of ``T`` entropy values; entry ``t`` (0-based) belongs to
        ``P^(t + 1)``.
    """
    # A diffusion operator D^-1 K is generally not symmetric. A Hermitian
    # solver such as `eigvalsh` only reads one triangle of the matrix and
    # would return wrong eigenvalues here. Use the general solver and keep
    # magnitudes; for symmetric input this matches the Hermitian result.
    eigenvalues = np.abs(np.linalg.eigvals(P))
    eigenvalues_t = np.copy(eigenvalues)

    entropy = []

    for _ in range(T):
        probabilities = eigenvalues_t / np.sum(eigenvalues_t)
        # Small offset keeps the logarithm finite for zero probabilities.
        probabilities += 1e-8

        entropy.append(-np.sum(probabilities * np.log(probabilities)))

        # Eigenvalue magnitudes of the next power of P.
        eigenvalues_t = eigenvalues_t * eigenvalues

    return np.asarray(entropy)


#  ╭──────────────────────────────────────────────────────────╮
#  │ Distance covariance and distance correlation functions.  │
#  ╰──────────────────────────────────────────────────────────╯

# Originally taken from https://github.com/Pseudomanifold/dcov, released
# under the BSD 3-Clause License.


def _means(X):
    """Calculate matrix-based means.

    Parameters
    ----------
    X : array_like
        Input matrix

    Returns
    -------
    Tuple of np.array
        Vector of row means, vector of column means, and overall matrix
        mean.
    """
    row_mean = np.mean(X, axis=1, keepdims=True)
    col_mean = np.mean(X, axis=0, keepdims=True)
    mean = np.mean(X)

    return row_mean, col_mean, mean


def dcov(A, B):
    """Calculate distance covariance between two distance matrices.

    Both matrices are double-centred (row and column means subtracted,
    grand mean added back), and the result is the mean of the element-wise
    product of the centred matrices. This is the quantity Szekely et al.
    denote as the *squared* sample distance covariance (a V-statistic).

    Parameters
    ----------
    A: array_like
        First distance matrix (assumed square, shape ``(n, n)``)

    B: array_like
        Second distance matrix, same shape as ``A``

    Returns
    -------
    float
        Sample distance covariance

    Raises
    ------
    RuntimeError
        If ``A`` and ``B`` do not have the same shape.
    """
    # Accept any array_like input, e.g. nested lists, as documented.
    A = np.asarray(A)
    B = np.asarray(B)

    # Explicit check rather than `assert`: assertions are stripped when
    # Python runs with `-O`, which would silently skip the validation.
    if A.shape != B.shape:
        raise RuntimeError("Inputs must have the same cardinality")

    n = len(A)

    A_row_mean, A_col_mean, A_mean = _means(A)
    B_row_mean, B_col_mean, B_mean = _means(B)

    # Double centring; `_means` keeps row/column means 2-D so that they
    # broadcast across columns and rows respectively.
    A = A - A_row_mean - A_col_mean + A_mean
    B = B - B_row_mean - B_col_mean + B_mean

    d = 1 / (n**2) * np.sum(np.multiply(A, B))
    return d


def dcor(A, B):
    """Calculate distance correlation between two distance matrices.

    The result is ``dcov(A, B) / sqrt(dcov(A, A) * dcov(B, B))``. If
    either self-covariance is not strictly positive (e.g. a constant
    input), the unnormalised ``dcov(A, B)`` is returned instead of
    dividing by zero.

    NOTE(review): `dcov` returns the squared distance covariance, so
    this ratio corresponds to the squared distance correlation; confirm
    whether callers expect the square root.

    Parameters
    ----------
    A: array_like
        First distance matrix

    B: array_like
        Second distance matrix

    Returns
    -------
    float
        Sample distance correlation
    """
    cross = dcov(A, B)
    var_a = dcov(A, A)
    var_b = dcov(B, B)

    # Degenerate case: fall back to the raw covariance.
    if not (var_a > 0 and var_b > 0):
        return cross

    return cross / np.sqrt(var_a * var_b)


#  ╭──────────────────────────────────────────────────────────╮
#  │ Good old helper functions.                               │
#  ╰──────────────────────────────────────────────────────────╯


def check_labels_dim(y):
    """Ensure that labels have at least one dimension.

    Zero-dimensional (scalar) label tensors are given a new leading
    dimension of size 1, making them compatible with functions that
    expect inputs with at least one dimension. Tensors that already
    have one or more dimensions are returned unchanged.

    Parameters
    ----------
    y : torch.Tensor
        Label tensor; must provide ``dim()`` and ``unsqueeze()``.

    Returns
    -------
    torch.Tensor
        ``y`` itself if it has at least one dimension, otherwise ``y``
        reshaped to shape ``(1,)``.
    """
    # `Tensor.dim()` returns 0 for scalar tensors rather than raising,
    # so the dimensionality has to be checked explicitly.
    if y.dim() == 0:
        y = y.unsqueeze(0)
    return y


def set_if_available(cls, attr_name, value, args):
    """Store ``value`` under ``attr_name`` in ``args`` if ``cls`` accepts it.

    The check inspects the parameter names of ``cls.__init__``; when
    ``attr_name`` is not one of them, ``args`` is left untouched. The
    dictionary is modified in place and nothing is returned.
    """
    accepted = inspect.signature(cls.__init__).parameters
    if attr_name not in accepted:
        return
    args[attr_name] = value


def maybe_normalize_diameter(D):
    """Normalises a distance matrix if possible.

    If the largest entry of the distance matrix is positive, i.e. it
    describes a space with more than a single point, we divide the
    matrix by that entry (its diameter). Else, we just return the matrix
    unchanged. The input matrix is never modified in place.

    Parameters
    ----------
    D : array_like
        Square distance matrix, assumed to describe pairwise distances
        of a finite metric space.

    Returns
    -------
    array_like
        The distance matrix, normalised by its diameter, i.e. its
        largest value, or ``D`` itself if that value is not positive.
    """
    # Out-of-place division: does not silently mutate the caller's
    # matrix, and works for integer-valued matrices, for which in-place
    # true division raises a casting error.
    if (diam := np.max(D)) > 0:
        D = D / diam

    return D


def parse_query_string(arguments):
    """Parse query string to permit filtering a data frame.

    Some auxiliary scripts for plotting require collating data from
    different runs, which can potentially occlude relevant patterns
    in a visualisation. To prevent this, the scripts permit passing
    an additional query string, which is used to filter data frames
    in advance. This function parses the command-line arguments for
    the underlying query interface.

    Parameters
    ----------
    arguments : iterable of str
        Arguments to parse into a query. Each entry must be of the form
        `--key=value`, and will be parsed into a condition
        `key == "value"` for the `pd.DataFrame.query` interface. Only the
        first `=` separates key and value, so values may contain `=`.

    Returns
    -------
    str or `None`
        Query string for `pd.DataFrame.query`. Individual query
        conditions are concatenated with a logical "and" operator. If no
        query arguments are provided, will return `None` instead.

    Raises
    ------
    ValueError
        If an argument does not contain `=`.
    """
    conditions = []

    for argument in arguments:
        # Split on the first '=' only, so values may themselves contain '='.
        name, separator, value = argument.partition("=")
        if not separator:
            raise ValueError(
                f"Query argument {argument!r} must be of the form --key=value"
            )

        # Strip only the leading option prefix, not every "--" in the name.
        if name.startswith("--"):
            name = name[2:]

        # Escape backslashes and double quotes so the value stays a single
        # string literal inside the query expression.
        value = value.replace("\\", "\\\\").replace('"', '\\"')

        conditions.append(f'{name} == "{value}"')

    # Chain multiple conditions with a logical "and"; no conditions -> None.
    return " & ".join(conditions) if conditions else None


def save_pickle(obj, fp):
    """Serialise ``obj`` with pickle into the file at path ``fp``.

    The file is opened in binary write mode, so existing content at
    ``fp`` is overwritten.
    """
    with open(fp, mode="wb") as handle:
        pickle.dump(obj, handle)


def load_pickle(fp):
    """Load and return the object pickled in the file at path ``fp``.

    Warning: unpickling can execute arbitrary code, so only call this on
    files that come from a trusted source.
    """
    with open(fp, mode="rb") as handle:
        return pickle.load(handle)


def get_filepath(
    root,
    task,
    n_nodes,
    seed,
    n_graphs=1000,
    p=0.2,
    q=0.5,
):
    """Build the path of the pickled data set for a graph task.

    The file lives in ``root/<task folder>``, where the folder is chosen
    from the task abbreviation ("ASS", "MEC", "HAM", "OES", "PCL"). The
    file name encodes the task and its generation parameters; float
    parameters are written with '-' in place of '.'. Which parameters
    appear depends on the task:

    - HAM: number of nodes only
    - OES: number of graphs, nodes and seed
    - ASS, PCL: additionally ``p``
    - MEC: additionally ``p`` and ``q``

    Raises
    ------
    KeyError
        If ``task`` is not one of the known abbreviations.
    """
    folders = {
        "ASS": "assortativity",
        "MEC": "minedgecut",
        "HAM": "hamiltonian",
        "OES": "oddevensum",
        "PCL": "plantedclique",
    }
    # Look up the folder first so unknown tasks fail before any name is built.
    directory = os.path.join(root, folders[task])

    def _dashed(x):
        # Replace the decimal point so the file name has a single '.'.
        return str(x).replace(".", "-")

    stem = f"{task}_N-{n_graphs}_n-{n_nodes}"
    if task == "HAM":
        name = f"{task}_n-{n_nodes}.pkl"
    elif task == "OES":
        name = f"{stem}_s-{seed}.pkl"
    elif task in ("ASS", "PCL"):
        name = f"{stem}_p-{_dashed(p)}_s-{seed}.pkl"
    else:
        name = f"{stem}_p-{_dashed(p)}_q-{_dashed(q)}_s-{seed}.pkl"

    return os.path.join(directory, name)


def check_available_devices(force_cpu=False):
    """Return the name of the preferred torch device.

    Preference order is "cuda", then "mps", then "cpu". With
    ``force_cpu=True`` the accelerator checks are skipped and "cpu" is
    returned.
    """
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def set_random_seeds(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the CUDA generator is seeded as well. cuDNN
    is switched to deterministic, non-benchmarking mode so that
    convolution kernels are chosen reproducibly.
    """
    # Python's `random`, NumPy and PyTorch each keep their own global state.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Trade cuDNN autotuning speed for reproducible behaviour.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
