# Use pykeen.triples.TriplesFactory.split() to stratify datasets

import pickle
import torch
import numpy as np
import scipy.sparse
from tqdm import tqdm

from SEPAL.dataloader import DataLoader
from SEPAL import SEPAL_DIR


def remove_coo_duplicates(coo):
    """Return a COO matrix keeping only the first entry stored at each position.

    A COO matrix may store several entries for the same ``(row, col)``
    coordinate; converting it to CSR/CSC sums them. This helper drops every
    entry whose coordinate has already appeared earlier in ``coo``, so that
    each coordinate carries the value of its first stored occurrence.

    Parameters
    ----------
    coo : scipy.sparse.coo_matrix
        Input matrix, possibly containing duplicate coordinates. It is not
        modified.

    Returns
    -------
    scipy.sparse.coo_matrix
        New matrix with the same shape and data dtype as ``coo``. The
        surviving entries appear in the same relative order as in ``coo``.
    """
    rows = np.asarray(coo.row)
    cols = np.asarray(coo.col)
    values = np.asarray(coo.data)

    # Stable lexicographic sort by (row, col): within a group of identical
    # coordinates the original positions stay in increasing order, so the
    # first element of each group is the first occurrence in ``coo``.
    order = np.lexsort((cols, rows))
    sorted_rows = rows[order]
    sorted_cols = cols[order]

    # Flag the first element of every run of identical coordinates.
    starts_group = np.ones(order.shape[0], dtype=bool)
    starts_group[1:] = (sorted_rows[1:] != sorted_rows[:-1]) | (
        sorted_cols[1:] != sorted_cols[:-1]
    )

    # Map back to original positions and restore the original entry order.
    keep = np.sort(order[starts_group])

    # Create new COO matrix without duplicates
    return scipy.sparse.coo_matrix(
        (values[keep], (rows[keep], cols[keep])), shape=coo.shape
    )


def pykeen_split(data):
    """Split a knowledge graph into train/test/validation and reconnect the training graph.

    The triples factory of the dataset is split 90% / 5% / 5% with a fixed
    random state. The training graph is then made connected: for every
    connected component other than the largest one, a triple of the full graph
    linking that component to the largest component is appended to the
    training triples and removed from the testing and validation triples.
    The three resulting factories are pickled next to the dataset as
    ``training_tf.pkl``, ``testing_tf.pkl`` and ``validation_tf.pkl``.

    Parameters
    ----------
    data : str
        Name of the dataset directory under
        ``SEPAL_DIR / "datasets/knowledge_graphs"``.

    Raises
    ------
    RuntimeError
        If some component of the training graph cannot be linked to the
        largest component by any triple of the full graph (the loop would
        otherwise never terminate).
    """
    # Load triples factory
    triples_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    dl = DataLoader(triples_dir)
    tf = dl.get_triples_factory(create_inverse_triples=False)

    # Split the dataset
    training, testing, validation = tf.split([0.9, 0.05, 0.05], random_state=0)

    # Create adjacency of the full graph: rows are column 0 of the triples,
    # columns are column 2, and the stored value is column 1 (presumably
    # head, tail and relation ids respectively).
    coo_adjacency = scipy.sparse.coo_matrix(
        (tf.mapped_triples[:, 1], (tf.mapped_triples[:, 0], tf.mapped_triples[:, 2])),
        shape=(tf.num_entities, tf.num_entities),
        dtype="int32",
    )
    # We remove the duplicate entries (otherwise they would be summed when creating the csr/csc matrices).
    # This is not a problem because we only use these matrices to find triples to reconnect the graph.
    coo_adjacency = remove_coo_duplicates(coo_adjacency)
    csc_adjacency = coo_adjacency.tocsc()
    csr_adjacency = csc_adjacency.tocsr()

    ## Reconnect training set if not connected
    connected = False
    while not connected:
        connected = True
        # Number of components attached to the lcc during this pass
        n_reconnected = 0

        # Get connected components
        mapped_triples = training._add_inverse_triples_if_necessary(
            training.mapped_triples
        )
        # Keep columns 0 and 2 of the triples as the two endpoints of each edge
        edge_index = mapped_triples[:, 0::2].t()
        training_adjacency = scipy.sparse.csr_matrix(
            (torch.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])),
            shape=(training.num_entities, training.num_entities),
            dtype="float32",
        )
        n_components, labels = scipy.sparse.csgraph.connected_components(
            training_adjacency, directed=False, return_labels=True
        )
        # Get largest connected component (lcc)
        lcc_id = np.argmax(np.bincount(labels))
        lcc_entities = torch.tensor(np.where(labels == lcc_id)[0])

        # Iterate over the other components
        for i in tqdm(range(n_components)):
            if i == lcc_id:
                continue

            # Get a triple that reconnects the connected component to the lcc
            cc_entities = torch.tensor(np.where(labels == i)[0])

            # Try to find a triple going from the lcc to the cc.
            # Columns of sub_mat are positions within cc_entities, hence the
            # cc_entities[...] lookup to recover the entity ids.
            sub_mat = csc_adjacency[:, cc_entities].tocoo()
            mask = np.isin(sub_mat.row, lcc_entities.numpy())
            connecting_triples = np.hstack(
                [
                    np.expand_dims(arr, axis=1)
                    for arr in [
                        sub_mat.row[mask],
                        sub_mat.data[mask],
                        cc_entities[sub_mat.col[mask]],
                    ]
                ]
            )

            # If no triple was found, try to find a triple going from the cc to the lcc
            if connecting_triples.shape[0] == 0:
                # Rows of sub_mat are positions within cc_entities
                sub_mat = csr_adjacency[cc_entities].tocoo()
                mask = np.isin(sub_mat.col, lcc_entities.numpy())
                connecting_triples = np.hstack(
                    [
                        np.expand_dims(arr, axis=1)
                        for arr in [
                            cc_entities[sub_mat.row[mask]],
                            sub_mat.data[mask],
                            sub_mat.col[mask],
                        ]
                    ]
                )

            # If no such triple exists (yet), do another iteration: the cc may
            # only be adjacent to components that join the lcc later in this pass.
            if connecting_triples.shape[0] == 0:
                connected = False
                continue

            # A triple was found, reconnect the cc to the lcc with it
            connecting_triple = torch.tensor(connecting_triples[0])
            # Add this triple to the training set
            training.mapped_triples = torch.vstack(
                (training.mapped_triples, connecting_triple)
            )

            # Remove it from the test and validation sets. Both are checked
            # independently so the triple cannot stay in a held-out split
            # while also being part of the training set.
            for held_out in (testing, validation):
                in_split = (connecting_triple == held_out.mapped_triples).all(axis=1)
                if in_split.any():
                    held_out.mapped_triples = held_out.mapped_triples[~in_split]

            # Add cc_entities to lcc_entities
            lcc_entities = torch.hstack((lcc_entities, cc_entities))
            n_reconnected += 1

        # A pass that reconnected nothing leaves the training set unchanged, so
        # every later pass would be identical: stop instead of looping forever.
        if not connected and n_reconnected == 0:
            raise RuntimeError(
                f"Cannot reconnect the training graph of '{data}': some connected "
                "components share no triple with the largest connected component."
            )

    ## Save constructed triple factories
    with open(triples_dir / "training_tf.pkl", "wb") as f:
        pickle.dump(training, f)

    with open(triples_dir / "testing_tf.pkl", "wb") as f:
        pickle.dump(testing, f)

    with open(triples_dir / "validation_tf.pkl", "wb") as f:
        pickle.dump(validation, f)


if __name__ == "__main__":
    # Datasets to process when run as a script; commented-out names are
    # other available knowledge graphs that are currently skipped.
    selected_datasets = (
        # "mini_yago3_lcc",
        # "core_yago4",
        # "core_yago4.5",
        "rel_core_yago4",
        "rel_core_yago4.5",
        # "yago3_lcc",
        # "yago4_lcc",
        # "yago4.5_lcc",
        # "yago4_with_ontology",
        # "yago4_with_full_ontology",
        # "yago4.5_with_ontology",
        # "yago4.5_with_full_ontology",
    )
    for data in selected_datasets:
        # Print a banner, then build and save the splits for this dataset
        print(f"------------ {data} ------------")
        pykeen_split(data)
