import torch
import numpy as np
import scipy.sparse as sp


def sparse_I_n_tilde_numpy(mask):
    """Build a sparse diagonal selector matrix from ``mask``.

    Input: mask, used to index ``torch.arange(mask.shape[0])``
           (assumed to be a 1-D boolean mask -- TODO confirm with callers).
    Return: SciPy CSC matrix of shape (n, n), n = mask.shape[0], holding a
            one at (i, i) for every position i selected by ``mask`` and no
            other stored entries.
    """
    size = mask.shape[0]
    # Positions picked out by the mask; each becomes one diagonal entry.
    selected = torch.arange(size)[mask]
    coords = torch.stack([selected, selected])
    ones = torch.ones(coords.shape[1])
    return sp.csc_matrix((ones, coords), shape=(size, size))


def multiply_with_Gksn_inductive(G, s: list, v, mask, normalize = False):
    """Apply the mask-restricted polynomial operator to ``v``.

    Input: G, s, v, mask
    Return: (out, w) where, writing G_l = G[:, mask] and
            G_n = G[mask, :][:, mask],
                out = s[0]*v + s[1]*G_n@v + G_l.T @ w
                w   = multiply_G_ks_v(G, s[2:], G_l@v)
            The last term is only added when len(s) > 2; otherwise ``w`` is a
            zero tensor of shape (mask.shape[0], 1).
    If normalize = True, the flag is forwarded to ``multiply_G_ks_v`` and
    G_l.T is additionally scaled by 1/n with n = mask.shape[0].
    """
    n = mask.shape[0]
    cols_labeled = G[:, mask]
    rows_labeled = cols_labeled.T
    labeled_block = G[mask, :][:, mask]

    # Zeroth- and first-order terms only involve the masked block.
    first_order = s[1] * (labeled_block @ v)
    out = torch.tensor(first_order) + s[0] * v

    if len(s) <= 2:
        # No higher-order coefficients: return a zero placeholder for w.
        return out, torch.zeros((n, 1))

    # Higher-order terms go through the full matrix G.
    w = multiply_G_ks_v(G, s[2:], cols_labeled @ v, normalize = normalize)
    if normalize:
        higher_order = 1 / n * rows_labeled @ w
    else:
        higher_order = rows_labeled @ w
    out += torch.tensor(higher_order)
    return out, w
def multiply_G_ks_v(G, s, v, normalize = False):
    """Evaluate a polynomial in ``G`` applied to ``v``.

    Input: G, s, v
    Return: s_0v + s_1Gv + s_2G^2v + ... + s_kG^kv, with k = len(s) - 1
    If normalize = True, every application of G is scaled by 1/n
    (n = G.shape[0]), i.e. the result is
    s_0v + s_1Gv/n + s_2G^2v/n^2 + ... + s_kG^kv/n^k

    The sum is accumulated from the highest coefficient down (Horner
    scheme), so no explicit power of G is formed. ``s`` must be non-empty.
    """
    n = G.shape[0]

    acc = s[-1] * v
    # Walk the remaining coefficients from s[-2] down to s[0].
    for coeff in reversed(s[:-1]):
        propagated = G @ acc
        if normalize:
            propagated = 1 / n * propagated
        acc = propagated + coeff * v
    return acc
    
class STKR_inductive():
    """Iteratively fitted regressor built on ``multiply_with_Gksn_inductive``.

    ``fit`` runs a fixed-step residual iteration for the coefficients
    ``alpha_hat``; ``predict`` combines them with the auxiliary vector ``v``
    returned by the helper.
    """

    def __init__(self) -> None:
        # Coefficients found by the last ``fit`` (0 until fitted).
        self.alpha_hat = 0
        # Second return value of ``multiply_with_Gksn_inductive`` from ``fit``.
        self.v = 0
        # Polynomial coefficients passed to the last ``fit``.
        self.s = []
        # Relative squared residual reached by the last ``fit``
        # (10 is only the pre-fit placeholder).
        self.eps = 10

    def fit(self, G_K, y, train_labeled_mask_GK, s, normalize = False, beta = 10, gamma = 0.0001, max_iter = 200, target_eps = 1e-3):
        """Fit ``alpha_hat`` by a fixed-step residual iteration.

        Each step computes
            u = multiply_with_Gksn_inductive(G_K, s, alpha_hat, mask)[0]
                + n_labeled * beta * alpha_hat
        and updates ``alpha_hat -= gamma * (u - y)``. Iteration stops once
        ``||u - y||^2 / ||y||^2 <= target_eps`` or after ``max_iter`` steps.

        Input: G_K, y, train_labeled_mask_GK, s
               normalize: forwarded to ``multiply_with_Gksn_inductive``
               beta: weight of the ``n_labeled * beta * alpha_hat`` term
               gamma: step size
               max_iter: maximum number of iterations
               target_eps: relative squared residual at which to stop
        Raises: ValueError if ``u`` contains NaN values.
        """
        alpha_hat = torch.zeros_like(y)
        n_labeled = train_labeled_mask_GK.sum()
        # Same zero placeholder the helper returns when len(s) <= 2; keeps
        # ``self.v`` defined even if the loop body never runs (max_iter = 0).
        v = torch.zeros((train_labeled_mask_GK.shape[0], 1))
        # Start from +inf instead of ``self.eps``: ``alpha_hat`` is reset to
        # zero above, so a residual left over from a previous ``fit`` must not
        # be allowed to skip the iteration.
        eps = float("inf")
        num_iter = 0
        while eps > target_eps and num_iter < max_iter:
            u, v = multiply_with_Gksn_inductive(G_K, s = s, v = alpha_hat, mask = train_labeled_mask_GK, normalize = normalize)
            u += n_labeled * beta * alpha_hat
            if torch.isnan(u).any():
                raise ValueError('u contains NaN values, reduce gamma')
            residual = u - y
            alpha_hat -= gamma * residual
            eps = (residual.square().sum() / y.square().sum()).item()
            num_iter += 1

        # NOTE(review): ``v`` comes from the helper call made *before* the
        # final ``alpha_hat`` update, so it lags ``alpha_hat`` by one step.
        self.alpha_hat = alpha_hat
        self.v = v
        self.s = s
        self.eps = eps

    def predict(self, v_K, train_labeled_mask_vK):
        """Return ``v_K @ v + s[1] * v_K[:, mask] @ alpha_hat``.

        Input: v_K, train_labeled_mask_vK (column mask applied to ``v_K``)
        Original author's note on the expected ``v_K``:
            G[val_mask, :][:, train_labeled_mask + train_unlabeled_mask]
        """
        out = v_K @ self.v + self.s[1] * v_K[:, train_labeled_mask_vK] @ self.alpha_hat
        return out

    def get_eps(self):
        """Return the relative squared residual stored by the last ``fit``."""
        return self.eps
    


class STKR_inductive_sinv():
    """Iteratively fitted regressor using the coefficients ``s_inv``.

    ``fit`` runs a fixed-step residual iteration for ``theta`` over all rows
    of ``G_K`` and stores ``v = G_K^(r-1) @ theta`` (scaled by 1/n per power
    when ``normalize`` is set); ``predict`` returns ``v_K @ v``.
    """

    def __init__(self) -> None:
        # Iterate found by the last ``fit`` (0 until fitted).
        self.theta = 0
        # ``multiply_G_ks_v`` applied to ``theta`` with the G^(r-1) selector.
        self.v = 0
        # Coefficients and power passed to the last ``fit``.
        self.s_inv = []
        self.r = 1
        # Relative squared residual reached by the last ``fit``
        # (10 is only the pre-fit placeholder).
        self.eps = 10
        # Number of iterations performed by the last ``fit``.
        self.num_iter = 0

    def fit(self, G_K, y, train_labeled_mask_GK, s_inv, r,  normalize = False, beta = 10, gamma = 0.0001, max_iter = 200, target_eps = 1e-3):
        """Fit ``theta`` by a fixed-step residual iteration.

        With ``y_tilde`` equal to ``y`` on the masked rows and zero elsewhere,
        each step computes
            u = n_labeled * beta * multiply_G_ks_v(G_K, s_inv, theta)
                + n * I_n_tilde @ (G_K^r theta)
        and updates ``theta -= gamma * (u - y_tilde)``. Iteration stops once
        ``||u - y_tilde||^2 / ||y_tilde||^2 <= target_eps`` or after
        ``max_iter`` steps.

        Input: G_K, y, train_labeled_mask_GK, s_inv
               r: power of G_K used in the masked term; must be >= 1
               normalize: forwarded to ``multiply_G_ks_v``
               beta, gamma, max_iter, target_eps: weight of the s_inv term,
               step size and stopping controls
        Raises: ValueError if ``r < 1`` or if ``u`` contains NaN values.
        """
        # r < 1 leaves the coefficient list for ``v`` below empty, which
        # ``multiply_G_ks_v`` cannot handle; fail before iterating.
        if r < 1:
            raise ValueError('r must be at least 1')

        n = G_K.shape[0]
        # Embed the targets into a vector over all rows of G_K.
        y_tilde = np.zeros((n, y.shape[1]))
        y_tilde[train_labeled_mask_GK, :] = y
        theta = np.zeros_like(y_tilde)
        n_labeled = train_labeled_mask_GK.sum()

        # Loop invariants, built once instead of on every iteration.
        I_n_tilde = sparse_I_n_tilde_numpy(train_labeled_mask_GK)
        power_r = [int(i >= r) for i in range(r + 1)]  # selects the G^r term
        y_norm_sq = torch.tensor(y_tilde).square().sum()

        # Start from +inf instead of ``self.eps``: ``theta`` is reset to zero
        # above, so a residual left over from a previous ``fit`` must not be
        # allowed to skip the iteration.
        eps = float("inf")
        num_iter = 0

        while eps > target_eps and num_iter < max_iter:
            # u = M(theta)
            u = n_labeled*beta*multiply_G_ks_v(G_K, s = s_inv, v = theta, normalize = normalize) # n * beta * Q theta
            u += n*I_n_tilde @ multiply_G_ks_v(G_K, s = power_r, v = theta, normalize = normalize) # (n+m) I_tilde G_K^r
            if np.isnan(u).any():
                raise ValueError('u contains NaN values, reduce gamma')

            residual = u - y_tilde
            theta -= gamma * residual
            eps = (torch.tensor(residual).square().sum()/y_norm_sq).item()
            num_iter += 1

        # v = G_K^(r-1) theta (with 1/n per power when normalize is set).
        v = multiply_G_ks_v(G_K, s = [int(i >= r - 1) for i in range(r)], v = theta, normalize = normalize)

        self.theta = theta
        self.v = v
        self.s_inv = s_inv
        self.r = r
        self.eps = eps
        self.num_iter = num_iter

    def predict(self, v_K, train_labeled_mask_vK):
        """Return ``v_K @ v``; ``train_labeled_mask_vK`` is not used.

        Original author's note on the expected ``v_K``:
            G[val_mask, :][:, train_labeled_mask + train_unlabeled_mask]
        """
        out = v_K @ self.v
        return out

    def get_eps(self):
        """Return the relative squared residual stored by the last ``fit``."""
        return self.eps

    def get_num_iter(self):
        """Return the number of iterations performed by the last ``fit``."""
        return self.num_iter