import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch import tensor
import torch.nn as nn
from torch_geometric.utils import sparse as sp
from torch_geometric.data import Data
from torch.nn import Parameter
from scattering.utils_scat import GCN_diffusion, scattering_diffusion, sparse_mx_to_torch_sparse_tensor
import numpy as np
from torch_geometric.utils.convert import to_scipy_sparse_matrix
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csr_matrix
from torch.nn.modules.module import Module
class GC(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    https://github.com/tkipf/pygcn/blob/master/pygcn/models.py

    Computes ``D^{-1/2} (A + I) D^{-1/2} (X W + b)`` where ``D`` holds the
    column sums of ``A + I``.
    """
    def __init__(self, in_features, out_features):
        super(GC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.mlp = nn.Linear(in_features, out_features)

    def forward(self, input, adj, device='cuda'):
        """Apply the linear map, then symmetric-normalised propagation.

        Args:
            input: (N, in_features) node feature matrix.
            adj: (N, N) torch sparse adjacency matrix (``torch.sparse.sum``
                below requires a sparse tensor).
            device: kept for backward compatibility only. The self-loop
                matrix is always created on ``adj``'s device, which is the
                only placement under which the sparse addition can succeed.

        Returns:
            (N, out_features) tensor.
        """
        support = self.mlp(input)
        n = adj.size(0)
        # Sparse identity built directly on adj's device/dtype. This avoids a
        # host-side scipy allocation plus host->device copy on every call, and
        # no longer breaks for CPU tensors under the 'cuda' default.
        diag_idx = torch.arange(n, device=adj.device).repeat(2, 1)
        eye = torch.sparse_coo_tensor(
            diag_idx, torch.ones(n, dtype=adj.dtype, device=adj.device), (n, n)
        )
        a_hat = adj + eye
        # Column sums of A + I. Every entry is >= 1 thanks to the self-loop
        # (for non-negative edge weights), so the -1/2 power stays finite.
        d_inv_sqrt = torch.sparse.sum(a_hat, 0).to_dense().pow(-0.5).unsqueeze(1)
        return d_inv_sqrt * torch.spmm(a_hat, d_inv_sqrt * support)



class GC_withres(Module):
    """
    res conv

    GCN layer with a residual blend: the normalised propagation
    ``P = D^{-1/2} (A + I) D^{-1/2} (X W + b)`` is mixed with the
    un-propagated ``S = X W + b`` as ``(smooth * P + S) / (1 + smooth)``.
    ``D`` holds the column sums of ``A + I``.
    """
    def __init__(self, in_features, out_features, smooth):
        super(GC_withres, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.smooth = smooth  # weight of the propagated term in the blend
        self.mlp = nn.Linear(in_features, out_features)

    def forward(self, input, adj, device='cuda'):
        """Linear map, normalised propagation, then residual blend.

        Args:
            input: (N, in_features) node feature matrix.
            adj: (N, N) torch sparse adjacency matrix (``torch.sparse.sum``
                below requires a sparse tensor).
            device: kept for backward compatibility only. The self-loop
                matrix is always created on ``adj``'s device, which is the
                only placement under which the sparse addition can succeed.

        Returns:
            (N, out_features) tensor.
        """
        support = self.mlp(input)
        n = adj.size(0)
        # Sparse identity built directly on adj's device/dtype. This avoids a
        # host-side scipy allocation plus host->device copy on every call.
        diag_idx = torch.arange(n, device=adj.device).repeat(2, 1)
        eye = torch.sparse_coo_tensor(
            diag_idx, torch.ones(n, dtype=adj.dtype, device=adj.device), (n, n)
        )
        a_hat = adj + eye
        # Column sums of A + I; >= 1 for non-negative weights thanks to the
        # self-loop, so the -1/2 power stays finite.
        d_inv_sqrt = torch.sparse.sum(a_hat, 0).to_dense().pow(-0.5).unsqueeze(1)
        propagated = d_inv_sqrt * torch.spmm(a_hat, d_inv_sqrt * support)
        output = propagated * self.smooth + support
        return output / (1 + self.smooth)


class SCTConv(torch.nn.Module):
    """Attention-weighted mix of a GCN-diffusion channel and a scattering channel.

    Per node, a learned vector ``a`` scores the concatenations
    ``[h || h_A]`` and ``[h || h_sct1]``. A softmax over the two scores weights
    the channels, which are then averaged. An optional graph-residual layer
    (``GC_withres``) follows, then a two-layer MLP with leaky ReLUs.
    """
    def __init__(self, hidden_dim, smooth, dropout, Withgres=False):
        super().__init__()
        self.hid = hidden_dim
        self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        # Attention vector over [h || h_channel]. It is zero-initialised, so the
        # initial softmax weights over the two channels are uniform.
        self.a = Parameter(torch.zeros(size=(2 * hidden_dim, 1)))
        self.smoothlayer = Withgres  # turn on graph residual layer or not
        self.gres = GC_withres(hidden_dim, hidden_dim, smooth=smooth)
        # NOTE(review): stored but never applied in forward().
        self.dropout = dropout

    def forward(self, X, adj, moment=1, device='cuda'):
        """
        Params
        ------
        X [nodes x hidden_dim]: node feature matrix (2-D; there is no batch
            dimension -- rows are nodes).
        adj [nodes x nodes]: adjacency matrix; must be a torch sparse tensor
            when the graph-residual layer is enabled (GC_withres needs it).
        moment: exponent applied to |scattering response|.
        device: forwarded to GCN_diffusion and the residual layer.

        Returns
        -------
        X' [nodes x hidden_dim]: updated node feature matrix.
        """
        h = X
        N = h.size(0)
        # NOTE(review): three diffusion orders are requested, but only the
        # first one is used below.
        gcn_diffusion_list = GCN_diffusion(adj, 3, h, device=device)
        h_A = F.leaky_relu(gcn_diffusion_list[0])

        h_sct1 = scattering_diffusion(adj, h)
        h_sct1 = torch.abs(h_sct1) ** moment

        # (N, 2, 2*hidden): one [h || channel] row per channel per node.
        a_input_A = torch.hstack((h, h_A)).unsqueeze(1)
        a_input_sct1 = torch.hstack((h, h_sct1)).unsqueeze(1)
        a_input = torch.cat((a_input_A, a_input_sct1), 1).view(N, 2, -1)

        # Attention score per channel (comment in original: GATv2), then a
        # softmax across the two channels -> weights of shape (N, 2, 1).
        e = torch.matmul(F.relu(a_input), self.a).squeeze(2)
        attention = F.softmax(e, dim=1).view(N, 2, -1)

        h_all = torch.cat((h_A.unsqueeze(dim=1), h_sct1.unsqueeze(dim=1)), dim=1)
        # Element-wise weighting, then mean (not sum) over the two channels.
        h_prime = torch.mean(torch.mul(attention, h_all), 1)
        if self.smoothlayer:
            h_prime = self.gres(h_prime, adj, device)

        X = F.leaky_relu(self.linear1(h_prime))
        X = F.leaky_relu(self.linear2(X))
        return X

class Scat(nn.Module):
    """Scattering-attention GNN that outputs per-node scores rescaled to [0, 1).

    Pipeline: input projection -> ``n_layers`` SCTConv layers (each output is
    divided by ``sqrt(batch_data_sizes)``) -> concatenation of the projection
    and every layer output -> two-layer MLP readout -> global min-max rescale.
    """
    def __init__(self, conf, gmn_config):
        """Build the model from ``conf``.

        Reads ``conf.training.{dropout, device}`` and
        ``conf.model.{smooth, input_dim, hidden_dim, n_layers, use_smoo,
        output_dim}``. ``gmn_config`` is accepted but not used here.
        """
        super().__init__()
        self.conf = conf
        self.dropout = conf.training.dropout
        self.smooth = conf.model.smooth
        self.in_proj = torch.nn.Linear(conf.model.input_dim, conf.model.hidden_dim)
        self.convs = torch.nn.ModuleList()
        for _ in range(conf.model.n_layers):
            self.convs.append(SCTConv(conf.model.hidden_dim, self.smooth, self.dropout, conf.model.use_smoo))
        # The readout sees the input projection plus every layer's output.
        self.mlp1 = Linear(conf.model.hidden_dim * (1 + conf.model.n_layers), conf.model.hidden_dim)
        self.mlp2 = Linear(conf.model.hidden_dim, conf.model.output_dim)
        self.device = conf.training.device

    def forward(self, batch_data, batch_data_sizes, batch_adj):
        """Score every node of a single graph.

        This inherently uses a single graph as the "batch"; there is no batch
        dimension, and the naming only preserves interface consistency.

        Args:
            batch_data: graph object exposing ``x`` (tensor or array-like node
                features) and ``edge_index``.
            batch_data_sizes: its square root divides the features before the
                input projection and after every layer (presumably the node
                count -- TODO confirm against callers).
            batch_adj: unused; kept for interface compatibility.

        Returns:
            ``(probs, adjmatrix)`` in training mode, or
            ``(probs, adjmatrix, adj)`` in eval mode. ``adj`` is the SciPy
            sparse adjacency built from ``edge_index``, ``adjmatrix`` its torch
            sparse version on ``self.device``, and ``probs`` the readout
            min-max rescaled over the whole tensor.
        """
        if not torch.is_tensor(batch_data.x):
            features = torch.FloatTensor(batch_data.x).to(self.device)
        else:
            features = batch_data.x.to(self.device)
        adj = to_scipy_sparse_matrix(batch_data.edge_index)
        adjmatrix = sparse_mx_to_torch_sparse_tensor(adj).to(self.device)
        scale = np.sqrt(batch_data_sizes)
        features = self.in_proj(features / scale)
        hidden_states = features
        for layer in self.convs:
            features = layer(features, adjmatrix, moment=self.conf.model.moment, device=self.device)
            features = features / scale
            hidden_states = torch.cat([hidden_states, features], dim=1)
        features = self.mlp2(F.leaky_relu(self.mlp1(hidden_states)))
        maxval = torch.max(features)
        minval = torch.min(features)
        # Global min-max rescale; the 1e-6 keeps the denominator non-zero when
        # every score is equal (all probs become 0 in that case).
        probs = (features - minval) / (maxval + 1e-6 - minval)
        if self.training:
            return probs, adjmatrix
        return probs, adjmatrix, adj

    def compute_loss(self, adjmatrix, features):
        """Return the unsupervised loss for node scores ``features``.

        With ``A = adjmatrix``, ``J`` the all-ones matrix and
        ``p = conf.model.penalty_coefficient``::

            loss = -sum(f * (A @ f)) + p * sum(f * ((J - I - A) @ f))

        ``(J - I - A) @ f`` is evaluated as ``colsum(f) - f - A @ f``, so no
        dense N x N matrix is ever built. Memory is O(N*C + nnz(A)) instead of
        O(N^2).

        Args:
            adjmatrix: (N, N) adjacency, torch sparse or dense, on the same
                device as ``features``.
            features: (N, C) node scores (2-D, as required by ``torch.mm``).

        Returns:
            0-dim loss tensor.
        """
        adj_f = torch.mm(adjmatrix, features)  # A @ f
        # J @ f broadcasts each column's total to every row.
        col_totals = features.sum(dim=0, keepdim=True)
        diffusionprob = col_totals - features - adj_f  # (J - I - A) @ f
        lossComplE = self.conf.model.penalty_coefficient * torch.sum(features * diffusionprob)  # loss on compl of Edges
        lossE = torch.sum(features * adj_f)
        return -lossE + lossComplE
    

class Scat_Feat(Scat):
    """Variant of ``Scat``.

    In the lines shown here, construction simply delegates to
    ``Scat.__init__`` with the same ``conf`` and ``gmn_config``, and no other
    behaviour is overridden.
    """
    def __init__(self, conf, gmn_config):
        super().__init__(conf, gmn_config)
        