import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

from .layers import ImplicitGraph
from ..tools.utils import get_spectral_rad
from ..functions import ReLU

"""
------------------
Models Package
------------------
Model implementations for the experiments in `mono_ignn/experiments`.

    AmazonIGNN:

    ChainsIGNN:

    GraphIGNN:

    ppiIGNN:

------------------
Reference
------------------
https://github.com/SwiftieH/IGNN

"""

"""
IGNN - Single layer IGNN
"""
class IGNN(nn.Module):
    """Single implicit-graph layer followed by a bias-free linear read-out.

    ``forward`` runs one ``ImplicitGraph`` layer, transposes its output,
    L2-normalises it along the last dimension, applies dropout and projects
    the result with ``V``.

    Args:
        nfeat: First size argument of ``ImplicitGraph`` (presumably the input
            feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Second size argument of ``ImplicitGraph``; also the input width
            of ``V`` and the first dimension of ``X_0``.
        nclass: Output width of ``V``.
        num_node: Third argument of ``ImplicitGraph`` and second dimension of
            ``X_0``. When ``None`` no ``X_0`` parameter is created.
        dropout: Dropout probability used in ``forward``.
        adj: Adjacency matrix, or ``None`` to supply it later via ``forward``.
        kappa: Forwarded to ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layer as ``A_orig``.
        device: Device of ``X_0``; also forwarded to ``ImplicitGraph``.
        **kwargs: Forwarded to ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()
        self.adj = adj
        # Always define ``adj_rho`` so that ``forward`` never trips over a
        # missing attribute when the adjacency is only supplied later.
        self.adj_rho = get_spectral_rad(adj) if adj is not None else None
        self.adj_orig = adj_orig

        self.ig1 = ImplicitGraph(nfeat, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.dropout = dropout

        if num_node is not None:
            # Zero-initialised, learnable initial state handed to the layer.
            self.X_0 = nn.Parameter(torch.zeros(nhid, num_node, device=device))
        else:
            self.X_0 = None
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Compute the model output for ``features``.

        Args:
            features: Input handed to the implicit layer.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tuple ``(output, None)`` where ``output`` is the result of ``V``.

        Raises:
            ValueError: If no adjacency was supplied here or at construction.
        """
        if adj is not None and adj is not self.adj:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)
        if self.adj is None:
            raise ValueError(
                "IGNN.forward() needs an adjacency matrix: pass `adj` to the "
                "constructor or to forward()."
            )

        x = self.ig1(self.X_0, self.adj, features, ReLU(), self.adj_rho, A_orig=self.adj_orig).T
        x = F.normalize(x, dim=-1)
        x = F.dropout(x, self.dropout, training=self.training)
        return self.V(x), None


"""
CitationIGNN - Single layer IGNN with a log-softmax output.
"""
class CitationIGNN(nn.Module):
    """Single implicit-graph layer with a linear read-out and log-softmax.

    ``forward`` additionally prints a smoothness diagnostic computed from the
    hidden representation and the matrix ``L``.

    Args:
        nfeat: First size argument of ``ImplicitGraph`` (presumably the input
            feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Second size argument of ``ImplicitGraph``; also the input width
            of ``V`` and the first dimension of ``X_0``.
        nclass: Output width of ``V``.
        num_node: Third argument of ``ImplicitGraph`` and second dimension of
            ``X_0``.
        dropout: Dropout probability used in ``forward``.
        adj: Adjacency matrix, or ``None`` to supply it later via ``forward``.
        L: Matrix used in the ``x^T L x`` diagnostic (presumably a graph
            Laplacian -- confirm with the caller).
        kappa: Forwarded to ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layer as ``A_orig``.
        device: Device of ``X_0``; also forwarded to ``ImplicitGraph``.
        **kwargs: Forwarded to ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, L, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()
        self.adj = adj
        # Always define ``adj_rho`` so that ``forward`` never trips over a
        # missing attribute when the adjacency is only supplied later.
        self.adj_rho = get_spectral_rad(adj) if adj is not None else None
        self.adj_orig = adj_orig

        self.L = L
        self.ig1 = ImplicitGraph(nfeat, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.dropout = dropout

        # Zero-initialised, learnable initial state handed to the layer.
        self.X_0 = nn.Parameter(torch.zeros(nhid, num_node, device=device))
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Compute log-probabilities for ``features``.

        Args:
            features: Input handed to the implicit layer.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tuple ``(log_probs, None)`` with ``log_probs`` the log-softmax of
            the read-out along dimension 1.

        Raises:
            ValueError: If no adjacency was supplied here or at construction.
        """
        if adj is not None and adj is not self.adj:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)
        if self.adj is None:
            raise ValueError(
                "CitationIGNN.forward() needs an adjacency matrix: pass `adj` "
                "to the constructor or to forward()."
            )

        x = self.ig1(self.X_0, self.adj, features, ReLU(), self.adj_rho, A_orig=self.adj_orig).T
        # Diagnostic: sorted diagonal of x^T L x, printed on every call.
        # NOTE(review): this writes to stdout on each forward pass and forces
        # a host sync via .item(); kept because callers may rely on the output.
        smooth = x.t()@self.L@x
        print('Smooth', *(val.item() for val in torch.diag(smooth).sort()[0]),sep=', ')
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.V(x)
        return F.log_softmax(x, dim=1), None


"""
AmazonIGNN - Model for amazon co-purchasing node classification task.
"""
class AmazonIGNN(nn.Module):
    """Single implicit-graph layer with dropout and a bias-free read-out.

    Args:
        nfeat: First size argument of ``ImplicitGraph`` (presumably the input
            feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Second size argument of ``ImplicitGraph``; also the input width
            of ``V`` and the first dimension of ``X_0``.
        nclass: Output width of ``V``.
        num_node: Third argument of ``ImplicitGraph`` and second dimension of
            ``X_0``.
        dropout: Dropout probability used in ``forward``.
        adj: Adjacency matrix, or ``None`` to supply it later via ``forward``.
        kappa: Forwarded to ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layer as ``A_orig``.
        device: Device of ``X_0``; also forwarded to ``ImplicitGraph``.
        **kwargs: Forwarded to ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()

        self.adj = adj
        # Always define ``adj_rho`` so that ``forward`` never trips over a
        # missing attribute when the adjacency is only supplied later.
        self.adj_rho = get_spectral_rad(adj) if adj is not None else None
        self.adj_orig = adj_orig

        self.ig1 = ImplicitGraph(nfeat, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.dropout = dropout
        # Zero-initialised, learnable initial state handed to the layer.
        self.X_0 = nn.Parameter(torch.zeros(nhid, num_node, device=device))
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Compute the model output for ``features``.

        Args:
            features: Input handed to the implicit layer.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tuple ``(output, None)`` where ``output`` is the result of ``V``.

        Raises:
            ValueError: If no adjacency was supplied here or at construction.
        """
        if adj is not None and adj is not self.adj:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)
        if self.adj is None:
            raise ValueError(
                "AmazonIGNN.forward() needs an adjacency matrix: pass `adj` "
                "to the constructor or to forward()."
            )

        x = self.ig1(self.X_0, self.adj, features, ReLU(), self.adj_rho, A_orig=self.adj_orig).T
        x = F.dropout(x, self.dropout, training=self.training)
        return self.V(x), None

"""
ChainsIGNN - Model for synthetic node classification task.
"""
class ChainsIGNN(nn.Module):
    """Single implicit-graph layer with normalisation, dropout and read-out.

    The constructor computes the spectral radius of ``adj`` unconditionally,
    so an adjacency must be supplied at construction time.

    Args:
        nfeat: First size argument of ``ImplicitGraph`` (presumably the input
            feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Second size argument of ``ImplicitGraph``; also the input width
            of ``V`` and the first dimension of ``X_0``.
        nclass: Output width of ``V``.
        num_node: Third argument of ``ImplicitGraph`` and second dimension of
            ``X_0``.
        dropout: Dropout probability used in ``forward``.
        adj: Adjacency matrix handed to ``get_spectral_rad`` and the layer.
        kappa: Forwarded to ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layer as ``A_orig``.
        device: Device of ``X_0``; also forwarded to ``ImplicitGraph``.
        **kwargs: Forwarded to ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()

        # Graph structure and its cached spectral radius.
        self.adj = adj
        self.adj_rho = get_spectral_rad(adj)
        self.adj_orig = adj_orig

        # Implicit layer, learnable zero initial state, bias-free read-out.
        self.ig1 = ImplicitGraph(nfeat, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.dropout = dropout
        initial_state = torch.zeros(nhid, num_node, device=device)
        self.X_0 = nn.Parameter(initial_state)
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Compute the model output for ``features``.

        Args:
            features: Input handed to the implicit layer.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tuple ``(output, None)`` where ``output`` is the result of ``V``.
        """
        adjacency_changed = adj is not None and adj is not self.adj
        if adjacency_changed:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)

        hidden = self.ig1(
            self.X_0, self.adj, features, ReLU(), self.adj_rho, A_orig=self.adj_orig
        ).T
        # L2-normalise along the last dimension before dropout.
        hidden = F.normalize(hidden, dim=-1)
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.V(hidden)
        return logits, None

"""
GraphClassificationIGNN - Three stacked IGNN layers with sum pooling and a two-layer MLP head.
"""
class GraphClassificationIGNN(nn.Module):
    """Three stacked implicit-graph layers, sum pooling and a two-layer MLP.

    Args:
        nfeat: First size argument of the first ``ImplicitGraph`` (presumably
            the input feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Hidden size used by all three implicit layers and by ``V_0``.
        nclass: Output width of ``V_1``.
        num_node: Third argument of every ``ImplicitGraph``.
        dropout: Dropout probability used in ``forward``.
        kappa: Passed positionally as the fourth ``ImplicitGraph`` argument.
        adj_orig: Forwarded to the implicit layers as ``A_orig``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, kappa=0.9, adj_orig=None):
        super().__init__()

        # Nothing is cached at construction; the adjacency arrives in forward.
        self.adj = None
        self.adj_rho = None
        self.adj_orig = adj_orig

        # Three implicit layers, created and registered in order.
        layer_sizes = ((nfeat, nhid), (nhid, nhid), (nhid, nhid))
        self.ig1, self.ig2, self.ig3 = (
            ImplicitGraph(n_in, n_out, num_node, kappa) for n_in, n_out in layer_sizes
        )
        self.dropout = dropout
        self.X_0 = None
        # Two-layer MLP head applied to the pooled representation.
        self.V_0 = nn.Linear(nhid, nhid)
        self.V_1 = nn.Linear(nhid, nclass)

    def forward(self, features, adj, batch):
        """Compute log-probabilities for a batch.

        Args:
            features: Input handed to the first implicit layer.
            adj: Adjacency handed to every implicit layer.
            batch: Assignment vector handed to ``global_add_pool``.

        Returns:
            Tuple ``(log_probs, None)`` with ``log_probs`` the log-softmax of
            the MLP output along dimension 1.
        """
        # The spectral radius is fixed to 1 instead of being computed from adj.
        self.adj_rho = 1

        hidden = features
        for implicit_layer in (self.ig1, self.ig2, self.ig3):
            hidden = implicit_layer(
                self.X_0, adj, hidden, ReLU(), self.adj_rho, A_orig=self.adj_orig
            )

        # Only the last layer's output is transposed before pooling.
        pooled = global_add_pool(hidden.T, batch)
        pooled = F.relu(self.V_0(pooled))
        pooled = F.dropout(pooled, self.dropout, training=self.training)
        scores = self.V_1(pooled)
        return F.log_softmax(scores, dim=1), None

"""
HeteroIGNN - Single layer IGNN
"""
class HeteroIGNN(nn.Module):
    """Single implicit-graph layer that returns its transposed output as-is.

    No read-out layer is created; ``nclass`` and ``dropout`` are accepted but
    not used in ``forward``. The initial state ``X_0`` is a zero parameter
    with gradients disabled.

    Args:
        nfeat: First size argument of ``ImplicitGraph`` (presumably the input
            feature width -- confirm against ``layers.ImplicitGraph``).
        nhid: Second size argument of ``ImplicitGraph`` and first dimension
            of ``X_0``.
        nclass: Unused by this class.
        num_node: Third argument of ``ImplicitGraph`` and second dimension of
            ``X_0``.
        dropout: Stored on the instance; not applied in ``forward``.
        adj: Adjacency matrix, or ``None`` to supply it later via ``forward``.
        kappa: Forwarded to ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layer as ``A_orig``.
        device: Device of ``X_0``; also forwarded to ``ImplicitGraph``.
        **kwargs: Forwarded to ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()
        self.adj = adj
        # Always define ``adj_rho`` so that ``forward`` never trips over a
        # missing attribute when the adjacency is only supplied later.
        self.adj_rho = get_spectral_rad(adj) if adj is not None else None
        self.adj_orig = adj_orig

        self.ig1 = ImplicitGraph(nfeat, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.dropout = dropout

        # Registered as a parameter but frozen (requires_grad=False).
        self.X_0 = nn.Parameter(torch.zeros(nhid, num_node, device=device), requires_grad=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Run the implicit layer on ``features``.

        Args:
            features: Input handed to the implicit layer.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tuple ``(hidden, None)`` where ``hidden`` is the transposed output
            of the implicit layer.

        Raises:
            ValueError: If no adjacency was supplied here or at construction.
        """
        if adj is not None and adj is not self.adj:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)
        if self.adj is None:
            raise ValueError(
                "HeteroIGNN.forward() needs an adjacency matrix: pass `adj` "
                "to the constructor or to forward()."
            )

        x = self.ig1(self.X_0, self.adj, features, ReLU(), self.adj_rho, A_orig=self.adj_orig).T
        return x, None

"""
ppiIGNN - Model for PPI node classification task.
"""
class ppiIGNN(nn.Module):
    """Five stacked implicit-graph layers, each with a linear skip branch.

    The first four stages compute ``elu(ig(x)^T + V_k(x^T))^T``; the last
    stage returns ``ig5(x)^T + V(x^T)`` without an activation.

    Unlike the other models in this module, ``forward`` returns the output
    tensor directly rather than an ``(output, None)`` tuple.

    Args:
        nfeat: Input width of the first stage (``ig1`` and ``V_0``).
        nhid: Base hidden size; stage widths are ``4*nhid``, ``2*nhid``,
            ``2*nhid`` and ``nhid``.
        nclass: Output width of the last stage (``ig5`` and ``V``).
        num_node: Third argument of every ``ImplicitGraph``.
        dropout: Stored on the instance; not applied in ``forward``.
        adj: Adjacency matrix, or ``None`` to supply it later via ``forward``.
        kappa: Forwarded to every ``ImplicitGraph``.
        adj_orig: Forwarded to the implicit layers as ``A_orig``.
        device: Forwarded to every ``ImplicitGraph``.
        **kwargs: Forwarded to every ``ImplicitGraph``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, kappa=0.9, adj_orig=None, device='cpu', **kwargs):
        super().__init__()

        self.adj = adj
        # Always define ``adj_rho`` so that ``forward`` never trips over a
        # missing attribute when the adjacency is only supplied later.
        self.adj_rho = get_spectral_rad(adj) if adj is not None else None
        self.adj_orig = adj_orig

        self.ig1 = ImplicitGraph(nfeat, 4*nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.ig2 = ImplicitGraph(4*nhid, 2*nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.ig3 = ImplicitGraph(2*nhid, 2*nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.ig4 = ImplicitGraph(2*nhid, nhid, num_node, kappa=kappa, device=device, **kwargs)
        self.ig5 = ImplicitGraph(nhid, nclass, num_node, kappa=kappa, device=device, **kwargs)

        self.dropout = dropout
        self.X_0 = None

        # Bias-free skip projections, one per stage (V pairs with ig5).
        self.V = nn.Linear(nhid, nclass, bias=False)
        self.V_0 = nn.Linear(nfeat, 4*nhid, bias=False)
        self.V_1 = nn.Linear(4*nhid, 2*nhid, bias=False)
        self.V_2 = nn.Linear(2*nhid, 2*nhid, bias=False)
        self.V_3 = nn.Linear(2*nhid, nhid, bias=False)

    def forward(self, features, *args, adj=None, compute_jac_loss=False, **kwargs):
        """Compute the model output for ``features``.

        Args:
            features: Input to the first stage; its transpose is fed to
                ``V_0``, so its first dimension must match ``nfeat``.
            *args: Accepted for call compatibility; unused.
            adj: Optional adjacency. When given and not the object already
                stored, it replaces the stored adjacency and the spectral
                radius is recomputed. When omitted, the stored adjacency is
                used.
            compute_jac_loss: Accepted for call compatibility; unused.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            The output tensor of the last stage (no softmax is applied).

        Raises:
            ValueError: If no adjacency was supplied here or at construction.
        """
        if adj is not None and adj is not self.adj:
            self.adj = adj
            self.adj_rho = get_spectral_rad(adj)
        if self.adj is None:
            raise ValueError(
                "ppiIGNN.forward() needs an adjacency matrix: pass `adj` to "
                "the constructor or to forward()."
            )

        # Use the stored adjacency for every layer: the ``adj`` argument is
        # ``None`` whenever the caller relies on the adjacency given earlier,
        # while ``self.adj`` is identical to ``adj`` whenever one was passed.
        x = features
        stages = (
            (self.ig1, self.V_0),
            (self.ig2, self.V_1),
            (self.ig3, self.V_2),
            (self.ig4, self.V_3),
        )
        for implicit_layer, skip in stages:
            propagated = implicit_layer(self.X_0, self.adj, x, F.relu, self.adj_rho, A_orig=self.adj_orig).T
            x = F.elu(propagated + skip(x.T)).T

        return self.ig5(self.X_0, self.adj, x, F.relu, self.adj_rho, A_orig=self.adj_orig).T + self.V(x.T)
