from typing import Tuple, Optional, List, Callable

import torch
import torch_geometric
from pydgn.model.interface import ModelInterface
from torch import relu
from torch.nn import Sequential, Linear, ReLU, ModuleList
from torch_geometric.data import Batch
from torch_geometric.nn import MessagePassing, GIN, GCNConv
from torch_geometric.nn import knn_graph


class MLP(ModelInterface):
    """
    Node-wise multi-layer perceptron that ignores the graph structure.

    Architecture: an input projection ``dim_node_features -> hidden_units``,
    ``num_layers`` projections ``hidden_units -> hidden_units`` and an output
    projection ``hidden_units -> dim_target``. Every layer except the output
    one is followed by a ReLU.
    """

    def __init__(
        self,
        dim_node_features: int,
        dim_edge_features: int,
        dim_target: int,
        readout_class: Callable[..., torch.nn.Module],
        config: dict,
    ):
        super().__init__(
            dim_node_features,
            dim_edge_features,
            dim_target,
            readout_class,
            config,
        )

        self.num_layers = config["num_layers"]
        self.hidden_units = config["hidden_units"]

        # Layers are created input -> hidden... -> output.
        input_proj = Linear(dim_node_features, self.hidden_units)
        hidden_projs = [
            Linear(self.hidden_units, self.hidden_units)
            for _ in range(self.num_layers)
        ]
        output_proj = Linear(self.hidden_units, dim_target)
        self.layers = ModuleList([input_proj, *hidden_projs, output_proj])

    def forward(
        self, data: Batch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        """
        Apply the MLP to every node, then keep only the nodes of the active split.

        Args:
            data: batch exposing ``x``, ``y``, ``training_indices`` and
                ``eval_indices``.

        Returns:
            ``(output, hidden, [target])`` restricted to ``training_indices``
            in training mode and to ``eval_indices`` otherwise. ``hidden`` is
            the activation fed to the output layer.
        """
        *activated_layers, output_proj = self.layers

        h = data.x
        for layer in activated_layers:
            h = relu(layer(h))
        o = output_proj(h)

        # The data loader tells us which nodes belong to the current
        # [tr/val/test] split; only those outputs/targets are returned.
        split = data.training_indices if self.training else data.eval_indices
        return o[split], h[split], [data.y[split]]


class SimpleDGNConv(MessagePassing):
    """
    Parameter-free message passing layer: every node receives the mean of the
    feature vectors sent by its neighbors along ``edge_index``.
    """

    def __init__(self):
        super().__init__(aggr="mean")

    def forward(self, x, edge_index):
        """Send ``x`` along ``edge_index`` and return the mean-aggregated result."""
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # The message is the unmodified feature vector of the sending node.
        return x_j


class SimpleDGN(ModelInterface):
    """
    Non-trainable neighborhood averaging followed by an MLP classifier.

    Node features are replaced by the mean over a kNN graph built in input
    space (with self-loops), then passed through the same layer stack used by
    the ``MLP`` model: input projection, ``num_layers`` hidden projections and
    an output projection, with ReLU after every layer but the last.
    """

    def __init__(
        self,
        dim_node_features: int,
        dim_edge_features: int,
        dim_target: int,
        readout_class: Callable[..., torch.nn.Module],
        config: dict,
    ):
        super().__init__(
            dim_node_features,
            dim_edge_features,
            dim_target,
            readout_class,
            config,
        )

        self.num_layers = config["num_layers"]
        self.hidden_units = config["hidden_units"]
        self.k = config["k"]

        self.mean_aggregation = SimpleDGNConv()

        input_proj = Linear(dim_node_features, self.hidden_units)
        hidden_projs = [
            Linear(self.hidden_units, self.hidden_units)
            for _ in range(self.num_layers)
        ]
        output_proj = Linear(self.hidden_units, dim_target)
        self.layers = ModuleList([input_proj, *hidden_projs, output_proj])

        # Built lazily from the node features seen in the first forward call
        # and reused afterwards.
        self.knn_edge_index = None

    def forward(
        self, data: Batch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        """
        Smooth node features over the cached kNN graph and classify them.

        Args:
            data: batch exposing ``x``, ``y``, ``training_indices`` and
                ``eval_indices``.

        Returns:
            ``(output, hidden, [target])`` restricted to ``training_indices``
            in training mode and to ``eval_indices`` otherwise.
        """
        if self.knn_edge_index is None:
            # loop=True: each node is part of its own neighborhood.
            self.knn_edge_index = knn_graph(data.x, self.k, loop=True, batch=None)

        smoothed = self.mean_aggregation(data.x, self.knn_edge_index)

        *activated_layers, output_proj = self.layers

        h = smoothed
        for layer in activated_layers:
            h = relu(layer(h))
        o = output_proj(h)

        # Keep only the nodes of the current [tr/val/test] split.
        split = data.training_indices if self.training else data.eval_indices
        return o[split], h[split], [data.y[split]]


class GIN(ModelInterface):
    """
    Graph Isomorphism Network applied to a kNN graph built from node features.

    Wraps ``torch_geometric.nn.GIN`` with jumping-knowledge concatenation,
    trainable epsilon (initialised to 1.0) and a configurable neighborhood
    aggregation (``config["aggregation"]``).
    """

    def __init__(
        self,
        dim_node_features: int,
        dim_edge_features: int,
        dim_target: int,
        readout_class: Callable[..., torch.nn.Module],
        config: dict,
    ):
        super().__init__(
            dim_node_features,
            dim_edge_features,
            dim_target,
            readout_class,
            config,
        )

        self.num_layers = config["num_layers"]
        self.hidden_units = config["hidden_units"]
        self.k = config["k"]
        self.dropout = config["dropout"]
        self.aggregation = config["aggregation"]

        # The aggregation must be chosen at construction time. PyG's
        # MessagePassing resolves ``aggr`` into ``aggr_module`` inside
        # ``__init__``, and ``aggregate`` uses that module, so assigning
        # ``conv.aggr`` afterwards leaves the default ("add") in place.
        # Extra keyword arguments of ``torch_geometric.nn.GIN`` are forwarded
        # to every ``GINConv`` it builds.
        self.gin = torch_geometric.nn.GIN(
            dim_node_features,
            self.hidden_units,
            self.num_layers,
            dim_target,
            self.dropout,
            jk="cat",
            train_eps=True,
            eps=1.0,
            aggr=self.aggregation,
        )

        # Built lazily from the node features seen in the first forward call
        # and reused afterwards.
        self.knn_edge_index = None

    def forward(
        self, data: Batch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        """
        Run GIN on the cached kNN graph and keep only the active split.

        Args:
            data: batch exposing ``x``, ``y``, ``training_indices`` and
                ``eval_indices``.

        Returns:
            ``(output, output, [target])`` restricted to ``training_indices``
            in training mode and to ``eval_indices`` otherwise. No separate
            embedding is produced, so the output is repeated in its place.
        """
        if self.knn_edge_index is None:
            # build knn graph and store it in the model
            self.knn_edge_index = knn_graph(
                data.x,
                self.k,
                loop=False,  # GIN already uses dest node
                batch=None,
            )

        o = self.gin(data.x, self.knn_edge_index)

        # Keep only the nodes of the current [tr/val/test] split.
        split = data.training_indices if self.training else data.eval_indices
        o = o[split]
        return o, o, [data.y[split]]


class GCN(ModelInterface):
    """
    Stack of GCN convolutions applied to a kNN graph built from node features.

    The first ``num_layers - 1`` convolutions map to ``hidden_units`` and are
    followed by ReLU. The last convolution maps to ``dim_target`` and returns
    raw outputs.

    Raises:
        ValueError: if ``config["num_layers"]`` is smaller than 1.
    """

    def __init__(
        self,
        dim_node_features: int,
        dim_edge_features: int,
        dim_target: int,
        readout_class: Callable[..., torch.nn.Module],
        config: dict,
    ):
        super().__init__(
            dim_node_features,
            dim_edge_features,
            dim_target,
            readout_class,
            config,
        )

        self.num_layers = config["num_layers"]
        self.hidden_units = config["hidden_units"]
        self.k = config["k"]

        if self.num_layers < 1:
            # Without any layer forward() would have no output to return.
            raise ValueError(
                f"GCN requires num_layers >= 1, got {self.num_layers}"
            )

        # A hand-written stack of GCNConv layers is used instead of
        # torch_geometric.nn.GCN, whose implementation might not correspond
        # exactly to the original GCN paper.
        #
        # Widths: input -> hidden (num_layers - 1 times) -> target, so the
        # last layer produces dim_target outputs.
        widths = (
            [dim_node_features]
            + [self.hidden_units] * (self.num_layers - 1)
            + [dim_target]
        )
        # cached=True: the kNN graph is fixed after the first forward call,
        # so the normalized adjacency can be computed once and reused.
        self.layers = ModuleList(
            [
                GCNConv(w_in, w_out, cached=True)
                for w_in, w_out in zip(widths[:-1], widths[1:])
            ]
        )

        # Built lazily from the node features seen in the first forward call
        # and reused afterwards.
        self.knn_edge_index = None

    def forward(
        self, data: Batch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[object]]]:
        """
        Run the GCN stack on the cached kNN graph and keep only the active split.

        Args:
            data: batch exposing ``x``, ``y``, ``training_indices`` and
                ``eval_indices``.

        Returns:
            ``(output, output, [target])`` restricted to ``training_indices``
            in training mode and to ``eval_indices`` otherwise. No separate
            embedding is produced, so the output is repeated in its place.
        """
        if self.knn_edge_index is None:
            # build knn graph and store it in the model
            self.knn_edge_index = knn_graph(
                data.x,
                self.k,
                loop=False,  # GCNConv adds self-loops itself by default
                batch=None,
            )

        last = len(self.layers) - 1
        o = data.x
        for l, conv in enumerate(self.layers):
            o = conv(o, self.knn_edge_index)
            if l < last:
                # intermediate layer; the last one outputs raw class scores
                o = torch.relu(o)

        # Keep only the nodes of the current [tr/val/test] split.
        split = data.training_indices if self.training else data.eval_indices
        o = o[split]
        return o, o, [data.y[split]]