import numpy as np
import math
import tensorflow as tf
import scipy.sparse as sp

def repeat_pooling(X, A, pooling, ratio=0.5, k=None):
    """Apply ``pooling`` repeatedly, once per factor of 0.5 contained in ``ratio``.

    ``pooling`` is called ``int(log2(1 / ratio))`` times, so ``ratio=0.5``
    gives one pass, ``ratio=0.25`` two passes, and so on. A ``ratio`` that is
    not a power of 0.5 is effectively rounded down (``0.3`` -> one pass), and
    any ``ratio > 0.5`` performs no pass at all. Whether each pass really
    halves the graph depends on the supplied ``pooling`` (TODO confirm).

    Args:
        X: features handed to ``pooling`` on the first pass (presumably node
            features — verify against callers).
        A: adjacency structure; only ``.copy()`` is required of it here.
        pooling: callable ``pooling(X, A) -> (X, A, A_out, S)``. ``S`` may be
            a single matrix or a list, in which case its first element is used.
        ratio: target reduction ratio; must be positive.
        k: accepted for backward compatibility; currently unused.

    Returns:
        ``(X, A_original, A_out, S_out)``, where:

        - ``X`` is the features returned by the last pass.
        - ``A_original`` is a copy of the input ``A``.
        - ``A_out`` is the third value returned by the last pass, or a copy
          of ``A`` when no pass ran.
        - ``S_out`` is the chained product ``S_1 @ S_2 @ ...`` of the per-pass
          ``S`` matrices. It is ``None`` when no pass ran, and is wrapped in a
          one-element list if any pass returned ``S`` as a list.

    Raises:
        ValueError: if ``ratio <= 0``.
    """
    if ratio <= 0:
        raise ValueError(f"ratio must be positive, got {ratio!r}")
    # math.log(ratio, 0.5) divides two independently rounded logarithms and is
    # not guaranteed to return an exact integer for powers of 0.5; int() would
    # truncate a result just below k to k - 1. math.log2 is exact for powers
    # of two and mathematically equal otherwise.
    n_passes = int(-math.log2(ratio))

    A_original = A.copy()
    A_out = A.copy()
    S_out = None
    return_list = False
    for _ in range(n_passes):
        # The second value returned by pooling is not used; each pass feeds on
        # the A_out produced by the previous one.
        X, _, A_out, S = pooling(X, A_out)
        if isinstance(S, list):
            S = S[0]
            return_list = True
        # Chain the per-pass matrices so S_out spans from the input graph to
        # the final pooled graph (presumably assignment/selection matrices).
        S_out = S if S_out is None else S_out @ S

    if return_list:
        S_out = [S_out]
    return X, A_original, A_out, S_out