import torch
from pykeen.training import SLCWATrainingLoop
from pykeen.models import DistMult, TransE, TuckER, HolE


# Lookup table from the optimizer name configured in
# ``ctrl.embed_setting.optimizer`` to the torch optimizer class that is
# instantiated over the model's trainable parameters.
# NOTE(review): torch documents SparseAdam as an optimizer for sparse
# gradients; confirm the models trained in this module produce sparse
# gradients before selecting it, otherwise the optimizer step may fail.
optimizers = dict(
    SparseAdam=torch.optim.SparseAdam,
    Adam=torch.optim.Adam,
)


def _train_model(model_cls, ctrl, core_graph):
    """Build ``model_cls`` on ``core_graph`` and train it with an sLCWA loop.

    Shared by :func:`tucker`, :func:`hole`, :func:`distmult` and
    :func:`transe`, which differ only in the model class and in which
    parameters they return afterwards.

    Args:
        model_cls: pykeen model class accepting ``triples_factory``,
            ``random_seed`` and ``embedding_dim`` keyword arguments.
        ctrl: Run configuration. Reads ``seed``, ``embed_dim``, ``device``,
            ``num_epochs``, ``batch_size``, ``embed_setting.optimizer`` and
            ``embed_setting.lr``; writes ``embed_setting.training_losses``.
        core_graph: Object exposing the ``triples_factory`` to train on.

    Returns:
        The trained model instance (moved to ``ctrl.device``).

    Raises:
        KeyError: If ``ctrl.embed_setting.optimizer`` is not a key of the
            module-level ``optimizers`` mapping.
    """
    optimizer_name = ctrl.embed_setting.optimizer
    try:
        optimizer_cls = optimizers[optimizer_name]
    except KeyError:
        # Same exception type as a plain dict lookup, but name the valid
        # choices, and fail before spending time building the model.
        raise KeyError(
            f"unknown optimizer {optimizer_name!r}; "
            f"expected one of {sorted(optimizers)}"
        ) from None

    triples_factory = core_graph.triples_factory
    model = model_cls(
        triples_factory=triples_factory,
        random_seed=ctrl.seed,
        embedding_dim=ctrl.embed_dim,
    ).to(ctrl.device)

    # The optimizer must be created after .to(device) so it holds the
    # parameters that will actually be trained.
    optimizer = optimizer_cls(
        params=model.get_grad_params(), lr=ctrl.embed_setting.lr
    )
    training_loop = SLCWATrainingLoop(
        model=model,
        triples_factory=triples_factory,
        optimizer=optimizer,
    )
    losses = training_loop.train(
        triples_factory=triples_factory,
        num_epochs=ctrl.num_epochs,
        batch_size=ctrl.batch_size,
    )

    # Stored as a single-element list wrapping the loss history returned
    # by ``train``.
    ctrl.embed_setting.training_losses = [losses]
    return model


def _entity_relation_embeddings(model):
    """Return ``(entity, relation)`` tensors from ``model``'s first representations.

    Each representation module is called with no arguments (presumably
    yielding the full embedding matrix — confirm against pykeen's
    ``Representation.forward``). The result is detached from the autograd
    graph.
    """
    return (
        model.entity_representations[0]().detach(),
        model.relation_representations[0]().detach(),
    )


def tucker(ctrl, core_graph):
    """Train a TuckER model on ``core_graph`` and return its parameters.

    Training losses are recorded in ``ctrl.embed_setting.training_losses``.

    Returns:
        ``(core_tensor, entity_embeddings, relation_embeddings)``, all
        detached from the autograd graph.
    """
    model = _train_model(TuckER, ctrl, core_graph)
    core_tensor = model.interaction.core_tensor.detach()
    return (core_tensor, *_entity_relation_embeddings(model))


def hole(ctrl, core_graph):
    """Train a HolE model on ``core_graph`` and return its embeddings.

    Training losses are recorded in ``ctrl.embed_setting.training_losses``.

    Returns:
        ``(entity_embeddings, relation_embeddings)``, detached.
    """
    return _entity_relation_embeddings(_train_model(HolE, ctrl, core_graph))


def distmult(ctrl, core_graph):
    """Train a DistMult model on ``core_graph`` and return its embeddings.

    Training losses are recorded in ``ctrl.embed_setting.training_losses``.

    Returns:
        ``(entity_embeddings, relation_embeddings)``, detached.
    """
    return _entity_relation_embeddings(_train_model(DistMult, ctrl, core_graph))


def transe(ctrl, core_graph):
    """Train a TransE model on ``core_graph`` and return its embeddings.

    Training losses are recorded in ``ctrl.embed_setting.training_losses``.

    Returns:
        ``(entity_embeddings, relation_embeddings)``, detached.
    """
    return _entity_relation_embeddings(_train_model(TransE, ctrl, core_graph))


def random(ctrl, core_graph):
    """Return randomly initialised ``(entity, relation)`` embedding matrices.

    Both matrices are drawn independently from a normal distribution with
    mean ``ctrl.embed_setting.mean`` and standard deviation
    ``ctrl.embed_setting.std``. They are allocated on ``ctrl.device``, the
    same device the trained-model embeddings in this module are moved to.

    NOTE(review): unlike the trained models, this does not use ``ctrl.seed``.
    It draws from torch's global RNG state. The function name also coincides
    with the stdlib ``random`` module.

    Args:
        ctrl: Run configuration. Reads ``embed_dim``, ``device``,
            ``embed_setting.mean`` and ``embed_setting.std``.
        core_graph: Object exposing ``num_entities`` and ``num_relations``.

    Returns:
        ``(entity_embeddings, relation_embeddings)`` with shapes
        ``(core_graph.num_entities, ctrl.embed_dim)`` and
        ``(core_graph.num_relations, ctrl.embed_dim)``.
    """
    setting = ctrl.embed_setting

    def _sample(num_rows):
        # Same distribution for entities and relations; only the row count differs.
        return torch.normal(
            mean=setting.mean,
            std=setting.std,
            size=(num_rows, ctrl.embed_dim),
            device=ctrl.device,
        )

    return _sample(core_graph.num_entities), _sample(core_graph.num_relations)
