import copy

from gcip.models.model_gcip import GCIPLightning
from gcip.modules import module_dict
from gcip.modules.rl import reward_dict, env_dict
from gcip.modules.rl.graph.policy import GraphActorCritic
from gcip.modules.rl.graph.ppo import GraphPPO
from .load_model import _load


def _build_gnn(cfg, preparator, **model_overrides):
    """Instantiate the GNN class selected by ``cfg.model.layer_name``.

    ``cfg`` itself is not modified here. The overrides are applied to a
    deep copy, and that copy is what the class's ``kwargs`` factory receives.

    Args:
        cfg: Experiment config; ``cfg.model.layer_name`` picks the class out
            of ``module_dict``.
        preparator: Forwarded to the class's ``kwargs`` factory.
        **model_overrides: Attribute values set on the copied ``cfg.model``
            before the constructor kwargs are built.

    Returns:
        The constructed GNN module.
    """
    gnn_cls = module_dict[cfg.model.layer_name]
    patched_cfg = copy.deepcopy(cfg)
    for attr_name, attr_value in model_overrides.items():
        setattr(patched_cfg.model, attr_name, attr_value)
    return gnn_cls(**gnn_cls.kwargs(patched_cfg, preparator=preparator))


def load_gcip(cfg, preparator, init_fn=None, ckpt_file=None):
    """Assemble every GCIP component and pass them to ``_load``.

    Construction order:
      1. a graph classifier GNN, built with ``model.dim_latent = 0``;
      2. the reward named by ``cfg.reward.name`` (it receives the classifier);
      3. the environment named by ``cfg.env.name`` (it receives classifier
         and reward);
      4. a second GNN, built with ``model.pooling = None``, wrapped in a
         ``GraphActorCritic`` policy that gets the configured pooling as
         ``pool_type``;
      5. a ``GraphPPO`` trainer around that policy.

    The classifier, the reward and the policy are printed as they are built.

    NOTE(review): the construction order is kept as-is. Building modules may
    consume random state, so reordering could change initialisation.

    Args:
        cfg: Experiment config; reads the ``model``, ``reward`` and ``env``
            sections.
        preparator: Forwarded to every ``kwargs`` factory and to ``_load``.
        init_fn: Forwarded to ``GraphActorCritic`` and to ``_load``.
        ckpt_file: Forwarded to ``_load``. Presumably a checkpoint path to
            restore from; confirm against ``_load``.

    Returns:
        Whatever ``_load`` returns for ``Model=GCIPLightning``.
    """
    # Classifier network: same GNN family, with the latent dimension zeroed
    # on a private config copy.
    graph_clf = _build_gnn(cfg, preparator, dim_latent=0)
    print(graph_clf)

    reward_cls = reward_dict[cfg.reward.name]
    reward = reward_cls(**reward_cls.kwargs(cfg=cfg,
                                            preparator=preparator,
                                            graph_clf=graph_clf))
    print(reward)

    env_cls = env_dict[cfg.env.name]
    env = env_cls(**env_cls.kwargs(cfg=cfg,
                                   preparator=preparator,
                                   graph_clf=graph_clf,
                                   reward=reward))

    model_cfg = cfg.model

    # Policy backbone: pooling is disabled inside the GNN itself; the
    # configured pooling is handed to GraphActorCritic as ``pool_type``.
    backbone = _build_gnn(cfg, preparator, pooling=None)

    policy = GraphActorCritic(gnn=backbone,
                              action_refers_to=cfg.env.action_refers_to,
                              pool_type=model_cfg.pooling,
                              action_distr=env.get_action_distr_name(),
                              act_fn=model_cfg.act,
                              bn=model_cfg.has_bn,
                              dropout=model_cfg.dropout,
                              init_fn=init_fn)
    print(policy)

    trainer = GraphPPO(policy=policy,
                       eps_clip=model_cfg.eps_clip,
                       gamma=model_cfg.gamma,
                       coeff_mse=model_cfg.coeff_mse,
                       coeff_entropy=model_cfg.coeff_entropy)

    return _load(ckpt_file=ckpt_file,
                 Model=GCIPLightning,
                 preparator=preparator,
                 env=env,
                 graph_clf=graph_clf,
                 ppo=trainer,
                 env_steps=model_cfg.n_steps,
                 warm_up_epochs=model_cfg.warm_up_epochs,
                 ppo_steps=model_cfg.ppo_steps,
                 init_fn=init_fn,
                 plot=model_cfg.plot)
