import torch, math
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import scipy.sparse as sp
from scipy.sparse import coo_matrix

def process(mul_L_real, mul_L_imag, weight, X_real, X_imag):
    """
    Multiply a complex operator with a complex signal and project the result.

    The operator is given as separate real/imaginary matrices, the signal as
    separate real/imaginary feature matrices. Every product is projected by
    the same ``weight`` matrix.

    :param mul_L_real: real part of the operator (first argument of
        ``torch.spmm`` -- presumably a sparse [N, N] tensor; confirm with caller).
    :param mul_L_imag: imaginary part of the operator, same layout.
    :param weight: dense projection matrix applied on the right.
    :param X_real: real part of the signal.
    :param X_imag: imaginary part of the signal.
    :return: tensor stacking ``[real, imag]`` along a new leading dimension.
    """
    # Real part of (L_r + i L_i)(X_r + i X_i): L_r X_r - L_i X_i.
    rr = torch.matmul(torch.spmm(mul_L_real, X_real), weight)
    ii = torch.matmul(-1.0 * torch.spmm(mul_L_imag, X_imag), weight)
    real_part = rr + ii

    # Imaginary part: L_i X_r + L_r X_i.
    ir = torch.matmul(torch.spmm(mul_L_imag, X_real), weight)
    ri = torch.matmul(torch.spmm(mul_L_real, X_imag), weight)
    imag_part = ir + ri

    return torch.stack([real_part, imag_part])


class ChebConv(nn.Module):
    """
    The MagNet convolution operation.

    :param in_c: int, number of input channels.
    :param out_c: int, number of output channels.
    :param K: int, the order of the Chebyshev polynomial; ``K + 1`` weight
        matrices of shape [in_c, out_c] are allocated.
    :param L_norm_real: sequence of operators holding the real parts of the
        Chebyshev terms (described in the original code as a list of sparse
        N x N tensors).
    :param L_norm_imag: sequence holding the matching imaginary parts.
    :param bias: bool, when true a learnable [1, out_c] bias is added to both
        the real and the imaginary output; when false no bias is used.
    """

    def __init__(self, in_c, out_c, K, L_norm_real, L_norm_imag, bias=True):
        super(ChebConv, self).__init__()

        # Stored exactly as given; forward() indexes them term by term.
        self.mul_L_real = L_norm_real
        self.mul_L_imag = L_norm_imag

        # One [in_c, out_c] projection per polynomial term.
        self.weight = nn.Parameter(torch.empty(K + 1, in_c, out_c))
        stdv = 1. / math.sqrt(self.weight.size(-1))
        nn.init.uniform_(self.weight, -stdv, stdv)

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_c))
        else:
            self.register_parameter("bias", None)

    def forward(self, data):
        """
        Run the convolution on a complex signal.

        :param data: indexable pair ``(X_real, X_imag)``. Both are handed to
            ``process``, which uses ``torch.spmm`` -- presumably dense [N, C]
            tensors; TODO confirm against the caller.
        :return: tuple ``(real, imag)`` of output tensors.
        """
        X_real, X_imag = data[0], data[1]

        # Launch one task per stored operator term, then collect the results.
        futures = [
            torch.jit.fork(process,
                           lap_real,
                           self.mul_L_imag[order],
                           self.weight[order], X_real, X_imag)
            for order, lap_real in enumerate(self.mul_L_real)
        ]
        partial_sums = [torch.jit.wait(task) for task in futures]

        # Each task returns a stacked [real, imag] tensor; sum over the terms.
        total = torch.sum(torch.stack(partial_sums), dim=0)
        real, imag = total[0], total[1]

        # bias is None when the layer was built with bias=False.
        if self.bias is not None:
            real = real + self.bias
            imag = imag + self.bias
        return real, imag


class complex_relu_layer(nn.Module):
    """
    ReLU-style gate for complex signals stored as separate real/imag tensors.

    Entries whose real part is negative are zeroed in BOTH the real and the
    imaginary tensor; all other entries pass through unchanged.
    """

    def __init__(self, ):
        super(complex_relu_layer, self).__init__()

    def complex_relu(self, real, img):
        """
        :param real: real part of the signal.
        :param img: imaginary part of the signal.
        :return: tuple ``(real, img)`` after masking with ``real >= 0``.
        """
        mask = 1.0 * (real >= 0)
        return mask * real, mask * img

    def forward(self, real, img=None):
        """
        :param real: the real part, or -- when ``img`` is omitted -- an
            indexable pair ``(real, img)`` as produced inside ``nn.Sequential``.
        :param img: the imaginary part, or None when ``real`` carries both.
        :return: tuple ``(real, img)`` of gated tensors.
        """
        # Identity check: `==` on a tensor is an element-wise comparison and
        # must not be used to test for a missing argument.
        if img is None:
            real, img = real[0], real[1]

        return self.complex_relu(real, img)


class MagNet(nn.Module):
    """
    Stack of ChebConv layers followed by a kernel-size-1 Conv1d classifier.
    """

    def __init__(self, nfeatures: int, nclasses: int, L_norm_real, L_norm_imag, num_filter=2, K=2,
                 activation=False, layer=2, dropout=False):
        """
        :param nfeatures: int, number of input channels of the first ChebConv.
        :param nclasses: int, number of output channels of the final Conv1d.
        :param L_norm_real, L_norm_imag: normalized laplacian terms, forwarded
            unchanged to every ChebConv.
        :param num_filter: int, number of output channels of every ChebConv.
        :param K: for cheb series
        :param activation: when truthy, a complex_relu_layer follows each ChebConv.
        :param layer: total number of ChebConv layers (at least one is always built).
        :param dropout: dropout probability; applied only when it is > 0.
        """
        super(MagNet, self).__init__()

        # The first convolution consumes the raw features; every further one
        # (indices 1 .. layer-1) maps num_filter -> num_filter.
        input_widths = [nfeatures] + [num_filter] * len(range(1, layer))

        stages = []
        for width in input_widths:
            stages.append(ChebConv(in_c=width, out_c=num_filter, K=K,
                                   L_norm_real=L_norm_real,
                                   L_norm_imag=L_norm_imag))
            if activation:
                stages.append(complex_relu_layer())
        self.Chebs = torch.nn.Sequential(*stages)

        # Real and imaginary outputs are concatenated before classification,
        # hence twice num_filter input channels.
        parts = 2
        self.Conv = nn.Conv1d(num_filter * parts, nclasses, kernel_size=1)
        self.dropout = dropout
        self.reg_params = list(self.parameters())

    def forward(self, real, idx):
        """
        :param real: input features; used as both the real and the imaginary
            part of the signal fed to the convolution stack.
        :param idx: index applied to the rows of the per-node class scores.
        :return: log-softmax (over dim 1) of the selected rows.
        """
        out_real, out_imag = self.Chebs((real, real))
        features = torch.cat((out_real, out_imag), dim=-1)

        if self.dropout > 0:
            features = F.dropout(features, self.dropout, training=self.training)

        # Conv1d wants [batch, channels, length]: add a batch axis and move the
        # channel axis in front of the node axis.
        channels_first = features.unsqueeze(0).permute((0, 2, 1))
        scores = self.Conv(channels_first).squeeze(0).permute(1, 0)[idx]
        return F.log_softmax(scores, dim=1)







############################################################# UTILS #############################################################
###########################################
####### Sparse implementation #############
###########################################
def cheb_poly_sparse(A, K):
    """
    Build the Chebyshev polynomial terms T_0(A) .. T_K(A) of a sparse matrix.

    Uses the recurrence T_0 = I, T_1 = A, T_k = 2 * A * T_{k-1} - T_{k-2}.

    :param A: square sparse matrix exposing ``shape`` and ``dot``.
    :param K: int, highest polynomial order.
    :return: list of sparse matrices; ``[I]`` for K == 0, otherwise the list
        starts with ``[I, A]`` and holds K + 1 entries for K >= 1.
    """
    num_terms = K + 1
    size = A.shape[0]  # A is [size, size]
    identity = coo_matrix((np.ones(size), (np.arange(size), np.arange(size))),
                          shape=(size, size), dtype=np.float32)

    if num_terms == 1:
        return [identity]

    terms = [identity, A]
    for _ in range(2, num_terms):
        # The two most recent terms are always the last two list entries.
        terms.append(2.0 * A.dot(terms[-1]) - terms[-2])
    return terms


def hermitian_decomp_sparse(row, col, size, q=0.25, norm=True, laplacian=True,
                            max_eigen=2,
                            gcn_appr=False, edge_weight=None):
    """
    Build the sparse magnetic (Hermitian) Laplacian of a directed graph.

    With A the adjacency matrix, A_sym = (A + A^T) / 2 and
    Theta = 2 * pi * q * (A - A^T), the Hermitian adjacency is
    H = A_sym * exp(i * Theta) (element-wise) and the result is D - H, where
    D is the degree matrix of A_sym (identity when ``norm`` is true, because
    A_sym is then symmetrically normalized first).

    :param row: source node index of every edge.
    :param col: target node index of every edge.
    :param size: int, number of nodes.
    :param q: phase parameter of the magnetic Laplacian.
    :param norm: when true, normalize A_sym by D^{-1/2} on both sides and
        rescale the result to ``(2 / max_eigen) * L - I``.
    :param laplacian: must be true; no other output is implemented.
    :param max_eigen: eigenvalue bound used for the rescaling when ``norm``.
    :param gcn_appr: when true, add self-loops (A + I) before anything else.
    :param edge_weight: optional per-edge weights; defaults to all ones.
    :return: complex scipy sparse matrix of shape [size, size].
    :raises NotImplementedError: if ``laplacian`` is false.
    """
    if not laplacian:
        # Previously this fell through to an unbound local `L`.
        raise NotImplementedError(
            "hermitian_decomp_sparse only implements laplacian=True")

    shape = (size, size)
    weights = np.ones(len(row)) if edge_weight is None else edge_weight
    A = coo_matrix((weights, (row, col)), shape=shape, dtype=np.float32)

    diag = coo_matrix((np.ones(size), (np.arange(size), np.arange(size))),
                      shape=shape, dtype=np.float32)
    if gcn_appr:
        A = A + diag

    A_sym = 0.5 * (A + A.T)  # symmetrized adjacency

    # Flatten the 1 x size matrix returned by the sparse sum into a 1-D array
    # so it can be used as diagonal data.
    degree = np.asarray(A_sym.sum(axis=0)).ravel()
    if norm:
        degree[degree == 0] = 1  # isolated nodes: avoid division by zero
        inv_sqrt = np.power(degree, -0.5)
        D_inv_sqrt = coo_matrix(
            (inv_sqrt, (np.arange(size), np.arange(size))),
            shape=shape, dtype=np.float32)
        A_sym = D_inv_sqrt.dot(A_sym).dot(D_inv_sqrt)
        D = diag  # the normalized graph has unit degree matrix
    else:
        D = coo_matrix((degree, (np.arange(size), np.arange(size))),
                       shape=shape, dtype=np.float32)

    # Sparse subtraction stores only the entries that do not cancel, so
    # `antisym` covers the one-directional edges only.
    antisym = A - A.T
    Theta = 2 * np.pi * q * 1j * antisym  # phase angle array
    Theta.data = np.exp(Theta.data)

    # Reciprocal edges and self-loops have phase 0, i.e. factor exp(0) = 1,
    # but are absent from Theta's sparsity pattern. Add them back unchanged
    # instead of letting the element-wise product drop them.
    one_way = antisym
    one_way.data = np.ones_like(one_way.data)
    zero_phase = A_sym - one_way.multiply(A_sym)
    H = Theta.multiply(A_sym) + zero_phase  # element-wise

    L = D - H
    if norm:
        L = (2.0 / max_eigen) * L - diag

    return L


def geometric_dataset_sparse(q, K, adj: sp.csr_matrix, laplacian=True,
                             gcn_appr=False):
    """
    Build the Chebyshev terms of the magnetic Laplacian of a graph.

    Every stored entry of ``adj`` is treated as one edge of weight 1; the
    values stored in ``adj`` are not used.

    :param q: phase parameter forwarded to ``hermitian_decomp_sparse``.
    :param K: int, highest Chebyshev order forwarded to ``cheb_poly_sparse``.
    :param adj: sparse adjacency matrix; row index = source node, column
        index = target node.
    :param laplacian: forwarded to ``hermitian_decomp_sparse``.
    :param gcn_appr: forwarded to ``hermitian_decomp_sparse``.
    :return: list of sparse matrices as returned by ``cheb_poly_sparse``.
    """
    # COO conversion yields the (row, col) index of every stored entry in one
    # vectorized step, and also copes with a matrix that has no entries.
    edges = adj.tocoo()

    L = hermitian_decomp_sparse(edges.row, edges.col, adj.shape[0], q,
                                norm=True,
                                laplacian=laplacian,
                                max_eigen=2.0, gcn_appr=gcn_appr)

    return cheb_poly_sparse(L, K)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """
    Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    :param sparse_mx: scipy sparse matrix (any format providing ``tocoo``).
    :return: torch sparse COO tensor with the same shape and float32 values.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # torch expects a [2, nnz] int64 index tensor: row indices, then columns.
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    # constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def tensorize_L(L, device):
    """
    Split complex sparse matrices into real/imaginary torch sparse tensors.

    :param L: sequence of matrices exposing ``.real`` and ``.imag``.
    :param device: target device passed to ``Tensor.to``.
    :return: tuple ``(L_real, L_img)`` of two lists, in the order of ``L``.
    """
    def _to_device(matrix):
        return sparse_mx_to_torch_sparse_tensor(matrix).to(device)

    imag_parts = [_to_device(term.imag) for term in L]
    real_parts = [_to_device(term.real) for term in L]
    return real_parts, imag_parts