import torch
from torch import nn
import numpy as np
from gnn import GraphNet
import math
from sklearn.metrics import pairwise_distances as pair
from sklearn.preprocessing import normalize
import scipy.sparse as sp

def normalization(mx):
    """Row-normalize a matrix so that every non-empty row sums to 1.

    Each row is scaled by the reciprocal of its row sum. Rows whose sum is
    zero are left as all-zero rows instead of being filled with inf/NaN.

    Args:
        mx: A scipy sparse matrix. Integer dtypes are accepted; the row sums
            are promoted to float before being inverted.

    Returns:
        The product ``diag(1 / rowsum) @ mx``.
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    # NumPy refuses negative powers of integer arrays, so promote those first.
    # Floating inputs keep their dtype so the result dtype is unchanged.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero rows produce inf here on purpose; they are reset to 0 right below,
    # so the divide-by-zero warning is just noise.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv).dot(mx)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: Any scipy sparse matrix; it is converted to COO format and
            its values are cast to ``float32``.

    Returns:
        A torch sparse COO tensor with the same shape and non-zero entries.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Torch expects the indices as a 2 x nnz int64 tensor (row ids, col ids).
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported way to build the same tensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def create_graph(x, p=0.5, mode='heat', mask=None):
    """Build one row-normalized nearest-neighbour adjacency matrix per sample.

    For every sample the V columns of ``x`` are treated as graph nodes. A
    pairwise similarity between nodes is computed, each node is linked to its
    most similar nodes, the links are made symmetric, self-loops are added and
    the rows are normalized to sum to 1.

    Args:
        x: Tensor laid out as (B, L, V). It is only read; the computation runs
            on a detached CPU copy/view and the input is never modified.
        p: Fraction of the V nodes used to derive the neighbour count
            ``k = ceil(V * p)``.
        mode: ``'heat'`` uses ``exp(-0.5 * d**2)`` of the pairwise distances;
            ``'ncos'`` uses dot products of the L1-normalized node vectors.
        mask: If not ``None``, ``k`` is drawn uniformly at random from
            ``[ceil(V * p), V)`` for each sample. The value itself is not
            inspected, only whether it is ``None``.

    Returns:
        A dense tensor of shape (B, V, V) created with ``torch.zeros`` (so it
        has the default dtype and is not moved to the device of ``x``).

    Raises:
        ValueError: If ``mode`` is neither ``'heat'`` nor ``'ncos'``.
    """
    if mode not in ('heat', 'ncos'):
        raise ValueError(
            "unknown mode {!r}; expected 'heat' or 'ncos'".format(mode))

    # x: B, L, V -> B, V, L so that every node (variable) is one row.
    # detach() lets this work on tensors that are part of an autograd graph.
    feats = x.detach().transpose(1, 2).cpu().numpy()
    batch, num_nodes = feats.shape[0], feats.shape[1]

    adjs = torch.zeros((batch, num_nodes, num_nodes))
    if num_nodes == 0:
        return adjs

    base_k = math.ceil(num_nodes * p)
    node_ids = np.arange(num_nodes)

    for i in range(batch):
        sample = feats[i]
        if mode == 'heat':
            dist = np.exp(-0.5 * pair(sample) ** 2)
        else:
            # Normalize into a fresh array: `feats` can share memory with the
            # caller's tensor, so writing back would silently change the input.
            sample = normalize(sample, axis=1, norm='l1')
            dist = np.dot(sample, sample.T)

        if mask is not None and base_k < num_nodes:
            k = np.random.randint(base_k, num_nodes)
        else:
            # A node has at most V - 1 neighbours besides itself; clamping
            # keeps argpartition (and randint above) inside valid bounds.
            k = min(base_k, num_nodes - 1)

        # Take the k + 1 most similar nodes per row: in 'heat' mode a node's
        # similarity to itself is exp(0) = 1, the largest possible value, so
        # one slot is expected to be the node itself and is dropped below.
        top = np.argpartition(dist, -(k + 1), axis=1)[:, -(k + 1):]
        rows = np.repeat(node_ids, top.shape[1])
        cols = top.ravel()
        not_self = rows != cols
        rows, cols = rows[not_self], cols[not_self]

        adj = sp.coo_matrix((np.ones(rows.shape[0]), (rows, cols)),
                            shape=(num_nodes, num_nodes), dtype=np.float32)

        # Symmetrize: keep an edge if either endpoint selected the other.
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        # Add self-loops, then make every row sum to 1.
        adj = adj + sp.eye(adj.shape[0])
        adj = normalization(adj)

        adjs[i] = sparse_mx_to_torch_sparse_tensor(adj).to_dense()
    return adjs

class SEncoder(nn.Module):
    """Spatial encoder: builds a graph from the input and encodes it.

    ``forward`` derives one adjacency matrix per sample with ``create_graph``
    and passes the input together with that adjacency to a ``GraphNet``,
    followed by dropout on the resulting representation.

    Args:
        input_dims: First argument forwarded to ``GraphNet`` (the sequence
            length, per the original inline note).
        output_dims: Second argument forwarded to ``GraphNet``.
        hidden_dims: Third argument forwarded to ``GraphNet``.
        p: Neighbour fraction handed to ``create_graph``.
        mask_mode: Value passed as ``mask`` to ``create_graph`` while the
            module is in training mode and the caller gave no explicit mask.
    """

    def __init__(self, input_dims, output_dims, hidden_dims=64, p=0.5, mask_mode='binomial'):
        super().__init__()
        self.input_dims = input_dims  # sequence length
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.p = p
        self.mask_mode = mask_mode
        self.spatial_encoder = GraphNet(self.input_dims, self.output_dims, self.hidden_dims)
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):  # x: B x L x V
        """Encode ``x`` with the graph network.

        Args:
            x: Input tensor, B x L x V. Not modified by this call.
            mask: Forwarded to ``create_graph``. When ``None`` it defaults to
                ``self.mask_mode`` in training mode and stays ``None`` in
                eval mode.

        Returns:
            The dropout-regularized output of ``GraphNet`` (B x V x Zo per
            the original inline note; TODO confirm against GraphNet).
        """
        # Zero every time step that contains a NaN. masked_fill works out of
        # place, so the caller's tensor is left untouched and this also works
        # for leaf tensors that require grad.
        has_nan = x.isnan().any(dim=-1, keepdim=True)
        x = x.masked_fill(has_nan, 0)

        if mask is None and self.training:
            mask = self.mask_mode

        adj = create_graph(x, p=self.p, mask=mask)  # B V V
        # create_graph allocates its result without a device argument; move
        # it next to x so the encoder does not mix devices (no-op if same).
        adj = adj.to(x.device)
        return self.repr_dropout(self.spatial_encoder(x, adj))  # B V Zo
        