"""GNS from https://github.com/wu375/simple-physics-simulator-pytorch-geometry."""

import torch
from torch import nn
from torch_geometric.nn import MessagePassing


def build_mlp(
    input_size,
    layer_sizes,
    output_size=None,
    output_activation=torch.nn.Identity,
    activation=torch.nn.ReLU,
):
    """Build a fully connected network as a ``torch.nn.Sequential``.

    Every ``Linear`` layer is followed by an activation module: ``activation``
    after all but the last layer, ``output_activation`` after the last one.

    Args:
        input_size: Width of the input features.
        layer_sizes: Widths of the hidden layers; any iterable of ints
            (list, tuple, ...). It is never modified.
        output_size: Optional width of an extra final layer. When ``None`` the
            last entry of ``layer_sizes`` is the output width.
        output_activation: Module class instantiated after the last layer.
        activation: Module class instantiated after every other layer.

    Returns:
        A ``torch.nn.Sequential`` alternating ``Linear`` and activation modules
        (empty if there is no layer to build).
    """
    # Copy into a fresh list so tuples/generators work and the caller's
    # sequence is never touched.
    sizes = [input_size, *layer_sizes]
    if output_size is not None:
        sizes.append(output_size)

    last_layer = len(sizes) - 2  # index of the final Linear layer
    layers = []
    for i, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act = activation if i < last_layer else output_activation
        layers += [torch.nn.Linear(fan_in, fan_out), act()]
    return torch.nn.Sequential(*layers)


def time_diff(input_sequence):
    """Return forward differences along axis 1 (the time axis).

    ``out[:, t] == input_sequence[:, t + 1] - input_sequence[:, t]``, so the
    result is one element shorter than the input along that axis.
    """
    later_steps = input_sequence[:, 1:]
    earlier_steps = input_sequence[:, :-1]
    return later_steps - earlier_steps


def get_random_walk_noise_for_position_sequence(position_sequence, noise_std_last_step):
    """Returns random-walk noise in the velocity applied to the position.

    Gaussian noise is drawn for every velocity step and accumulated with a
    cumulative sum (a random walk in velocity). It is then integrated once more
    (Euler step, dt = 1) into positions. The per-step std is
    ``noise_std_last_step / sqrt(num_velocities)``, so the accumulated velocity
    noise at the last step has std ``noise_std_last_step``.

    Args:
        position_sequence: Positions with time on dim 1 (the axis ``time_diff``
            differences), presumably ``(num_particles, num_steps, dim)`` —
            TODO confirm layout with callers.
        noise_std_last_step: Target std of the velocity noise at the last step.

    Returns:
        Noise with the same shape as ``position_sequence``; the first time step
        is always zero. The noise lives on the input's device and, for
        floating-point tensor input, has the input's dtype (otherwise torch's
        default float dtype).
    """
    is_tensor = isinstance(position_sequence, torch.Tensor)
    # Allocate noise where the positions live; the caller adds it to them.
    device = position_sequence.device if is_tensor else None
    dtype = (
        position_sequence.dtype
        if is_tensor and position_sequence.is_floating_point()
        else None  # None -> torch's default float dtype
    )

    velocity_sequence = time_diff(position_sequence)
    num_velocities = velocity_sequence.shape[1]
    if num_velocities == 0:
        # A single frame has no velocity to perturb, and the per-step std below
        # would divide by zero. The first frame never receives noise anyway.
        return torch.zeros(tuple(position_sequence.shape), dtype=dtype, device=device)

    step_std = noise_std_last_step / num_velocities**0.5
    velocity_sequence_noise = (
        torch.randn(tuple(velocity_sequence.shape), dtype=dtype, device=device) * step_std
    )

    # Random walk: accumulate the per-step velocity noise over time.
    velocity_sequence_noise = torch.cumsum(velocity_sequence_noise, dim=1)

    # Integrate the velocity noise into positions; the first position is left
    # unperturbed.
    position_sequence_noise = torch.cat(
        [
            torch.zeros_like(velocity_sequence_noise[:, 0:1]),
            torch.cumsum(velocity_sequence_noise, dim=1),
        ],
        dim=1,
    )

    return position_sequence_noise


class Encoder(nn.Module):
    """Embed raw node and edge features into latent vectors.

    Node features and edge features each go through their own MLP
    (``build_mlp``) followed by a ``LayerNorm``.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        hidden_sizes = [mlp_hidden_dim] * mlp_num_layers
        # Build the node branch before the edge branch: construction order
        # fixes the parameter initialisation and registration order.
        self.node_fn = nn.Sequential(
            build_mlp(node_in, hidden_sizes, node_out),
            nn.LayerNorm(node_out),
        )
        self.edge_fn = nn.Sequential(
            build_mlp(edge_in, hidden_sizes, edge_out),
            nn.LayerNorm(edge_out),
        )

    def forward(self, x, edge_index, e_features):
        """Return ``(node_latents, edge_latents)``.

        ``x`` holds one row per node (``(num_nodes, node_in)``) and
        ``e_features`` one row per edge (``(num_edges, edge_in)``).
        ``edge_index`` is accepted for signature parity with the other graph
        blocks but is not used here.
        """
        node_latents = self.node_fn(x)
        edge_latents = self.edge_fn(e_features)
        return node_latents, edge_latents


class InteractionNetwork(MessagePassing):
    """One residual message-passing step over node and edge latents.

    Each edge is first updated from (receiver node, sender node, edge)
    features. The updated edge features are then summed onto their receiver
    nodes (``aggr="add"``) and combined with each node's own features.
    Both nodes and edges get a residual connection, so ``node_out`` must equal
    ``node_in`` and ``edge_out`` must equal ``edge_in``.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__(aggr="add")
        hidden_sizes = [mlp_hidden_dim] * mlp_num_layers
        # Node MLP input: aggregated edge messages concatenated with node features.
        self.node_fn = nn.Sequential(
            build_mlp(node_in + edge_out, hidden_sizes, node_out),
            nn.LayerNorm(node_out),
        )
        # Edge MLP input: receiver features, sender features, edge features.
        self.edge_fn = nn.Sequential(
            build_mlp(node_in + node_in + edge_in, hidden_sizes, edge_out),
            nn.LayerNorm(edge_out),
        )

    def forward(self, x, edge_index, e_features):
        """Apply one interaction step.

        Args:
            x: Node latents, ``(num_nodes, node_in)``.
            edge_index: Dense ``(2, num_edges)`` index tensor. Under PyG's
                default ``source_to_target`` flow, row 0 is the sender and
                row 1 the receiver.
            e_features: Edge latents, ``(num_edges, edge_in)``.

        Returns:
            ``(x_new, e_new)``: updated node and edge latents, residuals added.
        """
        # The edge update is computed here rather than inside ``message``. PyG
        # hands ``update`` the *original* keyword arguments, so edge features
        # produced in ``message`` were silently dropped and the old code
        # returned ``e_features + e_features`` for the edges.
        x_i = x[edge_index[1]]  # receivers (targets)
        x_j = x[edge_index[0]]  # senders (sources)
        e_updated = self.edge_fn(torch.cat([x_i, x_j, e_features], dim=-1))

        # ``message`` no longer lifts ``x``, so PyG cannot infer the node count.
        # Pass it explicitly so nodes without incoming edges still get a
        # (zero) aggregate row.
        num_nodes = x.size(0)
        x_updated = self.propagate(
            edge_index=edge_index,
            size=(num_nodes, num_nodes),
            x=x,
            e_updated=e_updated,
        )
        return x_updated + x, e_updated + e_features

    def message(self, e_updated):
        # Messages are the already-updated edge features, (num_edges, edge_out).
        return e_updated

    def update(self, x_updated, x):
        # x_updated: summed messages per receiver, (num_nodes, edge_out)
        # x: node latents, (num_nodes, node_in)
        return self.node_fn(torch.cat([x_updated, x], dim=-1))


class Processor(MessagePassing):
    """Run ``num_message_passing_steps`` InteractionNetwork layers in sequence.

    Each step owns its own weights (nothing is shared between steps). This
    class never calls ``propagate`` itself; the ``aggr="max"`` given to the
    base class is therefore unused. All aggregation happens inside each
    InteractionNetwork.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__(aggr="max")
        self.gnn_stacks = nn.ModuleList()
        for _ in range(num_message_passing_steps):
            self.gnn_stacks.append(
                InteractionNetwork(
                    node_in=node_in,
                    node_out=node_out,
                    edge_in=edge_in,
                    edge_out=edge_out,
                    mlp_num_layers=mlp_num_layers,
                    mlp_hidden_dim=mlp_hidden_dim,
                )
            )

    def forward(self, x, edge_index, e_features):
        """Feed node/edge latents through every step; return the final pair."""
        node_latents, edge_latents = x, e_features
        for step in self.gnn_stacks:
            node_latents, edge_latents = step(node_latents, edge_index, edge_latents)
        return node_latents, edge_latents


class Decoder(nn.Module):
    """Map node latents to per-node outputs with a plain MLP (no LayerNorm)."""

    def __init__(
        self,
        node_in,
        node_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        hidden_sizes = [mlp_hidden_dim] * mlp_num_layers
        self.node_fn = build_mlp(node_in, hidden_sizes, node_out)

    def forward(self, x):
        # x: node latents, (num_nodes, node_in) -> (num_nodes, node_out)
        decoded = self.node_fn(x)
        return decoded


class EncodeProcessDecode(nn.Module):
    """Encoder -> Processor -> Decoder graph network.

    Every latent (node and edge) has width ``latent_dim``, which lets the
    processor's residual connections line up. Only node outputs are returned;
    the processed edge latents are discarded after the processor.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        latent_dim,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        mlp_kwargs = dict(mlp_num_layers=mlp_num_layers, mlp_hidden_dim=mlp_hidden_dim)
        # Submodules are created in encoder, processor, decoder order; keep it
        # so parameter initialisation and registration order are unchanged.
        self._encoder = Encoder(
            node_in=node_in,
            node_out=latent_dim,
            edge_in=edge_in,
            edge_out=latent_dim,
            **mlp_kwargs,
        )
        self._processor = Processor(
            node_in=latent_dim,
            node_out=latent_dim,
            edge_in=latent_dim,
            edge_out=latent_dim,
            num_message_passing_steps=num_message_passing_steps,
            **mlp_kwargs,
        )
        self._decoder = Decoder(
            node_in=latent_dim,
            node_out=node_out,
            **mlp_kwargs,
        )

    def forward(self, x, edge_index, e_features):
        """Return per-node predictions of width ``node_out``.

        ``x`` holds one row per node and ``e_features`` one row per edge of
        ``edge_index``.
        """
        node_latents, edge_latents = self._encoder(x, edge_index, e_features)
        node_latents, _ = self._processor(node_latents, edge_index, edge_latents)
        return self._decoder(node_latents)
