"""
Layers for the Simplicial CRaWl model.

Notes
-----
All layers in this module support TorchScript. The JIT compiler does not support subscripting the
built-in `list` in function parameter annotations, so the deprecated `typing.List` is used instead.
"""
from enum import Enum
from typing import List

import torch
from torch.nn import (
    BatchNorm1d,
    Conv1d,
    Identity,
    Linear,
    Module,
    ModuleList,
    ReLU,
    Sequential,
)
from torch_scatter import scatter_mean, scatter_sum

from scrawl.simplicial import SimplicialData
from scrawl.walker import RandomWalk, Walker


class PoolingMethod(Enum):
    """
    Enumeration of the pooling methods.

    Each member is associated with the scatter function of the same name
    (`scatter_mean` for `MEAN`, `scatter_sum` for `SUM`).
    """

    MEAN = "mean"
    SUM = "sum"

    @property
    def scatter_function(self):
        """
        The scatter function for this pooling method.

        Returns
        -------
        callable
            The scatter function.

        Raises
        ------
        NotImplementedError
            If no scatter function is associated with this member.
        """
        if self is PoolingMethod.MEAN:
            return scatter_mean
        if self is PoolingMethod.SUM:
            return scatter_sum
        # Guard against members added later without a matching scatter
        # function instead of silently returning `None`.
        raise NotImplementedError(f"No scatter function for pooling method {self!r}.")

    def scatter(self, *args, **kwargs) -> torch.Tensor:
        """
        Scatter the input according to the pooling method.

        Parameters
        ----------
        *args : positional arguments
            Positional arguments passed to the scatter function.
        **kwargs : keyword arguments
            Keyword arguments passed to the scatter function.

        Returns
        -------
        torch.Tensor
            The output of the scatter function.
        """
        # `scatter_function` is a property, so the attribute access already
        # yields the callable; it must be called exactly once, with the
        # forwarded arguments.
        return self.scatter_function(*args, **kwargs)


class ConvolutionModule(Module):
    """
    SCRaWl convolution module.

    Runs a stack of 1D convolutions over the walk features, mean-pools the
    result into the central simplex of each receptive field, and adds the
    pooled embedding to the (optionally rescaled) input features as a
    residual update.

    Parameters
    ----------
    feature_size : int
        Dimension of the input embeddings.
    walk_feature_size : int
        Size of the walk features, which include in addition to the simplex
        features also information about the connectivity and identity of
        the simplices seen in the walk.
    convolution_size : int
        Dimension of the convolution output.
    output_size : int
        Dimension of the output embeddings.
    kernel_size : int
        Size of the pooling kernel. Walklets of length `kernel_size` are
        pooled into its central node.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    """

    def __init__(
        self,
        feature_size: int,
        walk_feature_size: int,
        convolution_size: int,
        output_size: int,
        kernel_size: int,
        batch_running_stats: bool = True,
    ) -> None:
        """
        SCRaWl convolution module.

        Parameters
        ----------
        feature_size : int
            Dimension of the input embeddings.
        walk_feature_size : int
            Size of the walk features, which include in addition to the simplex
            features also information about the connectivity and identity of
            the simplices seen in the walk.
        convolution_size : int
            Dimension of the convolution output.
        output_size : int
            Dimension of the output embeddings.
        kernel_size : int
            Size of the pooling kernel. Walklets of length `kernel_size` are
            pooled into its central node.
        batch_running_stats : bool, default=True
            Whether to use batch running statistics for batch normalization.
        """
        super().__init__()

        self.feature_size = feature_size
        self.walk_feature_size = walk_feature_size
        self.convolution_size = convolution_size
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.batch_running_stats = batch_running_stats

        # Pointwise projection -> depthwise convolution over the walk
        # (`groups == channels`) -> pointwise mixing. Without padding, the
        # depthwise convolution shortens the walk by `kernel_size - 1` steps.
        self.convolutions = Sequential(
            Conv1d(
                self.walk_feature_size, self.convolution_size, 1, padding=0, bias=False
            ),
            Conv1d(
                self.convolution_size,
                self.convolution_size,
                self.kernel_size,
                groups=self.convolution_size,
                padding=0,
                bias=False,
            ),
            BatchNorm1d(
                self.convolution_size, track_running_stats=self.batch_running_stats
            ),
            ReLU(),
            Conv1d(
                self.convolution_size, self.convolution_size, 1, padding=0, bias=False
            ),
            ReLU(),
        )
        # The residual connection needs matching sizes; project the input only
        # when the sizes differ.
        self.rescale = (
            Linear(self.feature_size, output_size, bias=False)
            if self.feature_size != output_size
            else Identity()
        )
        self.output = Sequential(
            Linear(convolution_size, 2 * output_size, bias=False),
            BatchNorm1d(2 * output_size, track_running_stats=self.batch_running_stats),
            ReLU(),
            Linear(2 * output_size, output_size, bias=False),
        )

    def forward(
        self,
        data: torch.Tensor,
        walk_simplices: torch.Tensor,
        walk_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of this layer.

        The input tensor `data` is not modified.

        Parameters
        ----------
        data : torch.Tensor, shape = [num_simplices, feature_size]
            Tensor containing the input features.
        walk_simplices : torch.Tensor, shape = [num_walks, walk_length]
            Tensor containing the simplex indices along the walk.
        walk_tensor : torch.Tensor, shape = [num_walks, walk_length, walk_dim]
            Tensor containing the walk features.

        Returns
        -------
        torch.Tensor, shape = [num_simplices, output_size]
            The output of this layer.
        """
        # Step 1: Run 1D-Convolutions on the walk features. `Conv1d` expects
        #         the channel axis before the walk axis, hence the transposes.
        conv_out = self.convolutions(walk_tensor.transpose(1, 2)).transpose(1, 2)

        # Step 2: Pool the convolution outputs into the simplices, more
        #         precisely, pool the result into the central simplex of each
        #         receptive field.
        flat_out = conv_out.reshape(conv_out.shape[0] * conv_out.shape[1], -1)

        # Indices of the central simplices. The end of the slice is computed
        # explicitly: a negative stop of `-(kernel_size - 1 - center_offset)`
        # degenerates to `0` for kernel sizes 1 and 2 and would select nothing.
        center_offset = self.kernel_size // 2
        trailing = self.kernel_size - 1 - center_offset
        center_end = walk_simplices.size(1) - trailing
        center_simplices = walk_simplices[:, center_offset:center_end].reshape(-1)

        # Mean-pool the walklet embeddings into their central simplices.
        pooled = scatter_mean(
            flat_out,
            center_simplices.to(dtype=torch.int64),
            dim=0,
            dim_size=data.size(0),
        )

        # Residual connection. The sum is computed out of place: with an
        # `Identity` rescale an in-place `+=` would overwrite the caller's
        # input tensor.
        return self.rescale(data) + self.output(pooled)


class SimplicialCRaWlLayer(Module):
    """
    Single SCRaWl layer.

    Holds one `ConvolutionModule` per rank and applies it to the stacked
    feature matrices of the random walks on simplices of that rank.

    Parameters
    ----------
    feature_sizes : list of int
        Sizes of the features supported on the simplices of the respective
        ranks. Entry `i` corresponds to the size of the features of
        `i`-simplices.
    local_window_sizes : list of int
        Window sizes for the connectivity and identity features in the walk
        feature matrix, indexed by rank.
    convolution_sizes : list of int
        Sizes of the convolution outputs.
    embedding_sizes : list of int
        Sizes of the embeddings computed by the individual layers.
    kernel_sizes : list of int
        Size of the pooling kernels. Walklets on the `i`-simplices of length
        `kernel_sizes[i]` are pooled into their central simplex.
    dropout : float between 0 and 1
        Dropout probability. The value is stored on the layer; no dropout is
        applied by the code of this class.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    """

    feature_sizes: List[int]
    local_window_sizes: List[int]
    dropout: float

    def __init__(
        self,
        feature_sizes: List[int],
        local_window_sizes: List[int],
        convolution_sizes: List[int],
        embedding_sizes: List[int],
        kernel_sizes: List[int],
        *,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
    ) -> None:
        """
        Single SCRaWl layer.

        Parameters
        ----------
        feature_sizes : list of int
            Sizes of the features supported on the simplices of the respective
            ranks. Entry `i` corresponds to the size of the features of
            `i`-simplices.
        local_window_sizes : list of int
            Window sizes for the connectivity and identity features in the walk
            feature matrix, indexed by rank.
        convolution_sizes : list of int
            Sizes of the convolution outputs.
        embedding_sizes : list of int
            Sizes of the embeddings computed by the individual layers.
        kernel_sizes : list of int
            Size of the pooling kernels. Walklets on the `i`-simplices of length
            `kernel_sizes[i]` are pooled into their central simplex.
        dropout : float between 0 and 1
            Dropout probability.
        batch_running_stats : bool, default=True
            Whether to use batch running statistics for batch normalization.
        """
        super().__init__()

        self.feature_sizes = feature_sizes
        self.local_window_sizes = local_window_sizes
        self.convolution_sizes = convolution_sizes
        self.embedding_sizes = embedding_sizes
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.batch_running_stats = batch_running_stats

        num_ranks = len(feature_sizes)
        top_rank = len(self.feature_sizes) - 1

        # Width of the walk feature matrix per rank. The feature sizes of the
        # neighbouring ranks enter as well; a missing neighbour (below rank 0
        # or above the top rank) contributes size 0.
        walk_tensor_sizes = []
        for rank in range(num_ranks):
            lower_size = self.feature_sizes[rank - 1] if rank > 0 else 0
            upper_size = self.feature_sizes[rank + 1] if rank < top_rank else 0
            walk_tensor_sizes.append(
                RandomWalk.walk_feature_size(
                    self.local_window_sizes[rank],
                    self.feature_sizes[rank],
                    lower_size,
                    upper_size,
                )
            )
        self.walk_tensor_sizes = walk_tensor_sizes

        # One convolution module per rank.
        self.convolutions = ModuleList(
            [
                ConvolutionModule(
                    feature_size=self.feature_sizes[rank],
                    walk_feature_size=self.walk_tensor_sizes[rank],
                    convolution_size=self.convolution_sizes[rank],
                    output_size=self.embedding_sizes[rank],
                    kernel_size=self.kernel_sizes[rank],
                    batch_running_stats=self.batch_running_stats,
                )
                for rank in range(num_ranks)
            ]
        )

    def forward(
        self, data: SimplicialData, random_walks: dict[int, list[RandomWalk]]
    ) -> SimplicialData:
        """
        Forward pass of this layer.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        random_walks : dict of list of RandomWalk instances
            Lists of random walks, grouped by rank of simplices they walk on.

        Returns
        -------
        SimplicialData
            The output of this layer. Only ranks that have an entry in
            `random_walks` are assigned.
        """
        top_rank = len(self.feature_sizes) - 1

        # Stack the per-walk matrices into one pair of tensors per rank:
        # (simplex indices along the walks, walk features).
        stacked_walks: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
        for rank, walks in random_walks.items():
            index_rows: list[torch.Tensor] = []
            feature_rows: list[torch.Tensor] = []
            for walk in walks:
                simplex_indices, features = walk.feature_matrix(
                    data,
                    self.local_window_sizes[rank],
                    self.feature_sizes[rank - 1] if rank > 0 else 0,
                    self.feature_sizes[rank + 1] if rank < top_rank else 0,
                )
                index_rows.append(simplex_indices)
                feature_rows.append(features)

            stacked_indices = torch.stack(index_rows).to(data.device)
            stacked_features = torch.stack(feature_rows).to(data.device)
            stacked_walks[rank] = (stacked_indices, stacked_features)

        # Run the rank-specific convolution wherever walks were supplied.
        result = SimplicialData(data.domain, dtype=data.dtype, device=data.device)
        for rank, convolution in enumerate(self.convolutions):
            if rank not in stacked_walks:
                continue
            result[rank] = convolution(
                data[rank], stacked_walks[rank][0], stacked_walks[rank][1]
            )

        return result


class SimplicialCRaWlNodeOutput(Module):
    """
    Node output layer for SCRaWl.

    A two-layer perceptron that is applied to the rank-0 entry of the
    simplicial data.

    Parameters
    ----------
    embedding_size : int
        The size of the embedding.
    num_outputs : int
        The number of outputs.
    """

    def __init__(self, embedding_size: int, num_outputs: int) -> None:
        """
        Node output layer for SCRaWl.

        Parameters
        ----------
        embedding_size : int
            The size of the embedding.
        num_outputs : int
            The number of outputs.
        """
        super().__init__()

        hidden = Linear(embedding_size, embedding_size)
        readout = Linear(embedding_size, num_outputs)
        self.output = Sequential(hidden, ReLU(), readout)

    def forward(self, data: SimplicialData) -> torch.Tensor:
        """
        Forward pass of this layer.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data. Only the entry at index 0 is read.

        Returns
        -------
        torch.Tensor
            The output of this layer.
        """
        node_embeddings = data[0]
        return self.output(node_embeddings)


class SimplicialCRaWlWalked(Module):
    """
    SCRaWl model with pre-computed random walks.

    Parameters
    ----------
    num_layers : int
        The number of concurrent layers used in SimplicialCRaWl. Other
        hyperparameters must have appropriate lengths for this number of
        layers.
    feature_sizes : list of ints
        The initial feature sizes of the simplices. Entry `i` corresponds
        to the size of features supported in the `i`-simplices.
    local_window_sizes : list of list of int, list of int, or int
        The local window sizes for identity and connectivity features in
        the random walks. Entry `i, j` corresponds to the local window size
        of `j`-simplices in the `i`-th layer.

        The size of the outer list must be equal to `num_layers` and the
        size of the inner lists must be equal to the size of
        `feature_sizes`. An `int` in place of an inner list is used for all
        ranks of that layer; an `int` in place of the outer list is used
        for all layers and ranks.
    embedding_sizes : list of list of int, list of int, or int
        The sizes of the embeddings computed by the individual layers.
        Entry `i, j` corresponds to the embedding size of `j`-simplices
        computed by the `i`-th layer of the network.

        The size of the outer list must be equal to `num_layers` and the
        size of the inner lists must be equal to the size of
        `feature_sizes`. `int` values are expanded in the same way as for
        `local_window_sizes`.
    kernel_sizes : list of int
        Size of the pooling kernels. The same list is passed to every layer.
    pooling : PoolingMethod, default=PoolingMethod.MEAN
        The pooling method to use. The value is stored on the model; it is
        not passed on to the layers built here.
    dropout : float between 0 and 1
        Dropout probability.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    """

    def __init__(
        self,
        num_layers: int,
        feature_sizes: List[int] | int,
        local_window_sizes: List[List[int]] | List[int],
        embedding_sizes: List[List[int]] | List[int],
        kernel_sizes: List[int],
        *,
        pooling: PoolingMethod = PoolingMethod.MEAN,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
    ) -> None:
        """
        SCRaWl model with pre-computed random walks.

        The size arguments are expanded into new lists; the lists passed by
        the caller are not modified.

        Parameters
        ----------
        num_layers : int
            The number of concurrent layers used in SimplicialCRaWl. Other
            hyperparameters must have appropriate lengths for this number of
            layers.
        feature_sizes : list of ints
            The initial feature sizes of the simplices. Entry `i` corresponds
            to the size of features supported in the `i`-simplices.
        local_window_sizes : list of list of int, list of int, or int
            The local window sizes for identity and connectivity features in
            the random walks. Entry `i, j` corresponds to the local window size
            of `j`-simplices in the `i`-th layer.

            The size of the outer list must be equal to `num_layers` and the
            size of the inner lists must be equal to the size of
            `feature_sizes`. An `int` in place of an inner list is used for
            all ranks of that layer; an `int` in place of the outer list is
            used for all layers and ranks.
        embedding_sizes : list of list of int, list of int, or int
            The sizes of the embeddings computed by the individual layers.
            Entry `i, j` corresponds to the embedding size of `j`-simplices
            computed by the `i`-th layer of the network.

            The size of the outer list must be equal to `num_layers` and the
            size of the inner lists must be equal to the size of
            `feature_sizes`. `int` values are expanded in the same way as
            for `local_window_sizes`.
        kernel_sizes : list of int
            Size of the pooling kernels.
        pooling : PoolingMethod, default=PoolingMethod.MEAN
            The pooling method to use.
        dropout : float between 0 and 1
            Dropout probability.
        batch_running_stats : bool, default=True
            Whether to use batch running statistics for batch normalization.

        Raises
        ------
        ValueError
            If the number of entries of `local_window_sizes` or
            `embedding_sizes` does not match `num_layers`, or if an inner
            list does not have one entry per rank.
        """
        super().__init__()

        # --- local window sizes: one list per layer, one entry per rank ---
        if isinstance(local_window_sizes, int):
            local_window_sizes = [local_window_sizes] * num_layers
        elif len(local_window_sizes) != num_layers:
            raise ValueError(
                f"The number of local window sizes must match the number of "
                f"layers. Got {len(local_window_sizes)} entries, but expected {num_layers}."
            )

        num_ranks = len(feature_sizes)

        # Expand into a fresh list so that the caller's argument is never
        # written to.
        expanded_window_sizes = []
        for i, layer_window_sizes in enumerate(local_window_sizes):
            if isinstance(layer_window_sizes, int):
                expanded_window_sizes.append([layer_window_sizes] * num_ranks)
            elif len(layer_window_sizes) != num_ranks:
                raise ValueError(
                    f"Wrong number of local window sizes for layer {i}. Got "
                    f"{len(layer_window_sizes)} but expected {len(feature_sizes)} entries."
                )
            else:
                expanded_window_sizes.append(list(layer_window_sizes))
        local_window_sizes = expanded_window_sizes

        # --- embedding sizes: same layout and shorthands as above ---
        if isinstance(embedding_sizes, int):
            embedding_sizes = [embedding_sizes] * num_layers
        elif len(embedding_sizes) != num_layers:
            raise ValueError(
                "The number of embedding sizes must match the number of layers."
            )

        expanded_embedding_sizes = []
        for i, layer_embedding_sizes in enumerate(embedding_sizes):
            if isinstance(layer_embedding_sizes, int):
                expanded_embedding_sizes.append([layer_embedding_sizes] * num_ranks)
            elif len(layer_embedding_sizes) != num_ranks:
                raise ValueError(
                    f"Wrong number of embedding sizes for layer {i}. Got "
                    f"{len(layer_embedding_sizes)} but expected {len(feature_sizes)} entries."
                )
            else:
                expanded_embedding_sizes.append(list(layer_embedding_sizes))
        embedding_sizes = expanded_embedding_sizes

        self.num_layers = num_layers
        self.feature_sizes = feature_sizes
        self.local_window_sizes = local_window_sizes
        self.embedding_sizes = embedding_sizes
        self.kernel_sizes = kernel_sizes
        self.pooling = pooling
        self.dropout = dropout
        self.batch_running_stats = batch_running_stats

        # The convolution sizes are currently tied to the embedding sizes.
        self.conv_dim = self.embedding_sizes
        convolution_sizes = embedding_sizes  # TODO: make configurable

        # Chain the layers: the first layer consumes the raw feature sizes,
        # every later layer consumes the embedding sizes of its predecessor.
        self.layers = ModuleList()
        layer_input_sizes = feature_sizes
        for i in range(self.num_layers):
            self.layers.append(
                SimplicialCRaWlLayer(
                    layer_input_sizes,
                    local_window_sizes[i],
                    convolution_sizes[i],
                    embedding_sizes[i],
                    kernel_sizes,
                    dropout=dropout,
                    batch_running_stats=batch_running_stats,
                )
            )
            layer_input_sizes = embedding_sizes[i]

    def forward(
        self, data: SimplicialData, walks: dict[int, list[RandomWalk]]
    ) -> SimplicialData:
        """
        Forward pass of the model.

        The layers are applied one after another; every layer receives the
        same random walks.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        walks : dict of list of RandomWalk instances
            Lists of random walks, grouped by rank of simplices they walk on.

        Returns
        -------
        SimplicialData
            The output of the model.
        """
        for layer in self.layers:
            data = layer(data, walks)

        return data


class SimplicialCRaWl(Module):
    """
    SCRaWl model.

    Samples random walks on the domain of the input data in every forward
    pass and feeds them to a `SimplicialCRaWlWalked` model.

    Parameters
    ----------
    num_layers : int
        The number of concurrent layers used in SimplicialCRaWl. Other
        hyperparameters must have appropriate lengths for this number of
        layers.
    feature_sizes : list of ints
        The initial feature sizes of the simplices. Entry `i` corresponds
        to the size of features supported in the `i`-simplices.
    local_window_sizes : list of list of int
        The local window sizes for identity and connectivity features in
        the random walks. Entry `i, j` corresponds to the local window size
        of `j`-simplices in the `i`-th layer.

        The size of the outer list must be equal to `num_layers` and the
        size of the inner lists must be equal to the size of
        `feature_sizes`.
    embedding_sizes : list of list of int
        The sizes of the embeddings computed by the individual layers.
        Entry `i, j` corresponds to the embedding size of `j`-simplices
        computed by the `i`-th layer of the network.

        The size of the outer list must be equal to `num_layers` and the
        size of the inner lists must be equal to the size of
        `feature_sizes`.
    kernel_sizes : list of int
        Size of the pooling kernels.
    walk_prop : float between 0 and 1
        Proportion of the simplices of each rank that is used as the number
        of start simplices for the random walks.
    pooling : PoolingMethod, default=PoolingMethod.MEAN
        The pooling method to use.
    dropout : float between 0 and 1
        Dropout probability.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    use_lower_connections : bool, default=True
        Whether to use lower connections.
    use_upper_connections : bool, default=True
        Whether to use upper connections.
    """

    def __init__(
        self,
        num_layers: int,
        feature_sizes: List[int],
        local_window_sizes: list[List[int]] | List[int],
        embedding_sizes: list[List[int]] | List[int],
        kernel_sizes: List[int],
        *,
        walk_prop: float = 1.0,
        pooling: PoolingMethod = PoolingMethod.MEAN,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
        use_lower_connections: bool = True,
        use_upper_connections: bool = True,
    ) -> None:
        """
        SCRaWl model.

        Parameters
        ----------
        num_layers : int
            The number of concurrent layers used in SimplicialCRaWl. Other
            hyperparameters must have appropriate lengths for this number of
            layers.
        feature_sizes : list of ints
            The initial feature sizes of the simplices. Entry `i` corresponds
            to the size of features supported in the `i`-simplices.
        local_window_sizes : list of list of int
            The local window sizes for identity and connectivity features in
            the random walks. Entry `i, j` corresponds to the local window size
            of `j`-simplices in the `i`-th layer.

            The size of the outer list must be equal to `num_layers` and the
            size of the inner lists must be equal to the size of
            `feature_sizes`.
        embedding_sizes : list of list of int
            The sizes of the embeddings computed by the individual layers.
            Entry `i, j` corresponds to the embedding size of `j`-simplices
            computed by the `i`-th layer of the network.

            The size of the outer list must be equal to `num_layers` and the
            size of the inner lists must be equal to the size of
            `feature_sizes`.
        kernel_sizes : list of int
            Size of the pooling kernels.
        walk_prop : float between 0 and 1
            Proportion of the simplices of each rank that is used as the
            number of start simplices for the random walks.
        pooling : PoolingMethod, default=PoolingMethod.MEAN
            The pooling method to use.
        dropout : float between 0 and 1
            Dropout probability.
        batch_running_stats : bool, default=True
            Whether to use batch running statistics for batch normalization.
        use_lower_connections : bool, default=True
            Whether to use lower connections.
        use_upper_connections : bool, default=True
            Whether to use upper connections.
        """
        super().__init__()

        # Settings that control how walks are sampled in `forward`.
        self.walk_prop = walk_prop
        self.use_lower_connections = use_lower_connections
        self.use_upper_connections = use_upper_connections

        # The network itself; it consumes pre-computed walks.
        self.inner = SimplicialCRaWlWalked(
            num_layers,
            feature_sizes,
            local_window_sizes,
            embedding_sizes,
            kernel_sizes,
            pooling=pooling,
            dropout=dropout,
            batch_running_stats=batch_running_stats,
        )

    def forward(self, data: SimplicialData, walks: int) -> SimplicialData:
        """
        Forward pass of the model.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        walks : int
            The number of walks to use. The value is forwarded unchanged to
            `Walker.random_walks`.

        Returns
        -------
        SimplicialData
            The output of the model.
        """
        num_ranks = len(self.inner.feature_sizes)

        walker = Walker(
            data.domain,
            use_lower_connections=self.use_lower_connections,
            use_upper_connections=self.use_upper_connections,
            max_rank=num_ranks,
        )

        # Sample the walks rank by rank. `zip` stops at the shorter of the two
        # sequences, so only ranks that have both a feature size and an entry
        # in the shape of the domain are visited.
        walk_groups: dict[int, list[RandomWalk]] = {}
        for rank, num_simplices in zip(range(num_ranks), data.domain.shape):
            if self.walk_prop == 1.0:
                # Start from every simplex of this rank.
                start_simplices = torch.arange(num_simplices, dtype=torch.int64)
            else:
                # Draw the start simplices uniformly at random; `torch.randint`
                # may return the same index more than once.
                num_starts = int(self.walk_prop * num_simplices)
                start_simplices = torch.randint(0, num_simplices, (num_starts,))

            walk_groups[rank] = walker.random_walks(rank, start_simplices, walks)

        return self.inner(data, walk_groups)
