from typing import Optional

import torch
import torch.nn
from torch import Tensor

import dgl
from dgl import DGLGraph
from dgl.nn import GraphConv


class GraphConvWithDropout(GraphConv):
    """
    Description
    -----------
    A GraphConv that applies dropout to its input node features
    before running the graph convolution.

    Parameters
    ----------
    in_feats : int
        Forwarded to :obj:`GraphConv`.
    out_feats : int
        Forwarded to :obj:`GraphConv`.
    dropout : float, optional
        Dropout probability applied to the input features.
        Default: :obj:`0.3`
    norm, weight, bias, activation, allow_zero_in_degree
        Forwarded unchanged to :obj:`GraphConv`.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        dropout=0.3,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super().__init__(
            in_feats,
            out_feats,
            norm=norm,
            weight=weight,
            bias=bias,
            activation=activation,
            allow_zero_in_degree=allow_zero_in_degree,
        )
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, graph, feat, weight=None, **kwargs):
        """
        Description
        -----------
        Drop out the input features, then run the parent graph convolution.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph.
        feat : torch.Tensor
            The input node features.
        weight : torch.Tensor, optional
            Forwarded to the parent ``forward``.
            Default: :obj:`None`
        **kwargs
            Any extra keyword arguments are forwarded to the parent
            ``forward`` untouched.
        """
        # torch.nn.Module.__call__ dispatches to ``forward``; the override
        # has to live here, otherwise the dropout layer is never executed.
        feat = self.dropout(feat)
        return super().forward(graph, feat, weight, **kwargs)

    def call(self, graph, feat, weight=None):
        """Backward-compatible alias that simply invokes the module."""
        return self(graph, feat, weight)


class Discriminator(torch.nn.Module):
    """
    Description
    -----------
    A discriminator used to let the network discriminate
    between positive (neighborhood of center node) and
    negative (any neighborhood in graph) samplings.

    Each score is produced by a bilinear form with a single output
    channel between a sample feature and the center-node feature.

    Parameters
    ----------
    feat_dim : int
        The number of channels of node features.
    """

    def __init__(self, feat_dim: int):
        super(Discriminator, self).__init__()
        self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform the bilinear weight and zero its bias."""
        torch.nn.init.xavier_uniform_(self.affine.weight)
        torch.nn.init.zeros_(self.affine.bias)

    def forward(
        self,
        h_x: Tensor,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Parameters
        ----------
        h_x : torch.Tensor
            Node features, shape: :obj:`(num_nodes, feat_dim)`
        h_pos : torch.Tensor
            The node features of positive samples
            It has the same shape as :obj:`h_x`
        h_neg : torch.Tensor
            The node features of negative samples
            It has the same shape as :obj:`h_x`
        bias_pos : torch.Tensor, optional
            Bias added to the positive scores,
            shape: :obj:`(num_nodes)`
        bias_neg : torch.Tensor, optional
            Bias added to the negative scores,
            shape: :obj:`(num_nodes)`

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            ``logits``: positive scores followed by negative scores,
            shape (2 * num_nodes,), and ``score_pos``: the positive
            scores alone, shape (num_nodes,).
        """
        # Drop only the trailing singleton channel. A bare ``squeeze()``
        # would also remove the node axis when num_nodes == 1, leaving a
        # 0-d tensor that ``torch.cat`` refuses to concatenate.
        score_pos = self.affine(h_pos, h_x).squeeze(-1)
        score_neg = self.affine(h_neg, h_x).squeeze(-1)
        if bias_pos is not None:
            score_pos = score_pos + bias_pos
        if bias_neg is not None:
            score_neg = score_neg + bias_neg

        logits = torch.cat((score_pos, score_neg), 0)

        return logits, score_pos


class DenseLayer(torch.nn.Module):
    """
    Description
    -----------
    Dense layer with a linear layer and an activation function

    Parameters
    ----------
    in_dim : int
        Input feature size of the linear layer.
    out_dim : int
        Output feature size of the linear layer.
    act : str, optional
        Activation type, compared case-insensitively. :obj:`'prelu'`
        selects a learnable PReLU; any other value selects ReLU.
        Default: :obj:`'prelu'`
    bias : bool, optional
        Whether the linear layer has a bias term.
        Default: :obj:`True`
    """

    # Initial PReLU slope; equals torch.nn.PReLU's own default ``init``.
    _PRELU_INIT = 0.25

    def __init__(
        self, in_dim: int, out_dim: int, act: str = "prelu", bias: bool = True
    ):
        super(DenseLayer, self).__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.act_type = act.lower()
        # Build the activation exactly once so that its parameter keeps the
        # same identity (optimizer registration, device, dtype) for the
        # lifetime of the layer.
        if self.act_type == "prelu":
            self.act = torch.nn.PReLU(init=self._PRELU_INIT)
        else:
            self.act = torch.relu
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the linear layer and the PReLU slope in place."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        if self.lin.bias is not None:
            torch.nn.init.zeros_(self.lin.bias)
        if isinstance(self.act, torch.nn.PReLU):
            # In-place reset instead of replacing the module.
            torch.nn.init.constant_(self.act.weight, self._PRELU_INIT)

    def forward(self, x):
        """Apply the linear map followed by the activation."""
        x = self.lin(x)
        return self.act(x)


class IndexSelect(torch.nn.Module):
    """
    Description
    -----------
    The index selection layer used by VIPool. It scores every node
    with a discriminator and keeps the highest-scoring ones.

    Parameters
    ----------
    pool_ratio : float
        The pooling ratio (for keeping nodes). For example,
        if `pool_ratio=0.8`, 80% of the nodes will be preserved.
    hidden_dim : int
        The number of channels in node features.
    act : str, optional
        The activation function type.
        Default: :obj:`'prelu'`
    dist : int, optional
        DO NOT USE THIS PARAMETER
        (it is stored on the module but not read by ``forward``).
    """

    def __init__(
        self,
        pool_ratio: float,
        hidden_dim: int,
        act: str = "prelu",
        dist: int = 1,
    ):
        super().__init__()
        self.pool_ratio = pool_ratio
        self.dist = dist
        # Submodules: shared dense projection, pos/neg discriminator and
        # the graph convolution that produces the center-node embedding.
        self.dense = DenseLayer(hidden_dim, hidden_dim, act)
        self.discriminator = Discriminator(hidden_dim)
        self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(
        self,
        graph: DGLGraph,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Description
        -----------
        Perform index selection

        Parameters
        ----------
        graph : dgl.DGLGraph
            Input graph.
        h_pos : torch.Tensor
            The node features of positive samples
        h_neg : torch.Tensor
            The node features of negative samples
            It has the same shape as :obj:`h_pos`
        bias_pos : torch.Tensor, optional
            Bias parameter vector for positive scores
            shape: :obj:`(num_nodes)`
        bias_neg : torch.Tensor, optional
            Bias parameter vector for negative scores
            shape: :obj:`(num_nodes)`

        Returns
        -------
        tuple
            ``(logit, select_scores, select_idx, non_select_idx, embed)``:
            the discriminator logits, the scores and indices of the kept
            nodes (highest score first), the indices of the dropped
            nodes, and the graph-convolution output.
        """
        # Both sample sets go through the same dense projection.
        pos_feat = self.dense(h_pos)
        neg_feat = self.dense(h_neg)

        # The center representation is the squashed graph convolution
        # of the positive features.
        embed = self.gcn(graph, pos_feat)
        logit, pos_logit = self.discriminator(
            torch.sigmoid(embed), pos_feat, neg_feat, bias_pos, bias_neg
        )

        # Rank nodes by positive score, best first.
        ranked_scores, ranked_idx = torch.sigmoid(pos_logit).sort(
            descending=True
        )

        # Cut the ranking into a kept head and a dropped tail.
        total = graph.num_nodes()
        keep = int(self.pool_ratio * total)
        sizes = [keep, total - keep]
        select_scores = ranked_scores.split(sizes, dim=0)[0]
        select_idx, non_select_idx = ranked_idx.split(sizes, dim=0)

        return logit, select_scores, select_idx, non_select_idx, embed


class GraphPool(torch.nn.Module):
    """
    Description
    -----------
    The pooling module for graph: gathers the features of the selected
    nodes and, on request, extracts the matching node-induced subgraph.

    Parameters
    ----------
    hidden_dim : int
        The number of channels of node features.
    use_gcn : bool, optional
        Whether use gcn in down sampling process.
        default: :obj:`False`
    """

    def __init__(self, hidden_dim: int, use_gcn=False):
        super().__init__()
        self.use_gcn = use_gcn
        if use_gcn:
            self.down_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        else:
            self.down_sample_gcn = None

    def forward(
        self,
        graph: DGLGraph,
        feat: Tensor,
        select_idx: Tensor,
        non_select_idx: Optional[Tensor] = None,
        scores: Optional[Tensor] = None,
        pool_graph=False,
    ):
        """
        Description
        -----------
        Perform graph pooling.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph
        feat : torch.Tensor
            The input node feature
        select_idx : torch.Tensor
            The index in fine graph of node from
            coarse graph, this is obtained from
            previous graph pooling layers.
        non_select_idx : torch.Tensor, optional
            The index that not included in output graph.
            Accepted for interface symmetry; not read by this method.
            default: :obj:`None`
        scores : torch.Tensor, optional
            Scores for nodes used for pooling and scaling.
            default: :obj:`None`
        pool_graph : bool, optional
            Whether perform graph pooling on graph topology.
            default: :obj:`False`

        Returns
        -------
        torch.Tensor or (torch.Tensor, dgl.DGLGraph)
            The pooled features, plus the pooled graph when
            ``pool_graph`` is true.
        """
        if self.use_gcn:
            feat = self.down_sample_gcn(graph, feat)

        # Gather the kept rows and optionally scale each one by its score.
        pooled = feat[select_idx]
        if scores is not None:
            pooled = pooled * scores.unsqueeze(-1)

        if not pool_graph:
            return pooled

        # NOTE(review): the batch sizes are read from the *input* graph and
        # written onto the subgraph unchanged; presumably they no longer
        # match the subgraph's per-graph node counts -- confirm with callers.
        batch_sizes = graph.batch_num_nodes()
        subgraph = dgl.node_subgraph(graph, select_idx)
        subgraph.set_batch_num_nodes(batch_sizes)
        return pooled, subgraph


class GraphUnpool(torch.nn.Module):
    """
    Description
    -----------
    The unpooling module for graph: scatters coarse node features back
    to their positions in the fine graph and smooths them with a graph
    convolution.

    Parameters
    ----------
    hidden_dim : int
        The number of channels of node features.
    """

    def __init__(self, hidden_dim: int):
        super(GraphUnpool, self).__init__()
        self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor):
        """
        Description
        -----------
        Perform graph unpooling

        Parameters
        ----------
        graph : dgl.DGLGraph
            The input graph
        feat : torch.Tensor
            The input node feature
        select_idx : torch.Tensor
            The index in fine graph of node from
            coarse graph, this is obtained from
            previous graph pooling layers.

        Returns
        -------
        torch.Tensor
            The output of the up-sampling graph convolution applied to
            the zero-filled fine feature matrix.
        """
        # ``new_zeros`` inherits both device and dtype from ``feat``; with a
        # default-dtype buffer the indexed assignment below fails for
        # non-float32 features (e.g. float64 or half precision).
        fine_feat = feat.new_zeros((graph.num_nodes(), feat.size(-1)))
        # Rows that were not selected stay zero.
        fine_feat[select_idx] = feat
        fine_feat = self.up_sample_gcn(graph, fine_feat)
        return fine_feat
