import torch
import numpy as np
import sys
import time

import warnings
warnings.filterwarnings('ignore')


### Parameters


# Command line: <inputfile> <k> <m>
#   inputfile -- path handed to torch.load
#   k         -- target low rank
#   m         -- sketching dimension
inputfile, k, m = sys.argv[1], int(sys.argv[2]), int(sys.argv[3])

### Load data
# NOTE: torch.load unpickles the file, so only trusted input files should be
# passed here.  The file is expected to hold a (dataset, testset) pair.
dataset, testset = torch.load(inputfile)
_first_shape = dataset[0].size()
dmin = min(_first_shape)
n, q = len(dataset), len(testset)
d1, d2 = _first_shape  # every matrix is assumed to share the first one's shape

iters = n

# Echo the experiment configuration.
for label, *values in (
    ("Dataset size:", n),
    ("Test set size:", q),
    ("Dimensions:", d1, d2),
    ("Target low rank:", k),
    ("Sketching dimension:", m),
):
    print(label, *values)

### Sketch-based low rank approximation

def np_svd(A):
    """Full SVD of ``A`` returned as ``(U, Sigma, Vh)``.

    ``torch.svd`` yields the right singular vectors as the columns of ``V``;
    this wrapper hands back the plain transpose ``V.T`` instead, so that
    ``U[:, :r] @ diag(Sigma) @ Vh[:r, :]`` reconstructs a real matrix ``A``
    (``r = len(Sigma)``).

    Args:
        A: 2-D tensor to decompose.

    Returns:
        Tuple ``(U, Sigma, Vh)`` with square ``U`` and ``Vh`` (``some=False``).
    """
    decomposition = torch.svd(A, some=False)
    return decomposition.U, decomposition.S, decomposition.V.T

def opt_low_rank(A, k):
    """Return the best rank-``k`` approximation of ``A`` via a truncated SVD.

    The SVD is computed on whatever device ``A`` lives on; the truncated
    factors are then moved to the GPU when CUDA is available, so on a CUDA
    host the result is a CUDA tensor exactly as before.  Without CUDA the
    result stays on ``A``'s device instead of raising.

    Args:
        A: 2-D tensor of shape ``(d1, d2)``.
        k: Target rank.  Values larger than ``min(d1, d2)`` are clipped by
            the slicing below, in which case ``A`` itself is reconstructed.

    Returns:
        Tensor of shape ``(d1, d2)`` with rank at most ``k``.
    """
    # Only the leading k singular triplets are used, so the reduced SVD is
    # enough; the trailing columns of a full U / V were never read.
    U, Sigma, V = torch.svd(A, some=True)

    use_cuda = torch.cuda.is_available()

    def on_device(t):
        return t.cuda() if use_cuda else t

    # Scaling the columns of U_k by the singular values equals
    # U_k @ diag(Sigma_k) without materialising a dense diagonal matrix.
    scaled_left = on_device(U[:, :k]) * on_device(Sigma[:k])
    return scaled_left @ on_device(V[:, :k].T)

def SCW(S, A, k):
    """Rank-``k`` approximation of ``A`` restricted to the row space of ``S @ A``.

    NOTE(review): the name presumably refers to the sketch-and-solve low-rank
    algorithm of Sarlos / Clarkson-Woodruff -- confirm.

    Steps:
      1. Take the first ``S.size()[0]`` right singular vectors of the
         sketched matrix ``S @ A`` (detached from the autograd graph).
      2. Express ``A`` in that basis (``A @ Vh.T``).
      3. Take the best rank-``k`` approximation of the projected matrix and
         map it back with ``Vh``.

    Args:
        S: Sketch matrix, multiplied on the left of ``A``.
        A: Matrix to approximate.
        k: Target rank.

    Returns:
        Tensor with the same shape as ``A``.
    """
    sketch_rows = S.size()[0]
    _, _, full_Vh = np_svd(S @ A)
    row_basis = full_Vh[:sketch_rows, :].detach()
    projected = A @ row_basis.T
    return opt_low_rank(projected, k) @ row_basis

def get_loss(S, test_set, k):
    """Mean approximation error of sketch ``S`` over ``test_set``.

    For every matrix the error is ``torch.norm(A - SCW(S, A, k))`` (the
    default norm, i.e. Frobenius for a matrix), evaluated on the GPU and
    accumulated on the CPU.

    Args:
        S: Sketch matrix passed through to ``SCW``.
        test_set: Sized collection of matrices; each is moved to the GPU.
        k: Target rank.

    Returns:
        The average error as a Python ``float``.
    """
    def error_of(matrix):
        # The GPU copy is local to this call, so it is released on return.
        on_gpu = matrix.cuda()
        return torch.norm(on_gpu - SCW(S, on_gpu, k)).cpu()

    total = sum(error_of(matrix) for matrix in test_set)
    return float(total * 1. / len(test_set))

def get_one_loss(S, A, k):
    """Approximation error of sketch ``S`` on the single matrix ``A``.

    Args:
        S: Sketch matrix passed through to ``SCW``.
        A: Matrix to approximate; it is moved to the GPU first.
        k: Target rank.

    Returns:
        ``torch.norm(A - SCW(S, A, k))`` as a Python ``float``.
    """
    gpu_matrix = A.cuda()
    approximation = SCW(S, gpu_matrix, k)
    return float(torch.norm(gpu_matrix - approximation).cpu())


### Train

def _mean_optimal_loss(matrices, count):
    """Sum ``torch.norm(A - opt_low_rank(A, k))`` over ``matrices``, scaled by ``1/count``."""
    accumulated = 0
    for matrix in matrices:
        accumulated += torch.norm(matrix.cuda() - opt_low_rank(matrix, k))
    accumulated *= 1./count
    return accumulated


# Lookup tables of precomputed optimal losses keyed by input file.  Both are
# empty here, so the losses are always recomputed below.
opt_train_loss_dic = {}
opt_test_loss_dic = {}

# Baseline on the training set: error of the optimal rank-k approximation.
if inputfile not in opt_train_loss_dic:
    print("Computing optimal training loss...")
    opt_train_loss = _mean_optimal_loss(dataset, n)
else:
    opt_train_loss = opt_train_loss_dic[inputfile]
print("Optimal training loss:", opt_train_loss)

# Same baseline on the test set.
if inputfile not in opt_test_loss_dic:
    print("Computing optimal test loss...")
    opt_loss = _mean_optimal_loss(testset, q)
else:
    opt_loss = opt_test_loss_dic[inputfile]
print("Optimal test loss:", opt_loss)

# Number of non-zero entries per column of the sparsity pattern.  This name was
# read without ever being defined (NameError); it is now an optional fourth
# command-line argument.
# NOTE(review): the default of 1 (one non-zero per column, a CountSketch-style
# pattern) is an assumption -- confirm the intended value.
sparsity = int(sys.argv[4]) if len(sys.argv) > 4 else 1
if sparsity < 1:
    raise ValueError("sparsity must be a positive integer, got %d" % sparsity)

# Random sparsity pattern: S0[i, j] == 1 iff sketch row i may use input row j.
# Each column gets `sparsity` ones in uniformly random rows.  The pattern is
# filled on the host and copied to the GPU once, instead of issuing one
# device write per entry.
pattern = np.zeros((m, d1), dtype=np.uint8)
for j in range(d1):
    pattern[np.random.permutation(m)[:sparsity], j] = 1
S0 = torch.from_numpy(pattern).cuda()

# A single randomly chosen training matrix drives the construction.
sumA = dataset[np.random.randint(n)].cuda()
# Two sketch rows per pattern row; dtype/device follow the data so the
# assignments below and the later S @ A never mix dtypes.
S = torch.zeros(2*m, d1, dtype=sumA.dtype, device=sumA.device)

for iii in range(m):
    # Boolean mask (uint8 mask indexing is deprecated), taken directly on the
    # GPU rather than via a Python list of device scalars.
    support = S0[iii, :].bool()
    if not bool(support.any()):
        continue  # this pattern row selects nothing: leave both rows zero

    # Rows of the training matrix selected by this pattern row.
    sumAsub = sumA[support]
    # Only columns of U that pair with a singular value are ever read, so the
    # reduced SVD is sufficient.
    subU, subSigma, _ = torch.svd(sumAsub, some=True)

    S[2*iii, support] = subU[:, 0]  # This is the top sigvec

    # Squared singular values of the remaining (non-top) singular vectors.
    sqrsigvals = np.square(subSigma.cpu().numpy()[1:].astype(np.float64))
    if len(sqrsigvals) == 0:
        continue  # no second singular vector exists: leave the odd row zero
    energy = np.sum(sqrsigvals)
    # Pick singular vector w.p. proportional to square singular value; fall
    # back to a uniform pick when all remaining singular values are zero
    # (the normalisation would otherwise produce NaNs).
    probabilities = sqrsigvals / energy if energy > 0 else None
    subsigvec = np.random.choice(len(sqrsigvals), p=probabilities)
    S[2*iii+1, support] = subU[:, subsigvec + 1]  # This is another random sigvec

# The former `S = S0 @ torch.diag(vals).cuda()` line was removed: `vals` is
# never defined, the product mixes a byte tensor with a float matrix, and it
# would discard the sketch constructed above.

print("Loss:", get_loss(S, testset, k) - opt_loss)
