import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
from sklearn.metrics import f1_score
from numba import jit
import numpy as np
import warnings
# NOTE: installs a process-wide "ignore" filter at import time, so every
# warning category (including deprecation and numerical warnings raised by
# other libraries) is silenced for any program that imports this module.
warnings.filterwarnings("ignore")

class NewConvolution(Module):
    """Dense graph layer that smooths the features before a linear projection.

    The forward pass evaluates ``(I + lamba * L)^-1 @ input @ weight (+ bias)``
    where ``L`` is the degree-scaled Laplacian returned by
    :meth:`feature_smoothing`.

    Parameters
    ----------
    in_features : int
        number of input feature columns
    out_features : int
        number of output feature columns
    lamba : float
        coefficient multiplying the Laplacian in ``I + lamba * L``
    with_bias : bool
        whether to add a learnable bias to the projected features
    """

    def __init__(self, in_features, out_features, lamba,with_bias=True):
        super(NewConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lamba = lamba
        self.weight = Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        if with_bias:
            self.bias = Parameter(torch.empty(out_features, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-1/sqrt(out_features), 1/sqrt(out_features)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """Smooth ``input`` over ``adj`` and project it with the layer weight.

        Parameters
        ----------
        input : torch.Tensor
            dense feature matrix with one row per node
        adj : torch.Tensor
            dense square adjacency matrix

        Returns
        -------
        torch.Tensor
            ``(I + lamba * L)^-1 @ input @ weight``, plus the bias when present
        """
        laplacian = self.feature_smoothing(adj, input)
        n = adj.size(0)
        identity = torch.eye(n, dtype=laplacian.dtype, device=laplacian.device)
        # Solve the linear system directly instead of forming the inverse and
        # multiplying: same result, cheaper and numerically more stable.
        smoothed = torch.linalg.solve(identity + self.lamba * laplacian, input)
        output = torch.mm(smoothed, self.weight)
        if self.bias is not None:
            return output + self.bias
        return output

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)

    def feature_smoothing(self, adj, X):
        """Return the degree-scaled Laplacian ``S (D - adj) S`` of ``adj``.

        ``D`` is the diagonal matrix of row sums of ``adj`` and ``S`` is the
        diagonal matrix of ``(row_sum + 1e-3) ** -0.5``.

        Parameters
        ----------
        adj : torch.Tensor
            dense square adjacency matrix
        X : torch.Tensor
            unused; kept so existing callers keep working

        Returns
        -------
        torch.Tensor
            dense matrix with the same shape as ``adj``
        """
        degree = adj.sum(1).flatten()
        laplacian = torch.diag(degree) - adj

        # The 1e-3 shift keeps zero-degree rows finite.
        scale = (degree + 1e-3).pow(-1 / 2)
        # Out-of-place masking so a tensor autograd may still need is not mutated.
        scale = scale.masked_fill(torch.isinf(scale), 0.)
        # Row/column scaling by broadcasting equals diag(scale) @ L @ diag(scale)
        # without two dense n x n matrix products.
        return scale.unsqueeze(1) * laplacian * scale.unsqueeze(0)

class R_GUGNN(nn.Module):
    """Two-layer graph network trained on an iteratively refined adjacency.

    Before the weights are trained, the adjacency (with self-loops added) is
    updated ``iterations`` times: a step against the pairwise squared
    distances of the degree-scaled features, singular-value soft-thresholding,
    re-adding the identity and clamping every entry to ``[0, 1]``.

    Parameters
    ----------
    nfeat : int
        size of input feature dimension
    nhid : int
        number of hidden units
    nclass : int
        size of output dimension
    c : float
        step size; each refinement step subtracts ``c / 2`` times the pairwise
        feature-distance matrix from the adjacency
    lamba : float
        Laplacian coefficient handed to both ``NewConvolution`` layers
    beta : float
        threshold subtracted from the singular values in `prox_nuclear`
    iterations : int
        number of adjacency refinement steps
    dropout : float
        dropout rate applied between the two layers
    lr : float
        learning rate of the Adam optimizer
    weight_decay : float
        weight decay coefficient (l2 normalization).
        When `with_relu` is False, `weight_decay` is set to 0.
    with_relu : bool
        whether to use relu activation function. If False, the model is linearized.
    with_bias: bool
        whether to include bias term in the layer weights.
    device: str
        'cpu' or 'cuda:0','cuda:1'...
    """

    def __init__(self, nfeat, nhid, nclass,c,lamba,beta,iterations, dropout=0.5, lr=0.01, weight_decay=5e-4,
            with_relu=True, with_bias=True, device=None):

        super(R_GUGNN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.c = c
        self.lamba = lamba
        self.beta = beta
        self.iterations = iterations
        self.gc1 = NewConvolution(nfeat, nhid, lamba,with_bias=with_bias).to(self.device)
        self.gc2 = NewConvolution(nhid, nclass, lamba,with_bias=with_bias).to(self.device)
        self.dropout = dropout
        self.lr = lr
        # Weight decay is only applied to the non-linear variant.
        self.weight_decay = weight_decay if with_relu else 0
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.labels = None
        self.samples = None

    def forward(self, x, adj):
        """Return per-node log-probabilities for features ``x`` on graph ``adj``.

        Both layers use the ``adj`` that is passed in, so the method also works
        before `fit` has stored an adjacency on the model.
        """
        x = self.gc1(x, adj)
        if self.with_relu:
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Re-initialize the parameters of both layers.
        """
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=500, initialize=True, verbose=False, **kwargs):
        """Refine the adjacency, then train the model.

        When idx_val is not None, the weights with the best validation
        loss/accuracy seen during training are restored at the end.

        Parameters
        ----------
        features : torch.Tensor
            dense node feature matrix
        adj : torch.Tensor
            dense square adjacency matrix; the identity is added to it here
        labels : torch.Tensor
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices. If not given (None), all `train_iters`
            epochs are run and the final weights are kept
        train_iters : int
            number of training epochs
        initialize : bool
            whether to initialize parameters before training
        verbose : bool
            whether to show verbose logs
        **kwargs :
            accepted for call compatibility and ignored
        """

        if initialize:
            self.initialize()
        features = features.to(self.device)
        adj = adj.to(self.device)
        labels = labels.to(self.device)
        self.adj_norm = adj + torch.eye(adj.shape[0]).to(self.device)
        self.samples = self.adj_norm.size(0)
        self.features = features
        self.labels = labels
        self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def _refine_adjacency(self, adj_norm, features, verbose=False, record=False):
        """Run ``self.iterations`` refinement steps on ``adj_norm`` and return the result.

        With ``record=True`` the adjacency entering each step is stored on
        ``self._best_adj``. The input tensor is not modified.
        """
        n = adj_norm.size(0)
        identity = torch.eye(n).to(self.device)
        features_np = features.detach().cpu().numpy()
        step = 0
        while step < self.iterations:
            step += 1
            if verbose:
                print(step)
            if record:
                self._best_adj = adj_norm
            D = self.solveD(adj_norm)
            temp = self.solvetemp(n, D.detach().cpu().numpy(), features_np)
            temp = torch.from_numpy(temp).to(self.device)
            adj_norm = adj_norm - self.c / 2 * temp
            adj_norm = self.prox_nuclear(adj_norm, self.beta).to(self.device)
            adj_norm = torch.clamp(adj_norm + identity, min=0, max=1)
        return adj_norm

    def _train_step(self, optimizer, labels, idx_train):
        """Run one optimizer step on the training nodes and return the loss."""
        self.train()
        optimizer.zero_grad()
        output = self.forward(self.features, self.adj_norm)
        loss_train = F.nll_loss(output[idx_train], labels[idx_train])
        loss_train.backward()
        optimizer.step()
        return loss_train

    def _train_without_val(self, optimizer, labels, idx_train, train_iters, verbose):
        """Train for ``train_iters`` epochs and store the final evaluation output."""
        for i in range(train_iters):
            loss_train = self._train_step(optimizer, labels, idx_train)
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
        self.eval()
        with torch.no_grad():
            self.output = self.forward(self.features, self.adj_norm)

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Refine ``self.adj_norm``, train, and restore the best validated weights.

        If ``idx_val`` is None there is nothing to select on, so training just
        runs for ``train_iters`` epochs and keeps the final weights.
        """
        if verbose:
            print('=== training gnn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.adj_norm = self._refine_adjacency(
            self.adj_norm, self.features, verbose=verbose, record=True)

        if idx_val is None:
            self._train_without_val(optimizer, labels, idx_train, train_iters, verbose)
            return

        best_loss_val = float('inf')
        best_acc_val = 0
        # Snapshot up front so there is always a state to restore, even when
        # no epoch runs.
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            loss_train = self._train_step(optimizer, labels, idx_train)

            self.eval()
            # Evaluation needs no gradients; this avoids keeping a graph alive
            # through the stored output.
            with torch.no_grad():
                output = self.forward(self.features, self.adj_norm)
                loss_val = F.nll_loss(output[idx_val], labels[idx_val]).item()
                acc_val = utils.accuracy(output[idx_val], labels[idx_val])
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {},validation loss: {},acc_val:{}'.format(i,loss_train.item(), loss_val,acc_val))

            # A new best on either criterion replaces the snapshot.
            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Print and return the accuracy on the given test nodes.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    def predict(self, features=None, adj=None):
        """Return log-probabilities in evaluation mode.

        Parameters
        ----------
        features :
            node features. If None, the features stored by `fit` are used.
        adj :
            raw adjacency matrix (dense tensor, without self-loops). If None,
            the refined adjacency stored by `fit` is used. If given, it goes
            through the same preprocessing as in `fit`: the identity is added
            and the refinement steps are run with the features used for this
            call. The state stored on the model is left untouched.

        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of the model
        """

        self.eval()
        if features is None:
            features = self.features
        else:
            features = features.to(self.device)
        if adj is None:
            adj = self.adj_norm
        else:
            adj = adj.to(self.device)
            adj = adj + torch.eye(adj.shape[0]).to(self.device)
            adj = self._refine_adjacency(adj, features)
        return self.forward(features, adj)

    def solveD(self,adj):
        """Return the dense diagonal matrix of square-rooted row sums of ``adj``."""
        rowsum = adj.sum(1)
        D = torch.diag(rowsum.flatten())
        return D**(0.5)

    def feature_smoothing(self, adj, X):
        """Return the degree-scaled Laplacian ``S (D - adj) S`` of ``adj``.

        ``D`` holds the row sums of ``adj`` and ``S`` is the diagonal matrix of
        ``(row_sum + 1e-3) ** -0.5``. ``X`` is unused and kept for call
        compatibility.
        """
        degree = adj.sum(1).flatten()
        laplacian = torch.diag(degree) - adj

        scale = (degree + 1e-3).pow(-1 / 2)
        scale = scale.masked_fill(torch.isinf(scale), 0.)
        # Broadcasting equals diag(scale) @ L @ diag(scale) without dense matmuls.
        return scale.unsqueeze(1) * laplacian * scale.unsqueeze(0)

    def solvetemp(self,n,D,input):
        """Return the pairwise squared distances of the degree-scaled features.

        Entry ``[i, j]`` is ``sum((input[i] / D[i, i] - input[j] / D[j, j]) ** 2)``.

        Parameters
        ----------
        n : int
            number of rows to compare
        D : numpy.ndarray
            square matrix; only its diagonal is read
        input : numpy.ndarray
            feature matrix with at least ``n`` rows

        Returns
        -------
        numpy.ndarray
            ``(n, n)`` float64 matrix
        """
        scale = np.diagonal(np.asarray(D))[:n]
        scaled = np.asarray(input[:n], dtype=np.float64) / scale[:, None]
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once
        # instead of an O(n^2) Python loop.
        sq_norms = np.einsum('ij,ij->i', scaled, scaled)
        temp = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (scaled @ scaled.T)
        # Rounding can leave tiny negatives; a squared distance is never negative
        # and is exactly zero on the diagonal.
        np.maximum(temp, 0.0, out=temp)
        np.fill_diagonal(temp, 0.0)
        return temp

    def prox_nuclear(self, data, alpha):
        """Proximal operator for nuclear norm (trace norm).

        Subtracts ``alpha`` from every singular value of ``data`` (floored at
        zero) and rebuilds the matrix. The sum of the original singular values
        is stored on ``self.nuclear_norm``. Returns a float32 CPU tensor.
        """
        # detach() lets numpy read tensors that are part of an autograd graph.
        U, S, Vh = np.linalg.svd(data.detach().cpu().numpy(), full_matrices=False)
        U = torch.as_tensor(U, dtype=torch.float32)
        S = torch.as_tensor(S, dtype=torch.float32)
        Vh = torch.as_tensor(Vh, dtype=torch.float32)
        self.nuclear_norm = S.sum()

        shrunk = torch.clamp(S - alpha, min=0)
        # Scaling the columns of U equals U @ diag(shrunk).
        return torch.matmul(U * shrunk, Vh)
