import gcip.utils.io as playbook_io
from gcip.modules import module_dict
from gcip.utils.init import get_init_fn


def _load(ckpt_file, Model, **model_args):
    """Build a ``Model``, restoring it from a checkpoint when one is given.

    Args:
        ckpt_file: Checkpoint location as a ``str`` or an ``os.PathLike``
            object (e.g. ``pathlib.Path``). Any other value, such as
            ``None``, means "no checkpoint".
        Model: Class to instantiate. It must expose a
            ``load_from_checkpoint`` callable accepting ``checkpoint_path``
            when a checkpoint is used.
        **model_args: Keyword arguments forwarded unchanged to either
            ``Model.load_from_checkpoint`` or ``Model``.

    Returns:
        The restored model if a checkpoint path was supplied, otherwise a
        freshly constructed ``Model(**model_args)``.
    """
    import os

    # Path-like objects are accepted alongside plain strings; previously a
    # ``pathlib.Path`` was silently ignored and a fresh model was returned.
    if isinstance(ckpt_file, (str, os.PathLike)):
        # Normalise so the callee receives a plain path, exactly as it did
        # when only ``str`` was supported.
        ckpt_path = os.fspath(ckpt_file)
        playbook_io.print_info(f"Loading {Model} from {ckpt_path}")
        return Model.load_from_checkpoint(checkpoint_path=ckpt_path, **model_args)

    return Model(**model_args)


def load_model(cfg, preparator, ckpt_file=None):
    """Create the model selected by ``cfg.model.name`` and set its optimizers.

    Supported names are ``'gcip'``, ``'gnn_pooling'`` and ``'gnn'``. The
    model-specific loaders are imported lazily inside their branch.

    Args:
        cfg: Configuration object. This function reads ``cfg.model`` (and
            its ``name``; for ``'gnn'`` also ``layer_name`` and ``plot``),
            ``cfg.optim`` and ``cfg.optim_2``.
        preparator: Data preparator forwarded to the model constructors.
        ckpt_file: Optional checkpoint forwarded to the model loader.

    Returns:
        The constructed model, after ``set_optim_config(cfg.optim)`` and
        ``set_optim_config_2(cfg.optim_2)`` have been called on it.

    Raises:
        AssertionError: If ``cfg.model.name`` is not a supported name.
    """
    init_fn = get_init_fn(cfg_model=cfg.model)
    model_name = cfg.model.name

    if model_name == 'gcip':
        from .load_gcip import load_gcip
        model = load_gcip(cfg, preparator, init_fn=init_fn, ckpt_file=ckpt_file)
    elif model_name == 'gnn_pooling':
        from .load_gnn_pooling import load_gnn_pooling
        model = load_gnn_pooling(cfg, preparator, init_fn=init_fn, ckpt_file=ckpt_file)
    elif model_name == 'gnn':
        from gcip.models.model_gnn import GNNLightning

        # Look up the GNN layer class by name and build the inner module
        # from the kwargs the class derives from the config.
        GNN = module_dict[cfg.model.layer_name]
        module_args = GNN.kwargs(cfg, preparator=preparator)
        module = GNN(**module_args)
        print(module)

        model = _load(ckpt_file=ckpt_file,
                      Model=GNNLightning,
                      preparator=preparator,
                      model=module,
                      init_fn=init_fn,
                      plot=cfg.model.plot)
    else:
        message = f'Model name not valid: {model_name}'
        playbook_io.print_warning(message)
        # Raised explicitly instead of ``assert False``: assert statements
        # are stripped under ``python -O``, which would let execution fall
        # through to an unbound ``model``. The exception type is kept for
        # callers that may already catch AssertionError.
        raise AssertionError(message)

    model.set_optim_config(cfg.optim)
    model.set_optim_config_2(cfg.optim_2)

    return model
