import os
from pathlib import Path
import sys
import toml

import torch
import torch.nn.functional as F

from exp_classification import NodeClassificationExp
from factory import (
    make_dataset,
    make_model,
    make_optimizer,
    make_reporter,
    make_lossfn,
)


def resolve_eigen_path(dataset_cfg, base_dir=None):
    """Return the POSIX-style directory used to cache eigen data for a dataset.

    The path is ``<base_dir>/<root>/<name>/eigen``. Only the config-derived
    ``<root>/<name>/eigen`` part is lowercased. The base directory is
    deliberately left as-is: lowercasing an absolute working directory such as
    ``/home/Alice/proj`` yields a path that does not exist on case-sensitive
    filesystems.

    Args:
        dataset_cfg: mapping with at least ``"root"`` and ``"name"`` keys
            (the ``[dataset]`` table of the TOML config).
        base_dir: directory the relative path is resolved against; defaults
            to the current working directory.

    Returns:
        The cache directory as a string with forward slashes.
    """
    base = Path.cwd() if base_dir is None else Path(base_dir)
    relative = Path(dataset_cfg["root"], dataset_cfg["name"], "eigen")
    # NOTE: if ``root`` is absolute, pathlib discards ``base`` and the whole
    # (absolute) root is lowercased, matching the script's previous behaviour.
    return (base / relative.as_posix().lower()).as_posix()


def main(argv=None):
    """Train a node-classification model described by a TOML config file.

    Args:
        argv: command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. The first element is the config file path;
            any further elements are ignored.

    Raises:
        SystemExit: with a usage message when no config path is supplied.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} CONFIG.toml")
    config_name = args[0]

    with open(config_name, mode="r", encoding="utf-8") as f:
        cfg = toml.load(f)

    # The reporter section is optional; without it no reporter is used.
    reporter = make_reporter(cfg["reporter"], cfg) if "reporter" in cfg else None

    # Inject eigen-cache settings into the model config before the model is
    # built; presumably make_model reads these keys (confirm in factory).
    cfg["model"]["eigen_path"] = resolve_eigen_path(cfg["dataset"])
    cfg["model"]["eigen_cache"] = True

    dataset = make_dataset(cfg["dataset"])
    model = make_model(cfg["model"], dataset)
    optimizer = make_optimizer(cfg["optimizer"], model)
    loss_fn, eval_func = make_lossfn(cfg["lossfn"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"running on device: {device}")

    exp = NodeClassificationExp(
        model,
        loss_fn,
        optimizer,
        dataset,
        reporter,
        device,
        model_cfg=cfg["model"],
    )
    exp.train()


if __name__ == "__main__":
    main()