import numpy as np

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from grakel.kernels import WeisfeilerLehman
from layers import * 
import pytorch_lightning as pl
    

def H(x):
    """Shannon entropy (natural log) of the distributions along the last axis of ``x``.

    Zero probabilities contribute 0 (the ``0 * log 0 = 0`` convention) instead
    of NaN: the log argument is clamped to the smallest positive normal value
    of ``x``'s floating dtype, which also keeps the gradient finite at 0.
    Entries at or above that value are computed exactly as ``x * log(x)``.
    """
    return -torch.sum(x * x.clamp(min=torch.finfo(x.dtype).tiny).log(), -1)


def JSD(x):
    """Generalised Jensen-Shannon divergence of the distributions stacked along dim 0.

    Computed as the entropy of the mean distribution minus the mean of the
    individual entropies.
    """
    return H(x.mean(0)) - H(x).mean(0)

import scipy
from grakel.kernels import WeisfeilerLehman, VertexHistogram, WeisfeilerLehmanOptimalAssignment, Propagation, GraphletSampling, RandomWalkLabeled, PyramidMatch
import time
from scipy import sparse
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

from layers import VectorQuantizerEMA

def one_hot_embedding(labels, nlabels):
    """Map integer ``labels`` to one-hot float rows of width ``nlabels``.

    Rows are taken from an ``nlabels`` x ``nlabels`` identity matrix, so the
    result has the shape of ``labels`` plus a trailing axis of size ``nlabels``.
    """
    identity = torch.eye(nlabels)
    rows = identity[labels]
    return rows

def max_comp(E, d):
    """Restrict a labelled graph to its largest connected component.

    Args:
        E: iterable of ``(u, v)`` integer edge pairs. Node ids are used as
            sparse-matrix indices, so they must be non-negative.
        d: mapping from node id to node label. It must contain every node of
            the largest component.

    Returns:
        ``(edges, labels)``:
            - ``edges`` is the list of edges of ``E`` (in ``E``'s iteration
              order) whose endpoints both lie in the largest component.
            - ``labels`` maps each node of that component (as ``int``) to
              ``d[node]``.
        When ``E`` is empty, ``([], d)`` is returned.
    """
    E = list(E)
    if not E:
        return E, d

    n = np.max(E) + 1
    rows, cols = zip(*E)
    graph = csr_matrix((np.ones(len(E)), (rows, cols)), shape=(n, n))
    _, comp = connected_components(csgraph=graph, directed=False, return_labels=True)

    # Ties between equally large components resolve to the smallest component
    # label (np.unique sorts, np.argmax returns the first maximum).
    unique, counts = np.unique(comp, return_counts=True)
    largest = unique[np.argmax(counts)]
    max_nodes = np.flatnonzero(comp == largest)

    # A set gives O(1) membership tests; scanning a numpy array per edge was O(k).
    keep = set(max_nodes.tolist())
    max_ed_list = [e for e in E if e[0] in keep and e[1] in keep]
    dnew = {int(k): d[int(k)] for k in max_nodes}
    return max_ed_list, dnew

class GKernel(nn.Module):
    """Graph-kernel layer with ``filters`` learnable filter graphs.

    Each filter is a graph over ``nodes`` nodes. It is stored as a symmetric
    0/1 adjacency without self-loops in ``P`` and one-hot node labels (one of
    ``labels`` classes per node) in ``X``. ``P`` and ``X`` have
    ``requires_grad=False``. ``Padd``/``Prem``/``Xp`` are the trainable logits
    that ``GKernelConv`` uses when sampling filter edits. ``forward`` runs
    ``GKernelConv`` once per kernel named in ``kernels`` and concatenates the
    results on the last axis.

    Args:
        nodes: number of nodes of every filter graph.
        labels: number of distinct node labels.
        filters: number of filter graphs.
        max_cc: accepted for API compatibility; not used by this class.
        hops: ego-network radius passed to ``GKernelConv``.
        kernels: '+'-separated names from ``_KERNEL_FACTORIES``
            (e.g. 'wl' or 'wl+prop').
        normalize: forwarded to every grakel kernel constructor.
        store_fit: kept as ``self.store_fit``; not otherwise used here.

    Raises:
        ValueError: if ``kernels`` contains a name with no factory.
    """

    # Supported kernel names -> factory taking the ``normalize`` flag. The
    # lambdas defer the grakel class lookup until a kernel is actually built.
    _KERNEL_FACTORIES = {
        'wl': lambda normalize: WeisfeilerLehman(n_iter=3, normalize=normalize),
        'wloa': lambda normalize: WeisfeilerLehmanOptimalAssignment(n_iter=3, normalize=normalize),
        'prop': lambda normalize: Propagation(normalize=normalize),
        'rw': lambda normalize: RandomWalkLabeled(normalize=normalize),
        'gl': lambda normalize: GraphletSampling(normalize=normalize),
        'py': lambda normalize: PyramidMatch(normalize=normalize),
    }

    def __init__(self, nodes, labels, filters=8, max_cc=None, hops=3, kernels='wl', normalize=True, store_fit=False):
        super(GKernel, self).__init__()

        # Validate before any random initialisation so a typo fails fast.
        # Unknown names used to be skipped silently, which yields fewer outputs
        # than callers that size their inputs by the number of names expect.
        kernel_names = kernels.split('+')
        unknown = [name for name in kernel_names if name not in self._KERNEL_FACTORIES]
        if unknown:
            raise ValueError(
                f"unknown kernel name(s) {unknown!r}; expected '+'-separated names "
                f"from {sorted(self._KERNEL_FACTORIES)}"
            )

        self.hops = hops

        # Random symmetric 0/1 adjacency per filter with the diagonal cleared.
        A = torch.from_numpy(np.random.rand(filters, nodes, nodes)).float()
        A = ((A + A.transpose(-2, -1)) > 1).float()
        A = torch.stack([a - torch.diag(torch.diag(a)) for a in A], 0)
        self.P = nn.Parameter(A, requires_grad=False)

        # Random one-hot node labels per filter, plus zero-initialised label logits.
        self.X = nn.Parameter(torch.stack([one_hot_embedding(torch.randint(labels, (nodes,)), labels) for _ in range(filters)], 0), requires_grad=False)
        self.Xp = nn.Parameter(torch.zeros((filters, nodes, labels)).float(), requires_grad=True)

        # Edge add/remove logits start at zero. randn is still drawn so the
        # global RNG stream consumed by this constructor stays the same.
        self.Padd = nn.Parameter(torch.randn(filters, nodes, nodes) * 0)
        self.Prem = nn.Parameter(torch.randn(filters, nodes, nodes) * 0)
        self.Padd.data = self.Padd.data + self.Padd.data.transpose(-2, -1)
        self.Prem.data = self.Prem.data + self.Prem.data.transpose(-2, -1)

        self.filters = filters
        self.store = [None] * filters

        # One single-argument builder per requested kernel (forward calls each
        # as ``gk(None)`` to obtain a fresh grakel kernel object).
        self.gks = []
        for name in kernel_names:
            factory = self._KERNEL_FACTORIES[name]
            self.gks.append(lambda _unused, factory=factory: factory(normalize))

        self.store_fit = store_fit
        self.stored = False

    def forward(self, x, edge_index, not_used=None, fixedges=None, node_indexes=[]):
        """Concatenate, on the last axis, the GKernelConv response of every kernel.

        ``not_used`` and ``fixedges`` are ignored. ``node_indexes`` is only
        consulted by ``GKernelConv`` when ``self.stored`` is true. It is never
        mutated, so the shared list default is harmless.
        """
        convs = [
            GKernelConv.apply(x, edge_index, self.P, self.Padd, self.Prem, self.X, self.Xp,
                              self.hops, self.training, gk(None), self.stored, node_indexes)
            for gk in self.gks
        ]
        return torch.cat(convs, -1)



def get_egonets(x, edge_index, i, hops=3):
    """Return the ``hops``-hop ego network around node ``i``.

    Args:
        x: node feature matrix whose rows are indexed by node id.
        edge_index: edge index of the full graph, in the layout expected by
            ``torch_geometric.utils.k_hop_subgraph`` (presumably ``(2, E)``).
        i: id of the centre node.
        hops: neighbourhood radius passed to ``k_hop_subgraph``.

    Returns:
        ``(ego_nodes, ego_edges)``:
            - ``ego_nodes`` holds the feature rows of the sub-graph's nodes.
            - ``ego_edges`` holds the sub-graph's edges, with endpoints
              renumbered to row positions in ``ego_nodes``.
    """
    # Imported explicitly: this module only does ``from torch_geometric.nn import ...``,
    # which does not bind the name ``torch_geometric`` this function used before.
    from torch_geometric.utils import k_hop_subgraph

    subset, sub_edges, _, _ = k_hop_subgraph([i], num_hops=hops, edge_index=edge_index)
    # Lookup table from global node id to local position. Only the entries
    # for ``subset`` are meaningful, and k_hop_subgraph keeps only edges whose
    # endpoints are both in ``subset``.
    local_id = torch.arange(subset.max() + 1)
    local_id[subset] = torch.arange(subset.shape[0])
    return x[subset, :], local_id[sub_edges]

class GKernelConv(torch.autograd.Function):
    """Autograd function scoring node ego networks against learnable filter graphs.

    ``forward`` evaluates a grakel graph kernel between the ego network of
    every input node and every filter graph. The kernel has no analytic
    gradient, so ``backward`` samples discrete edits of the filters (see
    ``random_edit``). It keeps the edits predicted to lower the loss, writing
    them into the ``P``/``X`` parameters, and returns surrogate gradients for
    the edit logits.
    """
    @staticmethod
    def forward(ctx, x, edge_index, P, Padd, Prem, X, Xp, hops, training, gk, stored, node_indexes):
        """Return kernel similarities between node ego networks and filter graphs.

        Args:
            x: node features. Each node's label is taken as the argmax of its row.
            edge_index: edges of the input graph(s), as consumed by ``get_egonets``.
            P, X: filter adjacencies and one-hot filter node labels, indexed by
                filter along dim 0.
            Padd, Prem, Xp: edit logits. They are only saved here and used by
                ``backward``.
            hops: ego-network radius.
            training: unused.
            gk: grakel kernel object; it is fitted here when ``stored`` is false.
            stored: if true, skip ego-network extraction and fitting, call
                ``gk.transform`` only, and keep the rows in ``node_indexes``.
                NOTE(review): GKernel.forward passes a freshly built kernel
                (``gk(None)``), which would not be fitted on this path. Confirm.
            node_indexes: row selection used only when ``stored`` is true.

        Returns:
            Similarity tensor cast to float32, with one column per filter.
            ``backward`` reduces over dim 0 and broadcasts the result against
            ``P``, which relies on this layout.
        """
        # Graph similarity between node ego networks and filter graphs.
        filters = P.shape[0]  # unused
        convs = []  # unused
        
        
        if not stored: 
          # One ego network per input node, converted to [edge set of local
          # (src, dst) pairs, {local node id: argmax label}] for grakel.
          egonets = [get_egonets(x,edge_index,i, hops) for i in  torch.arange(x.shape[0])]
          G1 = lambda i: [set([ (e[0],e[1]) for e in egonets[i][1].t().numpy()]),
                           dict(zip(range(egonets[i][0].shape[0]),egonets[i][0].argmax(-1).numpy()))]
          Gs1 = [G1(i) for i in range(x.shape[0])]
          
          # Fit the kernel on the ego networks and score the filters against them.
          conv = GKernelConv.eval_kernel(x, Gs1,P,X, gk, False)
        else:
          # Reuse an already fitted kernel; keep only the requested node rows.
          conv = GKernelConv.eval_kernel(None, None,P,X, gk, True)[node_indexes,:]
          Gs1 = None
                    
        ctx.save_for_backward(x, edge_index, P, Padd, Prem, X, Xp, conv)
        # Non-tensor state (and direct handles on P/X) needed by backward,
        # which writes accepted edits back through ctx.P / ctx.X.
        ctx.stored = stored
        ctx.node_indexes = node_indexes
        ctx.Gs1 = Gs1
        ctx.P = P
        ctx.X = X
        ctx.gk = gk
        
        return conv.float()
    
    @staticmethod
    def backward(ctx, grad_output):
        """Estimate gradients for Padd, Prem and Xp by trying random filter edits.

        For every filter whose column of ``grad_output`` is non-zero, a random
        edit is sampled: an edge toggle with probability 0.5 (always when there
        is a single label), otherwise a node relabel. ``proj`` is the per-filter
        dot product of ``conv - convnew`` with ``grad_output``. To first order,
        applying the edit changes the loss by ``-proj``. Filters with
        ``proj == 0`` are re-edited, up to 3 attempts in total, with edits
        accumulating in ``Pnew``/``Xnew``. Edits with ``proj >= th`` (th = 0)
        are written in place into the ``P``/``X`` parameters via ``.data``;
        the others are discarded.

        Returns one entry per ``forward`` input. Only Padd, Prem and Xp get a
        non-None gradient.
        """
        x, edge_index, P, Padd, Prem, X, Xp, conv = ctx.saved_tensors
        # Use the same P object forward received (also the target of the write-back below).
        P=ctx.P
#         grad_input = grad_weight = grad_bias = None

        # grad_output has conv's layout: one column per filter (see the
        # .sum(0) reduction and broadcast against P below).
        #todo: estimate gradient and keep the one maximizing dot product
        
        #perform random edit for each non zero filter gradient:
        grad_padd = 0
        grad_prem = 0
        grad_xp = 0
        
        # Indices of filters that received any non-zero upstream gradient.
        kindexes = torch.nonzero(torch.norm(grad_output,dim=0))[:,0]
        Pnew = P.clone()
        Xnew = X.clone()
       
        for i in range(3): #if the gradient of the edit w.r.t. the loss is 0 we try another edit operation 
            for fi in kindexes:
                edit_graph = torch.rand((1,)).item()<0.5 or X.shape[-1]==1
                Pnew,Xnew = GKernelConv.random_edit(fi,Pnew,Padd,Prem,Xnew,Xp,edit_graph)
            # Re-score with the kernel fitted in forward (transform only).
            if not ctx.stored:
              convnew = GKernelConv.eval_kernel( x, ctx.Gs1, Pnew, Xnew, ctx.gk, True)    
            else:
              convnew = GKernelConv.eval_kernel( None, None, Pnew, Xnew, ctx.gk, True)[ctx.node_indexes,:]  
            grad_fi = conv-convnew

            # Per-filter projection, shaped (filters, 1, 1) to broadcast against P and X.
            proj = (grad_fi*grad_output).sum(0)[:,None,None]
            # Retry only the filters whose edit did not change the projected loss.
            kindexes = kindexes[proj[kindexes,0,0]==0]
            if len(kindexes) == 0:
                break

        # Surrogate gradients on the toggled entries. With a plain gradient-descent
        # step, proj > 0 raises Padd where an edge was added, Prem where one was
        # removed, and Xp at a newly assigned label; proj < 0 does the opposite.
        grad_padd += proj*(P-Pnew)
        grad_prem += proj*(Pnew-P)
        grad_xp += proj*(X-Xnew)

        # Accept edits predicted not to increase the loss (proj >= 0); revert the rest.
        th=0
        ctx.P.data = (proj>=th)*Pnew + (proj<th)*P
        ctx.X.data = (proj>=th)*Xnew + (proj<th)*X

        # Gradients are scaled by the sigmoid derivative of the logits.
        # NOTE(review): random_edit samples with sigmoid(logit * temp) (temp=0.1
        # by default) but this scaling uses sigmoid(logit) without temp. Confirm
        # this is intended.
        return None, None, None, grad_padd*((Padd).sigmoid()*(1-(Padd).sigmoid())),\
                                 grad_prem*((Prem).sigmoid()*(1-(Prem).sigmoid())), None,\
                                 grad_xp*(Xp.sigmoid()*(1-Xp.sigmoid())), None, None, None, None, None
    
    @staticmethod
    def eval_kernel( x, Gs1, P, X, gk, stored=False):
        """Compare every filter graph with the graphs in ``Gs1`` using kernel ``gk``.

        Each filter ``fi`` is converted into the same [edge set, {node: label}]
        form used for ``Gs1``. The edges are the non-zero entries of ``P[fi]``
        and the labels are the argmax of each row of ``X[fi]``. The result is
        then cut down to its largest connected component by ``max_comp``.

        When ``stored`` is false, ``gk`` is first fitted on ``Gs1``; otherwise
        ``gk`` must already be fitted. NaNs in the similarity matrix become 0.
        The transposed matrix is returned as a tensor. ``x`` is not used.

        NOTE(review): the name ``torch_geometric`` is not bound by this
        module's visible imports (only ``from torch_geometric.nn import ...``).
        It presumably comes from ``from layers import *``. Confirm.
        """
        filters = P.shape[0]
        nodes = P.shape[1]
        
        Gs2 = [max_comp(set([ (e[0],e[1]) for e in torch_geometric.utils.dense_to_sparse(P[fi])[0].t().numpy()]),
                 dict(zip(range(nodes),X[fi].argmax(-1).flatten().detach().numpy()))) for fi in  range(filters)]

        
        if not stored:
          gk.fit(Gs1)
          sim = gk.transform(Gs2)
          sim = np.nan_to_num(sim)
        else:
          sim = gk.transform(Gs2)
          sim = np.nan_to_num(sim)
                            
        # Transposed so filters index the columns.
        return torch.from_numpy(sim.T)
        
    @staticmethod
    def random_edit(i, Pin,Padd,Prem,X,Xp,edit_graph, n_edits=1, temp=0.1):
        """Return copies of ``Pin`` and ``X`` with a random edit applied to filter ``i``.

        If ``edit_graph`` is true (edge edit):
            - ``n_edits`` distinct off-diagonal adjacency entries of filter ``i``
              are sampled. Existing edges are weighted by sigmoid(Prem * temp)
              and missing ones by sigmoid(Padd * temp); the diagonal is never
              sampled.
            - Each sampled entry is toggled together with its transpose.
            - If the filter is left with no edges at all, an unedited clone of
              ``Pin`` is returned instead.

        Otherwise (label edit):
            - Nodes are picked with probability proportional to
              1 - max(softmax(Xp[i] * temp)).
            - Each picked node's one-hot label is resampled from that softmax.
              The new label may equal the old one.

        ``filters`` is computed but unused.
        """
        filters = Pin.shape[0]
        
        P = Pin.clone()
        X = X.clone()
        if edit_graph: #edit graph
            Pmat = P[i]*(Prem[i]*temp).sigmoid().data + (1-P[i])*(Padd[i]*temp).sigmoid().data + 1e-8#sample edits
            # Zero the diagonal so self-loops are never proposed, then normalise to a distribution.
            Pmat = Pmat * (1-np.eye(Pmat.shape[-1]))
            Pmat = Pmat/Pmat.sum()
            inds = np.random.choice(Pmat.shape[0]**2,size=(n_edits,),replace=False,p=Pmat.flatten().numpy(),)
            # Flat indices -> (row, col) pairs.
            inds = torch.from_numpy(np.stack(np.unravel_index(inds,Pmat.shape),0)).to(Pmat.device)

            inds = torch.cat([inds,inds[[1,0],:]],-1) #symmetric edit
            P[i].data[inds[0],inds[1]] = 1-P[i].data[inds[0],inds[1]]
        
            if(P[i].sum()==0): #avoid fully disconnected graphs
                P = Pin.clone()
        else: #edit labels
            PX = (Xp[i]*temp).softmax(-1).data
            # Nodes whose label distribution is less peaked are more likely to be relabelled.
            pi = 1-PX.max(-1)[0]
            pi = pi/pi.sum(-1,keepdims=True)
    
            lab_ind = np.random.choice(X[i].shape[0],(n_edits,),p=pi.numpy())
            lab_val = [np.random.choice(PX.shape[1],size=(1,),replace=False,p=PX[j,:].numpy(),) for j in lab_ind]

            # NOTE(review): lab_val is a list of shape-(1,) arrays. Combined
            # with lab_ind of shape (n_edits,), this advanced index appears to
            # broadcast to n_edits x n_edits. That is only safe for n_edits=1
            # (the default). Confirm before using larger n_edits.
            X[i].data[lab_ind,:] = 0
            X[i].data[lab_ind,lab_val] = 1
            
        return P,X
    
    
class Model(pl.LightningModule):
    """Graph classifier: stacked GKernel layers, graph pooling and an MLP head.

    Layers:
        - Layer 0 is a ``GKernel`` over the input node features.
        - Every further layer first vector-quantises the previous layer's
          responses with ``VectorQuantizerEMA``, one-hot encodes the code
          indices, and applies another ``GKernel``.

    All layer responses are concatenated, pooled per graph with
    ``hparams.pooling`` ('add', 'max' or 'mean') and classified by
    ``self.fc``. Optimisation is manual (``automatic_optimization = False``).

    ``hparams`` may be an ``argparse.Namespace``-like object or any ``dict``
    (including dict subclasses), which is converted to a Namespace.
    ``activation`` defaults to 'relu' when absent.
    """

    def __init__(self, hparams):
        super(Model, self).__init__()

        if isinstance(hparams, dict):
            # Route through argparse so every key becomes an attribute. The old
            # ``type(...) is dict`` test skipped dict subclasses, whose keys are
            # not in ``__dict__``, so an existing 'activation' was overwritten.
            import argparse
            parser = argparse.ArgumentParser()
            for key, value in hparams.items():
                parser.add_argument('--' + key, default=value)
            hparams = parser.parse_args([])

        if not hasattr(hparams, 'activation'):
            hparams.activation = 'relu'

        self.save_hyperparameters(hparams)

        hidden, num_classes, labels = hparams.hidden, hparams.num_classes, hparams.labels
        # One GKernel output block per '+'-separated kernel name.
        n_kernels = len(hparams.kernel.split('+'))

        self.conv_layers = nn.ModuleList()
        self.vq_layers = nn.ModuleList()

        self.conv_layers.append(GKernel(hparams.nodes, labels, hidden, max_cc=self.hparams.max_cc, hops=hparams.hops, kernels=hparams.kernel, store_fit=True))

        commitment_cost = 0.25
        decay = 0.99
        for _ in range(1, hparams.layers):
            self.conv_layers.append(GKernel(hparams.nodes, hidden, hidden, max_cc=self.hparams.max_cc, hops=hparams.hops, kernels=hparams.kernel))
            self.vq_layers.append(VectorQuantizerEMA(hidden * n_kernels, hidden, commitment_cost, decay=decay))

        activation = nn.Sigmoid if hparams.activation == 'sigmoid' else nn.ReLU
        self.fc = nn.Sequential(
            nn.Linear(hidden * n_kernels * hparams.layers, hidden), activation(),
            nn.Linear(hidden, hidden), activation(),
            nn.Linear(hidden, num_classes),
        )

        self.eye = torch.eye(hidden)
        self.lin = nn.Linear(hidden, hidden)

        self.automatic_optimization = False

        def _regularizers(x):
            # Negative weighted JSD of the per-node softmax responses:
            # minimising the loss pushes the nodes' response distributions apart.
            return -(hparams.jsd_weight * JSD(x.softmax(-1)))

        self.regularizers = _regularizers

        self.mask = nn.Parameter(torch.ones(hidden).float(), requires_grad=False)

    def one_hot_embedding(self, labels):
        """One-hot encode ``labels`` with the ``hidden``-sized identity on ``labels``' device."""
        self.eye = self.eye.to(labels.device)
        return self.eye[labels]

    def forward(self, data):
        """Return ``(logits, per-layer responses, accumulated VQ loss)`` for a batch."""
        # getattr works whether attributes live in ``data.__dict__`` (PyG 1.x)
        # or in ``data._store`` (PyG 2.x). There the old ``__dict__`` test was
        # always true, and assigning None deleted any existing ``nidx``.
        node_indexes = getattr(data, 'nidx', None)

        batch = data.batch
        edge_index = data.edge_index
        x = data.x

        loss = x.sum().detach() * 0

        responses = []
        for conv, vq in zip(self.conv_layers, [None] + list(self.vq_layers)):
            if vq is not None:
                vqloss, x, qidx, _, _ = vq(x)
                loss = loss + vqloss
                x = self.one_hot_embedding(qidx)

            x = conv(x, edge_index, node_indexes=node_indexes)
            if self.mask is not None:
                x = x * self.mask[None, :].repeat(1, x.shape[-1] // self.mask.shape[-1])

            responses.append(x)
        x = torch.cat(responses, -1)

        pooling_ops = {'add': global_add_pool, 'max': global_max_pool, 'mean': global_mean_pool}
        try:
            pooling_op = pooling_ops[self.hparams.pooling]
        except KeyError:
            raise ValueError(
                f"unknown pooling {self.hparams.pooling!r}; expected one of {sorted(pooling_ops)}"
            ) from None

        return self.fc(pooling_op(x, batch)), responses, loss

    def configure_optimizers(self):
        """Adam with lr ``lr_graph`` for the GKernel layers and ``lr`` for everything else."""
        # Split by identity while keeping registration order. A set() would
        # give a different order on every run, and optimizer state is restored
        # by position.
        graph_params = list(self.conv_layers.parameters())
        graph_ids = {id(p) for p in graph_params}
        cla_params = [p for p in self.parameters() if id(p) not in graph_ids]
        optimizer = torch.optim.Adam([{'params': graph_params, 'lr': self.hparams.lr_graph},
                                      {'params': cla_params, 'lr': self.hparams.lr}])
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """One manual optimisation step on cross-entropy plus the JSD regulariser."""
        data = train_batch

        optimizer = self.optimizers()

        optimizer.zero_grad()
        output, responses, _ = self(data)
        loss_ce = torch.nn.functional.cross_entropy(output, data.y)

        loss_jsd = torch.stack([self.regularizers(x) for x in responses]).mean()

        loss = loss_ce + loss_jsd
        loss.backward()
        optimizer.step()

        acc = 100 * torch.mean((output.argmax(-1) == data.y).float()).detach().cpu()
        self.log('acc', acc, on_step=False, on_epoch=True)
        self.log('loss', loss.item(), on_step=False, on_epoch=True)
        self.log('loss_jsd', loss_jsd.item(), on_step=False, on_epoch=True)
        self.log('loss_ce', loss_ce.item(), on_step=False, on_epoch=True)

    def validation_step(self, train_batch, batch_idx):
        """Log validation loss and accuracy on the batch with isolated nodes removed."""
        data = clean_graph(train_batch)
        with torch.no_grad():
            output, _, _ = self(data)
            loss = torch.nn.functional.cross_entropy(output, data.y)
            acc = 100 * torch.mean((output.argmax(-1) == data.y).float()).detach().cpu()
            self.log('val_loss', loss.item(), on_step=False, on_epoch=True)
            self.log('val_acc', acc, on_step=False, on_epoch=True)

    def test_step(self, train_batch, batch_idx):
        """Log test loss (scaled by batch fill) and accuracy on the cleaned batch."""
        data = clean_graph(train_batch)

        with torch.no_grad():
            output, _, _ = self(data)
            loss = torch.nn.functional.cross_entropy(output, data.y) * output.shape[0] / self.hparams.batch_size

            acc = 100 * torch.mean((output.argmax(-1) == data.y).float()).detach().cpu()
            self.log('test_loss', loss.item(), on_step=False, on_epoch=True)
            self.log('test_acc', acc, on_step=False, on_epoch=True)

def clean_graph(data):
    """Drop nodes that are not an endpoint of any edge and renumber the rest.

    Updates ``data`` in place and returns it. ``edge_index`` is remapped to
    the compacted node numbering, and ``x`` and ``batch`` (and ``nidx`` when
    present and not None) keep only the surviving rows. ``x`` is cast to
    int32, as before.
    """
    num_nodes = data.x.shape[0]
    device = data.edge_index.device

    # Vectorised replacement for the per-element ``.item()`` loop.
    mask = torch.zeros((num_nodes,), dtype=torch.bool, device=device)
    mask[data.edge_index.flatten()] = True
    mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    mapping[mask] = torch.arange(int(mask.sum()), device=device)

    data.edge_index = mapping[data.edge_index]
    # getattr works for PyG 1.x (``__dict__``) and 2.x (``_store``) Data objects,
    # and a None ``nidx`` is skipped instead of crashing on ``None[mask]``.
    nidx = getattr(data, 'nidx', None)
    if nidx is not None:
        data.nidx = nidx[mask]
    data.x = data.x[mask, :].int()
    data.batch = data.batch[mask]
    return data