import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym import cfg
import torch_geometric.graphgym.register as register
import torch
from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse
from typing import Optional
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import set_sparse_value
from torch_geometric.utils import is_torch_sparse_tensor, to_edge_index, scatter

def gcn_norm(
    edge_index: Adj,
    edge_weight: OptTensor = None,
    num_nodes: Optional[int] = None,
    improved: bool = False,
    add_self_loops: bool = True,
    flow: str = "source_to_target",
    dtype: Optional[torch.dtype] = None,
):
    """Apply symmetric degree normalization ``D^{-1/2} A D^{-1/2}`` to a graph.

    Every edge value is multiplied by the inverse square root of the
    (weighted) degree of both of its endpoints. Nodes whose degree is zero
    get a normalization factor of 0 instead of ``inf``.

    The adjacency may be given in three forms, and the return value differs
    per form:

    * ``SparseTensor`` -- returns the normalized ``SparseTensor`` only.
    * native torch sparse tensor -- returns ``(normalized_adj, None)``.
      The CSC layout is rejected.
    * anything else (treated as an edge-index tensor indexed as
      ``edge_index[0]`` / ``edge_index[1]``) -- returns
      ``(edge_index, edge_weight)``.

    Args:
        edge_index: The adjacency in one of the forms listed above.
        edge_weight: Optional per-edge weights. Only read on the edge-index
            path; when ``None`` there, every edge gets weight 1.
        num_nodes: Number of nodes. On the edge-index path it is resolved
            through ``maybe_num_nodes``; on the torch-sparse path it is
            passed through as given. Not used on the ``SparseTensor`` path.
        improved: If ``True``, inserted self-loops get weight 2 instead of 1.
        add_self_loops: Whether to insert self-loops before normalizing.
        flow: ``'source_to_target'`` or ``'target_to_source'``. Only
            validated and used on the edge-index path, where it selects
            which endpoint the degree is accumulated over.
        dtype: dtype for the default values created when a ``SparseTensor``
            has no values, or when ``edge_weight`` is ``None`` on the
            edge-index path.

    Raises:
        NotImplementedError: For torch sparse tensors in CSC layout.
        AssertionError: If a sparse adjacency is not square, or ``flow`` is
            not one of the two supported values.
    """
    # These helpers are not imported at module level in this file, so they
    # are brought into scope here; without this the default
    # ``add_self_loops=True`` paths fail with a NameError. The alias is
    # needed because the ``add_self_loops`` parameter shadows the function.
    from torch_geometric.utils import add_remaining_self_loops
    from torch_geometric.utils import add_self_loops as add_self_loops_fn

    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        assert edge_index.size(0) == edge_index.size(1)

        adj_t = edge_index

        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = torch_sparse.fill_diag(adj_t, fill_value)

        deg = torch_sparse.sum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        # Zero-degree nodes produce inf after pow(-0.5); neutralize them.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        # Scale along both dimensions (column-vector, then row-vector view).
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))

        return adj_t

    if is_torch_sparse_tensor(edge_index):
        assert edge_index.size(0) == edge_index.size(1)

        if edge_index.layout == torch.sparse_csc:
            raise NotImplementedError("Sparse CSC matrices are not yet "
                                      "supported in 'gcn_norm'")

        adj_t = edge_index
        if add_self_loops:
            adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)

        edge_index, value = to_edge_index(adj_t)
        # Note the swapped naming relative to the edge-index path below:
        # the first index row is called ``col`` here.
        col, row = edge_index[0], edge_index[1]

        deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum')
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col]

        return set_sparse_value(adj_t, value), None

    assert flow in ['source_to_target', 'target_to_source']
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if add_self_loops:
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    row, col = edge_index[0], edge_index[1]
    # Accumulate degree over the second index row for 'source_to_target',
    # over the first one otherwise.
    idx = col if flow == 'source_to_target' else row
    deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    return edge_index, edge_weight

class GCNConvLayer(nn.Module):
    """GCN convolution layer with an optional degree-ordered edge split.

    In DAG mode (``dag=True``, the default) the edges of the batch are
    normalized once with ``gcn_norm`` (no self-loops added) and then
    partitioned by comparing ``batch.node_degrees_out`` at the two endpoints
    of each edge:

    * ``model1`` receives edges where the value at ``edge_index[0]`` is
      smaller than at ``edge_index[1]``,
    * ``model2`` receives edges where it is larger,
    * ``model3`` receives edges where both are equal.

    The three ``GCNConv`` outputs are summed. In this mode all three
    convolutions are built with ``normalize=False`` and are fed the
    pre-normalized edge weights; only ``model3`` carries a bias.

    With ``dag=False`` only ``model1`` is used, built with ``bias=True`` and
    ``normalize=True``, and called on the raw ``batch.edge_index``.

    In both modes the result goes through the configured activation followed
    by dropout, and the layer input is added back when ``residual`` is set.

    Args:
        dim_in: Input feature dimension passed to every ``GCNConv``.
        dim_out: Output feature dimension passed to every ``GCNConv``.
        dropout: Dropout probability applied after the activation.
        residual: If truthy, add the layer input to the layer output.
        dag: Keyword-only switch for the degree-ordered edge split described
            above. Defaults to ``True``.
    """
    def __init__(self, dim_in, dim_out, dropout, residual, *, dag=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dropout = dropout
        self.residual = residual

        # Activation class is looked up by the name configured in
        # ``cfg.gnn.act``; dropout follows it.
        self.act = nn.Sequential(
            register.act_dict[cfg.gnn.act](),
            nn.Dropout(self.dropout),
        )
        self.dag = dag
        # model2 / model3 are always constructed, so the set of registered
        # sub-modules does not depend on ``dag``; they are only used by
        # ``forward`` in DAG mode.
        self.model1 = pyg_nn.GCNConv(dim_in, dim_out, bias=not self.dag, normalize=not self.dag)
        self.model2 = pyg_nn.GCNConv(dim_in, dim_out, bias=False, normalize=False)
        self.model3 = pyg_nn.GCNConv(dim_in, dim_out, bias=True, normalize=False)
        # Not read anywhere in this class; kept for external readers.
        self.alpha = 1/3
        print('DAG:', self.dag)

    def forward(self, batch):
        """Run the layer on ``batch`` and return the same (mutated) batch.

        Reads ``batch.x`` and ``batch.edge_index`` (plus
        ``batch.node_degrees_out`` in DAG mode) and overwrites ``batch.x``
        with the layer output.
        """
        x_in = batch.x
        if self.dag:
            # Normalize over the full edge set first, then split; the
            # weights are therefore based on degrees of the whole graph.
            edge_index, edge_weight = gcn_norm(batch.edge_index, add_self_loops=False)
            # NOTE(review): an earlier revision had a commented-out
            # ``[batch.n_id]`` index here -- presumably for sampled
            # sub-graphs; confirm whether that is still needed.
            node_degrees = batch.node_degrees_out

            # Gather the per-edge endpoint values once and reuse them for
            # all three comparisons.
            first_deg = node_degrees[edge_index[0]]
            second_deg = node_degrees[edge_index[1]]
            lower = first_deg < second_deg
            higher = first_deg > second_deg
            equal = first_deg == second_deg

            batch.x = (
                self.model1(x_in, edge_index[:, lower], edge_weight[lower])
                + self.model2(x_in, edge_index[:, higher], edge_weight[higher])
                + self.model3(x_in, edge_index[:, equal], edge_weight[equal])
            )
        else:
            batch.x = self.model1(batch.x, batch.edge_index)
        batch.x = self.act(batch.x)

        if self.residual:
            batch.x = x_in + batch.x  # residual connection

        return batch
