import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

from .functions import *
from .layers import MonotoneImplicitGraph
from .solvers import *
from .tools.normalization import AugNorm, LaplaceNorm

class MIGNN(nn.Module):
    """Monotone implicit graph network: one implicit layer and a linear readout.

    Args:
        nfeat: Input feature size forwarded to ``monotone_setup``.
        nhid: Hidden size forwarded to ``monotone_setup``; also the input size
            of the readout ``V``.
        nclass: Output size of the readout ``V``.
        num_node: Forwarded to ``monotone_setup``.
        disable_norm: When true, ``forward`` skips the ``F.normalize`` step.
        dropout: Dropout probability applied before the readout.
        adj, sp_adj: Adjacency arguments forwarded to ``monotone_setup``.
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup``.
        *args: Ignored.
        adj_pow: Accepted for signature compatibility; not used by this class.
        **kwargs: Extra options forwarded to ``monotone_setup``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, disable_norm, dropout, adj, sp_adj,
                 linModule, fpMethod, invMethod, device,
                 *args, adj_pow=1, **kwargs):
        super().__init__()

        monotone_args = (nfeat, nhid, num_node, adj, sp_adj, linModule, fpMethod, invMethod, device)
        lin_module, nonlin_module, solver = monotone_setup(*monotone_args, **kwargs)
        # Built exactly once (the layer used to be constructed twice in a row).
        self.ig1 = MonotoneImplicitGraph(lin_module, nonlin_module, solver)

        # forward() reads this flag, so it has to be stored on the instance;
        # without it every forward pass raised AttributeError.
        self.disable_norm = disable_norm
        self.dropout = dropout
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, compute_jac_loss=False, **kwargs):
        """Run the implicit layer, optional normalisation, dropout and readout.

        Args:
            features: Input handed unchanged to the implicit layer.
            *args: Ignored.
            compute_jac_loss: Forwarded to the implicit layer.
            **kwargs: Only ``'adj'`` is inspected; when present, the linear
                module's adjacency is replaced before solving.

        Returns:
            Tuple ``(output, jac_loss)`` where ``output`` is the readout of the
            transposed implicit-layer result and ``jac_loss`` is whatever the
            implicit layer returned as its second value.
        """
        lin = self.ig1.lin_module

        if 'adj' in kwargs:  # some graphs may update the adj matrices
            lin.set_adj(kwargs['adj'], sp_adj=None)
            if isinstance(lin, ProjectedLinear):
                # Keep the spectral radius in sync with the new adjacency.
                lin.adj_rho = get_spectral_rad(kwargs['adj'])

        if isinstance(lin, ProjectedLinear):
            # Re-project C before every solve using the bound kappa / adj_rho.
            lin.C = projection_norm_inf(lin.C, kappa=lin.kappa / lin.adj_rho)

        x, jac_loss = self.ig1(features, compute_jac_loss=compute_jac_loss)
        # The readout acts on the last axis, so the layer output is transposed
        # first (presumably (nhid, num_node) -> (num_node, nhid) -- TODO confirm).
        x = x.T
        if not self.disable_norm:
            x = F.normalize(x, dim=-1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.V(x)
        return x, jac_loss


"""
AmazonIGNN - Model for amazon co-purchasing node classification task.
"""
class AmazonMonIGNN(nn.Module):
    """One monotone implicit graph layer followed by a bias-free linear readout.

    Args:
        nfeat: Input feature size forwarded to ``monotone_setup``.
        nhid: Hidden size forwarded to ``monotone_setup``; also the input size
            of the readout ``V``.
        nclass: Output size of the readout ``V``.
        num_node: Forwarded to ``monotone_setup``.
        dropout: Dropout probability applied before the readout.
        adj, sp_adj: Adjacency arguments forwarded to ``monotone_setup``.
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup``.
        *args: Ignored.
        **kwargs: Extra options forwarded to ``monotone_setup``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, sp_adj,
                 linModule, fpMethod, invMethod, device,
                 *args, **kwargs):
        super().__init__()
        self.dropout = dropout

        components = monotone_setup(nfeat, nhid, num_node, adj, sp_adj,
                                    linModule, fpMethod, invMethod, device,
                                    **kwargs)
        # components is (lin_module, nonlin_module, solver), in that order.
        self.ig1 = MonotoneImplicitGraph(*components)
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, compute_jac_loss=False, **kwargs):
        """Run the implicit layer, dropout and readout.

        Extra positional and keyword arguments are ignored: this model does
        not apply an ``'adj'`` override and does not re-project the linear
        module before solving (that step is disabled here).

        Returns:
            Tuple ``(output, jac_loss)``; ``jac_loss`` is the second value
            returned by the implicit layer.
        """
        hidden, jac_loss = self.ig1(features, compute_jac_loss=compute_jac_loss)
        hidden = F.dropout(hidden.T, self.dropout, training=self.training)
        return self.V(hidden), jac_loss

"""
ChainsIGNN - Model for synthetic node classification task.
"""
class ChainsMonIGNN(nn.Module):
    """One monotone implicit graph layer (ReLU non-linearity) plus a readout.

    Args:
        nfeat: Input feature size forwarded to ``monotone_setup``.
        nhid: Hidden size forwarded to ``monotone_setup``; also the input size
            of the readout ``V``.
        nclass: Output size of the readout ``V``.
        num_node: Forwarded to ``monotone_setup``.
        dropout: Stored on the instance; ``forward`` does not apply it.
        adj, sp_adj: Adjacency arguments forwarded to ``monotone_setup``.
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup``.
        *args: Ignored.
        **kwargs: Extra options forwarded to ``monotone_setup``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, sp_adj,
                 linModule, fpMethod, invMethod, device,
                 *args, **kwargs):
        super().__init__()
        self.dropout = dropout

        components = monotone_setup(nfeat, nhid, num_node, adj, sp_adj,
                                    linModule, fpMethod, invMethod, device,
                                    nonlin='relu', **kwargs)
        # components is (lin_module, nonlin_module, solver), in that order.
        self.ig1 = MonotoneImplicitGraph(*components)
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, compute_jac_loss=False, **kwargs):
        """Run the implicit layer and the readout.

        Args:
            features: Input handed unchanged to the implicit layer.
            *args: Ignored.
            compute_jac_loss: Forwarded to the implicit layer.
            **kwargs: Only ``'adj'`` is inspected; when present, the linear
                module's adjacency is replaced before solving.

        Returns:
            Tuple ``(output, jac_loss)``; ``jac_loss`` is the second value
            returned by the implicit layer.
        """
        lin = self.ig1.lin_module
        adj_given = 'adj' in kwargs

        if adj_given:  # some graphs may update the adj matrices
            lin.set_adj(kwargs['adj'], sp_adj=None)

        if isinstance(lin, ProjectedLinear):
            if adj_given:
                # Refresh the spectral radius for the new adjacency.
                lin.adj_rho = get_spectral_rad(kwargs['adj'])
            # Re-project C before every solve using the bound kappa / adj_rho.
            lin.C = projection_norm_inf(lin.C, kappa=lin.kappa / lin.adj_rho)

        hidden, jac_loss = self.ig1(features, compute_jac_loss=compute_jac_loss)
        return self.V(hidden.t()), jac_loss

"""
CitationsMonIGNN - Model for synthetic node classification task.
"""
class CitationMonIGNN(nn.Module):
    """Implicit graph layer on an averaged power adjacency, with a log-softmax readout.

    Args:
        nfeat: Input feature size forwarded to ``monotone_setup``.
        nhid: Hidden size forwarded to ``monotone_setup``; also the input size
            of the readout ``V``.
        nclass: Output size of the readout ``V``.
        num_node: Forwarded to ``monotone_setup``.
        dropout: Dropout probability applied before the readout.
        adj: Tensor supporting ``to_dense()``; replaced by
            ``(A + A^2 + ... + A^adj_pow) / adj_pow`` (stored sparse) before
            being forwarded to ``monotone_setup``.
        L: Matrix used only for the smoothness diagnostic ``x @ L @ x.T`` in
            ``forward`` (presumably a graph Laplacian -- TODO confirm).
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup``.
        *args: Ignored.
        adj_pow: Highest adjacency power included in the average; must be >= 1.
        log_smoothness: When true (the default, matching the previous
            behaviour) ``forward`` prints the sorted diagonal of
            ``x @ L @ x.T`` on every call. Pass ``False`` to skip the extra
            matmuls and the per-element ``.item()`` calls.
        **kwargs: Extra options forwarded to ``monotone_setup``.

    Raises:
        ValueError: If ``adj_pow`` is smaller than 1.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, L,
                 linModule, fpMethod, invMethod, device,
                 *args, adj_pow=1, log_smoothness=True, **kwargs):
        super().__init__()

        if adj_pow < 1:
            # adj_pow is used as a divisor below; reject it up front instead of
            # silently producing an inf/nan adjacency.
            raise ValueError(f'adj_pow must be >= 1, got: {adj_pow}')

        # Densify once and reuse it for every power term.
        dense_adj = adj.to_dense()
        higher_powers = sum([torch.linalg.matrix_power(dense_adj, k)
                             for k in range(2, adj_pow + 1)])
        adj = (dense_adj + higher_powers).to_sparse() / adj_pow
        self.L = L
        self.log_smoothness = log_smoothness

        monotone_args = (nfeat, nhid, num_node, adj, None, linModule, fpMethod, invMethod, device)
        lin_module, nonlin_module, solver = monotone_setup(*monotone_args, **kwargs)
        self.ig1 = MonotoneImplicitGraph(lin_module, nonlin_module, solver)

        self.dropout = dropout
        self.V = nn.Linear(nhid, nclass, bias=False)

    def forward(self, features, *args, compute_jac_loss=False, **kwargs):
        """Run the implicit layer, dropout, readout and log-softmax.

        Args:
            features: Input handed unchanged to the implicit layer.
            *args: Ignored.
            compute_jac_loss: Forwarded to the implicit layer.
            **kwargs: Only ``'adj'`` is inspected; when present, the linear
                module's adjacency is replaced before solving.

        Returns:
            Tuple ``(log_probs, jac_loss)`` where ``log_probs`` is
            ``log_softmax`` over dim 1 of the readout and ``jac_loss`` is the
            second value returned by the implicit layer.
        """
        lin = self.ig1.lin_module

        if 'adj' in kwargs:  # some graphs may update the adj matrices
            lin.set_adj(kwargs['adj'], sp_adj=None)
            if isinstance(lin, ProjectedLinear):
                # Keep the spectral radius in sync with the new adjacency.
                lin.adj_rho = get_spectral_rad(kwargs['adj'])
        if isinstance(lin, ProjectedLinear):
            # Re-project C before every solve using the bound kappa / adj_rho.
            lin.C = projection_norm_inf(lin.C, kappa=lin.kappa / lin.adj_rho)

        x, jac_loss = self.ig1(features, compute_jac_loss=compute_jac_loss)

        if self.log_smoothness:
            # Diagnostic only: it never feeds the output, so keep it out of the
            # autograd graph.
            with torch.no_grad():
                smooth = x @ self.L @ x.t()
                print('Smooth', *(val.item() for val in torch.diag(smooth).sort()[0]), sep=', ')

        x = x.t()
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.V(x)
        return F.log_softmax(x, dim=1), jac_loss

"""
GraphClassificationIGNN - Single layer monIGNN for graph classification tasks.
"""
class GraphClassificationMonIGNN(nn.Module):
    """Three stacked implicit graph layers, sum pooling and a two-layer head.

    Args:
        nfeat: Input feature size of the first implicit layer.
        nhid: Hidden size of all three implicit layers and of ``V0``.
        nclass: Output size of the final linear layer ``V1``.
        num_node: Forwarded to ``monotone_setup``; overwritten on every
            ``forward`` call with ``adj.shape[0]``.
        dropout: Dropout probability applied between ``V0`` and ``V1``.
        adj, sp_adj: Adjacency arguments forwarded to ``monotone_setup``.
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup`` for each of the three layers.
        *args: Ignored.
        **kwargs: Extra options forwarded to ``monotone_setup``.
    """

    def __init__(self, nfeat, nhid, nclass, num_node, dropout, adj, sp_adj,
                 linModule, fpMethod, invMethod, device,
                 *args, **kwargs):
        super().__init__()

        # Everything except the input width is shared by the three layers.
        shared = (num_node, adj, sp_adj, linModule, fpMethod, invMethod, device)

        lin_module, nonlin_module, solver = monotone_setup(nfeat, nhid, *shared, **kwargs)
        self.ig1 = MonotoneImplicitGraph(lin_module, nonlin_module, solver)

        lin_module, nonlin_module, solver = monotone_setup(nhid, nhid, *shared, **kwargs)
        self.ig2 = MonotoneImplicitGraph(lin_module, nonlin_module, solver)

        lin_module, nonlin_module, solver = monotone_setup(nhid, nhid, *shared, **kwargs)
        self.ig3 = MonotoneImplicitGraph(lin_module, nonlin_module, solver)

        self.dropout = dropout
        self.V0 = nn.Linear(nhid, nhid)
        self.V1 = nn.Linear(nhid, nclass)

    def forward(self, features, adj, batch, *args, compute_jac_loss=False, sp_adj=None, **kwargs):
        """Classify the graphs of one batch.

        Args:
            features: Input handed unchanged to the first implicit layer.
            adj: Adjacency installed on all three layers; ``adj.shape[0]`` is
                written to each linear module's ``num_node``.
            batch: Assignment vector passed to ``global_add_pool``.
            *args: Ignored.
            compute_jac_loss: Accepted but unused; every layer is called with
                ``compute_jac_loss=False``.
            sp_adj: Passed to the first layer's ``set_adj`` only; the other two
                layers receive ``sp_adj=None``.
            **kwargs: Ignored.

        Returns:
            Tuple ``(log_probs, None)`` where ``log_probs`` is ``log_softmax``
            over dim 1 of the head output.
        """
        layers = (self.ig1, self.ig2, self.ig3)

        with torch.no_grad():
            for layer, layer_sp_adj in zip(layers, (sp_adj, None, None)):
                lin = layer.lin_module
                lin.set_adj(adj, sp_adj=layer_sp_adj)
                lin.num_node = adj.shape[0]

        # Project C for every layer, not just the first: all three linear
        # modules are built from the same linModule selector and need the same
        # norm bound before their fixed-point solve.
        # NOTE(review): adj_rho is not refreshed for the per-batch adjacency
        # here (the single-layer models recompute it when given a new adj) --
        # confirm whether the stored value is intended to be reused.
        for layer in layers:
            lin = layer.lin_module
            if isinstance(lin, ProjectedLinear):
                lin.C = projection_norm_inf(lin.C, kappa=lin.kappa / lin.adj_rho)

        x = features
        for layer in layers:
            x, _ = layer(x, compute_jac_loss=False)
        x = x.T
        x = global_add_pool(x, batch)
        x = F.relu(self.V0(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.V1(x)
        x = F.log_softmax(x, dim=1)
        return x, None


class porousMIGNN(nn.Module):
    """Encoder -> monotone implicit graph layer -> MLP decoder.

    Args:
        node_dim: Node input size of the ``AttnMPNN`` encoder.
        edge_dim: Edge input size of the ``AttnMPNN`` encoder.
        lifted_dim: Node output size of the encoder and input size of the
            fixed-point layer (bias function input dim).
        hidden_dim: Hidden state size of the fixed-point layer and input size
            of the decoder.
        output_dim: Output size of the decoder.
        activation: Activation name used for the encoder MLPs (hidden and
            output) and for the decoder's hidden layers.
        num_hidden_gn: Forwarded to the encoder.
        linModule, fpMethod, invMethod, device: Forwarded to
            ``monotone_setup``.
        mlp_num_neurons: Neuron counts for the encoder MLPs.
        reg_num_neurons: Neuron counts for the decoder MLP.
        **kwargs: Extra options forwarded to ``monotone_setup``.
    """

    def __init__(self,
                 node_dim: int,
                 edge_dim: int,
                 lifted_dim: int,  # bias function input dim
                 hidden_dim: int,  # hidden state dim (state of the fp equation)
                 output_dim: int,
                 activation: str,
                 num_hidden_gn: int,
                 linModule,
                 fpMethod,
                 invMethod,
                 device,
                 mlp_num_neurons: list = [128],
                 reg_num_neurons: list = [64, 32], **kwargs):
        super().__init__()

        # The defaults above are single list objects shared by every call.
        # Hand private copies to the sub-modules so nothing downstream can
        # mutate the shared defaults (or a caller-owned list).
        mlp_num_neurons = list(mlp_num_neurons)
        reg_num_neurons = list(reg_num_neurons)

        self.encoder = AttnMPNN(node_in_dim=node_dim,
                                edge_in_dim=edge_dim,
                                node_hidden_dim=64,
                                edge_hidden_dim=64,
                                node_out_dim=lifted_dim,
                                edge_out_dim=1,  # will be ignored
                                num_hidden_gn=num_hidden_gn,
                                node_aggregator='sum',
                                mlp_params={'num_neurons': mlp_num_neurons,
                                            'hidden_act': activation,
                                            'out_act': activation})

        # No adjacency / node count is known yet; forward() installs the
        # adjacency of each input graph via set_adj.
        monotone_args = (lifted_dim, hidden_dim, None, None, None, linModule, fpMethod, invMethod, device)
        lin_module, nonlin_module, solver = monotone_setup(*monotone_args, **kwargs)
        self.fp_layer = MonotoneImplicitGraph(lin_module, nonlin_module, solver)
        self.decoder = MLP(hidden_dim, output_dim,
                           hidden_act=activation,
                           num_neurons=reg_num_neurons)

    def forward(self, g, nf, ef):
        """
        1. Transform input graph with node/edge features to the bias terms of the fixed point equations
        2. Solve fixed point eq
        3. Decode the solution with MLP.

        Args:
            g: Graph object exposing ``adj()``.
            nf: Node features; also determines the device the adjacency is
                moved to.
            ef: Edge features.

        Returns:
            Decoder output for the fixed-point solution.
        """
        # Only the encoder's first output (updated node features) is used.
        unf, _ = self.encoder(g, nf, ef)

        # LaplaceNorm is the normalisation in use; AugNorm is imported as an
        # alternative but not applied here.
        adj = LaplaceNorm(g.adj().to(nf.device))
        self.fp_layer.lin_module.set_adj(adj, sp_adj=None)

        # The fixed-point layer is fed the transposed encoder output and its
        # result is transposed back before decoding.
        z, _ = self.fp_layer(unf.T)
        z = z.T
        pred = self.decoder(z)
        return pred




############################################################################################

def monotone_setup(nfeat, nhid, num_node, adj ,sp_adj, linModule, fpMethod, invMethod, device, nonlin='tanh',**kwargs):
    """Build the linear module, non-linearity and fixed-point solver of one layer.

    Args:
        nfeat, nhid, num_node: Positional arguments of the linear module.
        adj, sp_adj, device, invMethod: Passed to the linear module as keyword
            arguments of the same name.
        linModule: One of ``'cayley'``, ``'frob'``, ``'diagd'``, ``'expm'``,
            ``'proj'``, ``'symm'``, ``'skew'``.
        fpMethod: One of ``'fb'``, ``'fb+a'``, ``'pr'``, ``'pr+a'``, ``'dr'``,
            ``'dr+a'``, ``'dr+h'``, ``'pwr'``, ``'pwr+a'``.
        nonlin: ``'tanh'`` (default) or ``'relu'``.
        **kwargs: Extra options passed to BOTH the linear module and the
            solver; they override the defaults set here (e.g. ``verbose``).

    Returns:
        Tuple ``(lin_module, nonlin_module, solver)``.

    Raises:
        NotImplementedError: If ``linModule``, ``nonlin`` or ``fpMethod`` is
            not one of the supported values.
    """
    # Linear Module
    lin_args = {'adj': adj, 'sp_adj': sp_adj, 'device': device, 'invMethod': invMethod, **kwargs}
    if linModule == 'cayley':
        lin_module = CayleyLinear(nfeat, nhid, num_node, **lin_args)
    elif linModule == 'frob':
        lin_module = FrobeniusLinear(nfeat, nhid, num_node, **lin_args)
    elif linModule == 'diagd':
        lin_module = DiagDLinear(nfeat, nhid, num_node, **lin_args)
    elif linModule == 'expm':
        lin_module = ExpmLinear(nfeat, nhid, num_node, **lin_args)
    elif linModule == 'proj':
        # The projected module additionally needs the adjacency's spectral
        # radius; it stays None when no adjacency is available yet.
        adj_rho = None if (adj is None) else get_spectral_rad(adj)
        lin_module = ProjectedLinear(nfeat, nhid, num_node, adj_rho=adj_rho, **lin_args)
    elif linModule == 'symm':
        lin_module = SymmetricLinear(nfeat, nhid, num_node, **lin_args)
    elif linModule == 'skew':
        lin_module = SkewLinear(nfeat, nhid, num_node, **lin_args)
    else:
        raise NotImplementedError(f'Linear module is not supported, got: {linModule}')

    # Non-Linear Module
    if nonlin == 'relu':
        nonlin_module = ReLU()
    elif nonlin == 'tanh':
        nonlin_module = TanH()
    else:
        # Previously an unknown name left nonlin_module unbound and surfaced
        # later as an UnboundLocalError; fail like the other selectors do.
        raise NotImplementedError(f'Non-linear module is not supported, got: {nonlin}')

    # Fixed-Point Solver
    fp_args = {'verbose': True, **kwargs}
    if fpMethod == 'fb':
        solver = ForwardBackward(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'fb+a':
        solver = ForwardBackwardAnderson(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'pr':
        solver = PeacemanRachford(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'pr+a':
        solver = PeacemanRachfordAnderson(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'dr':
        solver = DouglasRachford(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'dr+a':
        solver = DouglasRachfordAnderson(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'dr+h':
        solver = DouglasRachfordHalpern(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'pwr':
        solver = PowerMethod(lin_module, nonlin_module, **fp_args)
    elif fpMethod == 'pwr+a':
        solver = PowerMethodAnderson(lin_module, nonlin_module, **fp_args)
    else:
        raise NotImplementedError(f'Fixed point method is not supported, got: {fpMethod}')

    return lin_module, nonlin_module, solver