import torch
from net.modules import encode, core, decode, transform, transform_corr


class EncodeDecode(torch.nn.Module):
    """Encode -> core -> decode graph network followed by a ``transform_corr`` head.

    The input graph is encoded and then processed by the core. When
    ``cfg.flag_gn_iteration`` is set, the core is ``cfg.gn_iteration`` separate
    core blocks applied with residual adds; otherwise it is a single core block.
    The result is decoded, and ``transform_corr`` is applied to the decoded nodes
    and edges.

    During evaluation with ``cfg.flag_encode_store`` enabled, the edge encoding
    computed at step ``it == 0`` is cached and reused for later steps.
    """

    def __init__(self, nodes_shape, edges_shape, globals_shape, output_shape, cfg):
        """Build the sub-networks.

        Args:
            nodes_shape, edges_shape, globals_shape: forwarded to ``encode``.
            output_shape: forwarded to both output heads.
            cfg: configuration object. This class reads ``flag_gn_iteration``,
                ``gn_iteration`` and ``flag_encode_store`` from it.
        """
        super().__init__()
        # encode net
        self._encode = encode(nodes_shape, edges_shape, globals_shape)
        # core net: a stack of independent core blocks, or a single shared one
        if cfg.flag_gn_iteration:
            self._core_nets = torch.nn.ModuleList([core() for _ in range(cfg.gn_iteration)])
        else:
            self._core = core()
        # decode net
        self._decode = decode()
        # output nets
        # NOTE(review): forward() does not use self._transform. It is kept,
        # presumably so that checkpoints containing its parameters still load;
        # confirm before removing.
        self._transform = transform(output_shape)
        self._transform_corr = transform_corr(output_shape)

        # Edge encoding cached during evaluation (populated in forward()).
        self._stored_edges = None
        self._cfg = cfg

    def forward(self, nodes, edges, globals, senders, receivers, nodes_index, edges_index, is_train=True, it=None):
        """Run encode -> core -> decode -> transform_corr.

        Args:
            nodes, edges, globals: input graph features passed to the encoder.
            senders, receivers, nodes_index, edges_index: passed unchanged to
                the core block(s).
            is_train: when False, and ``cfg.flag_encode_store`` is set and ``it``
                is not None, the edge encoding is cached and reused across calls.
            it: step index. It is forwarded to the encoder, and ``it == 0``
                refreshes the cached edge encoding.

        Returns:
            Tuple ``(output_corr, output_lambda_v, globals_d)``: the two outputs
            of ``transform_corr`` on the decoded nodes/edges, followed by the
            decoded globals.
        """
        nodes_e, edges_e, globals_e = self._encode(nodes, edges, globals, it)

        if (not is_train) and self._cfg.flag_encode_store and it is not None:
            # Reuse the edge encoding from the first step of the sequence. If
            # no it == 0 call has filled the cache yet, cache the fresh encoding
            # instead of feeding None into the core.
            if it == 0 or self._stored_edges is None:
                self._stored_edges = edges_e
            else:
                edges_e = self._stored_edges

        if self._cfg.flag_gn_iteration:
            # Residual updates: each core block's output is added to the running state.
            for i in range(self._cfg.gn_iteration):
                nodes_c, edges_c, globals_c = self._core_nets[i](
                    nodes_e, edges_e, globals_e, senders, receivers, nodes_index, edges_index
                )
                nodes_e, edges_e, globals_e = nodes_e + nodes_c, edges_e + edges_c, globals_e + globals_c
            nodes_d, edges_d, globals_d = self._decode(nodes_e, edges_e, globals_e)
        else:
            # NOTE(review): unlike the iterated path, the single core's output is
            # decoded directly without a residual add. Confirm this is intended.
            nodes_c, edges_c, globals_c = self._core(
                nodes_e, edges_e, globals_e, senders, receivers, nodes_index, edges_index
            )
            nodes_d, edges_d, globals_d = self._decode(nodes_c, edges_c, globals_c)

        output_corr, output_lambda_v = self._transform_corr(nodes_d, edges_d)
        return output_corr, output_lambda_v, globals_d
