import math
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple
from torch_geometric.typing import OptTensor
from torch_geometric.utils import get_laplacian



class SparseDropout(nn.Module):
    """Dropout for sparse COO tensors.

    Dropout is applied only to the stored values of the (coalesced) input.
    The index pattern is kept unchanged, so dropped entries remain in the
    result as explicit zeros. Surviving values are rescaled exactly as
    ``F.dropout`` does in training mode.

    Args:
        p: probability of zeroing each stored value.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, input):
        # Coalesce so every (row, col) position appears exactly once before
        # its value is (possibly) dropped.
        input_coal = input.coalesce()
        drop_val = F.dropout(input_coal._values(), self.p, self.training)
        # ``torch.sparse_coo_tensor`` replaces the deprecated legacy
        # ``torch.sparse.FloatTensor`` constructor. The legacy constructor
        # rejects CUDA tensors and forces float32; this one keeps the device
        # and dtype of ``drop_val``.
        return torch.sparse_coo_tensor(input_coal._indices(), drop_val, input.shape)


class MixedDropout(nn.Module):
    """Dropout that picks its implementation from the input's layout.

    Sparse tensors are routed to ``SparseDropout``; every other tensor goes
    through a regular ``nn.Dropout``. Both share the same drop rate ``p``.
    """

    def __init__(self, p):
        super().__init__()
        # Register dense first, sparse second (keeps submodule order stable).
        self.dense_dropout = nn.Dropout(p)
        self.sparse_dropout = SparseDropout(p)

    def forward(self, input):
        chosen = self.sparse_dropout if input.is_sparse else self.dense_dropout
        return chosen(input)


class MixedLinear(nn.Module):
    """Affine layer ``y = x @ W + b`` that accepts dense or sparse input.

    ``W`` is stored as ``(in_features, out_features)``, which is transposed
    relative to ``nn.Linear``. This lets the input be multiplied from the
    left with ``torch.sparse.mm`` / ``torch.sparse.addmm`` when it is sparse.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the output dimension.
        bias: if truthy, learn an additive bias of shape ``(out_features,)``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        bias_param = nn.Parameter(torch.Tensor(out_features)) if bias else None
        self.register_parameter('bias', bias_param)
        self.reset_parameters()

    def reset_parameters(self):
        # PyTorch reads size(0) as "fan_out". Because our weight is stored as
        # (in, out), size(0) is really the fan-in, hence mode='fan_out'.
        nn.init.kaiming_uniform_(self.weight, mode='fan_out', a=math.sqrt(5))
        if self.bias is None:
            return
        true_fan_in = self.weight.size(0)
        limit = 1 / math.sqrt(true_fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input):
        sparse_input = input.is_sparse
        if self.bias is not None:
            if sparse_input:
                # sparse.addmm needs a full (batch, out) matrix to add onto.
                batch_bias = self.bias.expand(input.shape[0], -1)
                return torch.sparse.addmm(batch_bias, input, self.weight)
            return torch.addmm(self.bias, input, self.weight)
        if sparse_input:
            return torch.sparse.mm(input, self.weight)
        return input.matmul(self.weight)

    def extra_repr(self):
        return (f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f'bias={self.bias is not None}')


def sparse_matrix_to_torch(X):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        X: any SciPy sparse matrix (it is converted with ``X.tocoo()``).

    Returns:
        Sparse COO tensor with shape ``X.shape``, int64 indices and
        ``torch.float32`` values. The data is copied, so no memory is shared
        with ``X``.
    """
    coo = X.tocoo()
    # torch expects int64 indices; SciPy may store row/col as int32.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    # ``torch.sparse_coo_tensor`` replaces the deprecated legacy
    # ``torch.sparse.FloatTensor`` / ``torch.LongTensor`` constructors.
    return torch.sparse_coo_tensor(indices, values, coo.shape)


def matrix_to_torch(X):
    """Convert a SciPy sparse matrix or a dense array-like to a float torch tensor.

    SciPy sparse inputs are delegated to ``sparse_matrix_to_torch``. Anything
    else is passed to ``torch.FloatTensor``.
    """
    if not sp.issparse(X):
        return torch.FloatTensor(X)
    return sparse_matrix_to_torch(X)
    












def calc_norm(
    L_edge_index: Tensor,
    L_edge_weight: Optional[Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    num_nodes: Optional[int] = None,
) -> Tensor:
    """Return the spectral norm (largest singular value) of a COO matrix.

    The matrix is densified before the norm is computed, so memory use is
    quadratic in the matrix size.

    Args:
        L_edge_index: integer tensor of shape ``(2, E)`` holding (row, col)
            indices.
        L_edge_weight: tensor of shape ``(E,)`` holding the entry values. If
            ``None``, every listed entry is 1. Duplicate (row, col) pairs are
            summed when densifying.
        dtype: dtype of the all-ones weights created when ``L_edge_weight``
            is ``None``; ignored otherwise.
        num_nodes: if given, the matrix is ``(num_nodes, num_nodes)``.
            Otherwise each dimension is inferred as the smallest size that
            holds all indices.

    Returns:
        0-dim tensor containing the matrix 2-norm.
    """
    if L_edge_weight is None:
        L_edge_weight = torch.ones(
            L_edge_index.size(1), dtype=dtype, device=L_edge_index.device)

    if num_nodes is None:
        A = torch.sparse_coo_tensor(L_edge_index, L_edge_weight)
    else:
        A = torch.sparse_coo_tensor(
            L_edge_index, L_edge_weight, (num_nodes, num_nodes))
    return torch.linalg.matrix_norm(A.to_dense(), ord=2)






def get_resolvent(
    edge_index: Tensor,
    edge_weight: OptTensor = None,
    omega: float = -1.,
    singular_value_normalization: Optional[bool] = False,
    dtype: Optional[torch.dtype] = None,
    num_nodes: Optional[int] = None,
    normalizing_factor: Optional[float] = None,
) -> Tuple[Tensor, OptTensor]:
    """Compute the resolvent ``(L - omega * I)^{-1}`` of a graph Laplacian.

    ``L`` is built by ``get_laplacian`` with its default (no) normalization,
    densified, and optionally rescaled before the dense inverse is taken.
    Memory is quadratic and the inversion is cubic in the number of nodes.

    Args:
        edge_index: graph connectivity passed to ``get_laplacian``.
        edge_weight: optional edge weights passed to ``get_laplacian``.
        omega: shift; must be negative (asserted). Presumably chosen so that,
            for a PSD Laplacian (non-negative weights), ``L - omega * I`` is
            positive definite and hence invertible.
        singular_value_normalization: if truthy, ``L`` is divided by its
            spectral norm. ``False`` and ``None`` disable this.
        dtype: forwarded to ``get_laplacian``.
        num_nodes: forwarded to ``get_laplacian`` and used as the matrix size.
            If ``None``, the size is inferred from the Laplacian's indices.
        normalizing_factor: if given, ``L`` is additionally divided by it
            (after any spectral normalization).

    Returns:
        ``(indices, values)`` of the resolvent in COO form. ``to_sparse``
        drops entries that are exactly zero.
    """
    assert omega < 0, "omega must be negative"

    lap_index, lap_weight = get_laplacian(
        edge_index, edge_weight, dtype=dtype, num_nodes=num_nodes)
    if num_nodes is None:
        L = torch.sparse_coo_tensor(lap_index, lap_weight)
    else:
        L = torch.sparse_coo_tensor(lap_index, lap_weight, (num_nodes, num_nodes))
    L = L.to_dense()

    # Truthiness check: the default ``False`` must NOT normalize.
    if singular_value_normalization:
        # Same value calc_norm would compute on the COO data, taken from the
        # dense matrix we already have.
        L = L / torch.linalg.matrix_norm(L, ord=2)
    if normalizing_factor is not None:
        L = L / normalizing_factor

    # Build the identity with L's dtype/device so no implicit promotion or
    # cross-device copy happens.
    identity = torch.eye(L.size(0), dtype=L.dtype, device=L.device)
    T = L - omega * identity
    R = torch.linalg.inv(T).to_sparse()

    return R.indices(), R.values()
    
