import torch
import numpy as np
from pykeen.training import SLCWATrainingLoop
from pykeen.models import DistMult, TransE, RotatE, TuckER, HolE
import h5py
import json

from SEPAL import SEPAL_DIR

# Optimizer classes selectable by name. The embedding functions in this module
# index this table with ``ctrl.embed_setting.optimizer``.
optimizers = dict(
    SparseAdam=torch.optim.SparseAdam,
    Adam=torch.optim.Adam,
)


def tucker(ctrl, core_graph):
    """Train a TuckER model on the core graph and return its learned tensors.

    The model is built from ``core_graph.triples_factory`` with ``ctrl.seed``
    and ``ctrl.embed_dim``, moved to ``ctrl.device`` and trained with an
    ``SLCWATrainingLoop`` for ``ctrl.num_epochs`` epochs with batches of
    ``ctrl.batch_size``. The optimizer class is looked up in ``optimizers``
    under ``ctrl.embed_setting.optimizer`` and receives
    ``ctrl.embed_setting.lr`` as learning rate.

    Side effect: the value returned by ``train`` is stored, wrapped in a
    one-element list, in ``ctrl.embed_setting.training_losses``.

    Returns:
        A 3-tuple of detached tensors: the interaction's core tensor, the
        entity representations and the relation representations.
    """
    factory = core_graph.triples_factory
    settings = ctrl.embed_setting

    model = TuckER(
        triples_factory=factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
    )
    model = model.to(ctrl.device)

    optimizer_cls = optimizers[settings.optimizer]
    optimizer = optimizer_cls(params=model.get_grad_params(), lr=settings.lr)
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
    )

    train_losses = trainer.train(
        triples_factory=factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )
    settings.training_losses = [train_losses]

    core_tensor = model.interaction.core_tensor.detach()
    entity_embeddings = model.entity_representations[0]().detach()
    relation_embeddings = model.relation_representations[0]().detach()
    return core_tensor, entity_embeddings, relation_embeddings


def hole(ctrl, core_graph):
    """Train a HolE model on the core graph and return its embeddings.

    The model is built from ``core_graph.triples_factory`` with ``ctrl.seed``
    and ``ctrl.embed_dim``, moved to ``ctrl.device`` and trained with an
    ``SLCWATrainingLoop`` for ``ctrl.num_epochs`` epochs with batches of
    ``ctrl.batch_size``. The optimizer class is looked up in ``optimizers``
    under ``ctrl.embed_setting.optimizer`` and receives
    ``ctrl.embed_setting.lr`` as learning rate.

    Side effect: the value returned by ``train`` is stored, wrapped in a
    one-element list, in ``ctrl.embed_setting.training_losses``.

    Returns:
        A 2-tuple of detached tensors: entity representations and relation
        representations.
    """
    factory = core_graph.triples_factory
    settings = ctrl.embed_setting

    model = HolE(
        triples_factory=factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
    )
    model = model.to(ctrl.device)

    optimizer_cls = optimizers[settings.optimizer]
    optimizer = optimizer_cls(params=model.get_grad_params(), lr=settings.lr)
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
    )

    train_losses = trainer.train(
        triples_factory=factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )
    settings.training_losses = [train_losses]

    entity_embeddings = model.entity_representations[0]().detach()
    relation_embeddings = model.relation_representations[0]().detach()
    return entity_embeddings, relation_embeddings


def distmult(ctrl, core_graph):
    """Train a DistMult model on the core graph and return its embeddings.

    The model is built from ``core_graph.triples_factory`` with ``ctrl.seed``,
    ``ctrl.embed_dim`` and the loss given by ``ctrl.embed_setting.loss_fn``,
    then moved to ``ctrl.device``. Training uses an ``SLCWATrainingLoop`` whose
    negative sampler is configured with ``ctrl.num_negs_per_pos``, and runs for
    ``ctrl.num_epochs`` epochs with batches of ``ctrl.batch_size``. The
    optimizer class is looked up in ``optimizers`` under
    ``ctrl.embed_setting.optimizer`` and receives ``ctrl.embed_setting.lr``.

    Side effect: the value returned by ``train`` is stored, wrapped in a
    one-element list, in ``ctrl.embed_setting.training_losses``.

    Returns:
        A 2-tuple of detached tensors: entity representations and relation
        representations.
    """
    factory = core_graph.triples_factory
    settings = ctrl.embed_setting

    model = DistMult(
        triples_factory=factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
        loss=settings.loss_fn,
    )
    model = model.to(ctrl.device)

    optimizer_cls = optimizers[settings.optimizer]
    optimizer = optimizer_cls(params=model.get_grad_params(), lr=settings.lr)
    sampler_options = {"num_negs_per_pos": ctrl.num_negs_per_pos}
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
        negative_sampler_kwargs=sampler_options,
    )

    train_losses = trainer.train(
        triples_factory=factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )
    settings.training_losses = [train_losses]

    entity_embeddings = model.entity_representations[0]().detach()
    relation_embeddings = model.relation_representations[0]().detach()
    return entity_embeddings, relation_embeddings


def transe(ctrl, core_graph):
    """Train a TransE model on the core graph and return its embeddings.

    The model is built from ``core_graph.triples_factory`` with ``ctrl.seed``
    and ``ctrl.embed_dim``, then moved to ``ctrl.device``. Training uses an
    ``SLCWATrainingLoop`` whose negative sampler is configured with
    ``ctrl.num_negs_per_pos``, and runs for ``ctrl.num_epochs`` epochs with
    batches of ``ctrl.batch_size``. The optimizer class is looked up in
    ``optimizers`` under ``ctrl.embed_setting.optimizer`` and receives
    ``ctrl.embed_setting.lr``.

    Side effect: the value returned by ``train`` is stored, wrapped in a
    one-element list, in ``ctrl.embed_setting.training_losses``.

    Returns:
        A 2-tuple of detached tensors: entity representations and relation
        representations.
    """
    factory = core_graph.triples_factory
    settings = ctrl.embed_setting

    model = TransE(
        triples_factory=factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
    )
    model = model.to(ctrl.device)

    optimizer_cls = optimizers[settings.optimizer]
    optimizer = optimizer_cls(params=model.get_grad_params(), lr=settings.lr)
    sampler_options = {"num_negs_per_pos": ctrl.num_negs_per_pos}
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
        negative_sampler_kwargs=sampler_options,
    )

    train_losses = trainer.train(
        triples_factory=factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )
    settings.training_losses = [train_losses]

    entity_embeddings = model.entity_representations[0]().detach()
    relation_embeddings = model.relation_representations[0]().detach()
    return entity_embeddings, relation_embeddings


def rotate(ctrl, core_graph):
    """Train a RotatE model on the core graph and return its embeddings.

    The model is built from ``core_graph.triples_factory`` with ``ctrl.seed``
    and ``ctrl.embed_dim``, then moved to ``ctrl.device``. Training uses an
    ``SLCWATrainingLoop`` whose negative sampler is configured with
    ``ctrl.num_negs_per_pos``, and runs for ``ctrl.num_epochs`` epochs with
    batches of ``ctrl.batch_size``. The optimizer class is looked up in
    ``optimizers`` under ``ctrl.embed_setting.optimizer`` and receives
    ``ctrl.embed_setting.lr``.

    Side effect: the value returned by ``train`` is stored, wrapped in a
    one-element list, in ``ctrl.embed_setting.training_losses``.

    Returns:
        A 2-tuple of detached tensors: entity representations and relation
        representations.
    """
    factory = core_graph.triples_factory
    settings = ctrl.embed_setting

    model = RotatE(
        triples_factory=factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
    )
    model = model.to(ctrl.device)

    optimizer_cls = optimizers[settings.optimizer]
    optimizer = optimizer_cls(params=model.get_grad_params(), lr=settings.lr)
    sampler_options = {"num_negs_per_pos": ctrl.num_negs_per_pos}
    trainer = SLCWATrainingLoop(
        model=model,
        triples_factory=factory,
        optimizer=optimizer,
        negative_sampler_kwargs=sampler_options,
    )

    train_losses = trainer.train(
        triples_factory=factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )
    settings.training_losses = [train_losses]

    entity_embeddings = model.entity_representations[0]().detach()
    relation_embeddings = model.relation_representations[0]().detach()
    return entity_embeddings, relation_embeddings


def random(ctrl, core_graph):
    """Draw random normal embeddings for the core graph's entities and relations.

    Both tensors are sampled with ``torch.normal`` using
    ``ctrl.embed_setting.mean`` and ``ctrl.embed_setting.std``; entities are
    sampled first, relations second.

    NOTE(review): no device is passed to ``torch.normal`` here, whereas the
    other embedding functions in this module place their outputs on
    ``ctrl.device`` -- confirm the caller handles this.

    Returns:
        A 2-tuple ``(entity_embeddings, relation_embeddings)`` of shapes
        ``(core_graph.num_entities, ctrl.embed_dim)`` and
        ``(core_graph.num_relations, ctrl.embed_dim)``.
    """
    settings = ctrl.embed_setting

    def _sample(num_rows):
        # One row per entity/relation, ``ctrl.embed_dim`` columns.
        return torch.normal(
            mean=settings.mean,
            std=settings.std,
            size=(num_rows, ctrl.embed_dim),
        )

    entity_embeddings = _sample(core_graph.num_entities)
    relation_embeddings = _sample(core_graph.num_relations)
    return entity_embeddings, relation_embeddings


def pbg_precomputed(ctrl, core_graph):
    """Load precomputed PBG core embeddings and align them with SEPAL's ids.

    PBG is presumably PyTorch-BigGraph (TODO confirm). Entity and relation
    embeddings are read from HDF5 files below
    ``SEPAL_DIR / "baselines/PBG/model/<core_name>/train_<num_epochs>"``,
    where ``core_name`` is derived from ``ctrl.data``, ``ctrl.core_selection``,
    ``ctrl.core_node_proportions`` and ``ctrl.core_edge_proportions``. The
    rows are then permuted so that row ``i`` belongs to the label whose id in
    ``core_graph.triples_factory`` is ``i``. Finally each relation embedding
    is followed by its element-wise reciprocal, used as the embedding of the
    inverse relation.

    Returns:
        A 2-tuple of tensors on ``ctrl.device``: the reordered entity
        embeddings, and the relation embeddings interleaved with their
        inverses (even rows: original relations, odd rows: inverses).

    Raises:
        KeyError: if a label of the SEPAL triples factory is missing from
            the PBG entity or relation name files.
    """
    print("Loading PBG core embeddings...")
    # Parameters
    data = ctrl.data
    core_selection = ctrl.core_selection
    node_prop = ctrl.core_node_proportions
    edge_prop = ctrl.core_edge_proportions
    num_epochs = ctrl.num_epochs

    core_name = f"core_{data}_{core_selection}_{node_prop}_{edge_prop}"

    ## Load embeddings
    embedding_dir = SEPAL_DIR / f"baselines/PBG/model/{core_name}/train_{num_epochs}"
    # The checkpoint version suffix in both file names is 10 * num_epochs.
    # Entities
    embeddings_path = embedding_dir / f"embeddings_all_0.v{10*num_epochs}.h5"
    with h5py.File(embeddings_path, "r") as file:
        embeddings = file["embeddings"][()]
    # Relations
    model_path = embedding_dir / f"model.v{10*num_epochs}.h5"
    with h5py.File(model_path, "r") as file:
        rel_embeddings = file["model/relations/0/operator/rhs/diagonals"][()]

    ## Reorder embeddings
    # Load SEPAL training triples factory
    tf = core_graph.triples_factory
    sepal_entity_to_idx = tf.entity_to_id
    sepal_relation_to_idx = tf.relation_to_id

    # Load PBG entity and relation names
    pbg_data_dir = SEPAL_DIR / f"baselines/PBG/data/{core_name}"
    with open(pbg_data_dir / "entity_names_all_0.json", encoding="utf-8") as f:
        entity_names = json.load(f)
    with open(pbg_data_dir / "dynamic_rel_names.json", encoding="utf-8") as f:
        relation_names = json.load(f)
    # Position of a name in the PBG list is its row in the PBG arrays.
    pbg_entity_to_idx = {name: row for row, name in enumerate(entity_names)}
    pbg_relation_to_idx = {name: row for row, name in enumerate(relation_names)}

    # Get the reorderings.
    # Row i of the output must belong to the label whose SEPAL id is i, so the
    # labels are walked in id order rather than in dict insertion order (the
    # two coincide only if the mapping was built in increasing id order).
    entity_reordering = [
        pbg_entity_to_idx[label]
        for label in sorted(sepal_entity_to_idx, key=sepal_entity_to_idx.get)
    ]
    relation_reordering = [
        pbg_relation_to_idx[label]
        for label in sorted(sepal_relation_to_idx, key=sepal_relation_to_idx.get)
    ]

    # Reorder
    embeddings = embeddings[entity_reordering]
    rel_embeddings = rel_embeddings[relation_reordering]

    # Add inverse relations: the inverse of a relation is the element-wise
    # reciprocal of its embedding (any zero entry would produce inf).
    # NOTE(review): assumes rel_embeddings is 2-D (one row per relation).
    inv_rel_embeddings = 1 / rel_embeddings
    # Every row is overwritten below, so no zero-initialisation is needed.
    total_rel_embeddings = np.empty(
        (2 * rel_embeddings.shape[0], rel_embeddings.shape[1]),
        dtype=rel_embeddings.dtype,
    )
    total_rel_embeddings[0::2] = rel_embeddings
    total_rel_embeddings[1::2] = inv_rel_embeddings

    entity_tensor = torch.from_numpy(embeddings).to(ctrl.device)
    relation_tensor = torch.from_numpy(total_rel_embeddings).to(ctrl.device)
    return entity_tensor, relation_tensor
