import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from deeprobust.graph import utils
from copy import deepcopy
import sys
from scipy import stats

import tensorly as tl
tl.set_backend('pytorch')
from tensorly.decomposition import parafac, tucker, tensor_train, matrix_product_state


import numpy as np
import scipy.sparse as sp
from numba import njit

class GraphConvolution(Module):
    """Graph convolution layer computing ``adj @ (input @ W) (+ b)``.

    Modelled on https://github.com/tkipf/pygcn. Unlike the reference layer,
    ``forward`` returns a pair: the propagated output and the
    pre-propagation product ``input @ W``.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not with_bias:
            # Registered as None so ``self.bias`` exists but holds no tensor.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (then bias, if any) from U(-1/sqrt(out), 1/sqrt(out))."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Apply the layer.

        Parameters
        ----------
        input : torch.Tensor
            Node features; sparse or dense.
        adj : torch.Tensor
            Propagation matrix, multiplied with ``torch.spmm``.

        Returns
        -------
        tuple of torch.Tensor
            ``(adj @ (input @ W) [+ bias], input @ W)``.
        """
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        support = matmul(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            output = output + self.bias
        return output, support

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__,
                                      self.in_features, self.out_features)

class TGNN(nn.Module):
    """Two-layer GCN whose propagation graph is learned through a tensor
    decomposition.

    ``_init_td`` builds a 3-way tensor ``T`` by stacking the self-looped input
    adjacency with auxiliary views selected by ``pros``: a comma-separated
    subset of ``knn``, ``svd`` and ``prune``. It decomposes ``T`` once with
    CP / Tucker / TT (``format``). The decomposition is stored frozen (``of*``,
    ``oweights``, ``ocore``), and a zero-initialised trainable offset (``f*``,
    ``weights``, ``core``) is learned on top of it. Slice 0 of the
    reconstruction is normalised and used as the GCN graph.

    Parameters
    ----------
    nfeat, nhid, nclass : int
        Input feature size, hidden size and number of classes.
    dropout, lr, weight_decay : float
        Dropout rate, Adam learning rate and GCN weight decay. Weight decay is
        forced to 0 when ``with_relu`` is False.
    with_relu, with_bias : bool
        ReLU between the two layers / bias terms in the layers.
    format : str
        ``'CP'``, ``'Tucker'`` or ``'TT'``.
    rank : int
        Decomposition rank.
    pros : str
        Comma-separated auxiliary views to stack into the tensor.
    euclidean : bool
        In the ``prune`` view, use the Jaccard-style score when True and the
        cosine-style score otherwise.
    svd_rank : int
        Rank kept by the truncated SVD of the ``svd`` view.
    prune_thd : float
        Similarity threshold of the ``prune`` view.
    lambda_t : float
        Weight of the tensor reconstruction loss (see ``_gcn_step``).
    weight_decay_t : float
        Weight decay of the tensor parameters.
    topk : int
        Neighbours per node in the ``knn`` view.
    device : str or torch.device
        Required.
    """

    def __init__(self, nfeat, nhid, nclass,
                 dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True,
                 format='Tucker', rank=32, pros='knn', euclidean=True,
                 svd_rank=200, prune_thd=0.01,
                 lambda_t = 1e-4, weight_decay_t=1e-5, topk=32,
                 device=None):

        super(TGNN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = nclass
        self.gc1 = GraphConvolution(nfeat, nhid, with_bias=with_bias).to(device)
        self.gc2 = GraphConvolution(nhid, nclass, with_bias=with_bias).to(device)
        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.output = None
        self.best_model = None # not used
        self.best_output = None # not used
        self.best_A = None
        self.adj_norm = None
        self.features = None

        self.format = format
        self.rank = rank
        # Attribute name kept as ``lamda_t`` for backward compatibility.
        self.lamda_t = lambda_t
        self.weight_decay_t = weight_decay_t
        self.topk = topk

        self.pros = pros.split(',')

        self.euclidean = euclidean

        self.svd_rank = svd_rank
        self.prune_thd = prune_thd

        self.acc_lst = []

    def initialize(self):
        """Initialize parameters of GCN.
        """
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()

    def truncatedSVD(self, data, k=50):
        """Truncated SVD on input data.

        Parameters
        ----------
        data :
            input matrix to be decomposed (scipy sparse or dense array)
        k : int
            number of singular values and vectors to compute.

        Returns
        -------
        numpy.array
            reconstructed matrix ``U diag(S) V`` built from the top-``k`` triplets.
        """
        print('=== GCN-SVD: rank={} ==='.format(k))
        if sp.issparse(data):
            # Import explicitly: ``import scipy.sparse`` alone is not guaranteed
            # to expose ``scipy.sparse.linalg`` as an attribute.
            from scipy.sparse.linalg import svds
            data = data.asfptype()
            U, S, V = svds(data, k=k)
            print("rank_after = {}".format(len(S.nonzero()[0])))
            diag_S = np.diag(S)
        else:
            U, S, V = np.linalg.svd(data)
            U = U[:, :k]
            S = S[:k]
            V = V[:k, :]
            print("rank_before = {}".format(len(S.nonzero()[0])))
            diag_S = np.diag(S)
            print("rank_after = {}".format(len(diag_S.nonzero()[0])))

        return U @ diag_S @ V

    def fit(self, features, adj, labels, idx_train, idx_val, train_iters=200,
            initialize=True, verbose=False, normalize=True, patience=500, **kwargs):
        """Preprocess the inputs, then train with validation-based model selection.

        Parameters
        ----------
        features, adj, labels :
            Node features, adjacency and labels, as torch tensors or inputs
            accepted by ``utils.to_tensor``.
        idx_train, idx_val :
            Train / validation node indices.
        train_iters : int
            Number of training epochs.
        initialize : bool
            Re-initialise the GCN layers first.
        verbose : bool
            Print stage banners.
        normalize : bool
            Normalise ``adj`` with ``utils.normalize_adj_tensor``.
        patience : int
            Accepted for API compatibility; not used.
        **kwargs :
            Accepted for API compatibility; not used.
        """
        self.device = self.gc1.weight.device
        if initialize:
            self.initialize()

        # Convert / move the inputs first so that the svd and prune
        # preprocessing below works for non-tensor inputs as well.
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        self.adj_org = adj + torch.eye(adj.size(0), device=self.device)

        if 'svd' in self.pros:
            svd_adj = self.truncatedSVD(data=self.adj_org.cpu().numpy(), k=self.svd_rank)
            svd_adj = torch.FloatTensor(svd_adj).to(self.device)
            self.svd_adj = torch.clamp(svd_adj, 0, 1).unsqueeze(0)

        if 'prune' in self.pros:
            # Sparse torch tensors have no ``.numpy()``; densify first.
            dense_features = features.to_dense() if features.is_sparse else features
            prune_adj = self.drop_dissimilar_edges(features=dense_features.cpu().numpy(),
                                                   adj=self.adj_org.cpu().numpy())
            self.prune_adj = torch.FloatTensor(prune_adj.todense()).to(self.device).unsqueeze(0)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        self.adj_norm = adj_norm
        self.features = features
        self.labels = labels

        self._train_with_val(labels, idx_train, idx_val, train_iters, verbose)

    def forward(self, x, adj, A_bar=None, need_feats=False):
        """Two-layer GCN forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Node features.
        adj : torch.Tensor
            Normalised graph, used when ``A_bar`` is None.
        A_bar : torch.Tensor, optional
            Tensor whose slice 0 is normalised with ``self.normalize`` and used
            as the graph for both layers instead of ``adj``.
        need_feats : bool
            Also return ``[input, hidden, logits]``.

        Returns
        -------
        torch.Tensor or (torch.Tensor, list)
            Log-softmax over classes (and the intermediate features).
        """
        feats = [x]
        # Both layers propagate over the same graph; normalise it only once.
        graph = adj if A_bar is None else self.normalize(A_bar[0, :, :])

        x, _ = self.gc1(x, graph)
        if self.with_relu:
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        feats.append(x)

        x, _ = self.gc2(x, graph)
        feats.append(x)

        out = F.log_softmax(x, dim=1)
        if need_feats:
            return out, feats
        return out

    def _norm_feat(self, X, p='l2'):
        """Row-normalise ``X`` by its L1 or L2 norm.

        Returns ``X`` unchanged for ``p=None`` and ``None`` for an unknown ``p``.
        The scale is detached, so no gradient flows through the norm.
        """
        if p is None:
            return X
        if p not in ('l1', 'l2'):
            return None
        order = 1 if p == 'l1' else 2
        scale = 1 / (X.norm(p=order, dim=1, keepdim=True) + 1e-9).detach()
        scale[torch.isinf(scale)] = 0.
        return X * scale

    def normalize(self, adj):
        """Clamp ``adj`` to [0, 1], add self-loops and symmetrically normalise it."""
        adj = torch.clamp(adj, 0, 1)
        return self._normalize(adj + torch.eye(adj.shape[0], device=self.device))

    def _normalize(self, mx):
        """Return ``D^-1/2 mx D^-1/2`` (``mx`` clamped to [0, 1]; D detached)."""
        mx = torch.clamp(mx, 0, 1)
        rowsum = torch.abs(mx).sum(1)
        r_inv = rowsum.pow(-1/2).flatten()
        r_inv[torch.isinf(r_inv)] = 0.
        r_mat_inv = torch.diag(r_inv).detach()
        mx = r_mat_inv @ mx
        mx = mx @ r_mat_inv
        return mx

    def _knn(self, X, k=None):
        """0/1 matrix marking, per row, the ``k`` largest entries of ``X``
        (``k`` defaults to ``self.topk``)."""
        k = k if k is not None else self.topk
        topk_idx = X.topk(k, dim=1)[1]
        # Vectorised equivalent of setting ret[i, topk_idx[i]] = 1 row by row.
        return torch.zeros_like(X).scatter_(1, topk_idx, 1.0)

    def _get_knn(self, X, k=None):
        """Build a detached ``(1, n, n)`` kNN graph from the node features ``X``."""
        # NOTE(review): ``self.is_binary`` is the per-slice flag list set in
        # ``_init_td`` (``[1, ...]``) before this is called. A non-empty list is
        # always truthy, so the Jaccard branch is always taken and the cdist
        # branch is unreachable from ``_init_td``. Confirm the intent.
        if self.is_binary:
            intersection = torch.mm(X, X.t())
            union = X.size(1) - torch.mm((1 - X), (1 - X).t())
            smooth = 1
            S = (intersection + smooth) / (union + smooth)
        else:
            X = self._norm_feat(X, p='l2')
            S = -torch.cdist(X, X)
        Z = self._knn(S, k=k).clone().detach().unsqueeze(0)
        return Z

    def _register_offsets(self, names, tensors):
        """Register a zero-initialised trainable ``<name>`` for every tensor,
        then the frozen tensor itself as ``o<name>`` (same order as before)."""
        for name, tensor in zip(names, tensors):
            self.register_parameter(name, Parameter(torch.zeros_like(tensor).to(self.device), requires_grad=True))
        for name, tensor in zip(names, tensors):
            self.register_parameter('o' + name, Parameter(tensor.to(self.device), requires_grad=False))

    def _init_td(self):
        """Stack the graph views into ``self.T`` and decompose it.

        Slice 0 is ``self.adj_org``. Further slices are added for each entry of
        ``self.pros``: ``knn`` from the input features, and ``svd`` / ``prune``
        from the tensors prepared in ``fit``. Raises ValueError for an unknown
        ``self.format``.
        """
        if self.format not in ('CP', 'Tucker', 'TT'):
            raise ValueError("Unknown tensor format {!r}; expected 'CP', 'Tucker' or 'TT'".format(self.format))

        A = torch.stack([self.adj_org for _ in range(1)], dim=0)
        self.OA = A
        _, X = self.forward(self.features, self.adj_norm, A_bar=A, need_feats=True)

        self.T = [A]
        self.is_binary = [1]
        self.reg_name = ['ADJ']

        if 'knn' in self.pros:
            Z = self._get_knn(X[0])
            self.Z = Z
            self.T.append(self.Z)
            self.is_binary.append(1)
            self.reg_name.append('KNN%d'%self.topk)

        if 'svd' in self.pros:
            self.T.append(self.svd_adj)
            self.is_binary.append(0)
            self.reg_name.append('SVD%d'%self.svd_rank)

        if 'prune' in self.pros:
            self.T.append(self.prune_adj)
            self.is_binary.append(1)
            self.reg_name.append('PRUNE%.6f'%self.prune_thd)

        self.T = torch.cat(self.T, dim=0)

        T = self.T
        self.A = T[:1, :, :]

        factor_names = ['f0', 'f1', 'f2']
        if self.format == 'CP':
            # Decompose the transposed tensor; forward_T transposes back (avoid svd oom).
            weights, factors = parafac(T.transpose(1, 0), self.rank, init='random', normalize_factors=True)
            self._register_offsets(factor_names, [factors[0], factors[1], factors[2]])
            self._register_offsets(['weights'], [weights])
        elif self.format == 'Tucker':
            core, factors = tucker(T, rank=self.rank, init='random')
            self._register_offsets(factor_names, [factors[0], factors[1], factors[2]])
            self._register_offsets(['core'], [core])
        else:  # 'TT'
            factors = matrix_product_state(T, rank=[1, self.rank, self.rank, 1])
            print([_.size() for _ in factors])
            self._register_offsets(factor_names, [factors[0], factors[1], factors[2]])

    def forward_T(self):
        """Reconstruct the tensor from frozen + trainable factors.

        Sets ``self.T_bar`` (full reconstruction) and ``self.A_bar`` (its first
        slice, kept 3-D). Raises ValueError for an unknown ``self.format``.
        """
        if self.format not in ('CP', 'Tucker', 'TT'):
            raise ValueError("Unknown tensor format {!r}; expected 'CP', 'Tucker' or 'TT'".format(self.format))

        factors = [self.f0 + self.of0, self.f1 + self.of1, self.f2 + self.of2]
        if self.format == 'CP':
            T_bar = tl.cp_to_tensor((self.weights + self.oweights, factors))
            T_bar = T_bar.transpose(1, 0) # avoid svd oom
        elif self.format == 'Tucker':
            T_bar = tl.tucker_to_tensor((self.core + self.ocore, factors))
        else:  # 'TT'
            T_bar = tl.tt_to_tensor(factors)

        self.T_bar = T_bar
        self.A_bar = T_bar[:1]

    def rec_loss(self, X, Y):
        """Mean squared error between ``clamp(X, 0, 1)`` and ``Y``."""
        X = torch.clamp(X, 0, 1)
        return ((X - Y) ** 2).mean()

    def _gcn_step(self, i, optimizer_G, optimizer_T, idx_train, labels):
        """One joint optimisation step of the GCN and (if present) the tensor.

        The loss is NLL on ``idx_train`` plus, when tensor parameters exist, the
        reconstruction loss between every clamped slice of ``self.T_bar`` and
        the clamped, detached slice of ``self.T``.
        """
        self.train()
        optimizer_G.zero_grad()
        if optimizer_T is not None:
            optimizer_T.zero_grad()

        self.forward_T()

        output, X = self.forward(self.features, self.adj_norm, A_bar=self.A_bar, need_feats=True)

        loss_cls = F.nll_loss(output[idx_train], labels[idx_train])

        loss_reg, loss_adj = 0, 0

        if optimizer_T is not None:
            for r in range(self.T.size(0)):
                S0_normed = torch.clamp(self.T_bar[r], 0, 1)
                S0_normed_t = torch.clamp(self.T[r].detach(), 0, 1)

                loss = self.rec_loss(S0_normed, S0_normed_t)

                if r == 0:
                    loss_adj = loss_adj + self.lamda_t * loss
                else:
                    loss_reg = loss_reg + self.lamda_t * loss
        sys.stdout.flush()

        # NOTE(review): lamda_t is applied both inside the loop above and here,
        # so the effective reconstruction weight is lamda_t ** 2. Kept as-is
        # because existing hyper-parameter settings depend on it; confirm intent.
        loss_train = loss_cls + (loss_reg * self.lamda_t + loss_adj * self.lamda_t)

        loss_train.backward()
        optimizer_G.step()
        if optimizer_T is not None:
            optimizer_T.step()

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train for ``train_iters`` epochs and restore the best weights.

        Weights are snapshotted whenever the validation loss decreases or the
        validation accuracy increases. ``self.output`` and ``self.best_A`` are
        updated at the same time.
        """
        if verbose:
            print('=== Initialization ===')
        self.eval()
        self._init_td()
        self.forward_T()

        if verbose:
            print('=== pre-training tensor model ===')

        # GCN layers vs. everything else (tensor offsets and frozen factors;
        # the frozen ones never receive gradients, so Adam skips them).
        g_parameters = []
        t_parameters = []
        for name, param in self.named_parameters():
            if name.startswith('gc'):
                g_parameters.append(param)
            else:
                t_parameters.append(param)
        optimizer_T = optim.Adam(t_parameters, lr=self.lr, weight_decay=self.weight_decay_t) if t_parameters else None
        optimizer_G = optim.Adam(g_parameters, lr=self.lr, weight_decay=self.weight_decay)

        if verbose:
            print('=== training gcn model ===')
        best_loss_val = 1e10
        best_acc_val = 0
        # Fallback snapshot so load_state_dict below always has something to
        # restore, even if no epoch improves (train_iters == 0, NaN losses).
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            self._gcn_step(i, optimizer_G, optimizer_T, idx_train, labels)

            self.eval()
            # Evaluation only: no autograd graph is needed here.
            with torch.no_grad():
                self.forward_T()
                output = self.forward(self.features, self.adj_norm, self.A_bar)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])
            loss_cls = F.nll_loss(output[idx_train], labels[idx_train])
            acc_train = utils.accuracy(output[idx_train], labels[idx_train])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                self.best_A = self.A_bar.clone().detach()
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                self.best_A = self.A_bar.clone().detach()
                weights = deepcopy(self.state_dict())

            print('Epoch: {:04d}'.format(i + 1),
                  'loss_val: {:.4f}'.format(loss_val.item()),
                  'loss_train: {:.4f}'.format(loss_cls.item()),
                  'acc_val: {:.4f}'.format(acc_val.item()),
                  'acc_train: {:.4f}'.format(acc_train.item()))

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.

        Parameters
        ----------
        idx_test :
            node testing indices

        Returns
        -------
        The test accuracy as returned by ``utils.accuracy``.
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test


    def predict(self):
        """Log-probabilities in eval mode, using ``self.best_A`` as the graph
        (or ``self.adj_norm`` when no best graph has been recorded)."""
        self.eval()
        return self.forward(self.features, self.adj_norm, A_bar=self.best_A)

    def drop_dissimilar_edges(self, features, adj, metric='similarity'):
        """Drop dissimilar edges. (Faster version using the module-level helpers.)

        Scores every upper-triangular entry (diagonal included) and zeroes those
        below ``self.prune_thd``. For ``metric='distance'``, it zeroes entries
        whose feature distance exceeds the threshold. The result is symmetrised
        without doubling the diagonal.

        Returns
        -------
        scipy.sparse matrix
            The pruned, symmetric adjacency.
        """
        if not sp.issparse(adj):
            adj = sp.csr_matrix(adj)

        adj_triu = sp.triu(adj, format='csr')

        # The helpers zero entries of adj_triu.data in place.
        if metric == 'distance':
            removed_cnt = dropedge_dis(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.prune_thd)
        else:
            if self.euclidean:
                removed_cnt = dropedge_prune(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.prune_thd)
            else:
                removed_cnt = dropedge_cosine(adj_triu.data, adj_triu.indptr, adj_triu.indices, features, threshold=self.prune_thd)
        print('removed %s edges in the original graph' % removed_cnt)
        # Mirror only the strictly-upper part: adding the whole triangle to its
        # transpose would double every diagonal (self-loop) entry.
        modified_adj = adj_triu + sp.triu(adj_triu, k=1, format='csr').transpose()
        return modified_adj

    def _drop_dissimilar_edges(self, features, adj):
        """Drop dissimilar edges. (Slower, pure-Python version.)

        For each edge ``(n1, n2)`` with ``n1 <= n2``, it scores the two feature
        rows with ``_prune_similarity`` (``self.euclidean``) or
        ``_cosine_similarity``. Both directions are removed when the score is
        below ``self.prune_thd``.
        """
        if not sp.issparse(adj):
            adj = sp.csr_matrix(adj)
        modified_adj = adj.copy().tolil()

        # preprocessing based on features
        print('=== GCN-Jaccrad ===')
        edges = np.array(modified_adj.nonzero()).T
        removed_cnt = 0
        for n1, n2 in edges:
            if n1 > n2:
                continue

            if self.euclidean:
                score = self._prune_similarity(features[n1], features[n2])
            else:
                # For non-binary features, use the cosine-style score.
                score = self._cosine_similarity(features[n1], features[n2])

            if score < self.prune_thd:
                modified_adj[n1, n2] = 0
                modified_adj[n2, n1] = 0
                removed_cnt += 1
        print('removed %s edges in the original graph' % removed_cnt)
        return modified_adj

    @staticmethod
    def _as_flat_array(v):
        """Return a dense vector or a 1 x d scipy sparse row as a flat ndarray."""
        if sp.issparse(v):
            return np.asarray(v.todense()).ravel()
        return np.asarray(v).ravel()

    def _prune_similarity(self, a, b):
        """Jaccard score of the non-zero patterns of feature rows ``a`` and ``b``.

        Returns 0.0 when both rows are all-zero (empty union) instead of
        dividing by zero.
        """
        a, b = self._as_flat_array(a), self._as_flat_array(b)
        intersection = np.count_nonzero(a * b)
        union = np.count_nonzero(a) + np.count_nonzero(b) - intersection
        if union == 0:
            return 0.0
        return intersection * 1.0 / union

    def _cosine_similarity(self, a, b):
        """Cosine-style score ``a.b / (sqrt(|a|^2 + |b|^2) + 1e-6)``.

        This is the same formula as ``dropedge_cosine``, so both pruning paths
        agree. Note that it is not the textbook cosine ``a.b / (|a| |b|)``.
        """
        a, b = self._as_flat_array(a), self._as_flat_array(b)
        inner_product = (a * b).sum()
        C = inner_product / (np.sqrt(np.square(a).sum() + np.square(b).sum()) + 1e-6)
        return C

def dropedge_prune(A, iA, jA, features, threshold):
    """Zero out CSR entries whose endpoint features have a low Jaccard score.

    ``A``, ``iA`` and ``jA`` are the ``data``, ``indptr`` and ``indices`` arrays
    of a CSR matrix. ``A`` is modified in place: pruned entries are set to 0
    but stay in the sparsity structure.

    The score is ``|nz(a*b)| / (|nz(a)| + |nz(b)| - |nz(a*b)|)`` over the
    non-zero patterns of the two feature rows. When both rows are all-zero
    (empty union), the score is taken to be 0 instead of raising
    ZeroDivisionError.

    Returns
    -------
    int
        Number of entries zeroed.
    """
    removed_cnt = 0
    for row in range(len(iA) - 1):
        a = features[row]
        nnz_a = np.count_nonzero(a)  # loop-invariant for this row
        for i in range(iA[row], iA[row + 1]):
            b = features[jA[i]]
            intersection = np.count_nonzero(np.multiply(a, b))
            union = nnz_a + np.count_nonzero(b) - intersection
            J = intersection * 1.0 / union if union > 0 else 0.0

            if J < threshold:
                A[i] = 0
                removed_cnt += 1
    return removed_cnt


@njit
def dropedge_cosine(A, iA, jA, features, threshold):
    """Zero out CSR entries whose endpoint features score below ``threshold``.

    ``A`` / ``iA`` / ``jA`` are the ``data`` / ``indptr`` / ``indices`` arrays
    of a CSR matrix, and ``A`` is edited in place. The score is
    ``u.v / (sqrt(|u|^2 + |v|^2) + 1e-6)``. Despite the name, this is not the
    textbook cosine ``u.v / (|u| |v|)``. Returns the number of zeroed entries.
    """
    n_removed = 0
    n_rows = len(iA) - 1
    for src in range(n_rows):
        u = features[src]
        start = iA[src]
        stop = iA[src + 1]
        for pos in range(start, stop):
            v = features[jA[pos]]
            dot = (u * v).sum()
            score = dot / (np.sqrt(np.square(u).sum() + np.square(v).sum())+ 1e-6)
            if score < threshold:
                A[pos] = 0
                n_removed += 1
    return n_removed

@njit
def dropedge_dis(A, iA, jA, features, threshold):
    """Zero out CSR entries whose endpoint features are farther apart than
    ``threshold`` (Euclidean norm of the difference).

    ``A`` / ``iA`` / ``jA`` are the ``data`` / ``indptr`` / ``indices`` arrays
    of a CSR matrix, and ``A`` is edited in place. Returns the number of
    zeroed entries.
    """
    n_removed = 0
    n_rows = len(iA) - 1
    for src in range(n_rows):
        start = iA[src]
        stop = iA[src + 1]
        for pos in range(start, stop):
            dist = np.linalg.norm(features[src] - features[jA[pos]])
            if dist > threshold:
                A[pos] = 0
                n_removed += 1

    return n_removed

@njit
def dropedge_both(A, iA, jA, features, threshold1=2.5, threshold2=0.01):
    """Zero out CSR entries that are too far apart OR too dissimilar.

    ``A`` / ``iA`` / ``jA`` are the ``data`` / ``indptr`` / ``indices`` arrays
    of a CSR matrix, and ``A`` is edited in place. An entry is pruned when the
    Euclidean distance between its endpoint features exceeds ``threshold1``,
    or when the cosine-style score ``a.b / (sqrt(|a|^2 + |b|^2) + 1e-6)``
    (same as ``dropedge_cosine``) falls below ``threshold2``.

    Returns
    -------
    int
        Number of entries zeroed.
    """
    removed_cnt = 0
    for row in range(len(iA) - 1):
        a = features[row]
        for i in range(iA[row], iA[row + 1]):
            b = features[jA[i]]
            C1 = np.linalg.norm(a - b)

            inner_product = (a * b).sum()
            C2 = inner_product / (np.sqrt(np.square(a).sum() + np.square(b).sum()) + 1e-6)
            # Previously ``threshold2 < 0`` compared the constant against 0,
            # which left C2 unused; the similarity criterion is C2 < threshold2.
            if C1 > threshold1 or C2 < threshold2:
                A[i] = 0
                removed_cnt += 1

    return removed_cnt
