import torch.nn.functional as F
import numpy as np
import networkx as nx
from magni.src.modules.metrics import *
from magnipy.magnitude.distances import get_dist

#  ╭──────────────────────────────────────────────────────────╮
#  │ Choosing the right Graph Metric                          │
#  ╰──────────────────────────────────────────────────────────╯

def choose_metric(metric, mode="structure"):
    """Return a callable ``get_metric(X, Adj)`` that computes a distance matrix.

    Metric names in the magnipy list below are delegated to ``get_dist``.
    Any other name is routed to the ``lift_attributes`` / ``lift_graph``
    helpers.

    Args:
        metric: Name of the distance metric.
        mode: Which inputs the metric uses:
            ``"attributes"`` - node features ``X`` only (``Adj`` is ignored);
            ``"structure"`` - adjacency matrix ``Adj`` only (``X`` is ignored);
            ``"full"`` - both ``X`` and ``Adj``.

    Returns:
        A function ``get_metric(X, Adj)``. For ``mode="full"`` combined with a
        metric that is not in the magnipy list, calling it raises
        ``NotImplementedError``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    magnipy_metrics = [
        "Lp",
        "isomap",
        "torch_cdist",
        "braycurtis",
        "canberra",
        "chebyshev",
        "cityblock",
        "correlation",
        "cosine",
        "dice",
        "euclidean",
        "hamming",
        "jaccard",
        "jensenshannon",
        "kulczynski1",
        "mahalanobis",
        "matching",
        "minkowski",
        "rogerstanimoto",
        "russellrao",
        "seuclidean",
        "sokalmichener",
        "sokalsneath",
        "sqeuclidean",
        "yule",
    ]

    if mode == "attributes":
        if metric in magnipy_metrics:

            def get_metric(X, Adj):
                return get_dist(X, metric=metric)

        else:

            def get_metric(X, Adj):
                return lift_attributes(X, metric=metric)

    elif mode == "structure":
        if metric in magnipy_metrics:

            def get_metric(X, Adj):
                return get_dist(X=None, Adj=Adj, metric=metric)

        else:

            def get_metric(X, Adj):
                # The lifting helper works on a networkx graph, not a raw matrix.
                G = nx.from_numpy_array(Adj)
                return lift_graph(G=G, metric=metric)

    elif mode == "full":
        if metric in magnipy_metrics:

            def get_metric(X, Adj):
                return get_dist(X=X, Adj=Adj, metric=metric)

        else:

            def get_metric(X, Adj):
                # No combined attribute+structure lifting exists yet. Fail
                # right away rather than first building a throwaway networkx
                # graph, which is wasted work and, for a malformed Adj, would
                # raise a different error that masks this one.
                raise NotImplementedError("This metric is not implemented yet")

    else:
        raise ValueError(
            "mode must be one of 'attributes', 'structure', or 'full'"
        )
    return get_metric


def choose_graph_metric(metric, mode="structure"):
    """Build a distance function that operates directly on a networkx graph.

    The returned ``get_metric(G)`` derives its inputs from ``G`` itself:
    node features come from each node's ``"feature"`` attribute, taken in
    ``G.nodes`` iteration order. The adjacency matrix comes from
    ``nx.to_numpy_array(G)``.

    Args:
        metric: Distance name. Names in the magnipy list are computed with
            ``get_dist``; all others go through ``lift_attributes`` /
            ``lift_graph``.
        mode: ``"attributes"`` (features only), ``"structure"`` (adjacency
            only) or ``"full"`` (features and adjacency).

    Returns:
        A one-argument callable ``get_metric(G)``. With ``mode="full"`` and a
        non-magnipy metric, calling it raises ``NotImplementedError``.

    Raises:
        ValueError: When ``mode`` is not a supported value.
    """
    if mode not in ("attributes", "structure", "full"):
        raise ValueError(
            "mode must be one of 'attributes', 'structure', or 'full'"
        )

    supported_by_magnipy = metric in [
        "Lp",
        "isomap",
        "torch_cdist",
        "braycurtis",
        "canberra",
        "chebyshev",
        "cityblock",
        "correlation",
        "cosine",
        "dice",
        "euclidean",
        "hamming",
        "jaccard",
        "jensenshannon",
        "kulczynski1",
        "mahalanobis",
        "matching",
        "minkowski",
        "rogerstanimoto",
        "russellrao",
        "seuclidean",
        "sokalmichener",
        "sokalsneath",
        "sqeuclidean",
        "yule",
    ]

    def node_features(graph):
        # Stack the per-node "feature" attributes in the graph's node order.
        return np.array([graph.nodes[node]["feature"] for node in graph.nodes])

    if mode == "attributes":
        if supported_by_magnipy:
            def get_metric(G):
                return get_dist(node_features(G), metric=metric)
        else:
            def get_metric(G):
                return lift_attributes(node_features(G), metric=metric)
        return get_metric

    if mode == "structure":
        if supported_by_magnipy:
            def get_metric(G):
                return get_dist(X=None, Adj=nx.to_numpy_array(G), metric=metric)
        else:
            def get_metric(G):
                return lift_graph(G=G, metric=metric)
        return get_metric

    # Only mode == "full" remains at this point.
    if supported_by_magnipy:
        def get_metric(G):
            adjacency = nx.to_numpy_array(G)
            features = node_features(G)
            return get_dist(X=features, Adj=adjacency, metric=metric)
    else:
        def get_metric(G):
            raise NotImplementedError("This metric is not implemented yet")
    return get_metric


def to_nx_graph(x, adj):
    """Create a networkx graph from ``adj`` and attach node features from ``x``.

    Args:
        x: Iterable of per-node features. Item ``i`` is stored as the
            ``"feature"`` attribute of node ``i``.
        adj: Adjacency matrix handed to ``nx.from_numpy_array``.

    Returns:
        The newly built networkx graph with features attached.
    """
    graph = nx.from_numpy_array(adj)
    for node_id, row in enumerate(x):
        node_attrs = graph.nodes[node_id]
        node_attrs["feature"] = row
    return graph
