import numpy as np
import random

def kg_binary(A: np.ndarray, kmax: int) -> np.ndarray:
    """Score each row of ``A`` with ``kmax`` rounds of leave-one-out message passing.

    The nonzero entries of ``A`` are the edges of a bipartite row/column
    graph. Two edge-message matrices ``X`` and ``Y`` are kept; both are 0 off
    the edges. ``Y`` starts at 1 on every edge, and each round computes

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j]) / (deg_col(j) - 1)

    ``deg`` is the sum of ``|A|`` along the row/column. That equals the edge
    count when ``A`` holds only -1/0/+1, which is presumably the intended input.

    A zero normaliser is treated as 1. For a ±1 matrix this happens on a
    row/column with a single edge, whose leave-one-out sum is empty.
    Dividing there would give 0/0 = NaN, which then spreads to every score
    through ``0 * NaN`` terms.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        kmax: number of message-passing rounds (0 skips the iteration).

    Returns:
        Float array of length ``A.shape[0]`` with
        ``final_x[i] = sum_j A[i, j] * Y[i, j]``.
    """
    A = np.asarray(A)
    mask = A != 0
    abs_a = np.abs(A)
    # The normalisers depend only on A, so compute them once, not every round.
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    Y = mask.astype(float)
    for _ in range(kmax):
        ay = A * Y
        # Row total minus the edge's own term is the leave-one-out sum.
        X = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
        ax = A * X
        Y = np.where(mask, ax.sum(axis=0, keepdims=True) - ax, 0.0) / col_norm
    return np.asarray((A * Y).sum(axis=1), dtype=float)

def kg_binary_with_init(A: np.ndarray, kmax: int, init: np.ndarray) -> np.ndarray:
    """Run the same leave-one-out message passing as ``kg_binary``, seeded from ``init``.

    On every edge (nonzero ``A[i, j]``), ``Y[i, j]`` starts at ``init[j]``
    instead of 1. Each round then computes

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j]) / (deg_col(j) - 1)

    ``deg`` is the sum of ``|A|`` along the row/column. A zero normaliser
    (a single edge in a ±1 matrix) is treated as 1, so that 0/0 = NaN cannot
    poison every score through ``0 * NaN`` terms.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        kmax: number of message-passing rounds (0 skips the iteration).
        init: per-column initial message values, length ``A.shape[1]``.

    Returns:
        Float array of length ``A.shape[0]`` with
        ``final_x[i] = sum_j A[i, j] * Y[i, j]``.
    """
    A = np.asarray(A)
    mask = A != 0
    abs_a = np.abs(A)
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    # init[j] is broadcast down column j; non-edges stay 0.
    Y = np.where(mask, np.asarray(init, dtype=float), 0.0)
    for _ in range(kmax):
        ay = A * Y
        X = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
        ax = A * X
        Y = np.where(mask, ax.sum(axis=0, keepdims=True) - ax, 0.0) / col_norm
    return np.asarray((A * Y).sum(axis=1), dtype=float)

def kg_binary_with_assist(A: np.ndarray, kmax: int, assist: np.ndarray) -> np.ndarray:
    """Run ``kg_binary``-style message passing, then add ``assist`` to every final message.

    ``Y`` starts at 1 on every edge (nonzero ``A[i, j]``). Each of ``kmax`` rounds computes

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j]) / (deg_col(j) - 1)

    ``deg`` is the sum of ``|A|`` along the row/column. A zero normaliser
    (a single edge in a ±1 matrix) is treated as 1 to avoid 0/0 = NaN.

    After the iteration, ``assist[j]`` is added once to every message in
    column ``j``. This is the ``lamBda == 1`` case of
    ``kg_binary_with_lambda_assist``.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        kmax: number of message-passing rounds (0 skips the iteration).
        assist: per-column offsets, length ``A.shape[1]`` (or a scalar).

    Returns:
        Float array of length ``A.shape[0]`` with
        ``final_x[i] = sum_j A[i, j] * (Y[i, j] + assist[j])``.
    """
    A = np.asarray(A)
    mask = A != 0
    abs_a = np.abs(A)
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    Y = mask.astype(float)
    for _ in range(kmax):
        ay = A * Y
        X = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
        ax = A * X
        Y = np.where(mask, ax.sum(axis=0, keepdims=True) - ax, 0.0) / col_norm
    # Merge the iterative result with the assist info: broadcasting adds
    # assist[j] exactly once to every row of column j. Values added off the
    # edges are harmless because A is 0 there.
    Y = Y + np.asarray(assist, dtype=float)
    return np.asarray((A * Y).sum(axis=1), dtype=float)

def kg_binary_with_lambda_assist(A: np.ndarray, kmax: int, assist: np.ndarray, lamBda: float) -> np.ndarray:
    """Run ``kg_binary``-style message passing, then add ``lamBda * assist`` to every final message.

    ``Y`` starts at 1 on every edge (nonzero ``A[i, j]``). Each of ``kmax`` rounds computes

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j]) / (deg_col(j) - 1)

    ``deg`` is the sum of ``|A|`` along the row/column. A zero normaliser
    (a single edge in a ±1 matrix) is treated as 1 to avoid 0/0 = NaN.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        kmax: number of message-passing rounds (0 skips the iteration).
        assist: per-column offsets, length ``A.shape[1]`` (or a scalar).
        lamBda: weight applied to ``assist`` before it is added.

    Returns:
        Float array of length ``A.shape[0]`` with
        ``final_x[i] = sum_j A[i, j] * (Y[i, j] + lamBda * assist[j])``.
    """
    A = np.asarray(A)
    mask = A != 0
    abs_a = np.abs(A)
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    Y = mask.astype(float)
    for _ in range(kmax):
        ay = A * Y
        X = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
        ax = A * X
        Y = np.where(mask, ax.sum(axis=0, keepdims=True) - ax, 0.0) / col_norm
    # Merge with the assist info. asarray guards against list inputs, where
    # int * list would mean repetition rather than scaling.
    Y = Y + lamBda * np.asarray(assist, dtype=float)
    return np.asarray((A * Y).sum(axis=1), dtype=float)

def kg_binary_with_iterative_assist(A: np.ndarray, kmax: int, assist: np.ndarray) -> np.ndarray:
    """Run leave-one-out message passing that injects ``assist`` into every column update.

    ``Y`` starts at 1 on every edge (nonzero ``A[i, j]``). Each of ``kmax`` rounds computes

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j] + assist[j]) / (deg_col(j) - 1)

    ``deg`` is the sum of ``|A|`` along the row/column. A zero normaliser
    (a single edge in a ±1 matrix) is treated as 1. Otherwise the division
    yields inf/NaN, which poisons every score through ``0 * NaN`` terms. For
    such a column the message then carries only ``assist[j]``.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        kmax: number of message-passing rounds (0 skips the iteration, so
            ``assist`` is never used).
        assist: per-column offsets, length ``A.shape[1]`` (or a scalar).

    Returns:
        Float array of length ``A.shape[0]`` with
        ``final_x[i] = sum_j A[i, j] * Y[i, j]``.
    """
    A = np.asarray(A)
    assist = np.asarray(assist, dtype=float)
    mask = A != 0
    abs_a = np.abs(A)
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    Y = mask.astype(float)
    for _ in range(kmax):
        ay = A * Y
        X = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
        ax = A * X
        # assist[j] broadcasts down column j and is applied only on edges.
        Y = np.where(mask, ax.sum(axis=0, keepdims=True) - ax + assist, 0.0) / col_norm
    return np.asarray((A * Y).sum(axis=1), dtype=float)


def kg_new(A: np.array):
    """Allocate the initial message matrices for step-wise iteration.

    Returns ``(X, Y)``: two float arrays with the same 2-D shape as ``A``.
    ``X`` is all zeros. ``Y`` is 1 wherever ``A`` is nonzero and 0 elsewhere.
    """
    n_rows, n_cols = A.shape
    x_msgs = np.zeros((n_rows, n_cols), dtype=float)
    y_msgs = np.zeros((n_rows, n_cols), dtype=float)
    # Seed a unit message on every edge (nonzero entry of A).
    y_msgs[A != 0] = 1
    return x_msgs, y_msgs

def kg_singlestep(A: np.ndarray, X: np.ndarray, Y: np.ndarray):
    """Perform one round of leave-one-out message passing, updating ``X`` and ``Y`` in place.

    The update is

        X[i, j] = (sum_{j' != j} A[i, j'] * Y[i, j']) / (deg_row(i) - 1)
        Y[i, j] = (sum_{i' != i} A[i', j] * X[i', j]) / (deg_col(j) - 1)

    The ``Y`` update uses the freshly computed ``X``. Entries off the edges
    (``A[i, j] == 0``) are set to 0. ``deg`` is the sum of ``|A|`` along the
    row/column. A zero normaliser (a single edge in a ±1 matrix) is treated
    as 1, so 0/0 = NaN cannot spread to every entry.

    Args:
        A: 2-D matrix; nonzero entries mark edges and carry their sign.
        X: float array shaped like ``A``; overwritten in place.
        Y: float array shaped like ``A``; read, then overwritten in place.

    Returns:
        The same ``(X, Y)`` objects that were passed in, after the update.
    """
    A = np.asarray(A)
    mask = A != 0
    abs_a = np.abs(A)
    row_norm = abs_a.sum(axis=1, keepdims=True) - 1.0
    row_norm[row_norm == 0] = 1.0
    col_norm = abs_a.sum(axis=0, keepdims=True) - 1.0
    col_norm[col_norm == 0] = 1.0

    ay = A * Y
    # Write through [...] so callers holding X/Y see the update.
    X[...] = np.where(mask, ay.sum(axis=1, keepdims=True) - ay, 0.0) / row_norm
    ax = A * X
    Y[...] = np.where(mask, ax.sum(axis=0, keepdims=True) - ax, 0.0) / col_norm
    return X, Y
        
def kg_ending(A: np.array, Y: np.array):
    """Collapse per-edge messages ``Y`` into one score per row of ``A``.

    Returns a float array ``s`` of length ``A.shape[0]`` with
    ``s[i] = sum_j A[i, j] * Y[i, j]``.
    """
    n_rows = A.shape[0]
    row_scores = (np.sum(A[r] * Y[r]) for r in range(n_rows))
    return np.fromiter(row_scores, dtype=float, count=n_rows)