import copy
import logging

from gcip.models.model_gnn import GNNLightning
from gcip.modules import module_dict
from gcip.modules.gnn.topk_gnn import TopKGNN

from .load_model import _load


def load_gnn_pooling(cfg, preparator, init_fn=None, ckpt_file=None):
    """Build a main GNN plus a TopK pooling module and load them via ``_load``.

    Two networks of the same layer class (``module_dict[cfg.model.layer_name]``)
    are constructed:

    * ``main_gnn`` from a deep copy of ``cfg`` left as-is;
    * a scoring GNN from a second deep copy with ``dim_latent`` set to
      ``dim_inner``, ``pooling`` set to ``None`` and, when
      ``cfg.model.hard_pooling`` is truthy, ``num_layers`` set to 0. This
      network is wrapped in ``TopKGNN``.

    Args:
        cfg: Experiment config. Reads ``cfg.model.layer_name``, ``dim_inner``,
            ``hard_pooling``, ``ratio`` and ``min_score``. This function only
            edits deep copies of it.
        preparator: Forwarded to ``GNN.kwargs`` and to ``_load``.
        init_fn: Optional value forwarded unchanged to ``_load``.
        ckpt_file: Optional value forwarded unchanged to ``_load``.

    Returns:
        Whatever ``_load`` returns for ``Model=GNNLightning`` (presumably a
        ``GNNLightning`` instance — confirm in ``load_model._load``).
    """
    GNN = module_dict[cfg.model.layer_name]

    # Main network. GNN.kwargs gets a deep copy; presumably it may mutate its
    # argument — TODO confirm.
    main_gnn = GNN(**GNN.kwargs(copy.deepcopy(cfg), preparator=preparator))

    # Scoring network used by TopK pooling. dim_latent is set to dim_inner
    # (presumably the output width — confirm), and its own pooling is disabled.
    score_cfg = copy.deepcopy(cfg)
    score_cfg.model.dim_latent = score_cfg.model.dim_inner
    score_cfg.model.pooling = None
    if cfg.model.hard_pooling:
        # NOTE(review): presumably leaves only the pre/post layers of the
        # scoring GNN — confirm against GNN.kwargs.
        score_cfg.model.num_layers = 0
    score_gnn = GNN(**GNN.kwargs(score_cfg, preparator=preparator))

    pooling = TopKGNN(gnn=score_gnn,
                      ratio=cfg.model.ratio,
                      min_score=cfg.model.min_score,
                      multiplier=1.0)

    # Report the pooling architecture through logging instead of an
    # unconditional print, so library callers control the verbosity.
    logging.getLogger(__name__).info("TopK pooling module:\n%s", pooling)

    return _load(ckpt_file=ckpt_file,
                 Model=GNNLightning,
                 preparator=preparator,
                 model=main_gnn,
                 pooling_gnn=pooling,
                 init_fn=init_fn)
