from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_sparse import SparseTensor, matmul


class ASGNN(MessagePassing):
    """Graph model that jointly refines node predictions and a learned structure.

    Node features are mapped to class scores by a two-layer MLP
    (``input_trans`` -> ReLU -> ``output_trans``). ``K`` unrolled steps then
    alternate between:

    * updating the scores by propagating them over the current structure ``S``
      (row-normalised by its inverse row sums) while anchoring them to the
      MLP output ``hh``, and
    * updating ``S`` from its previous value, the input graph ``A`` and the
      current scores, then shrinking by ``eta1 * alpha`` and clipping to
      ``[0, 1]``.

    The scalars ``alpha``, ``beta``, ``gamma``, ``lamb``, ``eta1`` and ``eta2``
    are learnable parameters.

    Every attribute of ``args`` is copied onto the module. This class reads
    ``num_layers``, ``num_features``, ``dim_hidden``, ``num_classes``,
    ``dropout``, ``lr`` and ``weight_decay`` from it, plus ``transductive``
    when present.

    NOTE: ``forward`` materialises dense ``n x n`` matrices, so memory grows
    quadratically with the number of nodes.
    """

    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self, args, cached: bool = False, add_self_loops: bool = True,
                 normalize: bool = True, **kwargs):
        """Build the model from the hyper-parameter object ``args``.

        Args:
            args: Object whose ``vars()`` are copied onto the module
                (e.g. an ``argparse.Namespace``).
            cached: Whether to cache the normalised graph between calls. Used
                only when ``args`` has no ``transductive`` attribute; otherwise
                ``args.transductive`` decides.
            add_self_loops: Forwarded to ``gcn_norm``.
            normalize: Whether to run the input graph through ``gcn_norm``.
            **kwargs: Forwarded to ``MessagePassing`` (``aggr`` defaults to
                ``'add'``).
        """
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        for k, v in vars(args).items():
            setattr(self, k, v)
        self.K = self.num_layers
        # ``args.transductive`` wins; ``cached`` is only a fallback so configs
        # without that field no longer raise AttributeError.
        self.cached = getattr(args, 'transductive', cached)
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self._cached_edge_index = None
        self._cached_adj_t = None

        self.input_trans = torch.nn.Linear(self.num_features, self.dim_hidden)
        self.output_trans = torch.nn.Linear(self.dim_hidden, self.num_classes)

        # Learnable scalars of the unrolled updates in ``forward``.
        self.alpha = torch.nn.Parameter(torch.tensor(1e-6))
        self.beta = torch.nn.Parameter(torch.tensor(1e-6))
        self.gamma = torch.nn.Parameter(torch.tensor(1e-6))
        self.lamb = torch.nn.Parameter(torch.tensor(9.))
        self.eta1 = torch.nn.Parameter(torch.tensor(0.05))
        self.eta2 = torch.nn.Parameter(torch.tensor(1e-6))

        # The module owns its optimizer over every parameter registered above.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr,
                                          weight_decay=self.weight_decay)

    def reset_parameters(self):
        """Drop the cached normalised graph so it is recomputed on the next call.

        NOTE(review): despite the name, learnable weights are not
        re-initialised here — confirm callers do not rely on that.
        """
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Return per-node log-probabilities of shape ``(num_nodes, num_classes)``.

        Args:
            x: Node features; the last dimension must be ``num_features`` and
                ``x.size(0)`` is taken as the number of nodes.
            edge_index: Graph as a ``[2, E]`` index tensor or a
                ``torch_sparse.SparseTensor``.
            edge_weight: Optional edge weights, only passed to ``gcn_norm``.
                The unrolled updates use just the 0/1 sparsity pattern of the
                (normalised) graph.
        """
        if self.normalize:
            edge_index = self._normalized_graph(x, edge_index, edge_weight)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.input_trans(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        hh = self.output_trans(x)

        alpha, beta, gamma = self.alpha, self.beta, self.gamma
        lamb, eta1, eta2 = self.lamb, self.eta1, self.eta2

        # Binary input graph A; S starts from it and is re-estimated each step.
        A = self._binary_adjacency(edge_index, hh.size(0), hh.device, hh.dtype)
        S = A
        x = hh
        for _ in range(self.K):
            # Diagonal of D^{-1} for the current S (0 for empty rows).
            r_inv = self._inverse_row_sums(S, x.dtype)

            # Score update. ``c * D^{-1} @ (S @ x)`` is a row scaling, so it is
            # done by broadcasting rather than a dense diagonal matmul.
            x = ((1 - 2 * eta1 * lamb - 2 * eta1) * x + 2 * eta1 * hh
                 + (2 * eta1 * lamb * r_inv).unsqueeze(1) * torch.sparse.mm(S, x))
            x = F.dropout(x, p=self.dropout, training=self.training)

            # (x @ x.T) @ D^{-1}: column j scaled by r_inv[j].
            x_normalize = torch.mm(x, x.t()) * r_inv.unsqueeze(0)
            # diag(D^{-1} @ S @ x_normalize), using S from *before* this update.
            penalty = r_inv * torch.diagonal(torch.sparse.mm(S, x_normalize))

            S_dense = (sparse_scaling(S, 1 - 2 * eta2 * gamma - 2 * eta2 * beta)
                       + sparse_scaling(A, 2 * eta2 * gamma)).to_dense()
            # NOTE(review): ``penalty`` is 1-D, so it broadcasts along rows:
            # entry (i, j) is reduced by eta2 * lamb * penalty[j]. If the
            # derivation intends penalty[i] (column vector times ones^T), this
            # should be ``penalty.unsqueeze(1)`` — confirm before changing.
            S_dense = S_dense + eta2 * lamb * x_normalize.t() - eta2 * lamb * penalty

            # Shrink every entry by eta1 * alpha, clip to [0, 1], re-sparsify.
            S = self._dense_to_sparse(
                torch.clamp(S_dense - eta1 * alpha, min=0, max=1))
        return F.log_softmax(x, dim=1)

    def _normalized_graph(self, x: Tensor, edge_index: Adj,
                          edge_weight: OptTensor) -> Adj:
        """Run ``gcn_norm`` on the input graph, reusing the cache if enabled.

        Only the returned connectivity is used by ``forward``; the normalised
        edge weights are cached but otherwise discarded.
        """
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is not None:
                return cache[0]
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(0), False,
                self.add_self_loops, dtype=x.dtype)
            if self.cached:
                self._cached_edge_index = (edge_index, edge_weight)
            return edge_index

        if isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is not None:
                return cache
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(0), False,
                self.add_self_loops, dtype=x.dtype)
            if self.cached:
                self._cached_adj_t = edge_index
            return edge_index

        return edge_index

    @staticmethod
    def _binary_adjacency(edge_index: Adj, num_nodes: int,
                          device: torch.device, dtype: torch.dtype) -> Tensor:
        """Return a sparse COO ``num_nodes x num_nodes`` matrix with S[src, dst] = 1."""
        if isinstance(edge_index, SparseTensor):
            # Indexing a SparseTensor like ``edge_index[0]`` does not yield
            # node ids, so read its COO triplets instead. Assumes the PyG
            # ``adj_t`` convention (row = target, col = source) — TODO confirm.
            row, col, _ = edge_index.coo()
            src, dst = col, row
        else:
            src, dst = edge_index[0], edge_index[1]
        # Dense fill deduplicates repeated edges (all become exactly 1).
        dense = torch.zeros(num_nodes, num_nodes, device=device, dtype=dtype)
        dense[src, dst] = 1
        return ASGNN._dense_to_sparse(dense)

    @staticmethod
    def _dense_to_sparse(dense: Tensor) -> Tensor:
        """Return a sparse COO tensor holding exactly the nonzero entries of ``dense``."""
        idx = torch.nonzero(dense).t()
        return torch.sparse_coo_tensor(idx, dense[idx[0], idx[1]], dense.shape)

    @staticmethod
    def _inverse_row_sums(S: Tensor, dtype: torch.dtype) -> Tensor:
        """Return ``1 / S.sum(1)`` as a dense 1-D tensor, with 0 where a row sums to 0."""
        rowsum = torch.sparse.sum(S, dim=1, dtype=dtype).to_dense()
        r_inv = rowsum.pow(-1)
        return r_inv.masked_fill(torch.isinf(r_inv), 0.)

    # NOTE: ``forward`` never calls ``self.propagate``, so the two
    # MessagePassing hooks below are not exercised by this class as written.
    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each neighbour message by its edge weight."""
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused sparse matmul path used by ``propagate`` for SparseTensor input."""
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        # ``.item()`` prints the learnable scalar as a number instead of the
        # multi-line ``Parameter containing: ...`` representation.
        return '{}(K={}, alpha={})'.format(self.__class__.__name__, self.K,
                                           self.alpha.item())
def sparse_scaling(S, s):
    """Return a sparse COO copy of ``S`` with every nonzero entry multiplied by ``s``.

    The scaling is applied to the COO values directly, so no dense ``n x n``
    copy of ``S`` is built and the input is never modified.

    Args:
        S: 2-D sparse COO tensor. A dense (strided) tensor is also accepted and
            converted first.
        s: Python number or 0-dim tensor. Gradients flow through it when it is
            a tensor that requires grad.

    Returns:
        Sparse COO tensor with the same shape as ``S``. Its stored entries are
        exactly the nonzero entries of ``S``: duplicates are summed, and
        explicit zeros are dropped.
    """
    if not S.is_sparse:
        # The old ``to_dense()`` path returned a strided input unchanged and
        # then wrote into it in place; converting avoids mutating the caller's tensor.
        S = S.to_sparse()
    S = S.coalesce()
    indices, values = S.indices(), S.values()
    keep = values != 0
    return torch.sparse_coo_tensor(indices[:, keep], values[keep] * s, S.shape)