import numpy as np
import random
import copy

# import time


def low_rank(C, k):
    """Return the best rank-``k`` approximation of ``C`` via truncated SVD.

    Only the ``k`` leading singular triplets are kept. By Eckart-Young this is
    the optimal rank-``k`` approximation in Frobenius and spectral norm.

    Parameters
    ----------
    C : 2-D array_like
        Matrix to approximate.
    k : int
        Number of singular triplets to keep. Values at or above the rank of
        ``C`` reproduce ``C`` up to rounding; ``0`` gives the zero matrix.

    Returns
    -------
    ndarray
        Array with the same shape as ``C``.

    Raises
    ------
    ValueError
        If ``k`` is negative. A negative slice bound would otherwise silently
        keep all but the last ``|k|`` singular values.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    u, sigma, vt = np.linalg.svd(C, full_matrices=False)
    # Scale the first k left singular vectors column-wise instead of
    # materialising a full diagonal matrix that is mostly zeros.
    return (u[:, :k] * sigma[:k]) @ vt[:k]


def countSketch(m, n):
    """Draw a random CountSketch matrix with ``m`` rows and ``n`` columns.

    Every column receives exactly one non-zero entry, a random sign (+1 or -1),
    placed in a uniformly random row. Randomness comes from the global NumPy
    RNG. For each column the row index is drawn before the sign.
    """
    sketch = np.zeros((m, n))
    for col in range(n):
        row = np.random.randint(0, m)
        sign = 1 if np.random.randint(0, 2) else -1
        sketch[row, col] = sign
    return sketch



def low_rank_appro(A, k, S):
    """One-sketch rank-``k`` approximation of ``A`` within the row space of ``S @ A``.

    ``A`` is projected onto rowspace(S A). The projection is expressed in an
    orthonormal basis of that space, truncated to rank ``k`` with ``low_rank``,
    and mapped back. The result equals [A P]_k, where P is the orthogonal
    projector onto rowspace(S A).

    Parameters
    ----------
    A : (n, d) array
        Matrix to approximate.
    k : int
        Target rank, forwarded to ``low_rank``.
    S : (m, n) array
        Sketching matrix.

    Returns
    -------
    ndarray of shape (n, d)
        Rank-``k`` approximation. All zeros if ``S @ A`` is numerically zero.
    """
    SA = S @ A
    _, sigma, vt = np.linalg.svd(SA, full_matrices=False)
    # When SA is rank-deficient (e.g. S has all-zero rows), the thin SVD still
    # returns right singular vectors for the zero singular values. Those
    # vectors lie outside rowspace(SA), so keep only the numerically non-zero
    # part, using the same tolerance as np.linalg.matrix_rank.
    if sigma.size:
        tol = sigma[0] * max(SA.shape) * np.finfo(sigma.dtype).eps
        vt = vt[sigma > tol]
    if vt.shape[0] == 0:
        return np.zeros(A.shape)
    AV = A @ vt.T  # coordinates of A's rows in the orthonormal basis vt
    AVk = low_rank(AV, k)
    return AVk @ vt



def new_idea_combine(A, k, m, lamb=10):
    """Learn a CountSketch-style sketch matrix with ``m`` rows from ``A``.

    Steps:

    1. Score each row a_i of ``A`` by its ridge leverage score
       a_i^T (A^T A + lamb * I)^+ a_i.
    2. Sample up to ``m`` distinct rows with probability proportional to
       their scores. Each sampled row is the centre of its own bucket.
    3. Assign every remaining row to the bucket whose centre has the largest
       absolute cosine similarity with it, and record the sign of that
       similarity. Ties between the +/- directions favour +.
    4. Hand the assignment to ``new_CS_Matrix``, which computes the final
       per-bucket weights.

    Parameters
    ----------
    A : (n, d) array_like
        Training matrix whose rows are hashed into buckets.
    k : int
        Target rank. Not used by this construction; kept for interface
        compatibility.
    m : int
        Number of buckets, i.e. rows of the returned sketch.
    lamb : float, optional
        Ridge regularisation for the leverage scores (default 10, the
        previously hard-coded value).

    Returns
    -------
    ndarray
        The sketch matrix built by ``new_CS_Matrix``.

    Raises
    ------
    ValueError
        If no row has a positive leverage score (e.g. ``A`` is all zeros),
        so there is no sampling distribution.
    """
    A = np.asarray(A, dtype=float)
    n, d = A.shape

    # 1. Ridge leverage scores, vectorised as diag(A C^+ A^T).
    C_pinv = np.linalg.pinv(A.T @ A + lamb * np.eye(d))
    scores = np.clip(np.sum((A @ C_pinv) * A, axis=1), 0.0, None)
    total = scores.sum()
    if not total > 0:
        raise ValueError("A has no row with a positive ridge leverage score")
    probs = scores / total

    # 2. Sample *distinct* centres. Sampling with replacement could pick the
    # same row twice: that row then sat alone in the later bucket while rows
    # resembling it were routed (argmax tie-break) to the earlier duplicate.
    # If fewer than m rows have non-zero probability, the extra buckets stay
    # empty.
    n_centres = min(m, int(np.count_nonzero(probs)))
    centres = np.random.choice(n, n_centres, replace=False, p=probs)

    # Unit-norm copies of the rows. All-zero rows stay zero instead of
    # turning into NaNs through division by zero.
    norms = np.linalg.norm(A, axis=1)
    units = A / np.where(norms > 0, norms, 1.0)[:, None]

    pos = np.full(n, -1, dtype=int)  # bucket of each row; -1 = unassigned
    value = np.ones(n)               # sign of each row inside its bucket
    pos[centres] = np.arange(n_centres)
    centre_units = units[centres]    # (n_centres, d)

    # 3. Route every non-centre row to the centre with the largest |cosine|.
    # argmin(sims) equals argmax(-sims), so the tie-breaking matches the
    # separate "+v" / "-v" searches of the original loop.
    rest = np.flatnonzero(pos == -1)
    if rest.size:
        sims = units[rest] @ centre_units.T
        best_pos = sims.argmax(axis=1)
        best_neg = sims.argmin(axis=1)
        rows = np.arange(rest.size)
        positive = sims[rows, best_pos] >= -sims[rows, best_neg]
        pos[rest] = np.where(positive, best_pos, best_neg)
        value[rest] = np.where(positive, 1.0, -1.0)

    # 4. Turn the hard +/-1 assignment into weighted sketch rows.
    return new_CS_Matrix(pos, value, m, n, A)

def new_CS_Matrix(pos, value, m, n, A):
    """Turn a bucket assignment into a weighted CountSketch-style matrix.

    Row ``i`` of the result is non-zero only on the rows of ``A`` assigned to
    bucket ``i``. Within a bucket, take the sign-aligned submatrix
    ``value[j] * A[j]`` and compute its top left singular vector. The weights
    are that vector multiplied back by ``value[j]``, which is exactly the top
    left singular vector of the bucket's rows of ``A``. As a result,
    ``S[i] @ A`` is the unit-weight combination of those rows with the largest
    norm.

    Parameters
    ----------
    pos : (n,) array_like
        Bucket index of each row of ``A`` (converted with ``int``).
    value : (n,) array_like
        Sign (+1/-1) assigned to each row of ``A``.
    m : int
        Number of buckets, i.e. rows of the result.
    n : int
        Unused; the number of columns is taken from ``A.shape[0]``.
    A : (n, d) array_like
        Matrix being sketched.

    Returns
    -------
    ndarray of shape (m, A.shape[0])
        Buckets that received no rows remain all-zero.
    """
    A = np.asarray(A)
    value = np.asarray(value, dtype=float)
    S = np.zeros((m, A.shape[0]))

    # Group row indices by bucket in one O(n) pass.
    buckets = [[] for _ in range(m)]
    for row, bucket in enumerate(pos):
        buckets[int(bucket)].append(row)

    for i, members in enumerate(buckets):
        if not members:
            continue  # empty bucket: nothing to weight, row i stays zero
        members = np.asarray(members)
        signs = value[members]
        aligned = signs[:, None] * A[members]
        u, _, _ = np.linalg.svd(aligned, full_matrices=False)
        # aligned = D @ A[members] with D = diag(signs), so D @ u[:, 0] is the
        # top left singular vector of A[members]. Without multiplying the
        # signs back in, rows of opposite sign would cancel in S[i] @ A.
        # Indexing u[:, 0] (not u[:, 0:1]) also avoids assigning 1-element
        # arrays into scalar slots.
        S[i, members] = signs * u[:, 0]
    return S




def four_sketch(A, k, S1, V1, R1, W1):
    """Rank-``k`` approximation of ``A`` from four sketching matrices.

    Let C = V1 A R1^T, D = S1 A W1^T and G = V1 A W1^T. The matrix
    X = C^+ [P_C G P_D]_k D^+ minimises ||C X D - G||_F over rank(X) <= k,
    where P_C / P_D project onto the column space of C / row space of D and
    [.]_k is the best rank-``k`` approximation. The function returns
    (A R1^T) X (S1 A).

    Singular values of C and D below a numerical-rank tolerance
    (largest * max(shape) * eps, as in ``np.linalg.matrix_rank``) are treated
    as zero. Rank-deficient sketches therefore give finite output instead of
    dividing by zero.

    Parameters
    ----------
    A : (n, d) array
        Matrix to approximate.
    k : int
        Target rank.
    S1 : (s, n) array
        Left sketch that forms the row-space factor S1 A.
    V1 : (v, n) array
        Left sketch used in the small regression.
    R1 : (r, d) array
        Right sketch that forms the column-space factor A R1^T.
    W1 : (w, d) array
        Right sketch used in the small regression.

    Returns
    -------
    ndarray of shape (n, d)
        Approximation of rank at most ``k``. All zeros if C or D is
        numerically zero.
    """
    A = np.asarray(A)
    AR = A @ R1.T  # (n, r)
    SA = S1 @ A    # (s, d)
    C = V1 @ AR
    D = SA @ W1.T
    G = V1 @ A @ W1.T

    def _truncated_svd(M):
        """Thin SVD of M with numerically-zero singular triplets dropped."""
        U, sig, Vh = np.linalg.svd(M, full_matrices=False)
        if sig.size == 0:
            return U, sig, Vh
        tol = sig[0] * max(M.shape) * np.finfo(sig.dtype).eps
        r = int(np.count_nonzero(sig > tol))
        return U[:, :r], sig[:r], Vh[:r]

    U_c, sigma_c, Vh_c = _truncated_svd(C)
    U_d, sigma_d, Vh_d = _truncated_svd(D.T)
    if sigma_c.size == 0 or sigma_d.size == 0:
        return np.zeros(A.shape)

    # Project G onto col(C) x row(D), then keep its best rank-k part.
    G_proj = U_c.T @ G @ U_d
    U1, sigma1, Vh1 = np.linalg.svd(G_proj, full_matrices=False)
    X_L = U1[:, :k] * sigma1[:k]
    X_R = Vh1[:k]

    # Column scaling by 1/sigma applies Sigma^-1 without building np.diag.
    Z1 = (Vh_c.T / sigma_c) @ X_L    # V_c Sigma_c^-1 X_L
    Z2 = X_R @ (Vh_d.T / sigma_d).T  # X_R Sigma_d^-1 V_d^T
    return AR @ Z1 @ Z2 @ SA


if __name__ == "__main__":
    # Compare a learned-sketch rank-k approximation of a test matrix with the
    # optimal rank-k approximation, and print the excess Frobenius error.
    #
    # NOTE(review): the original script left both matrices as placeholders
    # (which was a syntax error). Loading them from paths with np.load is an
    # assumption; adapt this to the real data source.
    import argparse

    parser = argparse.ArgumentParser(
        description="Compare a learned-sketch rank-k approximation with the optimal one."
    )
    parser.add_argument("test_matrix", help="path to the test matrix A (read with np.load)")
    parser.add_argument("train_matrix", help="path to the training matrix (read with np.load)")
    parser.add_argument("-m", type=int, default=40, help="sketch size (default: 40)")
    parser.add_argument("-k", type=int, default=10, help="target rank (default: 10)")
    parser.add_argument(
        "--four-sketch",
        action="store_true",
        help="use the four-sketch algorithm instead of the one-sketch algorithm",
    )
    args = parser.parse_args()

    A = np.load(args.test_matrix)
    A_train = np.load(args.train_matrix)
    m, k = args.m, args.k

    # S is learned from A_train but applied as S @ A, so A_train is assumed
    # to have the same number of rows as A.
    S = new_idea_combine(A_train, k, m)  # learned sketch matrix
    Ak = low_rank(A, k)  # best rank-k approximation

    if args.four_sketch:
        # The three extra sketches are only needed by the four-sketch path.
        V = new_idea_combine(A_train, k, 5 * m)
        R = new_idea_combine(A_train.T, k, m)
        W = new_idea_combine(A_train.T, k, 5 * m)
        Ak_appro = four_sketch(A, k, S, V, R, W)
    else:
        Ak_appro = low_rank_appro(A, k, S)  # one-sketch algorithm

    # Excess Frobenius error of the sketched approximation over the optimum.
    loss = np.linalg.norm(A - Ak_appro) - np.linalg.norm(A - Ak)

    print(loss)

