import logging
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from kge.dataset import get_triple_dataset
from kge.engine import Engine
from kge.loggers import ConsoleLogger, WandBLogger


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build the dataset, model and loggers from ``cfg`` and run training.

    The function loads the triple dataset, instantiates the model described
    by ``cfg.model``, moves it to ``cfg.device``, compiles it with
    ``torch.compile`` and hands everything to ``Engine.train``. Every logger
    that was created is closed on exit, including when training raises.

    Args:
        cfg: Configuration composed by Hydra. Keys read here are
            ``add_inverse_triples``, ``dataset``, ``data_folder``, ``model``,
            ``dimension``, ``device``, ``use_wandb``, ``wandb`` (``project``,
            ``entity``, ``name``) and ``engine_config``.
    """
    if not cfg.add_inverse_triples:
        msg = (
            "add_inverse_triples is set to False. Since the evaluation code uses inverse "
            "relations for head prediction, the results will only be for tail prediction. "
            "Consider running with add_inverse_triples=True for a fair comparison."
        )
        logging.warning(msg)

    dataset = get_triple_dataset(
        cfg.dataset,
        data_folder=Path(cfg.data_folder),
        add_inverse=cfg.add_inverse_triples,
    )

    # Publish the dataset sizes as top-level config keys before the model is
    # instantiated (presumably referenced by the model config; confirm there).
    cfg = OmegaConf.merge(
        cfg,
        {"num_entities": dataset.num_entities, "num_relations": dataset.num_relations},
    )

    if "gnn_embedding" in cfg.model and "CompGCN" in cfg.model.gnn_embedding._target_:
        # This embedding needs the training split at construction time, so it
        # is built first and then passed to the model as a ready-made object.
        gnn_embedding = hydra.utils.instantiate(
            cfg.model.gnn_embedding,
            train_dataset=dataset.train,
        )
        model = hydra.utils.instantiate(cfg.model, gnn_embedding=gnn_embedding)
    else:
        # Guard the lookup the same way as the branch above so that a model
        # config without a ``fusing_function`` entry reaches the generic path
        # instead of failing on attribute access.
        if (
            "fusing_function" in cfg.model
            and "Complex" in cfg.model.fusing_function._target_
        ):
            # "Complex" fusing functions get ``rank_dimension`` set to twice
            # the configured dimension before instantiation.
            cfg = OmegaConf.merge(cfg, {"rank_dimension": cfg.dimension * 2})
        model = hydra.utils.instantiate(cfg.model)

    model.to(cfg.device)
    model = torch.compile(model)

    loggers = [ConsoleLogger()]
    try:
        if cfg.use_wandb:
            # Fall back to generated names when the config leaves them unset.
            project_name = cfg.wandb.project
            if project_name is None:
                project_name = f"kge_{cfg.dataset}"
            run_name = cfg.wandb.name
            if run_name is None:
                run_name = f"{cfg.dataset}_{model.name}_{cfg.dimension}D"
            loggers.append(
                WandBLogger(
                    project=project_name,
                    entity=cfg.wandb.entity,
                    name=run_name,
                    config=OmegaConf.to_container(cfg, resolve=True),
                ),
            )
        engine_cfg = hydra.utils.instantiate(cfg.engine_config)
        engine = Engine(engine_cfg, model=model, dataset=dataset, loggers=loggers)
        engine.train()
    finally:
        # Close every logger created so far, even if setup or training failed.
        for logger in loggers:
            logger.close()


# Script entry point: run training only when this file is executed directly.
# ``main`` is called without arguments because the ``@hydra.main`` decorator
# supplies ``cfg``.
if __name__ == "__main__":
    main()
