from pathlib import Path
from time import time
import numpy as np
from datetime import datetime
import hashlib
import json
import torch
import pandas as pd
from pykeen.training import SLCWATrainingLoop
from pykeen.models import DistMult, TransE, TuckER, MuRE, HolE


from SEPAL.baselines.fastrp.fastrp import fastrp_wrapper
from SEPAL.dataloader import DataLoader
from SEPAL.utils import create_graph, create_train_graph, measure_performance
from SEPAL import SEPAL_DIR


# Registry of supported PyKEEN baselines: the lower-case name passed as
# ``emb_model_name`` is looked up here to obtain the model class to build.
# TODO: add RotatE, ComplEx, QuatE, SE, ConvE, ConvKB?
PYKEEN_MODELS = dict(
    distmult=DistMult,
    transe=TransE,
    mure=MuRE,
    tucker=TuckER,
    hole=HolE,
)


def fastrp(data="mini_yago3_lcc"):
    """Compute FastRP embeddings for a dataset and record a checkpoint row.

    The FastRP hyper-parameters are fixed in ``config``. The embedding id is
    the SHA-256 of that config (sorted-key JSON), computed before any run-time
    metadata (id, date, timing, memory) is added to it.

    Side effects:
        - writes the embedding matrix to ``SEPAL_DIR/embeddings/<id>.npy``
          (the ``embeddings`` directory is created if it is missing);
        - appends one row describing the run to
          ``SEPAL_DIR/checkpoints_sepal.parquet`` (created if missing).

    Args:
        data: Dataset name passed to ``create_graph``.
    """
    print(f"Running FastRP on {data}")
    config = {
        "projection_method": "sparse",
        "input_matrix": "adj",
        "weights": [1.0, 1.0, 7.81, 45.28],
        "normalization": False,
        "dim": 100,
        "alpha": -0.628,
        "C": 1.0,
        "data": data,
        "emb_model_name": "fastrp",
        "embed_method": "fastrp",
    }

    # The id only depends on the hyper-parameters above.
    embed_id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()
    emb_filename = SEPAL_DIR / f"embeddings/{embed_id}.npy"
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    A = create_graph(data).adjacency

    # measure_performance(f)(time_interval, *args) returns
    # (result, total_time, mem_usage); the leading 1 is presumably the
    # sampling interval, as passed by keyword in run_pykeen_model.
    U, config["total_time"], config["mem_usage"] = measure_performance(fastrp_wrapper)(
        1, A, config
    )

    # Wrap the weight list so the one-row DataFrame below stores it as a
    # single cell instead of trying to spread 4 values over a 1-row index.
    config["weights"] = [config["weights"]]

    # Save embeddings (np.save fails if the target directory does not exist).
    emb_filename.parent.mkdir(parents=True, exist_ok=True)
    np.save(emb_filename, U)

    # Save checkpoint: append to the existing parquet table, or start one.
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    new_checkpoint_info = pd.DataFrame(config, index=[0])
    if Path(checkpoint_path).is_file():
        checkpoint_info = pd.read_parquet(checkpoint_path)
        checkpoint_info = pd.concat([checkpoint_info, new_checkpoint_info]).reset_index(
            drop=True
        )
    else:
        checkpoint_info = new_checkpoint_info
    checkpoint_info.to_parquet(checkpoint_path, index=False)


def random(data="mini_yago3_lcc", seed=None):
    """Generate Gaussian random entity embeddings (baseline) and record a checkpoint.

    Every entity of the dataset gets a ``dim``-dimensional vector whose
    entries are drawn i.i.d. from N(mean, std**2), using the values fixed in
    ``config``.

    Args:
        data: Dataset name; triples are loaded with ``DataLoader`` from
            ``SEPAL_DIR/datasets/knowledge_graphs/<data>``.
        seed: Optional integer seed. When given, it is added to the hashed
            config (so seeded runs get their own embedding id) and a dedicated
            ``numpy.random.Generator`` is used, making the saved embeddings
            reproducible. When ``None`` (the default), the global NumPy RNG is
            used and the embedding id is the same as before this parameter
            existed.

    Side effects:
        - writes ``SEPAL_DIR/embeddings/<id>.npy`` (directory created if
          missing);
        - appends one row to ``SEPAL_DIR/checkpoints_sepal.parquet``.
    """
    print(f"Running random on {data}")
    config = {
        "dim": 100,
        "data": data,
        "emb_model_name": "random",
        "mean": 0,
        "std": 0.07,
    }
    # Only hash the seed when one is supplied, so unseeded runs keep their
    # historical embedding id.
    if seed is not None:
        config["seed"] = seed

    embed_id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()
    emb_filename = SEPAL_DIR / f"embeddings/{embed_id}.npy"
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    triples_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    dl = DataLoader(triples_dir)

    # Without a seed, keep using the global RNG (honours np.random.seed()).
    rng = np.random if seed is None else np.random.default_rng(seed)

    start = time()
    U = rng.normal(
        loc=config["mean"], scale=config["std"], size=(dl.n_entities, config["dim"])
    )
    end = time()
    config["total_time"] = end - start

    # Save embeddings (np.save fails if the target directory does not exist).
    emb_filename.parent.mkdir(parents=True, exist_ok=True)
    np.save(emb_filename, U)

    # Save checkpoint: append to the existing parquet table, or start one.
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    new_checkpoint_info = pd.DataFrame(config, index=[0])
    if Path(checkpoint_path).is_file():
        checkpoint_info = pd.read_parquet(checkpoint_path)
        checkpoint_info = pd.concat([checkpoint_info, new_checkpoint_info]).reset_index(
            drop=True
        )
    else:
        checkpoint_info = new_checkpoint_info
    checkpoint_info.to_parquet(checkpoint_path, index=False)


def pykeen_model_training(
    emb_model_name,
    graph,
    seed,
    embed_dim,
    device,
    lr,
    lr_scheduler,
    lr_scheduler_kwargs,
    num_epochs,
    batch_size,
):
    """Instantiate a PyKEEN model and train it with the sLCWA training loop.

    Args:
        emb_model_name: Key into ``PYKEEN_MODELS`` selecting the model class.
        graph: Object exposing a ``triples_factory`` used both to build the
            model and as training data.
        seed: ``random_seed`` passed to the model constructor.
        embed_dim: ``embedding_dim`` passed to the model constructor.
        device: Torch device the model is moved to before training.
        lr: Learning rate of the Adam optimizer.
        lr_scheduler, lr_scheduler_kwargs: Forwarded to ``SLCWATrainingLoop``.
        num_epochs, batch_size: Forwarded to ``SLCWATrainingLoop.train``.

    Returns:
        Tuple ``(model, losses)``: the trained model and whatever
        ``SLCWATrainingLoop.train`` returned.
    """
    model_cls = PYKEEN_MODELS[emb_model_name]
    model = model_cls(
        triples_factory=graph.triples_factory,
        random_seed=seed,
        embedding_dim=embed_dim,
    )
    model = model.to(device)

    # Only the trainable parameters are handed to the optimizer.
    adam = torch.optim.Adam(params=model.get_grad_params(), lr=lr)
    loop = SLCWATrainingLoop(
        model=model,
        triples_factory=graph.triples_factory,
        optimizer=adam,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
    )

    epoch_losses = loop.train(
        triples_factory=graph.triples_factory,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    return model, epoch_losses


def run_pykeen_model(
    emb_model_name,
    data="mini_yago3_lcc",
    embed_dim=100,
    num_epochs=40,
    batch_size=512,
    lr=1e-3,
    lr_scheduler=None,
    lr_scheduler_kwargs=None,
    seed=0,
    device="cuda:1",
    subset=None,
):
    """Train a PyKEEN baseline, save its learned parameters, and record a checkpoint.

    Args:
        emb_model_name: Key of ``PYKEEN_MODELS`` selecting the model class.
        data: Dataset name passed to ``create_graph`` / ``create_train_graph``.
        embed_dim: Embedding dimension given to the model constructor.
        num_epochs, batch_size: Forwarded to the sLCWA training loop.
        lr: Learning rate of the Adam optimizer.
        lr_scheduler, lr_scheduler_kwargs: Forwarded to ``SLCWATrainingLoop``.
        seed: ``random_seed`` given to the PyKEEN model.
        device: Torch device the model is moved to.
        subset: ``"train"`` trains on ``create_train_graph(data)``; any other
            value (including ``None``) trains on ``create_graph(data)``.

    Side effects:
        Writes ``SEPAL_DIR/embeddings/<id>.npy`` (entity embeddings) plus
        model-specific extra arrays (``<id>_relations*.npy``; for MuRE also
        ``<id>_biais_1/2.npy``; for TuckER ``<id>_tensor.npy``), creating the
        directory if missing, and appends one row to
        ``SEPAL_DIR/checkpoints_sepal.parquet``. ``<id>`` is the SHA-256 of
        the hyper-parameter config.

    Raises:
        ValueError: If ``emb_model_name`` is not a key of ``PYKEEN_MODELS``.
            Checked before the (possibly large) graph is loaded.
    """
    # Fail fast: otherwise the lookup only fails inside training, after the
    # graph has been loaded.
    if emb_model_name not in PYKEEN_MODELS:
        raise ValueError(
            f"Unknown PyKEEN model {emb_model_name!r}; "
            f"expected one of {sorted(PYKEEN_MODELS)}"
        )

    print(f"Running {emb_model_name} on {data}")
    # Make config
    config = {
        "data": data,
        "embed_dim": embed_dim,
        "emb_model_name": emb_model_name,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        "optimizer": "Adam",
        "lr_scheduler": lr_scheduler,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "seed": seed,
        "subset": subset,
    }

    # Load graph
    if subset == "train":
        graph = create_train_graph(data)
    else:
        graph = create_graph(data)

    # The id only depends on the hyper-parameters above.
    embed_id = hashlib.sha256(
        json.dumps(config, sort_keys=True).encode("ascii")
    ).hexdigest()
    emb_filename = SEPAL_DIR / f"embeddings/{embed_id}.npy"
    config["id"] = embed_id
    config["date"] = datetime.today().strftime("%Y-%m-%d")

    # measure_performance returns (result, total_time, mem_usage); the result
    # here is pykeen_model_training's (model, losses) pair.
    (
        (model, config["training_losses"]),
        config["total_time"],
        config["mem_usage"],
    ) = measure_performance(pykeen_model_training)(
        time_interval=1,
        emb_model_name=emb_model_name,
        graph=graph,
        seed=seed,
        embed_dim=embed_dim,
        device=device,
        lr=lr,
        lr_scheduler=lr_scheduler,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    config["embed_method"] = config["emb_model_name"]

    # np.save fails if the embeddings directory does not exist yet.
    emb_filename.parent.mkdir(parents=True, exist_ok=True)

    def save_array(suffix, tensor):
        # Drop autograd history and move to host memory before writing.
        np.save(
            SEPAL_DIR / f"embeddings/{embed_id}{suffix}.npy",
            tensor.detach().cpu().numpy(),
        )

    # Entity embeddings are always saved; the extra parameters are model-specific.
    # (The "biais" spelling is kept so existing file names stay valid.)
    save_array("", model.entity_representations[0]())
    if emb_model_name == "mure":
        save_array("_biais_1", model.entity_representations[1]())
        save_array("_biais_2", model.entity_representations[2]())
        save_array("_relations_0", model.relation_representations[0]())
        save_array("_relations_1", model.relation_representations[1]())
    elif emb_model_name == "tucker":
        save_array("_relations", model.relation_representations[0]())
        save_array("_tensor", model.interaction.core_tensor)
    else:
        save_array("_relations", model.relation_representations[0]())

    # Save checkpoint. ``[config]`` builds one row even though some values
    # (e.g. training_losses) are lists.
    checkpoint_path = SEPAL_DIR / "checkpoints_sepal.parquet"
    new_checkpoint_info = pd.DataFrame([config], index=[0])
    if Path(checkpoint_path).is_file():
        checkpoint_info = pd.read_parquet(checkpoint_path)
        checkpoint_info = pd.concat([checkpoint_info, new_checkpoint_info]).reset_index(
            drop=True
        )
    else:
        checkpoint_info = new_checkpoint_info
    checkpoint_info.to_parquet(checkpoint_path, index=False)


if __name__ == "__main__":
    # Script entry point: embed the YAGO4 graph (with full ontology) using FastRP.
    fastrp(data="yago4_with_full_ontology")
