import time
import torch
import matplotlib.pyplot as plt

from src import *

def compute_eigenspectrum(U_arr):
    """Return the eigenvalues of the average of P_j = U_j U_j^T.

    Stacking the samples side by side as M = [U_1 ... U_n] (shape
    d x n*k) gives sum_j U_j U_j^T = M M^T.  M M^T and M^T M share their
    nonzero eigenvalues, so when n*k < d only the small (n*k) x (n*k)
    Gram matrix is decomposed and the remaining d - n*k eigenvalues are
    exactly zero.  The d x d average is therefore never formed unless it
    is already the smaller of the two matrices.

    Args:
        U_arr: real tensor of shape (n_samples, dimension, rank).  The
            input is cast to float64, so float32 input is accepted.

    Returns:
        1-D float64 tensor holding `dimension` eigenvalues in ascending
        order.

    Raises:
        ValueError: if U_arr is not 3-dimensional or holds no samples.
    """
    if U_arr.dim() != 3:
        raise ValueError(
            f'expected U_arr of shape (n_samples, dimension, rank), '
            f'got {tuple(U_arr.shape)}'
        )
    n, d, k = U_arr.shape
    if n == 0:
        raise ValueError('U_arr must contain at least one sample')

    # (n, d, k) -> (d, n*k): column block j of M is U_j.  Casting once up
    # front avoids mixed-dtype matmuls, which torch refuses to promote.
    M = U_arr.double().permute(1, 0, 2).reshape(d, n * k)

    if n * k >= d:
        # The Gram matrix would be no smaller than the d x d average.
        return torch.linalg.eigvalsh(M @ M.T / n)

    gram_eigvals = torch.linalg.eigvalsh(M.T @ M / n)
    null_eigvals = torch.zeros(
        d - n * k, dtype=gram_eigvals.dtype, device=gram_eigvals.device
    )
    # Round-off can push tiny Gram eigenvalues below zero, so re-sort
    # after padding to keep the ascending-order contract.
    return torch.sort(torch.cat((null_eigvals, gram_eigvals))).values

def cdf(eigvals):
    """Sample the empirical CDF of `eigvals` on a 1000-point grid over [0, 1].

    Returns:
        (grid, fractions): `grid` is the tensor of thresholds and
        `fractions` is a list holding, for each threshold, the share of
        entries of `eigvals` that lie strictly below it.
    """
    grid = torch.linspace(0, 1, 1000)
    total = eigvals.numel()

    fractions = []
    for threshold in grid:
        below = (eigvals < threshold).sum()
        fractions.append(below / total)
    return grid, fractions


if __name__ == '__main__':
    # Benchmark: explicit averaging of U U^T versus compute_eigenspectrum,
    # followed by a visual comparison of the two eigenvalue CDFs.
    d = 1000
    k = 100
    n = 20
    radius = 0.5
    # Presumably returns a tensor of shape (n, d, k); the shape is printed
    # so this can be confirmed at run time.
    U_arr = get_random_clustered_grassmannian_points(d, k, n, radius=radius)
    print(U_arr.shape)

    # Baseline: accumulate the d x d average explicitly.  The accumulator
    # uses the input's dtype so a float64 U_arr is not silently down-cast
    # to float32 by the in-place add, which would make the two curves
    # differ by precision rather than by algorithm.
    start = time.perf_counter()  # monotonic clock intended for timing
    P_ave = torch.zeros(d, d, dtype=U_arr.dtype)
    for U in U_arr:
        P_ave += U @ U.T
    P_ave /= n
    eigvals, _ = torch.linalg.eigh(P_ave)
    end = time.perf_counter()
    print(f'Time taken: {end - start}')

    start = time.perf_counter()
    fast_eigvals = compute_eigenspectrum(U_arr)
    end = time.perf_counter()
    print(f'Time taken fast version: {end - start}')

    # The two curves should coincide if both paths agree.
    plt.figure()
    plt.plot(*cdf(eigvals), label='explicit average')
    plt.plot(*cdf(fast_eigvals), label='compute_eigenspectrum')
    plt.legend()
    plt.show()
