import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from typing import Tuple, Optional
import os.path as osp
# from typing import Optional
from .utils import MixedDropout, sparse_matrix_to_torch


def normalize_matrix(t: np.ndarray) -> np.ndarray:
    """Rescale ``t`` so that its ``ord=2`` norm equals one.

    For a 2-D array ``np.linalg.norm(..., ord=2)`` is the largest singular
    value (spectral norm); for a 1-D array it is the Euclidean length.
    An all-zero input divides by zero and therefore yields NaNs.
    """
    return t / np.linalg.norm(t, ord=2)






def calc_A_hat(adj_matrix: sp.spmatrix) -> sp.spmatrix:
    """Return the symmetrically normalized adjacency with self-loops.

    Computes ``D^{-1/2} (A + I) D^{-1/2}`` where ``D`` is the diagonal matrix
    of row sums of ``A + I``.

    Args:
        adj_matrix: Square sparse adjacency matrix (scipy sparse matrix or
            array). Assumes non-negative weights so every row sum of
            ``A + I`` is positive.

    Returns:
        Sparse normalized matrix of the same shape as ``adj_matrix``.
    """
    nnodes = adj_matrix.shape[0]
    A = adj_matrix + sp.eye(nnodes)
    # asarray(...).ravel() flattens both the (n, 1) np.matrix returned by sparse
    # matrices and the 1-D ndarray returned by sparse arrays (.A1 only exists
    # on the former).
    D_vec = np.asarray(A.sum(axis=1)).ravel()
    D_vec_invsqrt_corr = 1 / np.sqrt(D_vec)
    D_invsqrt_corr = sp.diags(D_vec_invsqrt_corr)
    return D_invsqrt_corr @ A @ D_invsqrt_corr


def calc_ppr_exact(adj_matrix: sp.spmatrix, alpha: float) -> np.ndarray:
    """Return the dense personalized-PageRank matrix used by ``PPRExact``.

    Computes ``alpha * (I - (1 - alpha) * A_hat)^{-1}`` with ``A_hat`` from
    ``calc_A_hat``. ``alpha`` is the restart weight. The dense inverse costs
    O(n^3) time and O(n^2) memory.

    Args:
        adj_matrix: Square sparse adjacency matrix.
        alpha: Restart weight of the personalized PageRank iteration.

    Returns:
        Dense ``(n, n)`` ndarray.
    """
    nnodes = adj_matrix.shape[0]
    M = calc_A_hat(adj_matrix)
    A_inner = sp.eye(nnodes) - (1 - alpha) * M
    return alpha * np.linalg.inv(A_inner.toarray())

def calc_Laplacian(adj_matrix: sp.spmatrix) -> sp.spmatrix:
    """Return the combinatorial graph Laplacian ``L = D - A``.

    ``D`` is the diagonal matrix of row sums of ``adj_matrix`` (weighted
    degrees for a symmetric adjacency). Every row of ``L`` sums to zero.

    Args:
        adj_matrix: Square sparse adjacency matrix (scipy sparse matrix or
            sparse array).

    Returns:
        Sparse Laplacian with the same shape as ``adj_matrix``.
    """
    # asarray(...).ravel() flattens both the (n, 1) np.matrix returned by sparse
    # matrices and the 1-D ndarray returned by sparse arrays; ``.A1`` exists
    # only on np.matrix and raised AttributeError for csr_array input.
    degrees = np.asarray(adj_matrix.sum(axis=1)).ravel()
    return sp.diags(degrees) - adj_matrix



def calc_resolvent_exact(
        t_matrix: sp.spmatrix, omega: float = 1, normalizing_factor: float = 1.0,
        resolvent_path: Optional[str] = None) -> np.ndarray:
    """Return the dense resolvent ``(L / normalizing_factor - omega * I)^{-1}``.

    ``L`` is the Laplacian of ``t_matrix`` as computed by ``calc_Laplacian``.
    The inverse is formed densely: O(n^3) time, O(n^2) memory.

    Args:
        t_matrix: Square sparse adjacency matrix.
        omega: Shift subtracted from the diagonal of the scaled Laplacian.
            ``L`` always has eigenvalue 0 (its rows sum to zero), so
            ``omega=0`` gives a singular matrix. Any ``omega`` equal to an
            eigenvalue of ``L / normalizing_factor`` makes ``np.linalg.inv``
            raise ``LinAlgError``.
        normalizing_factor: Divisor applied to the Laplacian before shifting.
        resolvent_path: Optional on-disk cache. If a file exists at this path,
            or at this path plus the ``.npy`` suffix that ``np.save`` appends,
            it is loaded. Otherwise the matrix is computed and saved.

    Returns:
        The resolvent as a dense ndarray (or whatever array the cache file holds).

    Note:
        The cache is keyed only on the path, not on ``omega`` or
        ``normalizing_factor``. Reusing a path with different parameters
        silently returns the previously saved matrix.
    """
    import os

    if resolvent_path is None:
        return _compute_resolvent(t_matrix, omega, normalizing_factor)

    path = os.fspath(resolvent_path)
    # np.save appends ".npy" to names lacking it, so a suffix-less path would
    # never be found by an exact-name check and the cache would never hit.
    saved_path = path if path.endswith('.npy') else path + '.npy'
    for candidate in (path, saved_path):
        if osp.exists(candidate):
            return np.load(candidate)

    resolvent_mat = _compute_resolvent(t_matrix, omega, normalizing_factor)
    np.save(saved_path, resolvent_mat)
    return resolvent_mat


def _compute_resolvent(t_matrix: sp.spmatrix, omega: float, normalizing_factor: float) -> np.ndarray:
    """Densely invert ``calc_Laplacian(t_matrix) / normalizing_factor - omega * I``."""
    nnodes = t_matrix.shape[0]
    laplacian = calc_Laplacian(t_matrix).toarray()
    return np.linalg.inv(laplacian / normalizing_factor - omega * np.eye(nnodes))






class PPRExact(nn.Module):
    """Propagate predictions with a precomputed dense PPR matrix.

    The matrix from ``calc_ppr_exact(adj_matrix, alpha)`` is stored as a float32
    buffer. ``forward`` selects the rows for ``idx``, applies optional dropout
    to them, and multiplies with ``predictions``.
    """

    def __init__(self, adj_matrix: sp.spmatrix, alpha: float, drop_prob: Optional[float] = None):
        super().__init__()

        ppr_mat = calc_ppr_exact(adj_matrix, alpha)
        # A buffer (not a Parameter) is saved in state_dict and follows .to(device)
        # but is not optimized.
        self.register_buffer('mat', torch.as_tensor(ppr_mat, dtype=torch.float32))

        # nn.Identity instead of a lambda keeps the module picklable
        # (torch.save(model), multiprocessing) and visible in repr().
        if drop_prob is None or drop_prob == 0:
            self.dropout = nn.Identity()
        else:
            self.dropout = MixedDropout(drop_prob)

    def forward(self, predictions: torch.FloatTensor, idx: torch.LongTensor):
        """Return ``dropout(mat[idx]) @ predictions``.

        ``predictions`` must have as many rows as ``mat`` has columns
        (presumably one row per node — confirm against callers).
        """
        return self.dropout(self.mat[idx]) @ predictions


class OurPropagate(nn.Module):
    """Propagate predictions with a precomputed dense Laplacian resolvent.

    The matrix ``(L / normalizing_factor - omega * I)^{-1}`` from
    ``calc_resolvent_exact`` (optionally cached at ``resolvent_path``) is stored
    as a float32 buffer. ``forward`` selects the rows for ``idx``, applies
    optional dropout to them, and multiplies with ``predictions``.
    """

    def __init__(
            self, adj_matrix: sp.spmatrix, omega: float = 1, drop_prob: Optional[float] = None,
            normalizing_factor: float = 1.0, resolvent_path: Optional[str] = None):
        super().__init__()

        resolvent_mat = calc_resolvent_exact(adj_matrix, omega, normalizing_factor, resolvent_path)
        # A buffer (not a Parameter) is saved in state_dict and follows .to(device)
        # but is not optimized.
        self.register_buffer('mat', torch.as_tensor(resolvent_mat, dtype=torch.float32))

        # nn.Identity instead of a lambda keeps the module picklable
        # (torch.save(model), multiprocessing) and visible in repr().
        if drop_prob is None or drop_prob == 0:
            self.dropout = nn.Identity()
        else:
            self.dropout = MixedDropout(drop_prob)

    def forward(self, predictions: torch.FloatTensor, idx: torch.LongTensor):
        """Return ``dropout(mat[idx]) @ predictions``.

        ``predictions`` must have as many rows as ``mat`` has columns
        (presumably one row per node — confirm against callers).
        """
        return self.dropout(self.mat[idx]) @ predictions


