from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
import pickle
import torch
import scipy.sparse
from functools import wraps
from time import time
from memory_profiler import memory_usage
from pykeen.triples import TriplesFactory
import igraph as ig
from tqdm import tqdm
from collections import deque
import hashlib

from SEPAL.knowledge_graph import KnowledgeGraph
from SEPAL.dataloader import DataLoader
from SEPAL import SEPAL_DIR, MODEL_NAMES, VALIDATION_FILES
from SEPAL.baselines.dglke.utils import load_dglke_checkpoint_info
from SEPAL.baselines.PBG.utils import load_pbg_checkpoints
from SEPAL.baselines.graphsh.utils import get_grash_checkpoints


def get_checkpoint_info(by_embed_method=False):
    """Load the SEPAL checkpoint table and derive its display columns.

    Reads ``checkpoints_sepal.parquet`` from ``SEPAL_DIR``, assigns a
    ``method`` label to the runs, back-fills ``embed_dim`` from ``dim`` and
    sets ``num_negs_per_pos`` for a fixed list of methods.

    Parameters
    ----------
    by_embed_method : bool, default False
        If True, rows whose label is still exactly "SEPAL" are relabelled as
        "<mapped embed method> + SEPAL".

    Returns
    -------
    pandas.DataFrame
        The checkpoint table with the derived columns.
    """
    df = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")

    # Method labels. The order of these assignments matters: a later
    # assignment overwrites the label written by an earlier one.
    df.loc[df["partitioning"] == "blocs", "method"] = "SEPAL"
    df.loc[df["embed_method"] == "fastrp", "method"] = "FastRP"
    df.loc[df["embed_method"] == "random", "method"] = "Random"

    # Runs without partitioning whose embed method has an entry in MODEL_NAMES
    plain_model = df["partitioning"].isna() & (
        df["embed_method"].isin(MODEL_NAMES.keys())
    )
    df.loc[plain_model, "method"] = df["embed_method"].map(MODEL_NAMES)

    is_nodepiece = df["emb_model_name"] == "nodepiece"
    df.loc[is_nodepiece, "method"] = df["emb_model_name"].map(MODEL_NAMES)

    df.loc[df["partitioning"] == "metis", "method"] = "SEPAL + METIS"

    # SEPAL runs that used a single subgraph get their own label
    single_subgraph = (df["method"] == "SEPAL") & (df["num_subgraphs"] == 1)
    df.loc[single_subgraph, "method"] = df["method"] + " without BLOCS"

    if by_embed_method:
        still_sepal = df["method"] == "SEPAL"
        df.loc[still_sepal, "method"] = (
            df["embed_method"].map(MODEL_NAMES) + " + " + df["method"]
        )

    # Use `dim` wherever `embed_dim` is missing
    df.loc[df["embed_dim"].isna(), "embed_dim"] = df["dim"]

    # For the methods below, read num_negs_per_pos out of the
    # negative_sampler_kwargs column, falling back to 1 when it is absent
    sampler_negs = (
        df["negative_sampler_kwargs"]
        .apply(pd.Series)["num_negs_per_pos"]
        .fillna(1)
        .astype("int")
    )
    target_methods = ["DistMult", "TransE", "RotatE", "MuRE", "TuckER", "HolE"]
    df.loc[df["method"].isin(target_methods), "num_negs_per_pos"] = sampler_negs

    return df


def get_full_checkpoint_info(by_embed_method=False):
    """Return the SEPAL checkpoint table stacked with the baseline tables.

    Concatenates, in this order, the outputs of ``get_checkpoint_info``,
    ``load_dglke_checkpoint_info``, ``load_pbg_checkpoints`` and
    ``get_grash_checkpoints``, then renumbers the index from 0.

    Parameters
    ----------
    by_embed_method : bool, default False
        Forwarded to ``get_checkpoint_info``.

    Returns
    -------
    pandas.DataFrame
    """
    tables = [
        get_checkpoint_info(by_embed_method=by_embed_method),
        load_dglke_checkpoint_info(),
        load_pbg_checkpoints(),
        get_grash_checkpoints(),
    ]
    return pd.concat(tables).reset_index(drop=True)


def create_graph(data, create_inverse_triples=True):
    """Build a KnowledgeGraph from the triples of dataset ``data``.

    Parameters
    ----------
    data : str
        Name of the dataset folder under ``datasets/knowledge_graphs``.
    create_inverse_triples : bool, default True
        Forwarded to ``DataLoader.get_triples_factory``.

    Returns
    -------
    KnowledgeGraph
    """
    print(f"Loading {data} data...")
    loader = DataLoader(SEPAL_DIR / f"datasets/knowledge_graphs/{data}")
    factory = loader.get_triples_factory(
        create_inverse_triples=create_inverse_triples
    )
    print("Building knowledge graph...")
    return KnowledgeGraph(factory)


def create_train_graph(data, create_inverse_triples=True):
    """Build a KnowledgeGraph from the pickled training triples of ``data``.

    The pickled object stored in ``training_tf.pkl`` is used only as a source
    of ``mapped_triples``, ``entity_to_id`` and ``relation_to_id``; a new
    ``TriplesFactory`` is created from them with the requested
    ``create_inverse_triples`` setting.

    Parameters
    ----------
    data : str
        Name of the dataset folder under ``datasets/knowledge_graphs``.
    create_inverse_triples : bool, default True
        Passed to the ``TriplesFactory`` constructor.

    Returns
    -------
    KnowledgeGraph
    """
    print(f"Loading {data} training data...")
    pickle_path = SEPAL_DIR / f"datasets/knowledge_graphs/{data}/training_tf.pkl"
    with open(pickle_path, "rb") as handle:
        stored_tf = pickle.load(handle)

    training_factory = TriplesFactory(
        mapped_triples=stored_tf.mapped_triples,
        entity_to_id=stored_tf.entity_to_id,
        relation_to_id=stored_tf.relation_to_id,
        create_inverse_triples=create_inverse_triples,
    )
    print("Building knowledge graph...")
    return KnowledgeGraph(training_factory)


def tf2nx(triples_factory):
    """Convert a triples factory into a NetworkX ``MultiDiGraph``.

    Each triple (head, relation, tail), after
    ``_add_inverse_triples_if_necessary`` has been applied, becomes an edge
    head -> tail carrying the attribute ``relation_idx``.

    Parameters
    ----------
    triples_factory : object exposing ``mapped_triples`` and
        ``_add_inverse_triples_if_necessary``.

    Returns
    -------
    networkx.MultiDiGraph
    """
    print("Creating NetworkX object...")
    triples = np.array(
        triples_factory._add_inverse_triples_if_necessary(
            triples_factory.mapped_triples
        )
    )
    heads, relations, tails = triples[:, 0], triples[:, 1], triples[:, 2]

    # One (u, v, attribute-dict) entry per triple
    edge_bunch = [
        (head, tail, {"relation_idx": relation})
        for head, relation, tail in zip(heads, relations, tails)
    ]

    G = nx.MultiDiGraph()
    G.add_edges_from(edge_bunch)
    return G


def store_embeddings(
    ctrl, embeddings, relations_embed, embedding_path, checkpoint_path, **kwargs
):
    """Save the embeddings of a run and append its config to the checkpoint table.

    Parameters
    ----------
    ctrl : object exposing ``get_config()``, ``id`` and ``embed_method``.
    embeddings : array-like
        Saved with ``np.save`` at ``embedding_path``.
    relations_embed : tensor
        Moved to CPU and saved as ``embeddings/{ctrl.id}_relations.npy``
        under ``SEPAL_DIR``.
    embedding_path : path-like
        Destination of the ``embeddings`` array.
    checkpoint_path : path-like
        Parquet file holding one row per run; created if it does not exist.
    **kwargs
        Must contain ``core_tensor`` when ``ctrl.embed_method == "tucker"``.

    Returns
    -------
    None
    """
    # One-row table describing this run
    run_row = pd.DataFrame(ctrl.get_config(), index=[0])

    # Persist the arrays
    np.save(embedding_path, embeddings)
    np.save(SEPAL_DIR / f"embeddings/{ctrl.id}_relations.npy", relations_embed.cpu())
    if ctrl.embed_method == "tucker":
        np.save(
            SEPAL_DIR / f"embeddings/{ctrl.id}_tensor.npy",
            kwargs["core_tensor"].cpu(),
        )

    # Append the row to the existing table, or start a new table with it
    if Path(checkpoint_path).is_file():
        previous_runs = pd.read_parquet(checkpoint_path)
        table = pd.concat([previous_runs, run_row]).reset_index(drop=True)
    else:
        table = run_row

    table.to_parquet(checkpoint_path, index=False)


def get_partition_list_from_df(df):
    """Turn a table of ``Partition_k`` columns into a list of int32 arrays.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``Partition_0`` ... ``Partition_{m-1}``,
        where ``m`` is the number of columns of ``df``. Missing values are
        treated as padding.

    Returns
    -------
    list of numpy.ndarray
        One int32 array per column, holding its non-missing values.
    """
    partitions = []
    for k in range(df.shape[1]):
        column = df[f"Partition_{k}"]
        partitions.append(column.dropna().astype(np.int32).values)
    return partitions


def keep_only_largest_cc(graph):
    """Remove all connected components from the graph except the largest one.

    Entities of the kept component are reindexed between 0 and n-1, keeping
    their relative order, and ``entity_to_id`` is rebuilt accordingly. The
    relation mapping and the ``create_inverse_triples`` setting are copied
    from the original triples factory.

    Parameters
    ----------
    graph: a KnowledgeGraph object.

    Returns
    -------
    new_graph: a KnowledgeGraph object, corresponding to the largest component of the graph.
    """
    print("Computing largest connected component...")
    _, labels = scipy.sparse.csgraph.connected_components(
        graph.adjacency, directed=False, return_labels=True
    )
    # Old ids of the nodes in the most populated component (ascending order)
    node_list = np.where(labels == np.argmax(np.bincount(labels)))[0]

    old_tf = graph.triples_factory
    # Keep the triples whose head and tail both belong to the component.
    # Boolean-mask indexing returns a copy, so the in-place reindexing below
    # does not modify the triples of the original factory.
    mask = torch.isin(old_tf.mapped_triples[:, [0, 2]], torch.tensor(node_list)).all(
        axis=1
    )
    mapped_triples = old_tf.mapped_triples[mask]

    # Reindex subgraph entities between 0 and n-1. node_list is sorted and
    # contains every remaining id, so the new id of an entity is simply its
    # position in node_list (also well-defined when no triple is left).
    mapped_triples[:, [0, 2]] = torch.tensor(
        np.searchsorted(node_list, np.asarray(mapped_triples[:, [0, 2]]))
    )

    # Build the new entity_to_id dictionary
    old_id_to_entity = {v: k for k, v in old_tf.entity_to_id.items()}
    entity_to_id = {
        old_id_to_entity[old_id]: new_id
        for new_id, old_id in enumerate(node_list.tolist())
    }

    # Create triple factory object
    triples_factory = TriplesFactory(
        mapped_triples=mapped_triples,
        entity_to_id=entity_to_id,
        relation_to_id=old_tf.relation_to_id,
        create_inverse_triples=old_tf.create_inverse_triples,
    )

    # Build knowledge_graph instance
    new_graph = KnowledgeGraph(triples_factory)

    return new_graph


def measure_performance(func):
    """Decorator measuring the run time and peak memory usage of ``func``.

    The decorated function takes one extra leading positional argument,
    ``time_interval``, which is passed to ``memory_usage`` as its
    ``interval``. The remaining arguments are forwarded to ``func``.

    Returns
    -------
    callable
        A wrapper returning ``(result, execution_time, peak_memory_usage)``,
        where ``execution_time`` is the difference of two ``time()`` readings
        taken around the ``memory_usage`` call.
    """

    @wraps(func)
    def wrapper(time_interval, *args, **kwargs):
        started_at = time()
        peak_memory_usage, result = memory_usage(
            (func, args, kwargs),
            interval=time_interval,
            max_usage=True,
            retval=True,
        )
        execution_time = time() - started_at
        return result, execution_time, peak_memory_usage

    return wrapper


def tf2igraph(triples_factory):
    """Convert a triples factory into an igraph ``Graph`` with unit edge weights.

    The graph is created with the ``ig.Graph()`` defaults. It has
    ``max entity id + 1`` vertices and one edge (head, tail) per triple,
    after ``_add_inverse_triples_if_necessary`` has been applied. Relation
    ids are not stored on the edges.

    Parameters
    ----------
    triples_factory : object exposing ``mapped_triples`` and
        ``_add_inverse_triples_if_necessary``.

    Returns
    -------
    igraph.Graph
        Graph whose edges all carry ``weight == 1``.
    """
    print("Creating igraph object...")
    mapped_triples = np.array(
        triples_factory._add_inverse_triples_if_necessary(
            triples_factory.mapped_triples
        )
    )
    # Create graph
    g = ig.Graph()
    # Add vertices
    n_vertices = int(mapped_triples[:, [0, 2]].max()) + 1
    g.add_vertices(n_vertices)
    # Add edges
    g.add_edges(list(zip(mapped_triples[:, 0].tolist(), mapped_triples[:, 2].tolist())))
    # Add edge weights: one value per edge (the list used to be sized by the
    # number of vertices, which does not match the edge sequence)
    g.es["weight"] = [1] * g.ecount()
    return g


def reorder_outer_subgraphs(graph, outer_subgraphs):
    """Reorder the outer subgraphs by decreasing connectivity with the core subgraph.

    Connectivity is the number of triples whose head is in
    ``graph.core_subgraph_idx`` and whose tail is in the outer subgraph.

    Parameters
    ----------
    graph : object exposing ``get_mapped_triples()`` and ``core_subgraph_idx``.
    outer_subgraphs : sequence of array-like
        Node ids of each outer subgraph.

    Returns
    -------
    list
        The elements of ``outer_subgraphs``, most connected first.
    """
    # The triples and the core membership do not depend on the subgraph being
    # scored, so compute them once instead of once per subgraph.
    edges = graph.get_mapped_triples()[:, [0, 2]]
    mask_core = np.isin(edges[:, 0], graph.core_subgraph_idx)
    # Tails of the triples leaving the core; only these can contribute a link
    core_tails = np.asarray(edges[:, 1])[mask_core]

    # Count the number of triples going from the core to each outer subgraph
    num_links = [
        np.sum(np.isin(core_tails, subgraph))
        for subgraph in tqdm(outer_subgraphs, desc="Reordering subgraphs")
    ]
    return [outer_subgraphs[i] for i in np.argsort(num_links)[::-1]]


def bfs_path_to_subgraph_csr(graph, source, subgraph_nodes):
    """
    Perform BFS on a graph represented as a csr_matrix to find the path from
    the source node to the closest node in the subgraph.

    Each node is enqueued at most once (it is marked as soon as it is
    discovered), and the path is rebuilt from parent pointers instead of
    being copied into every queue entry.

    Parameters:
        graph (csr_matrix): Sparse adjacency matrix of the graph (unweighted).
        source (int): Index of the source node.
        subgraph_nodes (set of int): Set of nodes in the target subgraph.

    Returns:
        list: The path from the source to the closest subgraph node, or an empty list if no path exists.
    """
    # parent[node] is the node from which `node` was first discovered; it also
    # serves as the set of already-discovered nodes. The source has no parent.
    parent = {source: None}
    queue = deque([source])

    while queue:
        current_node = queue.popleft()

        # Check if we've reached a node in the subgraph
        if current_node in subgraph_nodes:
            # Walk the parent pointers back to the source
            path = []
            node = current_node
            while node is not None:
                path.append(node)
                node = parent[node]
            path.reverse()
            return path

        # Neighbors are the non-zero column indices of the node's row
        for neighbor in graph[current_node].indices:
            if neighbor not in parent:
                parent[neighbor] = current_node
                queue.append(neighbor)

    return []  # Return an empty path if no connection exists


def subgraph_hash(ctrl):
    """Return the path of the subgraphs file matching the settings in ``ctrl``.

    The file name is the SHA-256 hex digest of the string form of a list of
    settings: the dataset settings, then (for "blocs" partitioning only) the
    core selection settings, then the partitioning settings.

    Parameters
    ----------
    ctrl : object exposing ``data``, ``subset``, ``partitioning`` and
        ``subgraph_max_size``; for "blocs" partitioning also
        ``core_selection``, ``core_prop``, ``core_edge_proportions``,
        ``core_node_proportions`` and ``diffusion_stop``.

    Returns
    -------
    Path-like
        ``SEPAL_DIR / "subgraphs_files/<digest>.pkl"``.
    """
    uses_blocs = ctrl.partitioning == "blocs"

    # Dataset parameters
    settings = [ctrl.data, ctrl.subset]
    if uses_blocs:
        # Core selection parameters
        settings += [
            ctrl.core_selection,
            ctrl.core_prop,
            ctrl.core_edge_proportions,
            ctrl.core_node_proportions,
        ]
    # Partitioning parameters
    settings += [ctrl.partitioning, ctrl.subgraph_max_size]
    if uses_blocs:
        settings.append(ctrl.diffusion_stop)

    digest = hashlib.sha256(str(settings).encode("utf-8")).hexdigest()
    return SEPAL_DIR / f"subgraphs_files/{digest}.pkl"


def get_downstream_wikidb_files():
    """List the WikiDB parquet files: classification first, then regression.

    Returns
    -------
    list of Path
        Every ``*.parquet`` file directly inside the two task folders.
    """
    task_folders = (
        SEPAL_DIR / "datasets/wikidb/classification",
        SEPAL_DIR / "datasets/wikidb/regression",
    )
    target_files = []
    for folder in task_folders:
        target_files.extend(folder.glob("*.parquet"))
    return target_files


def get_cls_wikidb_files():
    """List the ``*.parquet`` files of the WikiDB classification folder.

    Returns
    -------
    list of Path
    """
    folder = SEPAL_DIR / "datasets/wikidb/classification"
    return [parquet_file for parquet_file in folder.glob("*.parquet")]


def get_reg_wikidb_files():
    """List the ``*.parquet`` files of the WikiDB regression folder.

    Returns
    -------
    list of Path
    """
    folder = SEPAL_DIR / "datasets/wikidb/regression"
    return [parquet_file for parquet_file in folder.glob("*.parquet")]


def get_cls_val_files():
    """List the WikiDB classification files that belong to the validation set.

    A file is kept when its name, without the ``.parquet`` suffix, is in
    ``VALIDATION_FILES``.

    Returns
    -------
    list of Path
    """
    cls_folder = SEPAL_DIR / "datasets/wikidb/classification"
    # Path.name gives the last path component whatever the OS separator is,
    # unlike splitting the string form of the path on "/".
    return [
        file
        for file in cls_folder.glob("*.parquet")
        if file.name.removesuffix(".parquet") in VALIDATION_FILES
    ]


def get_reg_val_files():
    """List the WikiDB regression files that belong to the validation set.

    A file is kept when its name, without the ``.parquet`` suffix, is in
    ``VALIDATION_FILES``.

    Returns
    -------
    list of Path
    """
    reg_folder = SEPAL_DIR / "datasets/wikidb/regression"
    # Path.name gives the last path component whatever the OS separator is,
    # unlike splitting the string form of the path on "/".
    return [
        file
        for file in reg_folder.glob("*.parquet")
        if file.name.removesuffix(".parquet") in VALIDATION_FILES
    ]
