from pathlib import Path
from time import time
import numpy as np
from datetime import datetime
import hashlib
import json
import torch
import pandas as pd
from pykeen.training import SLCWATrainingLoop
from pykeen.models import DistMult, TransE, TuckER, MuRE, HolE, NodePiece, RotatE


from SEPAL.baselines.fastrp.fastrp import fastrp_wrapper
from SEPAL.dataloader import DataLoader
from SEPAL.utils import create_graph, create_train_graph, measure_performance
from SEPAL import SEPAL_DIR


# Registry mapping the lowercase model names accepted by this module to the
# PyKEEN model class that ``pykeen_model_training`` instantiates for them.
PYKEEN_MODELS = {  # TODO: add ComplEx, QuatE, SE, ConvE, ConvKB?
    "distmult": DistMult,
    "transe": TransE,
    "rotate": RotatE,
    "mure": MuRE,
    "tucker": TuckER,
    "hole": HolE,
    "nodepiece": NodePiece,
    # NOTE(review): "pbg_precomputed" is an alias of DistMult — presumably for
    # PyTorch-BigGraph embeddings scored with DistMult; confirm with callers.
    "pbg_precomputed": DistMult,
}


def fastrp(data="mini_yago3_lcc"):
    """Compute FastRP embeddings for ``data`` and record the run.

    The hyper-parameters are fixed in the config below. The run id is the
    SHA-256 hex digest of that config serialised as key-sorted JSON, taken
    before the id, date and measurements are added. The embeddings are written
    to ``SEPAL_DIR/embeddings/<id>.npy`` and the config is appended as one row
    to ``SEPAL_DIR/checkpoints_sepal.parquet``.

    Parameters
    ----------
    data : str
        Dataset name handed to ``create_graph``.
    """
    print(f"Running FastRP on {data}")
    config = dict(
        projection_method="sparse",
        input_matrix="adj",
        weights=[1.0, 1.0, 7.81, 45.28],
        normalization=False,
        dim=100,
        alpha=-0.628,
        C=1.0,
        data=data,
        emb_model_name="fastrp",
        embed_method="fastrp",
    )

    # The id only depends on the hyper-parameters and the dataset name
    serialized = json.dumps(config, sort_keys=True)
    embed_id = hashlib.sha256(serialized.encode("ascii")).hexdigest()
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    # Run FastRP on the adjacency matrix while measuring time and memory
    adjacency = create_graph(data).adjacency
    timed_fastrp = measure_performance(fastrp_wrapper)
    embeddings, elapsed, memory = timed_fastrp(1, adjacency, config)
    config["total_time"] = elapsed
    config["mem_usage"] = memory

    # Wrap the weights so the whole list lands in a single DataFrame cell
    config["weights"] = [config["weights"]]

    # Save embeddings
    np.save(SEPAL_DIR / f"embeddings/{embed_id}.npy", embeddings)

    # Append this run to the checkpoint table (created on first use)
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    run_row = pd.DataFrame(config, index=[0])
    if Path(checkpoint_path).is_file():
        history = pd.read_parquet(checkpoint_path)
        run_log = pd.concat([history, run_row]).reset_index(drop=True)
    else:
        run_log = run_row
    run_log.to_parquet(checkpoint_path, index=False)


def random(data="mini_yago3_lcc"):
    """Draw Gaussian random entity embeddings for ``data`` and record the run.

    One ``dim``-dimensional vector is sampled per entity reported by the
    dataset's ``DataLoader``. Samples come from NumPy's global random state,
    which this function does not seed. The embeddings are written to
    ``SEPAL_DIR/embeddings/<id>.npy`` and the config is appended as one row to
    ``SEPAL_DIR/checkpoints_sepal.parquet``.

    Parameters
    ----------
    data : str
        Name of the dataset folder under ``SEPAL_DIR/datasets/knowledge_graphs``.
    """
    print(f"Running random on {data}")
    config = dict(
        dim=100,
        data=data,
        emb_model_name="random",
        mean=0,
        std=0.07,
    )

    # The id is the hash of the key-sorted JSON form of the settings above
    serialized = json.dumps(config, sort_keys=True)
    embed_id = hashlib.sha256(serialized.encode("ascii")).hexdigest()
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    # Only the number of entities is needed from the dataset
    loader = DataLoader(SEPAL_DIR / f"datasets/knowledge_graphs/{data}")
    shape = (loader.n_entities, config["dim"])

    # Time the sampling itself
    started_at = time()
    embeddings = np.random.normal(loc=config["mean"], scale=config["std"], size=shape)
    config["total_time"] = time() - started_at

    # Save embeddings
    np.save(SEPAL_DIR / f"embeddings/{embed_id}.npy", embeddings)

    # Append this run to the checkpoint table (created on first use)
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    run_row = pd.DataFrame(config, index=[0])
    if Path(checkpoint_path).is_file():
        history = pd.read_parquet(checkpoint_path)
        run_log = pd.concat([history, run_row]).reset_index(drop=True)
    else:
        run_log = run_row
    run_log.to_parquet(checkpoint_path, index=False)


def pykeen_model_training(
    emb_model_name,
    graph,
    seed,
    embed_dim,
    device,
    lr,
    lr_scheduler,
    lr_scheduler_kwargs,
    negative_sampler,
    negative_sampler_kwargs,
    loss,
    num_epochs,
    batch_size,
):
    """Instantiate a registered PyKEEN model and train it with sLCWA.

    Parameters
    ----------
    emb_model_name : str
        Key of ``PYKEEN_MODELS`` selecting the model class.
    graph
        Object exposing the training triples as ``graph.triples_factory``.
    seed, embed_dim, loss
        Forwarded to the model constructor as ``random_seed``,
        ``embedding_dim`` and ``loss``.
    device
        Device the model is moved to before training.
    lr : float
        Learning rate of the Adam optimizer.
    lr_scheduler, lr_scheduler_kwargs, negative_sampler, negative_sampler_kwargs
        Forwarded unchanged to ``SLCWATrainingLoop``.
    num_epochs, batch_size
        Forwarded unchanged to ``SLCWATrainingLoop.train``.

    Returns
    -------
    tuple
        ``(model, losses)``: the trained model and the value returned by
        ``SLCWATrainingLoop.train``.
    """
    model_cls = PYKEEN_MODELS[emb_model_name]
    factory = graph.triples_factory

    # Initialize embedding model on the requested device
    model = model_cls(
        triples_factory=factory,
        random_seed=seed,
        embedding_dim=embed_dim,
        loss=loss,
    )
    model = model.to(device)

    # Adam only optimizes the parameters that require gradients
    optimizer = torch.optim.Adam(params=model.get_grad_params(), lr=lr)
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        negative_sampler=negative_sampler,
        negative_sampler_kwargs=negative_sampler_kwargs,
    )

    # Train model
    losses = trainer.train(
        triples_factory=factory,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    return model, losses


def run_pykeen_model(
    emb_model_name,
    data="mini_yago3_lcc",
    embed_dim=100,
    num_epochs=40,
    batch_size=512,
    lr=1e-3,
    lr_scheduler=None,
    lr_scheduler_kwargs=None,
    negative_sampler=None,
    negative_sampler_kwargs=None,
    loss=None,
    seed=0,
    device="cuda:1",
    subset=None,
):
    """Train a PyKEEN model on ``data``, export its embeddings and log the run.

    The run id is the SHA-256 hex digest of the key-sorted JSON form of the
    arguments listed in the config below (``device`` is not part of it), so
    those arguments must be JSON-serialisable. Entity embeddings go to
    ``SEPAL_DIR/embeddings/<id>.npy``; relation embeddings (and the MuRE
    biases / TuckER core tensor) go to sibling ``<id>_*.npy`` files. The
    config, training losses, run time and memory usage are appended as one row
    to ``SEPAL_DIR/checkpoints_sepal.parquet``.

    Parameters
    ----------
    emb_model_name : str
        Key of ``PYKEEN_MODELS``.
    data : str
        Dataset name handed to ``create_graph`` / ``create_train_graph``.
    subset : str or None
        ``"train"`` trains on ``create_train_graph(data)``; any other value
        trains on ``create_graph(data)``.
    device
        Device the model is trained on.
    embed_dim, num_epochs, batch_size, lr, lr_scheduler, lr_scheduler_kwargs,
    negative_sampler, negative_sampler_kwargs, loss, seed
        Forwarded to ``pykeen_model_training``.
    """

    def as_numpy(representation):
        # Materialise a PyKEEN representation module as a CPU NumPy array
        return representation().detach().cpu().numpy()

    # Make config
    config = dict(
        data=data,
        embed_dim=embed_dim,
        emb_model_name=emb_model_name,
        num_epochs=num_epochs,
        batch_size=batch_size,
        lr=lr,
        optimizer="Adam",
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        negative_sampler=negative_sampler,
        negative_sampler_kwargs=negative_sampler_kwargs,
        loss_fn=loss,
        seed=seed,
        subset=subset,
    )

    # Load graph
    graph = create_train_graph(data) if subset == "train" else create_graph(data)

    # Get config id
    serialized = json.dumps(config, sort_keys=True)
    embed_id = hashlib.sha256(serialized.encode("ascii")).hexdigest()
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    # Train while measuring run time and memory usage
    timed_training = measure_performance(pykeen_model_training)
    (model, losses), elapsed, memory = timed_training(
        time_interval=1,
        emb_model_name=emb_model_name,
        graph=graph,
        seed=seed,
        embed_dim=embed_dim,
        device=device,
        lr=lr,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        negative_sampler=negative_sampler,
        negative_sampler_kwargs=negative_sampler_kwargs,
        loss=loss,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    config["training_losses"] = losses
    config["total_time"] = elapsed
    config["mem_usage"] = memory
    config["embed_method"] = config["emb_model_name"]

    # Save entity embeddings
    np.save(
        SEPAL_DIR / f"embeddings/{embed_id}.npy",
        as_numpy(model.entity_representations[0]),
    )

    # Save the model-specific extras
    if emb_model_name == "mure":
        # MuRE exposes two extra entity representations and two relation ones
        biais_1 = as_numpy(model.entity_representations[1])
        biais_2 = as_numpy(model.entity_representations[2])
        np.save(SEPAL_DIR / f"embeddings/{embed_id}_biais_1.npy", biais_1)
        np.save(SEPAL_DIR / f"embeddings/{embed_id}_biais_2.npy", biais_2)
        relations_embed_0 = as_numpy(model.relation_representations[0])
        relations_embed_1 = as_numpy(model.relation_representations[1])
        np.save(SEPAL_DIR / f"embeddings/{embed_id}_relations_0.npy", relations_embed_0)
        np.save(SEPAL_DIR / f"embeddings/{embed_id}_relations_1.npy", relations_embed_1)
    else:
        # Every other model stores a single relation representation ...
        relations_embed = as_numpy(model.relation_representations[0])
        np.save(SEPAL_DIR / f"embeddings/{embed_id}_relations.npy", relations_embed)
        if emb_model_name == "tucker":
            # ... and TuckER additionally needs its core tensor
            core_tensor = model.interaction.core_tensor.detach().cpu().numpy()
            np.save(SEPAL_DIR / f"embeddings/{embed_id}_tensor.npy", core_tensor)

    # Append this run to the checkpoint table (created on first use)
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    run_row = pd.DataFrame([config], index=[0])
    if Path(checkpoint_path).is_file():
        history = pd.read_parquet(checkpoint_path)
        run_log = pd.concat([history, run_row]).reset_index(drop=True)
    else:
        run_log = run_row
    run_log.to_parquet(checkpoint_path, index=False)


def run_nodepiece(
    interaction,
    graph,
    seed,
    embed_dim,
    device,
    lr,
    lr_scheduler,
    lr_scheduler_kwargs,
    num_epochs,
    batch_size,
    relations_only,
):
    """Build a NodePiece model for ``graph`` and train it with sLCWA.

    Parameters
    ----------
    interaction
        Interaction forwarded to ``NodePiece``.
    graph
        Object exposing the training triples as ``graph.triples_factory``.
    seed, embed_dim
        Forwarded to ``NodePiece`` as ``random_seed`` and ``embedding_dim``.
    device
        Device the model is moved to before training.
    lr : float
        Learning rate of the Adam optimizer.
    lr_scheduler, lr_scheduler_kwargs
        Forwarded unchanged to ``SLCWATrainingLoop``.
    num_epochs, batch_size
        Forwarded unchanged to ``SLCWATrainingLoop.train``.
    relations_only : bool
        If true, entities are tokenized with relation tokens only; otherwise
        METIS anchor tokens are used in addition to relation tokens.

    Returns
    -------
    tuple
        ``(model, losses)``: the trained model and the value returned by
        ``SLCWATrainingLoop.train``.
    """
    # Only the tokenization settings differ between the two modes
    if relations_only:
        tokenizer_config = dict(
            tokenizers="RelationTokenizer",
            num_tokens=12,  # 12 tokens per relation
        )
    else:
        metis_anchor_kwargs = dict(
            num_partitions=20,
            device="cpu",  # METIS on cpu tends to be faster
            selection="MixtureAnchorSelection",  # any anchor selection strategy works here
            selection_kwargs=dict(
                selections=["degree", "random"],
                ratios=[0.5, 0.5],
                num_anchors=1000,  # overall, 20 * 1000 = 20000 anchors
            ),
            searcher="SparseBFSSearcher",  # efficient anchor searcher
            searcher_kwargs=dict(
                max_iter=5  # tokenize each node with anchors in its 5-hop neighborhood
            ),
        )
        tokenizer_config = dict(
            tokenizers=["MetisAnchorTokenizer", "RelationTokenizer"],
            num_tokens=[20, 12],  # 20 anchors per node for the Metis strategy
            tokenizers_kwargs=[metis_anchor_kwargs, dict()],
        )

    model = NodePiece(
        triples_factory=graph.triples_factory,
        embedding_dim=embed_dim,
        interaction=interaction,
        aggregation="mlp",
        random_seed=seed,
        **tokenizer_config,
    )

    # Load model on device
    model = model.to(device)

    # Initialize training loop
    optimizer = torch.optim.Adam(params=model.get_grad_params(), lr=lr)
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=graph.triples_factory,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
    )

    # Train model
    losses = trainer.train(
        triples_factory=graph.triples_factory,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )

    return model, losses


def nodepiece_wrapper(
    data="mini_yago3_lcc",
    interaction="distmult",
    relations_only=False,
    embed_dim=100,
    num_epochs=40,
    batch_size=512,
    lr=1e-3,
    lr_scheduler=None,
    lr_scheduler_kwargs=None,
    seed=0,
    device="cuda:1",
    subset=None,
):
    """Train NodePiece on ``data``, save its weights and log the run.

    The run id is the SHA-256 hex digest of the key-sorted JSON form of the
    arguments listed in the config below (``device`` is not part of it), so
    those arguments must be JSON-serialisable. The trained model's state dict
    is written to ``SEPAL_DIR/models/<id>.pt`` and the config, training
    losses, run time and memory usage are appended as one row to
    ``SEPAL_DIR/checkpoints_sepal.parquet``.

    Embeddings are not exported here; ``get_nodepiece_embeddings_from_model``
    rebuilds the model from the saved state dict and writes them.

    Parameters
    ----------
    data : str
        Dataset name handed to ``create_graph`` / ``create_train_graph``.
    interaction : str
        Interaction forwarded to ``run_nodepiece``; also recorded as the
        run's ``embed_method``.
    relations_only : bool
        Forwarded to ``run_nodepiece`` to select the tokenization mode.
    subset : str or None
        ``"train"`` trains on ``create_train_graph(data)``; any other value
        trains on ``create_graph(data)``.
    device
        Device the model is trained on.
    embed_dim, num_epochs, batch_size, lr, lr_scheduler, lr_scheduler_kwargs, seed
        Forwarded to ``run_nodepiece``.
    """
    print(f"Running NodePiece on {data}")
    # Make config
    config = {
        "data": data,
        "embed_dim": embed_dim,
        "emb_model_name": "nodepiece",
        "interaction": interaction,
        "relations_only": relations_only,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        "optimizer": "Adam",
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "seed": seed,
        "subset": subset,
    }

    # Load graph
    if subset == "train":
        graph = create_train_graph(data)
    else:
        graph = create_graph(data)

    # Get config id
    embed_id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    # Create the output directory before training, so that a missing folder
    # cannot make torch.save fail and discard a finished training run
    model_path = SEPAL_DIR / f"models/{embed_id}.pt"
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Train while measuring run time and memory usage
    (
        (model, config["training_losses"]),
        config["total_time"],
        config["mem_usage"],
    ) = measure_performance(run_nodepiece)(
        time_interval=1,
        interaction=interaction,
        relations_only=relations_only,
        graph=graph,
        seed=seed,
        embed_dim=embed_dim,
        device=device,
        lr=lr,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    config["embed_method"] = config["interaction"]

    # Load model on cpu so the saved tensors do not reference a GPU device
    model = model.to("cpu")

    # Save model state dict
    torch.save(model.state_dict(), model_path)

    # Save checkpoint
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    new_checkpoint_info = pd.DataFrame([config], index=[0])
    if Path(checkpoint_path).is_file():
        checkpoint_info = pd.read_parquet(checkpoint_path)
        checkpoint_info = pd.concat([checkpoint_info, new_checkpoint_info]).reset_index(
            drop=True
        )
    else:
        checkpoint_info = new_checkpoint_info
    checkpoint_info.to_parquet(checkpoint_path, index=False)
    return


def get_nodepiece_embeddings_from_model():
    """Export embeddings for every saved NodePiece model that lacks them.

    For each ``SEPAL_DIR/models/<id>.pt`` state dict without a matching
    ``SEPAL_DIR/embeddings/<id>.npy``, the run's config is looked up in
    ``SEPAL_DIR/checkpoints_sepal.parquet``, the NodePiece model is rebuilt
    with the same settings, the weights are loaded, and the entity and
    relation embeddings are written to ``embeddings/<id>.npy`` and
    ``embeddings/<id>_relations.npy``.

    Raises
    ------
    IndexError
        If a saved model has no row with its id in the checkpoint table.
    """
    model_dir = SEPAL_DIR / "models"
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    # The checkpoint table is read at most once, and only if a model needs it
    checkpoint_info = None

    for model_path in model_dir.iterdir():
        # State dicts are saved as ``<id>.pt``; ignore any other entry
        # (sub-directories, hidden files, ...) instead of failing on it
        if not model_path.is_file() or model_path.suffix != ".pt":
            continue

        # Get model id and skip models whose embeddings already exist
        model_id = model_path.stem
        emb_filename = SEPAL_DIR / f"embeddings/{model_id}.npy"
        if emb_filename.is_file():
            continue

        # Get config
        if checkpoint_info is None:
            checkpoint_info = pd.read_parquet(checkpoint_path)
        records = checkpoint_info[checkpoint_info["id"] == model_id].to_dict(
            orient="records"
        )
        if not records:
            raise IndexError(
                f"No checkpoint entry with id {model_id!r} in {checkpoint_path}"
            )
        config = records[0]

        # Load graph
        if config["subset"] == "train":
            graph = create_train_graph(config["data"])
        else:
            graph = create_graph(config["data"])

        # Only the tokenization settings differ between the two modes
        if config["relations_only"]:
            tokenizer_config = dict(
                tokenizers="RelationTokenizer",
                num_tokens=12,  # 12 tokens per relation
            )
        else:
            tokenizer_config = dict(
                tokenizers=["MetisAnchorTokenizer", "RelationTokenizer"],
                num_tokens=[20, 12],  # 20 anchors per node for the Metis strategy
                tokenizers_kwargs=[
                    dict(
                        num_partitions=20,
                        device="cpu",  # METIS on cpu tends to be faster
                        selection="MixtureAnchorSelection",
                        selection_kwargs=dict(
                            selections=["degree", "random"],
                            ratios=[0.5, 0.5],
                            num_anchors=1000,  # overall, 20 * 1000 = 20000 anchors
                        ),
                        searcher="SparseBFSSearcher",
                        searcher_kwargs=dict(
                            max_iter=5  # anchors in the 5-hop neighborhood
                        ),
                    ),
                    dict(),
                ],
            )

        # Build model
        # NOTE(review): tokenization is recomputed here; this assumes the seed
        # reproduces the anchors used at training time — confirm.
        model = NodePiece(
            triples_factory=graph.triples_factory,
            embedding_dim=int(config["embed_dim"]),
            interaction=config["interaction"],
            aggregation="mlp",
            random_seed=int(config["seed"]),
            **tokenizer_config,
        )

        # Load model state dict onto the CPU, where the model was just built
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

        # A freshly built module is in training mode; switch to inference mode
        # so train-only layers (e.g. dropout in the aggregation, if present)
        # do not perturb the exported embeddings
        model.eval()
        with torch.no_grad():
            embeddings = model.entity_representations[0]().detach().numpy()
            relations_embed = model.relation_representations[0]().detach().numpy()

        # Save embeddings for entities and relations
        emb_filename.parent.mkdir(parents=True, exist_ok=True)
        np.save(emb_filename, embeddings)
        np.save(SEPAL_DIR / f"embeddings/{model_id}_relations.npy", relations_embed)
    return


if __name__ == "__main__":
    # Train 50-dimensional RotatE embeddings on each dataset in turn
    for dataset in ("mini_yago3_lcc", "yago3_lcc"):
        run_pykeen_model(
            emb_model_name="rotate",
            data=dataset,
            embed_dim=50,
            device="cuda:1",
            subset=None,
        )
