import copy
from functools import partial

import torch
import torch.nn as nn
from torch.nn import functional as F

import dgl
import dgl.function as fn
import dgl.nn as dglnn


class MLP(nn.Module):
    """
    Plain multi-layer perceptron with ReLU between consecutive layers.

    Parameters
    ==========
    in_feats : int
        Size of the input features

    out_feats : int
        Size of the output features

    num_layers : int
        Total number of linear layers. Values below 2 still produce two
        layers (in_feats -> hidden -> out_feats).

    hidden : int
        Width of every hidden layer
    """

    def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()
        # The input projection keeps PyTorch's default Linear initialisation;
        # only the later layers use the N(0, 0.1) / zero-bias scheme.
        self.layers.append(nn.Linear(in_feats, hidden))
        # hidden -> hidden layers; the range is empty when num_layers <= 2.
        for _ in range(num_layers - 2):
            self.layers.append(self._init_linear(hidden, hidden))
        self.layers.append(self._init_linear(hidden, out_feats))

    @staticmethod
    def _init_linear(in_dim, out_dim):
        """Create a Linear layer with N(0, 0.1) weights and a zero bias."""
        layer = nn.Linear(in_dim, out_dim)
        nn.init.normal_(layer.weight, std=0.1)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return output_layer(x)


class PrepareLayer(nn.Module):
    """
    Prepare model inputs: normalise node features and derive edge features.

    Node features are shifted by the median and scaled by ``2 / (max - min)``.
    Each edge feature is ``src_feat - dst_feat`` computed on the normalised
    node features.

    Parameters
    ==========
    node_feats : int
        Number of node features (stored on the module)

    stat : dict
        Normalisation statistics with keys ``"median"``, ``"max"`` and ``"min"``
    """

    def __init__(self, node_feats, stat):
        super(PrepareLayer, self).__init__()
        self.node_feats = node_feats
        self.stat = stat

    def normalize_input(self, node_feature):
        """Return ``(node_feature - median) * 2 / (max - min)``."""
        stats = self.stat
        scale = 2 / (stats["max"] - stats["min"])
        centered = node_feature - stats["median"]
        return centered * scale

    def forward(self, g, node_feature):
        """Return the normalised node features and the per-edge differences."""
        normalized = self.normalize_input(node_feature)
        # local_scope keeps the temporary "feat"/"e" fields off the caller's graph.
        with g.local_scope():
            g.ndata["feat"] = normalized  # Only dynamic feature
            g.apply_edges(fn.u_sub_v("feat", "feat", "e"))
            return normalized, g.edata["e"]


class InteractionNet(nn.Module):
    """
    Single-layer interaction network for stellar multi-body simulation.

    Wraps one ``InteractionLayer`` in ``"n_n"`` mode. It maps the layer's
    normalised 2-dimensional node output back to the velocity scale using
    columns 3:5 of the supplied statistics. The original authors describe
    it as supporting systems of no more than 12 bodies.

    Parameters
    ==========
    node_feats : int
        Number of node features

    stat : dict
        Statistics for denormalization, with keys ``"median"``, ``"max"``
        and ``"min"``
    """

    def __init__(self, node_feats, stat):
        super(InteractionNet, self).__init__()
        self.node_feats = node_feats
        self.stat = stat
        self.in_layer = InteractionLayer(
            node_feats - 3,  # Use velocity only
            node_feats,
            out_node_feats=2,
            out_edge_feats=50,
            edge_fn=partial(MLP, num_layers=5, hidden=150),
            node_fn=partial(MLP, num_layers=2, hidden=100),
            mode="n_n",
        )

    def denormalize_output(self, out):
        """Invert the input normalisation for the velocity columns (3:5) only."""
        velocity = slice(3, 5)
        low = self.stat["min"][velocity]
        high = self.stat["max"][velocity]
        center = self.stat["median"][velocity]
        return out * (high - low) / 2 + center

    def forward(self, g, n_feat, e_feat, global_feats, relation_feats):
        """Run the interaction layer and return (denormalised nodes, edges)."""
        with g.local_scope():
            node_out, edge_out = self.in_layer(
                g, n_feat, e_feat, global_feats, relation_feats
            )
            return self.denormalize_output(node_out), edge_out


class InteractionLayer(nn.Module):
    """
    Implementation of a single interaction-network layer

    In ``forward`` the node features are concatenated with ``global_feats``
    and the edge features with ``relation_feats``. ``edge_fn`` produces
    per-edge messages, which are summed per destination node and fed
    (together with the node's features) to ``node_fn``.

    Parameters
    ==========
    in_node_feats : int
        Number of node features (before the global features are appended)

    in_edge_feats : int
        Number of edge features (before the relation features are appended)

    out_node_feats : int
        Number of node features after one interaction

    out_edge_feats : int
        Number of edge features after one interaction

    global_feats : int
        Number of global features appended to every node

    relate_feats : int
        Number of relation features appended to every edge

    edge_fn : callable
        Factory ``edge_fn(in_dim, out_dim)`` building the edge/message module

    node_fn : callable
        Factory ``node_fn(in_dim, out_dim)`` building the node-update module

    mode : str
        How edge messages are produced
        nne : edge_fn sees [src_feat, dst_feat, edge_feat]; ReLU is applied
              to the outputs of both edge_fn and node_fn.
        n_n : edge_fn sees only the edge feature (e.g. a precomputed
              src - dst difference); no ReLU is applied.
    """

    def __init__(
        self,
        in_node_feats,
        in_edge_feats,
        out_node_feats,
        out_edge_feats,
        global_feats=1,
        relate_feats=1,
        edge_fn=nn.Linear,
        node_fn=nn.Linear,
        mode="nne",
    ):  # 'n_n'
        super(InteractionLayer, self).__init__()
        self.in_node_feats = in_node_feats
        self.in_edge_feats = in_edge_feats
        self.out_edge_feats = out_edge_feats
        self.out_node_feats = out_node_feats
        self.mode = mode
        # forward() widens node features by the global features and edge
        # features by the relation features, so the MLP input sizes must
        # account for both (previously "nne" ignored them).
        node_input = self.in_node_feats + global_feats
        edge_input = self.in_edge_feats + relate_feats
        input_shape = 2 * node_input + edge_input if mode == "nne" else edge_input
        self.edge_fn = edge_fn(input_shape, self.out_edge_feats)  # 50 in IN paper

        # node_fn sees the node's own (widened) features plus the summed messages.
        self.node_fn = node_fn(
            node_input + self.out_edge_feats,
            self.out_node_feats,
        )

    # Edge UDF used with g.apply_edges in "nne" mode.
    def update_edge_fn(self, edges):
        """Compute edge messages from [src_feat, dst_feat, edge_feat]."""
        x = torch.cat(
            [edges.src["feat"], edges.dst["feat"], edges.data["feat"]], dim=1
        )
        ret = F.relu(self.edge_fn(x)) if self.mode == "nne" else self.edge_fn(x)
        return {"e": ret}

    # Node UDF applied after the built-in sum reduce has filled "agg".
    def update_node_fn(self, nodes):
        """Compute new node features from [node_feat, aggregated messages]."""
        x = torch.cat([nodes.data["feat"], nodes.data["agg"]], dim=1)
        ret = F.relu(self.node_fn(x)) if self.mode == "nne" else self.node_fn(x)
        return {"n": ret}

    def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):
        """
        Run one round of message passing.

        Returns
        =======
        (node_out, edge_out) : the updated node features ("n") and the edge
        messages ("e").
        """
        # local_scope keeps the temporary fields ("feat", "e", "agg", "n")
        # from leaking into the caller's graph.
        with g.local_scope():
            g.ndata["feat"] = torch.cat([node_feats, global_feats], dim=1)
            g.edata["feat"] = torch.cat([edge_feats, relation_feats], dim=1)
            if self.mode == "nne":
                g.apply_edges(self.update_edge_fn)
            else:
                g.edata["e"] = self.edge_fn(g.edata["feat"])

            g.update_all(
                fn.copy_e("e", "msg"), fn.sum("msg", "agg"), self.update_node_fn
            )
            return g.ndata["n"], g.edata["e"]
