import numpy as np
import pandas as pd
import scipy.sparse as sp
from math import exp, log, sqrt
import cupy as cp
import accelib1 as accelib
import time

# Wall-clock reference point; the elapsed time since this moment is printed
# later in the script.
start = time.time()


# Each CSV is read into a DataFrame and converted to a plain 2-D array.
# Column 0 is used as the row index, column 1 as the column index and
# column 2 as the stored value of a sparse matrix.
df_X = pd.read_csv("./data/ML20M1/ratings_processed_test1.csv")
df_Y = pd.read_csv("./data/ML20M1/ratings_processed_test2.csv")

arr_X, arr_Y = df_X.to_numpy(), df_Y.to_numpy()

# Shape shared by X and Y for the ML20M1 data set.
# Shapes recorded in this file for the other data sets:
#   NETFLIX:  (144058, 17770)
#   YELP2018: (769365, 40000)
#   MSD:      (305396, 40000)
dense_shape = (41549, 26744)

# X: sparse CSR built from the first CSV, then expanded to a dense float32
# matrix.
vals_X = np.squeeze(arr_X[:, 2])
rows_X = np.squeeze(arr_X[:, 0].astype(int))
cols_X = np.squeeze(arr_X[:, 1].astype(int))
X = sp.csr_matrix(
    (vals_X, (rows_X, cols_X)),
    shape=dense_shape,
    dtype="float32",
).todense()

# Y: same construction from the second CSV.
vals_Y = np.squeeze(arr_Y[:, 2])
rows_Y = np.squeeze(arr_Y[:, 0].astype(int))
cols_Y = np.squeeze(arr_Y[:, 1].astype(int))
Y = sp.csr_matrix(
    (vals_Y, (rows_Y, cols_Y)),
    shape=dense_shape,
    dtype="float32",
).todense()

# Number of rows of X; used below as the normalising count `m`.
m = X.shape[0]

# Load the precomputed matrices for ML20M1. The unpacking order matches the
# order of the paths: XX, YX, Sxx, Syy, Q_h, B, B1.
XX, YX, Sxx, Syy, Q_h, B, B1 = [
    np.load(npy_path)
    for npy_path in (
        "./data/ML20M1/matrix1/XX.npy",
        "./data/ML20M1/matrix1/YX.npy",
        "./data/ML20M1/matrix1/Sxx.npy",
        "./data/ML20M1/matrix1/Syy.npy",
        "./data/ML20M1/matrix1/Q_h.npy",
        "./data/ML20M1/matrix1/B.npy",
        "./data/ML20M1/matrix1/B1.npy",
    )
]

# Load the trained weight matrix.
W = np.load("./model2/ease_train_ML20M1_50.npy")

# Identity matrix with as many rows/columns as W has rows.
I = np.eye(len(W))

# sigma and its square are used throughout the bound computation below.
# (An alternative value of 0.01 was recorded in this file as a comment.)
sigma = 0.001
sigma2 = sigma ** 2


# Handle to CuPy's default device memory pool; used later to print how many
# bytes are currently in use.
mempool = cp.get_default_memory_pool()

# c = trace(Syy) - ||B||^2, with the norm of B evaluated on the GPU using
# cp.linalg.norm's default norm and the scalar copied back to the host.
B_gpu = cp.asarray(B)
norm_B_sq = cp.asnumpy(cp.linalg.norm(B_gpu) ** 2)
c = np.trace(Syy) - norm_B_sq
print("c: ", c)
del B_gpu  # drop the device copy of B; it is not needed again

# Seconds elapsed since `start` was taken.
elapsed = time.time() - start
print(elapsed)

# Eigendecomposition of Sxx on the GPU: L_gpu receives the eigenvalues and
# S_gpu the matrix of eigenvectors returned by cp.linalg.eigh.
Sxx_gpu = cp.asarray(Sxx)
L_gpu, S_gpu = cp.linalg.eigh(Sxx_gpu)

# The device copy of Sxx is no longer needed once the decomposition exists.
del Sxx_gpu

# Clamp negative eigenvalues to zero with a single masked assignment on the
# device. A Python-level loop would issue one comparison per element and
# bring each result back to the host for the `if`.
L_gpu[L_gpu < 0] = 0

# Bring eigenvectors and (clamped) eigenvalues back to host memory.
S = cp.asnumpy(S_gpu)
L = cp.asnumpy(L_gpu)

# NOTE: this uses W *before* it is transposed below, so the order of these
# statements matters.
# presumably B0 is the blocked product S.T @ (W - B1.T) and sum_B holds one
# norm-like value per column of B0.T -- confirm against accelib.
B0 = accelib.gpu_block_matmulxy1(S.T, W - B1.T)
sum_B = accelib.gpu_block_colnorm(B0.T)

# From here on W refers to its transpose (the original comment describes
# this as "set W to left").
W = W.T

# Sweep over ten values of lmd (512, 1024, ..., doubling every iteration).
# For each value the loop computes s, U, a KL term, a "true" risk, an
# empirical risk and phi, combines them into res0 and appends it to `res`.
# After the loop the full list and its minimum are printed.
#
# NOTE(review): np.diag(XX) and W.shape[0] are re-evaluated on every
# iteration although neither XX nor W is reassigned inside the loop; they
# could be hoisted above the loop.
res = []
lmd = 512
for t in range(10):
    
    # Compute Variance
    # Elementwise over the diagonal of XX, so s has one entry per diagonal
    # element.
    s = 1 / ( 2 * lmd / m * np.diag(XX) + 1 / sigma2 )
    
    # Compute Expectation
    
    # W was already transposed once before the loop, so it is not
    # transposed again here.
    tp = 1 / (2 * lmd * sigma2)
    W1 = tp * W + YX
    M1 = XX + tp * I

    # presumably gpu_block_inv returns the inverse of M1 -- confirm against
    # accelib.
    inv = accelib.gpu_block_inv(M1)
    inv_gpu = cp.asarray(inv)
    # Debug output: bytes currently in use in the CuPy memory pool.
    print(mempool.used_bytes())
    W1_gpu = cp.asarray(W1)
    U_gpu = W1_gpu @ inv_gpu
    del W1_gpu
    # Temp_gpu[i] = U[i, i] / inv[i, i]. Subtracting inv_gpu * Temp_gpu from
    # U_gpu therefore cancels the diagonal of U_gpu.
    # NOTE(review): Temp_gpu is 1-D, so the product below broadcasts along
    # the last axis (column j of inv_gpu is scaled by Temp_gpu[j]). Confirm
    # this is the intended orientation given that W is held transposed.
    Temp_gpu = cp.divide( cp.diag(U_gpu), cp.diag(inv_gpu) )
    inv_gpu = inv_gpu * Temp_gpu
    U_gpu = U_gpu - inv_gpu
    del Temp_gpu
    del inv_gpu
    print(mempool.used_bytes())
    # Copy the result back to host memory and drop the device reference.
    U = cp.asnumpy(U_gpu)
    del U_gpu

    print(mempool.used_bytes())


    # Compute KL
    # NOTE: this must run before `s` is rescaled in place further down; the
    # KL expression uses the unscaled s.
    # presumably gpu_block_norm1(U, W) is a norm of the difference U - W --
    # confirm against accelib.
    U_norm = accelib.gpu_block_norm1(U, W)
    KL = 0.5 * ( (W.shape[0] ** 2 - W.shape[0]) * ( 2 * log(sigma) - 1) - (W.shape[0] - 1) * np.sum(np.log(s) - s / sigma2) + U_norm / sigma2 )
    print("KL: ", KL)


    # Compute True Risk
    # s is multiplied in place by (W.shape[0] - 1); every later use of s in
    # this iteration (both gpu_block_norm2 calls) sees the rescaled values.
    s *= (W.shape[0] - 1)
    SgU = accelib.gpu_block_matmulxy(U, Q_h)
    add = accelib.gpu_block_norm2(Q_h, np.sqrt(s))
    true = accelib.gpu_block_norm1(SgU, B) + add + c
    print("true: ", true)


    # Compute Empirical Risk
    # U is replaced by its transpose via a round trip through the GPU
    # (host -> device, transpose, device -> host).
    # NOTE(review): a plain `U = U.T` was left commented out here originally;
    # presumably the round trip is used to obtain a freshly laid-out array
    # rather than a view -- confirm what accelib requires.
    U_gpu = cp.asarray(U)
    U_gpu = U_gpu.T
    U = cp.asnumpy(U_gpu)
    del U_gpu
    norm0 = accelib.gpu_block_norm_yxu(Y, X, U) / m
    # `add` is reused here; the value computed for the true risk above is
    # overwritten.
    add = accelib.gpu_block_norm2(X, np.sqrt(s)) / m
    emp = norm0 + add

    print("emp: ", emp)

    
    # Compute Phi
    # Starts from lmd * c and accumulates one term per index i, using the
    # eigenvalue L[i] and sum_B[i].
    # NOTE(review): the log argument 1 - 2 * lmd * L[i] * sigma2 must stay
    # positive; lmd doubles every iteration, so large eigenvalues can push
    # it to zero or below and make phi NaN/inf. Confirm the eigenvalue range.
    phi = lmd * c
    for i in range(W.shape[0]):
        phi += lmd * L[i] * sum_B[i] / (1 - 2 * lmd * L[i] * sigma2) - W.shape[0] / 2 * np.log(1 - 2 * lmd * L[i] * sigma2)
    print("phi: ", phi)

    
    # Compute RH of the bound
    # res0 = emp + (KL + log(10 / 0.01) + phi) / lmd; the log term is a
    # constant that does not depend on the iteration.
    res0 = emp + (KL + np.log(10 / 0.01) + phi) / lmd
    res.append(res0)

    print("iter: ", t, "   lambda: ", lmd, "   result: ", res0)
    # Double lmd for the next iteration.
    lmd *= 2


# All ten results, followed by the smallest one.
print(res)
print(min(res))


