import torch
import torch_scatter
from my_config import config as cfg


def make_mlp_model(input_size):
    """Build a multi-layer perceptron that ends in a ReLU.

    The network is ``Linear(input_size, cfg.latent_size)`` followed by
    ``cfg.num_layers - 1`` hidden ``Linear(cfg.latent_size, cfg.latent_size)``
    layers and a final ``ReLU``. When ``cfg.flag_all_relu`` is set, a ``ReLU``
    is additionally placed in front of every hidden layer; otherwise the
    linear layers are stacked back to back.

    Args:
        input_size: Number of input features of the first linear layer.

    Returns:
        A ``torch.nn.Sequential`` holding the layers described above.
    """
    layers = [torch.nn.Linear(input_size, cfg.latent_size)]
    # Instantiate a fresh module for every hidden layer. Repeating a list with
    # ``*`` only repeats references to ONE Linear instance, which would make
    # all hidden layers share a single weight matrix and bias.
    for _ in range(cfg.num_layers - 1):
        if cfg.flag_all_relu:
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(cfg.latent_size, cfg.latent_size))
    layers.append(torch.nn.ReLU())

    return torch.nn.Sequential(*layers)


def do_normalization(input):
    """Normalize ``input`` according to the normalization flags in ``cfg``.

    At most one of ``cfg.flag_layer_normalization`` and
    ``cfg.flag_batch_normalization`` may be enabled. When neither is enabled
    ``input`` is returned unchanged.

    The normalization is stateless: it has no learnable scale/bias and keeps
    no running statistics, so mean and variance are always computed from the
    tensor that is passed in. The functional calls used here run on whatever
    device ``input`` lives on, so ``cfg.flag_gpu`` is not consulted.

    Args:
        input: Tensor to normalize. Layer normalization is applied over the
            last dimension. Batch normalization flattens the tensor to
            ``(-1, cfg.latent_size)``, normalizes each of the
            ``cfg.latent_size`` features, and reshapes the result to
            ``(input.shape[0], -1, cfg.latent_size)``.

    Returns:
        The normalized tensor, or ``input`` itself when no normalization
        flag is set.

    Raises:
        AssertionError: If layer normalization is enabled while
            ``input.shape[0] != 1``, or if both normalization flags are
            enabled at once.
    """
    assert not (input.shape[0] != 1 and cfg.flag_layer_normalization == True), \
        "when lay normalization is true, batch size should be 1"
    assert not (cfg.flag_batch_normalization == True and cfg.flag_layer_normalization == True), \
        "only one type of normalization can be used"

    if cfg.flag_layer_normalization:
        # Equivalent to applying a freshly constructed torch.nn.LayerNorm
        # (unit scale, zero bias) without allocating a module on every call.
        return torch.nn.functional.layer_norm(input, (input.size(-1),))

    if cfg.flag_batch_normalization:
        batch_size = input.shape[0]
        # batch_norm expects the feature dimension second, i.e. (N, C), so
        # fold every leading dimension into N before normalizing.
        flat = input.view(-1, cfg.latent_size)
        # training=True selects batch statistics, which is what a freshly
        # constructed torch.nn.BatchNorm1d (always in training mode) uses.
        flat = torch.nn.functional.batch_norm(flat, None, None, training=True)
        return flat.view(batch_size, -1, cfg.latent_size)

    return input


class encode(torch.nn.Module):
    """Encoder that embeds node, edge and global features independently.

    Each feature kind has its own MLP built by ``make_mlp_model``; every MLP
    output is passed through ``do_normalization``.
    """

    def __init__(self, nodes_shape, edges_shape, globals_shape):
        """Create the three encoder MLPs.

        Args:
            nodes_shape: Indexable whose element ``1`` is the node input size.
            edges_shape: Indexable whose element ``1`` is the edge input size.
            globals_shape: Indexable whose element ``1`` is the global input
                size.
        """
        super(encode, self).__init__()
        self.nodes_mlp = make_mlp_model(nodes_shape[1])
        self.edges_mlp = make_mlp_model(edges_shape[1])
        self.globals_mlp = make_mlp_model(globals_shape[1])

    def forward(self, nodes, edges, globals, iter=None):
        """Encode the graph features.

        Args:
            nodes: Node features fed to the node MLP.
            edges: Edge features fed to the edge MLP (unused when the edge
                encoding is skipped, see ``iter``).
            globals: Global features; encoded only when ``cfg.flag_global``
                is set, otherwise passed through untouched.
            iter: Optional iteration counter. When ``cfg.flag_encode_store``
                is set and ``iter`` is neither ``None`` nor ``0``, the edges
                are not encoded and ``None`` is returned in their place.

        Returns:
            Tuple ``(nodes_, edges_, globals_)`` of encoded features, where
            ``edges_`` may be ``None`` as described above.
        """
        nodes_ = self.nodes_mlp(nodes)
        nodes_ = do_normalization(nodes_)

        # Edges are only encoded on the first iteration (or when no iteration
        # counter is supplied) if encoder-output storing is enabled.
        skip_edges = cfg.flag_encode_store and iter is not None and iter != 0
        if skip_edges:
            edges_ = None
        else:
            edges_ = self.edges_mlp(edges)
            edges_ = do_normalization(edges_)

        if cfg.flag_global:
            globals_ = self.globals_mlp(globals)
            globals_ = do_normalization(globals_)
        else:
            globals_ = globals

        return nodes_, edges_, globals_


class core(torch.nn.Module):
    """Message-passing block: updates edges, then nodes, then globals.

    All feature tensors are indexed as ``(batch, count, feature)``; the
    scratch buffers built here use ``cfg.latent_size`` as the feature size.
    Index arguments are indexed per batch element (``index[j]``).
    """

    def __init__(self, nodes_shape=None, edges_shape=None, globals_shape=None):
        """Create the edge, node and global update MLPs.

        The node and edge MLPs consume three latent-sized inputs, plus a
        fourth (the broadcast globals) when ``cfg.flag_global`` is set. The
        shape arguments are accepted but not used.
        """
        super(core, self).__init__()
        if cfg.flag_global:
            self.nodes_mlp = make_mlp_model(cfg.latent_size * 4)
            self.edges_mlp = make_mlp_model(cfg.latent_size * 4)
            self.globals_mlp = make_mlp_model(cfg.latent_size * 3)
        else:
            self.nodes_mlp = make_mlp_model(cfg.latent_size * 3)
            self.edges_mlp = make_mlp_model(cfg.latent_size * 3)
            self.globals_mlp = make_mlp_model(cfg.latent_size)

    @staticmethod
    def _broadcast(source, index, batch_size, rows):
        """Gather ``source[j][index[j]]`` for every batch element ``j``."""
        # new_zeros keeps the buffer on the device/dtype of the features, so
        # no explicit .cuda() transfer is required.
        out = source.new_zeros((batch_size, rows, cfg.latent_size))
        for j in range(batch_size):
            out[j, ...] = source[j][index[j]]
        return out

    @staticmethod
    def _aggregate(source, index, batch_size, rows):
        """Sum the rows of ``source[j]`` into the ``rows`` slots named by ``index[j]``."""
        out = source.new_zeros((batch_size, rows, cfg.latent_size))
        for j in range(batch_size):
            out[j] = torch_scatter.scatter_add(source[j], index[j], dim=0, out=out[j])
        return out

    def _update_edges(self, nodes, edges, senders, receivers, globals, edges_index):
        """Compute new edge features.

        Each edge is concatenated with its receiver node, its sender node
        and, when ``cfg.flag_global`` is set, the globals row selected by
        ``edges_index``; the result goes through the edge MLP and
        ``do_normalization``.
        """
        batch_size = edges.shape[0]
        edges_num = edges.shape[1]

        parts = [
            edges,
            self._broadcast(nodes, receivers, batch_size, edges_num),
            self._broadcast(nodes, senders, batch_size, edges_num),
        ]
        if cfg.flag_global:
            parts.append(self._broadcast(globals, edges_index, batch_size, edges_num))

        edges_ = self.edges_mlp(torch.cat(parts, -1))
        edges_ = do_normalization(edges_)
        return edges_

    def _update_nodes(self, nodes, edges, senders, receivers, globals, nodes_index):
        """Compute new node features.

        Each node is concatenated with the sum of the edges that name it as
        receiver, the sum of the edges that name it as sender and, when
        ``cfg.flag_global`` is set, the globals row selected by
        ``nodes_index``; the result goes through the node MLP and
        ``do_normalization``.
        """
        batch_size = nodes.shape[0]
        nodes_num = nodes.shape[1]

        parts = [
            nodes,
            self._aggregate(edges, receivers, batch_size, nodes_num),
            self._aggregate(edges, senders, batch_size, nodes_num),
        ]
        if cfg.flag_global:
            parts.append(self._broadcast(globals, nodes_index, batch_size, nodes_num))

        nodes_ = self.nodes_mlp(torch.cat(parts, -1))
        nodes_ = do_normalization(nodes_)
        return nodes_

    def _update_globals(self, nodes, edges, globals, nodes_index, edges_index):
        """Compute new global features.

        When ``cfg.flag_global`` is not set, ``globals`` is returned
        untouched (it is not inspected at all, so ``None`` is accepted).
        Otherwise the globals are concatenated with the per-slot sums of the
        nodes and edges (slots given by ``nodes_index`` / ``edges_index``)
        and passed through the global MLP and ``do_normalization``.
        """
        # Check the flag first: with globals disabled the value is merely
        # passed through and may not be a tensor with a ``shape``.
        if not cfg.flag_global:
            return globals

        batch_size = globals.shape[0]
        globals_num = globals.shape[1]
        nod = self._aggregate(nodes, nodes_index, batch_size, globals_num)
        edg = self._aggregate(edges, edges_index, batch_size, globals_num)

        globals_ = self.globals_mlp(torch.cat((globals, nod, edg), -1))
        globals_ = do_normalization(globals_)
        return globals_

    def forward(self, nodes, edges, globals, senders, receivers, nodes_index, edges_index):
        """Run one round of message passing.

        Edges are updated from the incoming nodes, nodes from the updated
        edges, and globals from the updated nodes and edges.

        Returns:
            Tuple ``(nodes_, edges_, globals_)`` of updated features.
        """
        edges_ = self._update_edges(nodes, edges, senders, receivers, globals, edges_index)
        nodes_ = self._update_nodes(nodes, edges_, senders, receivers, globals, nodes_index)
        globals_ = self._update_globals(nodes_, edges_, globals, nodes_index, edges_index)

        return nodes_, edges_, globals_


class decode(torch.nn.Module):
    """Decoder that maps latent node, edge and global features through MLPs."""

    def __init__(self, nodes_shape=None, edges_shape=None, globals_shape=None):
        """Create the three latent-to-latent decoder MLPs.

        The shape arguments are accepted but not used.
        """
        super(decode, self).__init__()
        latent = cfg.latent_size
        self.nodes_mlp = make_mlp_model(latent)
        self.edges_mlp = make_mlp_model(latent)
        self.globals_mlp = make_mlp_model(latent)

    def forward(self, nodes, edges, globals):
        """Decode the latent features.

        Nodes are always decoded. When ``cfg.flag_decode_close`` is set the
        edges and globals are skipped and returned as ``None``; otherwise
        edges are decoded, and globals are decoded only if
        ``cfg.flag_global`` is set (else passed through unchanged).

        Returns:
            Tuple ``(nodes_, edges_, globals_)``.
        """
        decoded_nodes = do_normalization(self.nodes_mlp(nodes))

        # Edge/global decoding switched off: only the nodes are produced.
        if cfg.flag_decode_close:
            return decoded_nodes, None, None

        decoded_edges = do_normalization(self.edges_mlp(edges))

        decoded_globals = globals
        if cfg.flag_global:
            decoded_globals = do_normalization(self.globals_mlp(globals))

        return decoded_nodes, decoded_edges, decoded_globals


class transform(torch.nn.Module):
    """Linear read-out from latent node features to the output size."""

    def __init__(self, output_shape):
        """Create the read-out layer.

        Args:
            output_shape: Number of output features of the linear layer.
        """
        super(transform, self).__init__()
        self.output_mlp = torch.nn.Linear(
            in_features=cfg.latent_size,
            out_features=output_shape,
        )

    def forward(self, nodes):
        """Apply the linear read-out to ``nodes`` and return the result."""
        return self.output_mlp(nodes)


class transform_corr(torch.nn.Module):
    """Two tanh-squashed linear read-outs: one for nodes, one for edges."""

    def __init__(self, output_shape):
        """Create the two read-out layers.

        Args:
            output_shape: Number of output features of each linear layer.
        """
        super(transform_corr, self).__init__()
        latent = cfg.latent_size
        self.output_corr_mlp = torch.nn.Linear(latent, output_shape)
        self.output_lambda_mlp = torch.nn.Linear(latent, output_shape)

    def forward(self, nodes, edges):
        """Return ``(tanh(corr_layer(nodes)), tanh(lambda_layer(edges)))``."""
        corr = torch.tanh(self.output_corr_mlp(nodes))
        lambda_v = torch.tanh(self.output_lambda_mlp(edges))
        return corr, lambda_v