"""
CLI entrypoint for running experiments.
"""


from args import parse_args
from src.data.loaders import load_datasets
from src.models.create_model import create_model_optim_loss

from src.train.trainer_base_gnn import train_base_gnn
from src.train.trainer_mise_gnn import train_mise_gnn
from src.evaluate.evaluator import evaluate_mise_gnn


def main():
    """Parse CLI arguments and run the experiment selected by ``args.mode``.

    Supported modes:
      * ``mgn``               -- ``train_base_gnn`` with ``base_dataset_is_tree=False``.
      * ``mgn_tree``          -- ``train_base_gnn`` with ``base_dataset_is_tree=True``.
      * ``mise_gnn``          -- ``train_mise_gnn`` with the ``basic_nodes`` /
                                 ``basic_edges`` builders.
      * ``evaluate_mise_gnn`` -- ``evaluate_mise_gnn`` on a saved checkpoint,
                                 writing results under a fixed output directory.

    Raises:
        SystemExit: if ``args.mode`` is not one of the supported modes. This
            is checked before datasets are loaded and the model is built, so a
            typo fails fast with a non-zero exit status. Previously an unknown
            mode did all the setup and then silently did nothing.
    """
    args = parse_args()
    print(args)

    supported_modes = ("mgn", "mgn_tree", "mise_gnn", "evaluate_mise_gnn")
    if args.mode not in supported_modes:
        raise SystemExit(
            f"Unknown mode {args.mode!r}; expected one of: "
            f"{', '.join(supported_modes)}"
        )

    train_data, test_data, train_idx, test_idx = load_datasets(
        args.dataset_name, args.dataset_path,
        args.split_train_name, args.split_test_name,
        args.target_field
    )

    model, optimizer, loss_fn = create_model_optim_loss(args)

    if args.mode in ("mgn", "mgn_tree"):
        # The two base-GNN modes differ only in whether the dataset is treated
        # as tree-structured.
        train_base_gnn(
            args, model, optimizer, loss_fn,
            train_data, test_data,
            base_dataset_is_tree=(args.mode == "mgn_tree")
        )

    elif args.mode == "mise_gnn":
        train_mise_gnn(
            args, model, optimizer, loss_fn,
            train_data, test_data, train_idx, test_idx, base_dataset_is_tree=True,
            node_builder_name="basic_nodes", edge_builder_name="basic_edges"
        )

    else:  # "evaluate_mise_gnn" -- the only remaining supported mode.
        # NOTE(review): these absolute paths are hard-coded; consider exposing
        # them as CLI options in parse_args.
        checkpoint_path = "/checkpoints/best_state.pt"
        save_dir = "/inference"

        evaluate_mise_gnn(
            args, model,
            checkpoint_path,
            train_data, test_data,
            train_idx, test_idx,
            save_dir,
            base_dataset_is_tree=True,
            node_builder_name="basic_nodes",
            edge_builder_name="basic_edges"
        )


if __name__ == "__main__":
    main()
