import torch
from torch import Tensor
import scipy.sparse as sp


from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import one_hot, spmm
from typing import Callable, Optional


# utility
def sparse_eye(n):
    """Build the ``n x n`` identity matrix as a sparse COO tensor.

    The diagonal coordinates ``(i, i)`` for ``i in range(n)`` each carry the
    value ``1`` (default float dtype of ``torch.ones``).
    """
    diagonal = torch.arange(n)
    coords = torch.stack((diagonal, diagonal))
    return torch.sparse_coo_tensor(coords, torch.ones(n), (n, n))

# sparse I_n_tilde
def sparse_I_n_tilde(mask):
    """Return the sparse diagonal selector of the positions picked by ``mask``.

    With ``n = mask.shape[0]``, the result is an ``n x n`` sparse COO tensor
    holding a one at ``(i, i)`` for every ``i`` that ``mask`` selects out of
    ``torch.arange(n)``; every other entry is zero. Left-multiplying a dense
    matrix by it therefore zeroes the rows that are not selected.
    """
    size = mask.shape[0]
    selected = torch.arange(size)[mask]
    coords = torch.stack([selected, selected])
    ones = torch.ones(coords.shape[1])
    return torch.sparse_coo_tensor(coords, ones, (size, size))


def convert_sparse_coo_to_csc(coo_tensor):
    """Convert a sparse COO torch tensor into a SciPy CSC matrix.

    The tensor is coalesced first, so duplicate coordinates are summed. The
    first two index rows are used as row / column coordinates and the values
    are handed to SciPy through ``.numpy()``.
    """
    merged = coo_tensor.coalesce()

    # Pull the data and the (row, col) coordinates out as NumPy arrays.
    coords = merged.indices()
    data = merged.values().numpy()
    rows = coords[0].numpy()
    cols = coords[1].numpy()

    # SciPy assembles the CSC matrix from the (data, (row, col)) triplets.
    return sp.csc_matrix((data, (rows, cols)), shape=merged.size())

def multiply_G_ks_v(G, s, v, normalize = True):
    '''
    Apply a polynomial in G to v using Horner's scheme.

    Input: G (square operator supporting ``G @ x``), s (coefficients
           [s_0, ..., s_k]), v
    Return: s_0v + s_1Gv + s_2G^2v + ... + s_kG^kv
    If normalize = True, return s_0v + s_1Gv/n + s_2G^2v/n^2 + ... + s_kG^kv/n^k
    where n = G.shape[0].
    '''
    n = G.shape[0]

    # Start from the highest-order coefficient and fold in the remaining
    # ones from s[-2] down to s[0]; each step costs one product with G.
    acc = s[-1] * v
    for k in range(len(s) - 2, -1, -1):
        step = G @ acc
        if normalize:
            step = 1/n * step
        acc = step + s[k] * v

    return acc
    
def multiply_with_Gksn(G, s: list, v, mask, normalize):
    '''
    Multiply v by the polynomial operator restricted to the rows/columns
    selected by mask.

    Input: G (square operator supporting 2-D indexing and ``@``), s
           (coefficients [s_0, s_1, ...], at least two entries), v (one row
           per selected node), mask (selects the rows/columns of G), normalize
    Return: s_0 v + s_1 G_n v + G_right P(G) G_left v   (as a torch tensor)
        where G_left = G[:, mask], G_right = G_left^T, G_n = G[mask][:, mask]
        and P is the polynomial with coefficients s[2:] evaluated by
        multiply_G_ks_v. The last term is only added when len(s) > 2.
    If normalize is truthy, P is evaluated with normalize=True and the last
    term is additionally scaled by 1/n with n = mask.shape[0].
    '''
    G_left = G[:, mask]
    G_right = G_left.T
    G_n = G[mask, :][:, mask]

    # torch.tensor(...) turns the product into a torch tensor even when G is
    # not a torch object and ``@`` therefore yields a non-torch array.
    out = torch.tensor(s[1] * (G_n @ v)) + s[0] * v

    if len(s) > 2:
        use_norm = bool(normalize)
        # Lift v to all nodes, apply the higher-order part on the full graph,
        # then project back onto the selected nodes.
        inner = multiply_G_ks_v(G, s[2:], G_left @ v, normalize=use_norm)
        if use_norm:
            n = mask.shape[0]
            out += torch.tensor(1/n * G_right @ inner)
        else:
            out += torch.tensor(G_right @ inner)
    return out


class Muliply_with_Gks(MessagePassing):
    ''' Apply a polynomial of the weighted graph operator G to node features.

        G is represented by ``edge_index`` and the per-edge weights ``norm``;
        one product with G is one ``propagate`` call with 'add' aggregation.

        Input: G, s ,v
        Return: s_0v + s_1Gv + s_2G^2v + ... + s_kG^kv
        If normalize = True, return s_0v + s_1Gv/n + s_2G^2v/n^2 + ... + s_kG^kv/n^k
        where n = num_nodes.
    '''
    def __init__(self, edge_index, norm, num_nodes, normalize = False):
        super().__init__(aggr='add')
        self.edge_index = edge_index
        self.norm = norm
        self.n = num_nodes
        self.normalize = normalize

    def forward(self, x, s):
        ''' Evaluate the polynomial with coefficients s = [s_0, ..., s_k] on x
            using Horner's scheme (len(s) - 1 propagations).
        '''
        # Keep the untouched input: every Horner step must add s_j * v.
        # Adding s_j times the running accumulator instead would compute
        # s_k * prod_j (G + s_j I) v rather than sum_j s_j G^j v.
        v = x
        # Dividing the edge weights by n once is the same as dividing each
        # propagated result by n, and the weights do not change in the loop.
        norm = self.norm / self.n if self.normalize else self.norm

        acc = s[-1] * v
        for k in range(len(s) - 2, -1, -1):
            acc = self.propagate(self.edge_index, x=acc, norm=norm) + s[k] * v
        return acc

    def message(self, x_j, norm):
        # Scale every source-node feature row by the weight of its edge.
        return norm.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Fused path: a single sparse-dense product with the class aggregation.
        return spmm(adj_t, x, reduce=self.aggr)



class STKR_inv(MessagePassing):
    r"""Iterative propagation solver parameterised by the polynomial ``xi``.

    Args:
        num_layers (int): The number of propagations.
        xi (list): polynomial [xi_0, xi_1, xi_2] Q = xi_0  + xi_1 x + xi_2 x**2 + ...
            ``xi_0`` must be non-zero because the update divides by it.
        r (int): degree of polynomial that s^{-1}(x)x^r = Q. Must be at
            least 1 (a coefficient list of length ``r`` is built from it).
        alpha (float): mixing weight of the iteration; every step computes
            ``alpha * update + (1 - alpha) * out_0`` where ``out_0`` is the
            clamped one-step propagation of the masked labels.
            (default: ``0.9``)
    """
    def __init__(self, num_layers: int, xi:list, r:int, alpha:float = 0.9):
        super().__init__(aggr='add')
        self.num_layers = num_layers
        self.xi = xi
        self.r = r
        self.alpha = alpha

    @torch.no_grad()
    def forward(
        self,
        y: Tensor,
        edge_index: Adj,
        mask: OptTensor = None,
        edge_weight: OptTensor = None,
        post_step: Optional[Callable[[Tensor], Tensor]] = None,
        normalize = False,
        beta = 1
    ) -> Tensor:
        r"""
        Args:
            y (torch.Tensor): The ground-truth label information
                :math:`\mathbf{Y}`. A 1-D (or column) ``long`` tensor is
                one-hot encoded first.
            edge_index (torch.Tensor or SparseTensor): The edge connectivity.
            mask (torch.Tensor, optional): A boolean mask with one entry per
                node denoting which nodes are used for label propagation.
                If :obj:`None`, every node is treated as labeled.
                (default: :obj:`None`)
            edge_weight (torch.Tensor, optional): The edge weights. If
                missing, they are computed with ``gcn_norm`` (without self
                loops). (default: :obj:`None`)
            post_step (callable, optional): A post step function specified
                to apply after label propagation. If no post step function
                is specified, the output will be clamped between 0 and 1.
                (default: :obj:`None`)
            normalize (bool, optional): If set to :obj:`True`, the edge weight
                will be normalized by the number of nodes.
            beta (float, optional): The hyperparameter for regularization.

        Returns:
            torch.Tensor: The propagated outputs, one row per node.
        """
        if y.dtype == torch.long and y.size(0) == y.numel():
            y = one_hot(y.view(-1))

        # ``mask=None`` is the advertised default; without this the size and
        # count computations below would fail on ``None``.
        if mask is None:
            mask = torch.ones(y.size(0), dtype=torch.bool, device=y.device)

        # Only the labeled rows carry information; the others start at zero.
        out = torch.zeros_like(y)
        out[mask] = y[mask]

        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                               add_self_loops=False)

        n = mask.shape[0]               # n + m (all nodes)
        n_labeled = mask.sum().item()   # n (labeled nodes)
        mgks = Muliply_with_Gks(edge_index, edge_weight, n, normalize = normalize)

        # Compute G_K * y
        out = self.propagate(edge_index, x=out, edge_weight=edge_weight, size=None)
        out.clamp_(0., 1.)

        # Constant term that is mixed back in after every iteration.
        res = (1 - self.alpha) * out

        # Diagonal selector that zeroes the rows of unlabeled nodes.
        I_n_tilde = sparse_I_n_tilde(mask)
        # Coefficients -xi_i / xi_0 with the constant term dropped.
        xi_tilde = [-coeff / self.xi[0] for coeff in self.xi]
        xi_tilde[0] = 0
        # Coefficient list of the monomial x^(r-1).
        power_r_1 = [int(i == self.r - 1) for i in range(self.r)]

        # The weight of the labeled-data term does not change across
        # iterations, so compute it once.
        if normalize:
            data_scale = 1 / (n_labeled * beta * self.xi[0])
        else:
            data_scale = n / (n_labeled * beta * self.xi[0])

        for _ in range(self.num_layers):
            # ``mgks`` builds new tensors and ``out`` is rebound below, so the
            # current iterate can be passed directly without a defensive copy.
            labeled_part = I_n_tilde @ mgks(out, power_r_1)
            out = mgks(out, xi_tilde) - data_scale * mgks(labeled_part, [0, 1])
            out.mul_(self.alpha).add_(res)
            if post_step is not None:
                out = post_step(out)
            else:
                out.clamp_(0., 1.)

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(xi={self.xi}, '
                f'r={self.r})')
    


class STKR(MessagePassing):
    r"""Iterative solver on the labeled nodes followed by polynomial
    propagation to all nodes.

    Args:
        num_layers (int): The number of iterations run on the labeled nodes.
        s (list): polynomial coefficients [s_0, s_1, s_2, ...] of the
            transformation s_0 + s_1 x + s_2 x**2 + ...; at least two
            entries are required because ``s[1]`` is always read.
        alpha (float): the masked labels are scaled by ``1 - alpha`` to form
            the constant term added in every iteration.
    """

    def __init__(self, num_layers: int, s:list, alpha:float):
        super().__init__(aggr='add')
        self.num_layers = num_layers
        self.s = s
        self.alpha = alpha

    @torch.no_grad()
    def forward(
        self,
        y: Tensor,
        edge_index: Adj,
        mask: OptTensor = None,
        edge_weight: OptTensor = None,
        post_step: Optional[Callable[[Tensor], Tensor]] = None,
        normalize = False,
        beta = 1
    ) -> Tensor:
        r"""
        Args:
            y (torch.Tensor): The label information. A 1-D (or column)
                ``long`` tensor is one-hot encoded first.
            edge_index (torch.Tensor): The edge connectivity.
            mask (torch.Tensor, optional): A boolean mask with one entry per
                node selecting the labeled nodes. If :obj:`None`, every node
                is treated as labeled. (default: :obj:`None`)
            edge_weight (torch.Tensor, optional): The edge weights. If
                missing, they are computed with ``gcn_norm`` (without self
                loops). (default: :obj:`None`)
            post_step (callable, optional): Applied to the labeled iterate
                after every iteration. If :obj:`None`, the iterate is clamped
                between 0 and 1 instead. (default: :obj:`None`)
            normalize (bool, optional): Forwarded to the final polynomial
                propagation only; the iterations on the labeled nodes always
                use the unnormalized operator.
            beta (float, optional): The hyperparameter for regularization.

        Returns:
            torch.Tensor: The propagated outputs, one row per node.
        """
        if y.dtype == torch.long and y.size(0) == y.numel():
            y = one_hot(y.view(-1))

        # ``mask=None`` is the advertised default; without this the size and
        # count computations below would fail on ``None``.
        if mask is None:
            mask = torch.ones(y.size(0), dtype=torch.bool, device=y.device)

        # Work on a fresh tensor so the caller's ``y`` is never written to.
        out = torch.zeros_like(y)
        out[mask] = y[mask]

        out_labeled = out[mask]

        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                               add_self_loops=False)

        n = mask.shape[0]               # n + m (all nodes)
        n_labeled = mask.sum().item()   # n (labeled nodes)

        # Weighted adjacency as a SciPy matrix so that multiply_with_Gksn can
        # slice out its labeled rows/columns.
        G = sp.csc_matrix((edge_weight, edge_index), shape=(n,n))
        mgks = Muliply_with_Gks(edge_index, edge_weight, n, normalize = normalize)

        res = (1 - self.alpha) * out_labeled
        # Loop-invariant factor of the update.
        step_scale = -1/(n_labeled*beta)
        for _ in range(self.num_layers):
            out_labeled = step_scale * multiply_with_Gksn(G, self.s, out_labeled, mask, normalize = False)
            out_labeled.add_(res)
            if post_step is not None:
                out_labeled = post_step(out_labeled)
            else:
                # Clamp the iterate itself. Clamping ``out`` here would have no
                # effect on the result, because its labeled rows are
                # overwritten with ``out_labeled`` right after the loop.
                out_labeled.clamp_(0., 1.)

        # Inference: write the solution into the labeled rows, then propagate
        # it to every node with the polynomial s.
        out[mask] = out_labeled
        out = mgks(out, self.s)
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(s={self.s})'