import sys
import torch
import numpy as np
from copy import deepcopy
import scipy.sparse as sp
from scipy.sparse import linalg
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F




def softmax_sample(logits):
    """Return the softmax of ``logits`` taken over the last dimension.

    Despite the name, nothing is sampled: the result is the deterministic
    probability tensor ``exp(logits) / sum(exp(logits))`` along ``dim=-1``.
    """
    probabilities = F.softmax(logits, dim=-1)
    return probabilities

def softmax(logits, hard=False):
    """Softmax over the last dimension, optionally snapped to a one-hot tensor.

    Args:
        logits: tensor of unnormalised scores; normalisation runs over dim -1.
        hard: when truthy, return a one-hot tensor (same shape, dtype and
            device as the probabilities) holding a 1 at the arg-max position
            along the last dimension, instead of the probabilities.

    Returns:
        The softmax probabilities, or their one-hot arg-max when ``hard``.

    Note:
        The one-hot tensor is built by scattering a constant into a fresh
        zero tensor, so no gradient flows back to ``logits`` on that path.
    """
    probs = softmax_sample(logits)
    if not hard:
        return probs

    # Position of the largest probability along the last dimension.
    winner = probs.max(dim=-1)[1]
    one_hot = torch.zeros_like(probs)
    one_hot.scatter_(-1, winner.unsqueeze(-1), 1)
    return one_hot


def normalize_adjacency(adjacency_matrix_tensor):
    """Degree-normalise a (possibly batched) square adjacency matrix.

    Entry ``[i, j]`` of the result is ``A[i, j] / sqrt(out_i) / sqrt(in_j)``.
    Here ``out_i`` is the sum of row ``i`` (last dim) and ``in_j`` is the sum
    of column ``j`` (second-to-last dim). Nodes whose degree is exactly zero
    get a scale factor of 0 rather than ``inf``.

    Args:
        adjacency_matrix_tensor: tensor whose last two dimensions are equal;
            any leading dimensions are treated as batch dimensions.

    Returns:
        A tensor of the same shape as the input.

    Raises:
        AssertionError: if the last two dimensions differ.
    """
    assert adjacency_matrix_tensor.shape[-1] == adjacency_matrix_tensor.shape[-2]

    def inv_sqrt_or_zero(degrees):
        # Only take 1/sqrt where the degree is non-zero, so empty rows or
        # columns produce 0 instead of inf.
        scale = torch.zeros_like(degrees)
        nonzero = degrees != 0.
        scale[nonzero] = 1. / torch.sqrt(degrees[nonzero])
        return scale

    # Out-degree scales rows (column vector); in-degree scales columns (row vector).
    row_scale = inv_sqrt_or_zero(adjacency_matrix_tensor.sum(dim=-1)).unsqueeze(-1)
    col_scale = inv_sqrt_or_zero(adjacency_matrix_tensor.sum(dim=-2)).unsqueeze(-2)

    return adjacency_matrix_tensor * (col_scale * row_scale)


def get_laplacian_matrix(adj):
    """Return the degree-normalised Laplacian ``D^-1 (D - A)`` of ``adj``.

    Args:
        adj: tensor whose last two dimensions form a square adjacency matrix;
            any leading dimensions are treated as batch dimensions. ``adj``
            itself is not modified.

    Returns:
        Tensor of the same shape as ``adj``. Row ``i`` is divided by the
        degree ``D_i = sum_j adj[..., i, j]``. Rows of isolated nodes
        (``D_i == 0``) are all zeros instead of NaN.

    Note:
        The diagonal is *overwritten* with the degree rather than set to
        ``D_i - A_ii``, so self-loop weights do not appear on the diagonal.
        Without self-loops this equals ``D - A`` exactly.
    """
    n = adj.shape[-1]
    # Build the index on adj's device so advanced indexing stays on-device.
    idx = torch.arange(n, device=adj.device)
    degree = torch.sum(adj, dim=-1)

    # L = D - A (negation allocates a new tensor, so adj is left untouched).
    lap = -adj
    lap[..., idx, idx] = degree

    # D^-1 L is a row scaling. Doing it by broadcasting avoids an O(N^3)
    # dense matmul with a diagonal matrix. Guarding zero degrees avoids
    # inf * 0 = NaN poisoning whole rows for isolated nodes.
    inv_degree = torch.zeros_like(degree)
    nonzero = degree != 0
    inv_degree[nonzero] = degree[nonzero] ** -1
    return lap * inv_degree.unsqueeze(-1)


#############  Access Graph Part #############

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import pyplot, patches

from sklearn.metrics import confusion_matrix
def get_offdiag(sz):
    """Return an ``sz x sz`` float mask: 1 everywhere except 0 on the diagonal."""
    # ones - identity == ones with the diagonal zeroed, without a Python loop.
    return 1 - torch.eye(sz)

def skip_diag_strided(A):
    """Return square matrix ``A`` with its main diagonal removed.

    Row ``i`` of the result holds ``A[i, j]`` for every ``j != i``, in
    column order, giving shape ``(m, m - 1)``.

    The previous stride-trick version read ``A.ravel()`` with ``A``'s own
    strides. For non-contiguous input (e.g. a transpose), ``ravel`` returns a
    C-ordered copy whose layout no longer matches those strides. That
    produced wrong values and out-of-bounds reads. A boolean mask is layout
    independent.

    Args:
        A: square 2-D array-like.

    Returns:
        A new NumPy array of shape ``(m, m - 1)`` with ``A``'s dtype.

    Raises:
        ValueError: if ``A`` is not a square 2-D array.
    """
    A = np.asarray(A)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"expected a square 2-D array, got shape {A.shape}")
    m = A.shape[0]
    off_diagonal = ~np.eye(m, dtype=bool)
    # max(...) keeps the empty 0x0 case reshapeable.
    return A[off_diagonal].reshape(m, max(m - 1, 0))


def calc_matrics(matrix, matrix_pred):
    """Score a predicted adjacency matrix against a ground-truth one.

    Only off-diagonal entries are scored; the diagonal is ignored.

    Args:
        matrix: ground-truth square matrix (NumPy array or array-like).
            Entries are truncated to ``int``; non-zero counts as an edge.
            Presumably 0/1 — confirm against callers.
        matrix_pred: square ``torch.Tensor`` of scores (any device). Scores
            > 0.5 become 1, scores < 0.5 become 0, and scores exactly 0.5
            become 0.5. In ``err``, 0.5 counts as half an error; in
            recall/precision it counts as a non-edge.

    Returns:
        ``(recall, precision, err)``.
        - ``recall`` is ``tp / (tp + fn + 1e-4)``.
        - ``precision`` is ``tp / (tp + fp + 1e-4)``. The epsilon avoids
          division by zero.
        - ``err`` is the summed absolute off-diagonal difference between the
          thresholded prediction and ``matrix``.
    """
    # sign(x - 0.5) in {-1, 0, 1}  ->  {0, 0.5, 1}
    matrix_pred = 1.0 * (torch.sign(matrix_pred - 0.5) + 1) / 2
    matrix_pred = matrix_pred.detach().to('cpu')
    # ascontiguousarray also accepts lists and negative-stride arrays,
    # which torch.from_numpy would reject.
    matrix = torch.from_numpy(np.ascontiguousarray(matrix))
    num_nodes = matrix_pred.shape[0]

    # Float mask (1 off-diagonal, 0 on it) so the diagonal contributes no error.
    offdiag = 1 - torch.eye(num_nodes)
    err = torch.sum(torch.abs(matrix_pred * offdiag - matrix * offdiag)).item()

    # Count confusion entries directly with a boolean mask. Unpacking
    # sklearn's confusion_matrix into four values crashed whenever only one
    # class was present (it then returns a 1x1 matrix). The old strided
    # diagonal removal also misread non-contiguous arrays.
    keep = ~np.eye(num_nodes, dtype=bool)
    truth_pos = matrix.numpy()[keep].astype(int) != 0
    pred_pos = matrix_pred.numpy()[keep].astype(int) != 0
    tp = np.sum(truth_pos & pred_pos)
    fp = np.sum(~truth_pos & pred_pos)
    fn = np.sum(truth_pos & ~pred_pos)

    recall = tp / (tp + fn + 0.0001)
    precision = tp / (tp + fp + 0.0001)

    return recall, precision, err



