import numpy as np
import cupy as cp
import scipy.sparse as sp
import cupyx.scipy.sparse as cpsp
import pandas as pd
import time

# Calculate X.T @ X, X should be a 2D numpy array
def gpu_block_matmul(X, *, n_blocks=5):
    """Compute the Gram matrix ``X.T @ X`` on the GPU in tiles.

    ``X`` (expected to be a 2-D NumPy array) is cut into an
    ``n_blocks`` x ``n_blocks`` grid of tiles whose sides are
    ``ceil(dim / n_blocks)``.  Tile ``(j, i)`` of the result is
    ``sum_r X[r, j].T @ X[r, i]`` over the row strips ``r``.  Because the
    result is symmetric, only tiles with ``i >= j`` are computed on the GPU;
    the lower triangle is filled with transposes on the host.

    Parameters
    ----------
    X : array_like, shape (n_rows, n_cols)
        Input matrix; tiles are copied to the GPU with ``cupy.asarray``.
    n_blocks : int, keyword-only, default 5
        Number of partitions along each axis.  Larger values keep smaller
        tiles on the GPU at a time, at the cost of more transfers.

    Returns
    -------
    numpy.ndarray, shape (n_cols, n_cols)
        ``X.T @ X`` as a host array.

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    def _bounds(length):
        # Segment edges of width ceil(length / n_blocks); starts are clipped
        # to ``length`` so surplus trailing segments are simply empty.
        step = -(-length // n_blocks)
        return [min(k * step, length) for k in range(n_blocks)] + [length]

    n_rows, n_cols = X.shape
    row_seg = _bounds(n_rows)
    col_seg = _bounds(n_cols)

    # blocks[j][i] receives tile (j, i) of the result.
    blocks = [[None] * n_blocks for _ in range(n_blocks)]

    for j in range(n_blocks):
        c0, c1 = col_seg[j], col_seg[j + 1]
        # Column strip j stays resident on the GPU (one tile per row strip)
        # because it is the left factor of every tile in result row j.
        left = [
            cp.asarray(X[row_seg[r]:row_seg[r + 1], c0:c1])
            for r in range(n_blocks)
        ]

        for i in range(j, n_blocks):
            d0, d1 = col_seg[i], col_seg[i + 1]
            acc = None
            for r, a in enumerate(left):
                # Diagonal tiles reuse the resident strip; off-diagonal tiles
                # stream one right-hand tile at a time and drop it right away.
                if i == j:
                    b = a
                else:
                    b = cp.asarray(X[row_seg[r]:row_seg[r + 1], d0:d1])
                term = a.T @ b
                if acc is None:
                    acc = term
                else:
                    acc += term
                del b, term
            blocks[j][i] = cp.asnumpy(acc)
            del acc

        del left

    # Mirror the upper triangle: tile (j, i) with i < j equals tile (i, j).T.
    for j in range(n_blocks):
        for i in range(j):
            blocks[j][i] = blocks[i][j].T

    return np.block(blocks)


# Calculate X^{-1}, X should be a 2D numpy array
def gpu_block_inv(X):
    """Invert ``X`` through a 2x2 block (Schur-complement) decomposition.

    ``X`` is split at half of each dimension into ``[[A, B], [C, D]]``.  With
    the Schur complement ``S = D - C A^{-1} B`` the inverse is assembled as::

        [[A^{-1} + A^{-1} B S^{-1} C A^{-1},   -A^{-1} B S^{-1}],
         [-S^{-1} C A^{-1},                     S^{-1}         ]]

    Only the two half-size inverses (``A^{-1}`` and ``S^{-1}``) go through
    ``cupy.linalg.inv``.  ``X`` is expected to be square, and both ``A`` and
    ``S`` must be invertible; neither condition is checked here.

    Returns the inverse as a host ``numpy.ndarray``.
    """
    half_r = X.shape[0] // 2
    half_c = X.shape[1] // 2

    a = cp.asarray(X[:half_r, :half_c])
    a_inv = cp.linalg.inv(a)
    b = cp.asarray(X[:half_r, half_c:])
    a_inv_b = a_inv @ b                      # A^{-1} B
    c = cp.asarray(X[half_r:, :half_c])
    c_a_inv = c @ a_inv                      # C A^{-1}
    d = cp.asarray(X[half_r:, half_c:])

    s_inv = cp.linalg.inv(d - c @ a_inv_b)   # S^{-1}
    lower_left = s_inv @ c_a_inv             # S^{-1} C A^{-1}
    upper_right = a_inv_b @ s_inv            # A^{-1} B S^{-1}
    upper_left = a_inv + a_inv_b @ lower_left

    top_left = cp.asnumpy(upper_left)
    top_right = cp.asnumpy(- upper_right)
    bottom_left = cp.asnumpy(- lower_left)
    bottom_right = cp.asnumpy(s_inv)

    # Drop every device buffer before building the host result.
    del a, b, c, d
    del upper_left, a_inv_b, c_a_inv, s_inv, a_inv, upper_right, lower_left

    return np.block([[top_left, top_right], [bottom_left, bottom_right]])



# Calculate X @ Y. Only split X into blocks
def gpu_block_matmulxy(X, Y, *, n_blocks=16):
    """Compute ``X @ Y`` on the GPU, streaming ``X`` in row strips.

    ``Y`` is copied to the GPU once.  ``X`` is split into ``n_blocks`` row
    strips of ``ceil(X.shape[0] / n_blocks)`` rows; each strip is multiplied
    by ``Y`` and copied back to the host before the next one is loaded.

    Parameters
    ----------
    X : array_like, shape (n_rows, k)
    Y : array_like, shape (k, m)
    n_blocks : int, keyword-only, default 16
        Number of row strips ``X`` is split into.

    Returns
    -------
    numpy.ndarray, shape (n_rows, m)

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    n_rows = X.shape[0]
    step = -(-n_rows // n_blocks)  # ceil(n_rows / n_blocks)
    row_seg = [min(k * step, n_rows) for k in range(n_blocks)] + [n_rows]

    Y_gpu = cp.asarray(Y)

    R = []
    for start, stop in zip(row_seg[:-1], row_seg[1:]):
        X_gpu = cp.asarray(X[start:stop, :])
        # Out-of-place on purpose: with NumPy-style in-place matmul semantics,
        # `X_gpu @= Y_gpu` must fit the product into X_gpu's shape (only
        # possible for square Y) and, when X is already a CuPy array, would
        # write into the caller's data through the slice view.
        prod = X_gpu @ Y_gpu
        R.append(cp.asnumpy(prod))
        del X_gpu, prod

    del Y_gpu

    return np.concatenate(R, axis=0)



# Calculate X @ Y. Only split Y into blocks
def gpu_block_matmulxy1(X, Y, *, n_blocks=6):
    """Compute ``X @ Y`` on the GPU, streaming ``Y`` in column strips.

    ``X`` is copied to the GPU once.  ``Y`` is split into ``n_blocks`` column
    strips of ``ceil(Y.shape[1] / n_blocks)`` columns; each strip is
    multiplied by ``X`` and copied back to the host before the next one is
    loaded.

    Parameters
    ----------
    X : array_like, shape (n, k)
    Y : array_like, shape (k, m)
    n_blocks : int, keyword-only, default 6
        Number of column strips ``Y`` is split into.

    Returns
    -------
    numpy.ndarray, shape (n, m)

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    n_cols = Y.shape[1]
    step = -(-n_cols // n_blocks)  # ceil(n_cols / n_blocks)
    col_seg = [min(k * step, n_cols) for k in range(n_blocks)] + [n_cols]

    X_gpu = cp.asarray(X)

    R = []
    for start, stop in zip(col_seg[:-1], col_seg[1:]):
        Y_gpu = cp.asarray(Y[:, start:stop])
        prod = X_gpu @ Y_gpu
        R.append(cp.asnumpy(prod))
        del Y_gpu, prod

    del X_gpu

    return np.concatenate(R, axis=1)



# Calculate X^T @ Y. X and Y should be of the same size. Split both X and Y into blocks
def gpu_block_matmulxy_large(X, Y, *, n_blocks=5):
    """Compute ``X.T @ Y`` on the GPU, tiling both operands.

    Both operands share the same ``n_blocks`` row strips, and each is split
    into ``n_blocks`` column strips.  Tile ``(j, i)`` of the result is
    ``sum_r X[r, j].T @ Y[r, i]``.  Unlike ``X.T @ X`` the product is not
    symmetric in general, so every tile is computed.

    Parameters
    ----------
    X : array_like, shape (n_rows, p)
    Y : array_like, shape (n_rows, q)
        Must have the same number of rows as ``X``.
    n_blocks : int, keyword-only, default 5
        Number of partitions along each axis.

    Returns
    -------
    numpy.ndarray, shape (p, q)
        ``X.T @ Y`` as a host array.

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1, or ``X`` and ``Y`` have different
        numbers of rows.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")
    n_rows = X.shape[0]
    if Y.shape[0] != n_rows:
        raise ValueError(
            "X and Y must have the same number of rows, "
            f"got {n_rows} and {Y.shape[0]}"
        )

    def _bounds(length):
        # Segment edges of width ceil(length / n_blocks); starts are clipped
        # to ``length`` so surplus trailing segments are simply empty.
        step = -(-length // n_blocks)
        return [min(k * step, length) for k in range(n_blocks)] + [length]

    row_seg = _bounds(n_rows)
    x_seg = _bounds(X.shape[1])
    y_seg = _bounds(Y.shape[1])

    blocks = []
    for j in range(n_blocks):
        c0, c1 = x_seg[j], x_seg[j + 1]
        # Column strip j of X is the left factor of every tile in result
        # row j, so keep it resident on the GPU while that row is built.
        left = [
            cp.asarray(X[row_seg[r]:row_seg[r + 1], c0:c1])
            for r in range(n_blocks)
        ]

        row_tiles = []
        for i in range(n_blocks):
            d0, d1 = y_seg[i], y_seg[i + 1]
            acc = None
            for r, a in enumerate(left):
                b = cp.asarray(Y[row_seg[r]:row_seg[r + 1], d0:d1])
                term = a.T @ b
                if acc is None:
                    acc = term
                else:
                    acc += term
                del b, term
            row_tiles.append(cp.asnumpy(acc))
            del acc

        blocks.append(row_tiles)
        del left

    return np.block(blocks)



# Calculate the squared Frobenius norm of X @ Y. Only X is split into blocks.
def gpu_block_norm(X, Y, *, n_blocks=16):
    """Return the squared Frobenius norm ``||X @ Y||_F^2``.

    ``Y`` is copied to the GPU once.  ``X`` is streamed in ``n_blocks`` row
    strips; the squared norm of each strip's product with ``Y`` is brought
    back to the host and the per-strip values are summed with ``np.sum``.

    Parameters
    ----------
    X : array_like, shape (n_rows, k)
    Y : array_like, shape (k, m)
    n_blocks : int, keyword-only, default 16
        Number of row strips ``X`` is split into.

    Returns
    -------
    numpy scalar
        Sum of the per-strip squared norms.

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    n_rows = X.shape[0]
    step = -(-n_rows // n_blocks)  # ceil(n_rows / n_blocks)
    row_seg = [min(k * step, n_rows) for k in range(n_blocks)] + [n_rows]

    Y_gpu = cp.asarray(Y)

    R = []
    for start, stop in zip(row_seg[:-1], row_seg[1:]):
        X_gpu = cp.asarray(X[start:stop, :])
        # Out-of-place product: an in-place `@=` cannot hold a result whose
        # shape differs from X_gpu's (non-square Y) and could write through
        # to the caller's array when X is already on the GPU.
        prod = X_gpu @ Y_gpu
        R.append(cp.asnumpy(cp.linalg.norm(prod) ** 2))
        del X_gpu, prod

    del Y_gpu

    return np.sum(R)



# Calculate the squared Frobenius norm of X - Y.
def gpu_block_norm1(X, Y, *, n_blocks=16):
    """Return the squared Frobenius norm ``||X - Y||_F^2``.

    Matching row strips of ``X`` and ``Y`` are copied to the GPU one pair at
    a time; the squared norm of each strip difference is brought back to the
    host and the per-strip values are summed with ``np.sum``.

    Parameters
    ----------
    X : array_like, shape (n_rows, m)
    Y : array_like, shape (n_rows, m)
        Must have the same number of rows as ``X``.
    n_blocks : int, keyword-only, default 16
        Number of row strips the operands are split into.

    Returns
    -------
    numpy scalar
        Sum of the per-strip squared norms.

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1, or ``X`` and ``Y`` have different
        numbers of rows.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")
    n_rows = X.shape[0]
    if Y.shape[0] != n_rows:
        raise ValueError(
            "X and Y must have the same number of rows, "
            f"got {n_rows} and {Y.shape[0]}"
        )

    step = -(-n_rows // n_blocks)  # ceil(n_rows / n_blocks)
    row_seg = [min(k * step, n_rows) for k in range(n_blocks)] + [n_rows]

    R = []
    for start, stop in zip(row_seg[:-1], row_seg[1:]):
        X_gpu = cp.asarray(X[start:stop, :])
        Y_gpu = cp.asarray(Y[start:stop, :])
        # Out-of-place subtraction: `X_gpu -= Y_gpu` would overwrite the
        # caller's X when it is already a CuPy array (cp.asarray does not
        # copy it and the slice is a view), and would reject or downcast a Y
        # whose dtype cannot be cast into X's.
        diff = X_gpu - Y_gpu
        R.append(cp.asnumpy(cp.linalg.norm(diff) ** 2))
        del X_gpu, Y_gpu, diff

    return np.sum(R)



# Calculate the squared Frobenius norm of X * y (i.e. X @ diag(y)).
def gpu_block_norm2(X, y, *, n_blocks=16):
    """Return the squared Frobenius norm ``||X * y||_F^2``.

    ``X * y`` broadcasts ``y`` across the rows of ``X``; for a 1-D ``y`` of
    length ``X.shape[1]`` this equals ``X @ diag(y)``.  ``y`` is copied to
    the GPU once and ``X`` is streamed in ``n_blocks`` row strips.  The
    per-strip squared norms are summed on the host with ``np.sum``.

    Parameters
    ----------
    X : array_like, shape (n_rows, m)
    y : array_like
        Must broadcast against a row strip of ``X`` (e.g. shape ``(m,)``).
    n_blocks : int, keyword-only, default 16
        Number of row strips ``X`` is split into.

    Returns
    -------
    numpy scalar
        Sum of the per-strip squared norms.

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    n_rows = X.shape[0]
    step = -(-n_rows // n_blocks)  # ceil(n_rows / n_blocks)
    row_seg = [min(k * step, n_rows) for k in range(n_blocks)] + [n_rows]

    y_gpu = cp.asarray(y)

    R = []
    for start, stop in zip(row_seg[:-1], row_seg[1:]):
        X_gpu = cp.asarray(X[start:stop, :])
        scaled = X_gpu * y_gpu
        R.append(cp.asnumpy(cp.linalg.norm(scaled) ** 2))
        del X_gpu, scaled

    del y_gpu

    return np.sum(R)



# Calculate the squared Euclidean norm of each column of X
def gpu_block_colnorm(X, *, n_blocks=16):
    """Return the squared Euclidean norm of every column of ``X``.

    ``X`` is streamed to the GPU in ``n_blocks`` row strips.  Each strip's
    per-column squared norms are added into a device accumulator created
    with ``cp.zeros`` (default float64), which is copied to the host once at
    the end.

    Parameters
    ----------
    X : array_like, shape (n_rows, m)
    n_blocks : int, keyword-only, default 16
        Number of row strips ``X`` is split into.

    Returns
    -------
    numpy.ndarray, shape (m,)

    Raises
    ------
    ValueError
        If ``n_blocks`` is smaller than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be at least 1, got {n_blocks}")

    n_rows = X.shape[0]
    step = -(-n_rows // n_blocks)  # ceil(n_rows / n_blocks)
    row_seg = [min(k * step, n_rows) for k in range(n_blocks)] + [n_rows]

    res_gpu = cp.zeros(X.shape[1])

    for start, stop in zip(row_seg[:-1], row_seg[1:]):
        X_gpu = cp.asarray(X[start:stop, :])
        res_gpu += cp.linalg.norm(X_gpu, axis=0) ** 2
        del X_gpu

    return cp.asnumpy(res_gpu)
