""" Lewis weight sampling. """
import numpy as np

# Note that we are using the pseudo-inverse for the case
# where A does not have full rank. Verified that this does
# the right thing (removes columns of U corresponding to 0
# singular values) through SVD calculations. We also use the
# pseudo-inverse at the beginning to deal with the edge case
# where w has 0 as an entry.
def lewis_iterate(A, p, w):
    """
    Perform one fixed-point update of the l_p Lewis weight estimates.

    Every weight is replaced by
        w_i <- (a_i^T (A^T W^{1 - 2/p} A)^+ a_i) ^ (p / 2)
    where a_i is the i-th row of A, W = diag(w) and ^+ denotes the
    pseudo-inverse (so a rank-deficient A is handled).

    :param A: data matrix, size n x d
    :param p: norm parameter of the Lewis weights
    :param w: current weight estimates, size n (left unmodified)
    :return: new array holding the updated weights, size n
    """
    A = np.asarray(A)
    w = np.asarray(w, dtype=float)

    # Only the diagonal of W matters, so work on the vector directly instead
    # of exponentiating a dense n x n matrix: the off-diagonal zeros would
    # turn into 1 for p = 2 (0 ** 0) and into inf for p > 2.
    if p >= 2:
        # Non-negative power: zero weights need no special treatment.
        scale = w ** (1 - 2 / p)
    else:
        # Pseudo-invert diag(w ** (2/p - 1)): entries at or below pinv's
        # default relative cutoff (1e-15 * largest entry) are mapped to 0,
        # which covers weights that are exactly 0.
        diag = w ** (2 / p - 1)
        magnitude = np.abs(diag)
        cutoff = 1e-15 * magnitude.max() if magnitude.size else 0.0
        keep = magnitude > cutoff
        scale = np.zeros_like(diag)
        scale[keep] = 1.0 / diag[keep]

    # (A.T * scale) multiplies column i of A.T, i.e. row i of A, by scale[i].
    M = np.linalg.pinv((A.T * scale) @ A)

    # Row-wise quadratic forms a_i^T M a_i, all rows at once.
    quad = ((A @ M) * A).sum(axis=1)

    # M is positive semi-definite in exact arithmetic; clip round-off
    # negatives so the fractional power below cannot produce nan.
    return np.maximum(quad, 0.0) ** (p / 2)


def approximate_lewis_weights(A, p, T):
    """
    Approximate the l_p Lewis weights of A by repeated fixed-point updates.

    Starts from all-ones weights and applies lewis_iterate T times.

    :param A: data matrix, size n x d
    :param p: norm parameter forwarded to lewis_iterate
    :param T: number of update rounds
    :return: weight vector of size n (all ones when T <= 0)
    """
    weights = np.ones(A.shape[0])
    for _round in range(T):
        weights = lewis_iterate(A, p, weights)
    return weights


def generate_OSNAP_sparse_embedding(n_row, n_col, s):
    """
    Build a random sparse sign matrix with s non-zero entries per column.

    For every column, s distinct rows are drawn without replacement and
    each is filled with +1/sqrt(s) or -1/sqrt(s); all other entries are 0.

    :param n_row: number of rows of the embedding
    :param n_col: number of columns of the embedding
    :param s: number of non-zero entries per column
    :return: dense array of size n_row x n_col
    """
    embedding = np.zeros((n_row, n_col))
    candidate_rows = np.arange(n_row)
    norm = np.sqrt(s)
    for col in range(n_col):
        # Keep the draw order (rows first, then signs) so that a seeded
        # generator yields the same matrix.
        rows = np.random.choice(candidate_rows, s, replace=False)
        signs = np.random.randint(2, size=s) * 2 - 1
        embedding[rows, col] = signs / norm
    return embedding


def l1_lewis_weights(A, approx_factor=100):
    """
    Row wise Lewis weights for p = 1.

    The number of fixed-point iterations is
    max(int(approx_factor * log(log(n))), 10), with n the number of rows.

    :param A: data matrix, size n x d
    :param approx_factor: multiplier applied to log(log(n)) when choosing
                          the iteration count
    :return: row wise Lewis weight, size n
    """
    n = A.shape[0]
    # log(log(n)) is -inf for n = 1 and nan for n = 0, neither of which
    # int() can convert. Skip the logs there and fall through to the floor.
    T = int(approx_factor * np.log(np.log(n))) if n >= 2 else 0
    T = max(T, 10)
    return approximate_lewis_weights(A, 1, T)

# lewis_weights are the lp lewis weights https://arxiv.org/pdf/1412.0588.pdf
# Theoretically, if sample_rows is sufficiently large, then we obtain an lp
# subspace embedding
def get_lewis_weight_sampling_matrix(A, lewis_weights, sample_rows, p):
    """
    Draw a row sampling and rescaling matrix S from the given weights.

    Each of the sample_rows rows of S selects one index j (with
    replacement) with probability lewis_weights[j] / sum(lewis_weights)
    and stores 1 / q_j ** (1/p) in column j, where
    q_j = sample_rows * lewis_weights[j] / sum(lewis_weights).

    :param A: data matrix, size n x d (only its shape is used)
    :param lewis_weights: sampling weights, size n (array or sequence)
    :param sample_rows: number of rows to draw
    :param p: norm parameter; entries of S are rescaled with exponent 1/p
    :return: (S, sel_indices) where S has size sample_rows x n and
             sel_indices[i] is the column index selected by row i of S
    :raises ValueError: if lewis_weights is not a vector of length n
    """
    n, _ = A.shape
    weights = np.asarray(lewis_weights, dtype=float)
    if weights.shape != (n,):
        raise ValueError(
            'lewis_weights must have one entry per row of A '
            '(expected {}, got shape {})'.format(n, weights.shape))

    # Normalise the weights first so the probabilities stay valid even
    # when sample_rows is 0.
    probabilities = weights / weights.sum()
    # Rescale lewis_weights so that they sum to N = sample_rows
    sampling_values = sample_rows * probabilities

    S = np.zeros((sample_rows, n))
    # One vectorised draw replaces the per-row loop; it also yields an
    # integer index array when sample_rows is 0.
    sel_indices = np.random.choice(n, sample_rows, p=probabilities)
    S[np.arange(sample_rows), sel_indices] = (
        1 / sampling_values[sel_indices] ** (1 / p))
    return S, sel_indices

def perform_l1_lewis_weight_sampling(A, sample_rows, approx_factor=25):
    """
    Compute l1 Lewis weights of A and sample rows according to them.

    :param A: data matrix, size n x d
    :param sample_rows: number of rows to sample
    :param approx_factor: forwarded to l1_lewis_weights
    :return: (S, sel_indices) as produced by
             get_lewis_weight_sampling_matrix with p = 1
    """
    weights = l1_lewis_weights(A, approx_factor)
    sampling_matrix, chosen = get_lewis_weight_sampling_matrix(
        A, weights, sample_rows, p=1)
    return sampling_matrix, chosen

# Main function for CSS under l_{p, 2} norm (column wise 
# sum of p-th power of Euclidean norm), based on algorithm
# from https://arxiv.org/pdf/1510.06073.pdf
def CSS_l12(A, sketch_size, sparsity, lewis_sample_rows, approx_factor=25):
    """
    Column subset selection on A followed by a sketched regression.

    :param A: data matrix; A.shape[0] sets the width of both embeddings
    :param sketch_size: number of rows of each sparse embedding
    :param sparsity: non-zeros per column of each sparse embedding
    :param lewis_sample_rows: number of samples drawn in the Lewis weight
                              sampling step
    :param approx_factor: forwarded to perform_l1_lewis_weight_sampling
    :return: (pinv(R A S^T) R A, sel_indices), where S and sel_indices
             come from the Lewis weight sampling step and R is the
             second sparse embedding
    """
    n_rows = A.shape[0]

    # 1. Sparse embedding applied to A from the left
    first_embedding = generate_OSNAP_sparse_embedding(
        sketch_size, n_rows, sparsity)
    sketched = np.matmul(first_embedding, A)

    # 2. l1 Lewis weight sampling on the transpose of the sketch
    S, sel_indices = perform_l1_lewis_weight_sampling(
        sketched.T, lewis_sample_rows, approx_factor)
    selected = np.matmul(A, S.T)

    # 3. Fresh sparse embedding for the regression step
    R = generate_OSNAP_sparse_embedding(sketch_size, n_rows, sparsity)

    # 4. Regression - not actually used for distributed protocol
    sketched_target = np.matmul(R, A)
    sketched_basis_inv = np.linalg.pinv(np.matmul(R, selected))
    solution = np.matmul(sketched_basis_inv, sketched_target)
    return solution, sel_indices

## Check result
def calc_p_norm(A, p):
    """
    Entrywise p-norm of A: (sum of |entry| ** p) ** (1/p).

    :param A: array of any shape
    :param p: norm parameter
    :return: the norm as a scalar
    """
    powered = np.abs(A) ** p
    total = np.sum(powered)
    return total ** (1 / p)


def check_result(A, B, p):
    """
    Compare the p-norms of A x and B x for a single random vector x.

    x has one entry per column of A, drawn with np.random.rand and scaled
    by 100. Both norms are printed.

    :param A: reference matrix, size n x d
    :param B: matrix to compare against, with d columns
    :param p: norm parameter passed to calc_p_norm
    :return: ratio of the norm of A x to the norm of B x
    """
    _, n_features = A.shape
    probe = np.random.rand(n_features) * 100
    reference_norm, sketched_norm = (
        calc_p_norm(np.matmul(mat, probe), p) for mat in (A, B))
    print('Ax norm {}'.format(reference_norm))
    print('SAx norm {}'.format(sketched_norm))
    return reference_norm / sketched_norm

"""
Obsolete:

def find_p(w, eps=0.2):
    p1 = w / (eps ** 2)
    N = np.sum(p1)
    # print('N: ', N)
    p2 = p1 * np.log(N)
    # print(np.log(N))
    # print('p1, p2, ', p1, p2)
    factor = 2
    while np.any(p1 < p2):
        p2 = p2 * factor
        N = np.sum(np.sum(p2))
        p1 = w / (eps ** 2) * factor * np.log(N)
        factor *= 2
    return p2

## main function for generating leverage score sampling matrix S
def approx_leverage_score_sampling(A, p, eps=0.2):
    n, d = A.shape
    sample_rows = int(d * np.log(d) / (eps ** 2))
    T = int(np.log(n)) + 2
    w = approximate_lewis_weights(A, p, T)
    raw_prob = find_p(w)
    S, _ = get_lewis_weight_sampling_matrix(A, raw_prob, sample_rows, p)
    return S
"""


if __name__ == '__main__':
    # Test approximation factor: sample 1000 rows of a random 100 x 50
    # matrix by l1 Lewis weights and compare the l1 norms of Ax and SAx
    # for one random x (drawn inside check_result).
    n, d = 100, 50
    A = np.random.rand(n, d) + 0.001
    A = A * 1000
    S, _ = perform_l1_lewis_weight_sampling(A, 1000)
    approx_factor = check_result(A, S @ A, 1)
    print('approx factor Ax / SAx {}'.format(approx_factor))

    # Test values of Lewis weights
    # The Lewis weights of the 2nd and 3rd rows should be 5 times
    # those of the last 3 rows.
    M_test = [[1, 2, 3], [5, 0, 0], [5, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]
    print(M_test)
    M_test = np.array(M_test)
    print(M_test.shape)
    print(l1_lewis_weights(M_test))

    # Disabled experiment kept for reference (a bare string, never run).
    """
    n, d = 100, 50
    A = np.random.rand(n, d) + 0.0001
    A = A * 1000
    T = int(np.log(n)) + 2
    # T = 10
    p = 3
    w = approximate_lewis_weights(A, p, T)
    print('A: ', A)
    print('w: ', w)
    print(np.sum(w))
    w_ = w ** (1/2 - 1/p)
    B = np.matmul(np.diag(w_), A)  # B = SA
    print('B: ', B)
    print('B shape: ', B.shape)
    # approx_factor = check_result(A, B, p)
    # print(approx_factor)

    eps = 0.2
    raw_prob = find_p(w, eps=eps)
    print('raw prob: {}'.format(raw_prob))
    sample_rows = int(d * int(np.log(d)) / (eps ** 2))
    print('sample rows : ', sample_rows)
    S, _ = get_lewis_weight_sampling_matrix(A, raw_prob, sample_rows, p)
    print(S, S.shape)
    approx_factor = check_result(A, np.matmul(S, A), p)
    print('approx factor Ax / SAx {}'.format(approx_factor))
    """
