from typing import Optional, List

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import Parameter
from torch_geometric.nn import GCNConv

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils import add_self_loops as add_self_loops_fn
from torch_geometric.utils import (
    is_torch_sparse_tensor,
    scatter,
    spmm,
    to_edge_index,
)
from torch_geometric.utils.num_nodes import maybe_num_nodes
import torch_geometric as pyg

from graphgps.layer.utils import dag_degree


class DirGCNConv(torch.nn.Module):
    """Directed GCN layer mixing a forward and a reversed-edge convolution.

    The output is ``(1 - alpha) * conv_fwd(x) + alpha * conv_rev(x)`` where
    ``conv_fwd`` runs on the edges normalised by :func:`directed_norm` and
    ``conv_rev`` runs on the same edges with source and target swapped.
    When ``alpha == 0`` only the forward convolution is returned.

    The normalised edges are computed on the first call to :meth:`forward`
    and stored on the module; every later call reuses them and does not
    look at the ``edge_index`` it is given.
    """

    def __init__(self, input_dim, output_dim, alpha):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        # Normalisation is done by `directed_norm`, so the wrapped GCNConv
        # layers must neither normalise nor add self-loops themselves.
        conv_options = dict(bias=False, normalize=False, add_self_loops=False)
        self.lin_src_to_dst = GCNConv(input_dim, output_dim, **conv_options)
        self.lin_dst_to_src = GCNConv(input_dim, output_dim, **conv_options)
        self.alpha = alpha
        self.edge_index_1 = None
        self.edge_index_t = None

    def forward(self, x, edge_index):
        """Apply the layer to node features ``x`` over ``edge_index``.

        Args:
            x: Node feature tensor; its first dimension is used as the
                number of nodes.
            edge_index: Edge index tensor with a source row and a target
                row. Only consulted on the first call (see class docstring).

        Returns:
            The mixed output of the forward and reversed convolutions.
        """
        cache_is_empty = self.edge_index_1 is None
        if cache_is_empty:
            node_count = x.shape[0]

            normalised = directed_norm(edge_index, node_count)
            self.edge_index_1, self.edge_weight_1 = normalised

            # The reversed graph is only needed when it contributes.
            if self.alpha > 0:
                swap_rows = torch.tensor([1, 0], dtype=torch.long).to(edge_index.device)
                reversed_edges = edge_index.index_select(0, swap_rows)
                normalised_t = directed_norm(reversed_edges, node_count)
                self.edge_index_t, self.edge_weight_t = normalised_t

        forward_part = self.lin_src_to_dst(x, self.edge_index_1, self.edge_weight_1)
        if self.alpha == 0:
            return forward_part

        backward_part = self.lin_dst_to_src(x, self.edge_index_t, self.edge_weight_t)
        return (1 - self.alpha) * forward_part + self.alpha * backward_part

def directed_norm(edge_index, num_nodes=None, dtype=None):
    r"""Compute edge weights for the directed-graph normalization
    :math:`\mathbf{D}_{out}^{-1/2} \mathbf{A} \mathbf{D}_{in}^{-1/2}`.

    Every edge starts with weight one; the weight of edge ``(row, col)``
    becomes ``out_deg[row] ** -0.5 * in_deg[col] ** -0.5``.

    Args:
        edge_index: Tensor whose row 0 holds the source (``row``) and row 1
            the target (``col``) node index of every edge.
        num_nodes: Number of nodes; resolved through ``maybe_num_nodes``
            when ``None``.
        dtype: dtype of the returned weights (``None`` uses the torch
            default).

    Returns:
        A tuple ``(edge_index, edge_weight)``: the unchanged ``edge_index``
        and a 1-D tensor with one normalised weight per edge.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                             device=edge_index.device)
    row, col = edge_index[0], edge_index[1]

    # In-degree: number of edges arriving at each node.
    in_deg = scatter(edge_weight, col, dim=0, dim_size=num_nodes, reduce='sum')
    in_deg_inv_sqrt = in_deg.pow(-0.5)
    # A zero degree gives inf; such nodes have no edge in this role, so the
    # replacement value never reaches an edge weight.
    in_deg_inv_sqrt.masked_fill_(in_deg_inv_sqrt == float('inf'), 1.0)

    # Out-degree: number of edges leaving each node.
    out_deg = scatter(edge_weight, row, dim=0, dim_size=num_nodes, reduce='sum')
    out_deg_inv_sqrt = out_deg.pow(-0.5)
    out_deg_inv_sqrt.masked_fill_(out_deg_inv_sqrt == float('inf'), 1.0)

    edge_weight = in_deg_inv_sqrt[col] * edge_weight * out_deg_inv_sqrt[row]

    return edge_index, edge_weight


class DAGGCNConv(MessagePassing):
    r"""GCN-style convolution over a ``dag_degree`` decomposition of a
    directed graph.

    For every active edge direction (forward and, when ``alpha > 0``, the
    reversed edges) the layer

    1. normalises the edges (``directed_norm`` when ``norm == 'dir'``,
       otherwise ``gcn_norm``),
    2. splits them into ``num_graphs`` (3) sub-graphs with
       ``graphgps.layer.utils.dag_degree``, using the out-degree of each
       node in the input graph as ordering,
    3. projects the node features with one linear layer whose output is
       divided into ``num_graphs`` blocks of ``out_channels`` columns, one
       block per sub-graph, and
    4. aggregates the weighted messages of all sub-graphs.

    The forward direction is scaled by ``1 - alpha`` and the reversed
    direction by ``alpha``; the results are summed and the bias is added.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        cached (bool, optional): If :obj:`True`, the decomposed graph
            computed on the first call is stored and reused by later calls.
            (default: :obj:`False`)
        add_self_loops (bool, optional): Forwarded to ``gcn_norm`` for the
            forward direction; ignored when ``norm == 'dir'``.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        alpha (float, optional): Mixing coefficient between the forward
            (``1 - alpha``) and the reversed (``alpha``) direction.
            (default: :obj:`0.0`)
        share_init (bool, optional): If :obj:`True`, all ``num_graphs``
            weight blocks of a linear layer start from the same values.
            (default: :obj:`False`)
        norm (str, optional): ``'dir'`` selects ``directed_norm``; any other
            value selects ``gcn_norm``. (default: :obj:`'sym'`)
        plus (bool, optional): If :obj:`True`, missing self-loops with
            weight 1 are added to the first two sub-graphs.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    # Kept for compatibility with GCNConv-style code; not read by this class.
    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cached: bool = False,
        add_self_loops: bool = False,
        bias: bool = True,
        alpha: float = 0.0,
        share_init: bool = False,
        norm: str = 'sym',
        plus: bool = False,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.plus = plus

        self._cached_edge_index = None
        self._cached_adj_t = None
        self.alpha = alpha
        self.num_graphs = 3
        self.share_init = share_init
        # Cache filled by `forward` when `cached` is True.
        self.edge_indices = None
        self.edge_weights = None
        self.norm = norm

        # One projection per active direction; `init_graph` builds its
        # directions under the same two conditions, so indices line up.
        self.lins = torch.nn.ModuleList()
        for direction_is_active in (alpha > 0, alpha < 1):
            if direction_is_active:
                self.lins.append(
                    Linear(in_channels, out_channels * self.num_graphs,
                           bias=False, weight_initializer='glorot'))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all weights and drop any cached graph."""
        super().reset_parameters()
        # Glorot-uniform bound scaled by the ReLU gain, computed from a
        # single (in_channels, out_channels) block.
        bound = nn.init.calculate_gain('relu') * np.sqrt(
            6 / (self.in_channels + self.out_channels))
        for lin in self.lins:
            lin.reset_parameters()
            torch.nn.init.uniform_(lin.weight.data, -bound, bound)
            if self.share_init:
                # Copy the first block into every sub-graph block.
                first_block = lin.weight.data[:self.out_channels]
                lin.weight = torch.nn.Parameter(
                    first_block.repeat(self.num_graphs, 1))
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None
        self.edge_indices = None
        self.edge_weights = None

    def _normalize(self, edge_index, edge_weight, num_nodes, add_self_loops,
                   dtype):
        """Return ``(edge_index, edge_weight)`` normalised per ``self.norm``.

        With ``norm == 'dir'`` the given ``edge_weight`` and
        ``add_self_loops`` are not used.
        """
        if self.norm == 'dir':
            return directed_norm(edge_index, num_nodes)
        return gcn_norm(edge_index, edge_weight, num_nodes, False,
                        add_self_loops, 'source_to_target', dtype)

    def init_graph(self, x, edge_index, ordering, edge_weight):
        """Normalise and decompose the graph for every active direction.

        Args:
            x: Node features; only the node count and dtype are used.
            edge_index: Edge index tensor with a source and a target row.
            ordering: Ignored. The ordering is always recomputed as the
                out-degree of ``edge_index``.
                NOTE(review): confirm whether callers expect a passed
                ordering to be honoured.
            edge_weight: Optional edge weights, forwarded to ``gcn_norm``.

        Returns:
            ``(edge_indices_list, edge_weights_list)``: one entry per active
            direction, each a list with the edge index / scaled edge weight
            tensor of every sub-graph returned by ``dag_degree``.
        """
        num_nodes = x.size(0)

        # (normalised edge_index, normalised weights, mixing coefficient).
        # Built under the same conditions as `self.lins`, so that
        # `self.lins[k]` always exists for direction `k`.
        directions = []
        if self.alpha < 1:
            fwd_index, fwd_weight = self._normalize(
                edge_index, edge_weight, num_nodes, self.add_self_loops,
                x.dtype)
            directions.append((fwd_index, fwd_weight, 1 - self.alpha))
        if self.alpha > 0:
            # The reversed direction never receives self-loops from gcn_norm.
            rev_index, rev_weight = self._normalize(
                edge_index.flip(0), edge_weight, num_nodes, False, x.dtype)
            directions.append((rev_index, rev_weight, self.alpha))

        ordering = pyg.utils.degree(edge_index[0], num_nodes, dtype=torch.long)

        edge_indices_list, edge_weights_list = [], []
        for norm_index, norm_weight, scale in directions:
            # `positions[j]` selects, out of `norm_weight`, the weights of
            # the edges assigned to sub-graph `j`.
            sub_indices, positions = dag_degree(
                norm_index, self.num_graphs, num_nodes, ordering)
            sub_weights = []
            for j in range(len(sub_indices)):
                weight = norm_weight[positions[j]]
                if self.plus and j < 2:
                    # Only the first two sub-graphs receive self-loops.
                    sub_indices[j], weight = pyg.utils.add_remaining_self_loops(
                        sub_indices[j], weight, 1, num_nodes)
                sub_weights.append(scale * weight)
            edge_indices_list.append(sub_indices)
            edge_weights_list.append(sub_weights)
        return edge_indices_list, edge_weights_list

    def forward(self, x: Tensor, edge_index: Adj,
                ordering: Optional[Tensor] = None,
                edge_weight: OptTensor = None) -> Tensor:
        """Run the convolution and return the updated node features.

        Raises:
            ValueError: If ``x`` is a tuple or list (bipartite input).
        """
        if isinstance(x, (tuple, list)):
            raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                             f"of node features as input while this layer "
                             f"does not support bipartite message passing. "
                             f"Please try other layers such as 'SAGEConv' or "
                             f"'GraphConv' instead")

        if self.cached:
            # Build once, then reuse on every later call.
            if self.edge_indices is None:
                self.edge_indices, self.edge_weights = self.init_graph(
                    x, edge_index, ordering, edge_weight)
            edge_indices, edge_weights = self.edge_indices, self.edge_weights
        else:
            edge_indices, edge_weights = self.init_graph(
                x, edge_index, ordering, edge_weight)

        size = (x.size(0), x.size(0))
        width = self.out_channels
        out = None

        for direction, (sub_indices, sub_weights) in enumerate(
                zip(edge_indices, edge_weights)):
            x_dir = self.lins[direction](x)
            # Sub-graph k reads the k-th block of projected features,
            # gathered at the source node of each of its edges.
            x_j = torch.cat(
                [x_dir[:, k * width:(k + 1) * width][sub_index[0]]
                 for k, sub_index in enumerate(sub_indices)],
                dim=0)
            contribution = self.propagate(
                torch.cat(sub_indices, dim=1), x_j=x_j,
                edge_weight=torch.cat(sub_weights, dim=0), size=size)
            # Sum the directions; the alpha weighting is already folded
            # into the edge weights.
            out = contribution if out is None else out + contribution

        if self.bias is not None:
            out = out + self.bias

        return out

    def _collect(self, args, edge_index, size, kwargs):
        """Assemble the argument dict for ``message`` / ``aggregate``.

        Overrides the default collection because ``x_j`` is gathered
        manually in :meth:`forward` and passed in directly.
        """
        out = {'x_j': kwargs.get('x_j'),
               'edge_index': edge_index,
               'size': size,
               'size_i': size[0],
               'size_j': size[0],
               'dim_size': size[0],
               'edge_weight': kwargs.get('edge_weight'),
               'index': edge_index[1]}
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Scale each gathered source feature row by its edge weight."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused sparse path: multiply ``adj_t`` with ``x`` using ``self.aggr``."""
        return spmm(adj_t, x, reduce=self.aggr)
