import torch
import torch.nn.functional as F
import numpy as np
import sys
import time

### Parameters

# Usage: <script> inputfile k m [eta] [sparsity] [sgd_iters]
if len(sys.argv) < 4:
    sys.exit(f"Usage: {sys.argv[0]} inputfile k m [eta] [sparsity] [sgd_iters]")

inputfile = sys.argv[1]
k = int(sys.argv[2]) # Target low rank
m = int(sys.argv[3]) # Sketching dimension

# NOTE(review): the argv[4] branch had no body (a SyntaxError) and `sparsity`
# was never defined although the sketch construction uses it. Mapping
# argv[4] -> eta and argv[5] -> sparsity is inferred from the argument order
# (argv[6] is sgd_iters); confirm against the intended CLI.
eta = 1 # SGD learning rate
if len(sys.argv) > 4:
    eta = float(sys.argv[4])

sparsity = 1 # Number of nonzeros per column of the sketching matrix
if len(sys.argv) > 5:
    sparsity = int(sys.argv[5])
if not 1 <= sparsity <= m:
    # Each column picks `sparsity` distinct rows out of m and is scaled by
    # 1/sqrt(sparsity), so the value must be in [1, m].
    sys.exit("sparsity must be between 1 and m")

sgd_iters = 50 # Number of SGD steps
if len(sys.argv) > 6:
    sgd_iters = int(sys.argv[6])

### Load data

# NOTE: torch.load unpickles the file, so only pass trusted input files.
# The file holds (train, test) or (train, test, <extra>); only the first two
# entries are used. Loading once replaces the old bare-except retry, which
# re-read the whole file and could mask unrelated errors.
data = torch.load(inputfile)
if len(data) < 2:
    raise ValueError(f"{inputfile}: expected at least (dataset, testset)")
dataset, testset = data[0], data[1]
del data  # release any extra entries

n = len(dataset)
q = len(testset)
d1, d2 = dataset[0].size()  # each sample is a d1 x d2 matrix

print("Dataset size:", n)
print("Test set size:", q)
print("Dimensions:", d1,d2)
print("Target low rank:", k)
print("Sketching dimension:", m)

### Sketch-based low rank approximation

def np_svd(A):
    """Full SVD of ``A`` in the ``(U, Sigma, Vh)`` order used by ``numpy.linalg.svd``.

    ``U`` and ``Vh`` are the full square factors (``some=False``), and
    ``A == U[:, :r] @ diag(Sigma) @ Vh[:r, :]`` with ``r = min(A.shape)``.
    ``Vh`` is the transpose of the ``V`` returned by ``torch.svd``.
    """
    left, singular_values, right = torch.svd(A, some=False)
    right_t = right.T
    return left, singular_values, right_t

def opt_low_rank(A, k, device="cuda"):
    """Return the best rank-``k`` approximation of ``A`` via truncated SVD.

    Keeps the ``k`` largest singular triplets (Eckart-Young optimum in the
    Frobenius norm).

    Args:
        A: 2-D tensor. The SVD runs on whichever device ``A`` is on.
        k: target rank. Values above ``min(A.shape)`` keep every component.
        device: device to which the truncated factors are moved and on which
            the result is formed. Defaults to ``"cuda"``, as before.

    Returns:
        Tensor with the same shape as ``A``, located on ``device``.
    """
    U, Sigma, V = torch.svd(A)
    # Slice before moving so only the k leading factors are transferred.
    U_k = U[:, :k].to(device)
    Sigma_k = Sigma[:k].to(device)
    Vh_k = V[:, :k].T.to(device)
    # Scale the columns of U_k by broadcasting instead of materialising
    # the full diag(Sigma) matrix and slicing it.
    return (U_k * Sigma_k) @ Vh_k

def SCW(S, A, k):
    """Sketch-based rank-``k`` approximation of ``A`` using the sketch ``S``.

    The name presumably refers to the Sarlos / Clarkson-Woodruff
    sketch-and-solve scheme. The function takes the first ``S.size()[0]``
    right singular vectors of ``S @ A`` as an orthonormal row basis. It
    projects ``A`` onto that basis, keeps the best rank-``k`` approximation
    there, and maps the result back to ``A``'s column space.

    The basis is detached, so no gradient flows back into ``S`` through it.
    """
    sketch_rows = S.size()[0]
    _, _, right_factor = np_svd(S @ A)
    basis = right_factor[:sketch_rows, :].detach()
    # Coordinates of A's rows in the sketched basis.
    projected = A @ basis.T
    return opt_low_rank(projected, k) @ basis

def get_loss(S, test_set, k):
    """Mean Frobenius error of the sketch-based rank-``k`` approximation.

    Args:
        S: sketching matrix passed through to ``SCW``.
        test_set: non-empty sequence of 2-D tensors. Each one is moved to the
            GPU in turn, and its GPU reference is dropped after use.
        k: target rank.

    Returns:
        Mean of ``||A - SCW(S, A, k)||_F`` over ``test_set``, as a Python float.

    Raises:
        ValueError: if ``test_set`` is empty.
    """
    if len(test_set) == 0:
        raise ValueError("get_loss: test_set is empty")
    total = 0.0
    # Evaluation only: no_grad stops autograd from recording S @ A and the
    # SVD when S still carries a graph from the SGD step.
    with torch.no_grad():
        for A in test_set:
            Ac = A.cuda()
            total += torch.norm(Ac - SCW(S, Ac, k)).item()
            del Ac
    return total / len(test_set)
        

### Train

# Cache of left singular factors U per training index, kept on the CPU;
# filled lazily by the SGD loop.
u_dic = {}

# Precomputed optimal losses keyed by input file. Both are empty, so the
# losses are always computed below; add entries to skip that work.
opt_train_loss_dic = {}
opt_test_loss_dic = {}


def _mean_opt_loss(matrices, k):
    """Mean Frobenius error of the optimal rank-``k`` approximation, as a float."""
    total = 0.0
    for A in matrices:
        total += torch.norm(A.cuda() - opt_low_rank(A, k)).item()
    return total / len(matrices)


if inputfile in opt_train_loss_dic:
    opt_train_loss = opt_train_loss_dic[inputfile]
else:
    print("Computing optimal training loss...")
    opt_train_loss = _mean_opt_loss(dataset, k)
print("Optimal training loss:", opt_train_loss)

if inputfile in opt_test_loss_dic:
    opt_loss = opt_test_loss_dic[inputfile]
else:
    print("Computing optimal test loss...")
    opt_loss = _mean_opt_loss(testset, k)
print("Optimal test loss:", opt_loss)


# Target of the SGD objective: I0 = [I_k | 0], shape (k, d1).
I0 = torch.eye(k, d1).cuda()

# Identity right factor applied in the SGD objective.
J0 = torch.eye(d1).cuda()

# Sparse sketching matrix S0 (m x d1). Each column gets `sparsity` nonzeros
# equal to 1/sqrt(sparsity), in distinct random rows. It is filled on the
# CPU and copied to the GPU once, instead of one tiny GPU write per nonzero.
# The numpy RNG is consumed in the same order as before, so a seeded run
# yields the same S0.
S0 = torch.zeros(m, d1)
for j in range(d1):
    rows = torch.as_tensor(np.random.permutation(m)[:sparsity])
    S0[rows, j] = 1./(sparsity**0.5)
S0 = S0.cuda()

# Per-column scales applied to S0, initialised to independent random signs.
vals = 2.*torch.randint(0,2,(1,d1)).view(-1).float() - torch.ones(d1)
orig_vals = vals.clone() # CPU copy of the initial signs
vals = vals.cuda()


for t in range(sgd_iters):
    # Reshuffle the training order at the start of every pass over the data.
    if t % n == 0:
        perm = np.random.permutation(n)
    if t>0:
        # Evaluates the sketch built in the previous iteration, i.e. before
        # that iteration's update was applied.
        print("===" + str(t) + "===")
        current_loss = get_loss(S, testset, k) - opt_loss
        del S
        print("Test loss at iteration " + str(t) + ": " + str(current_loss))

    # Get U for the SGD step: left singular vectors of the sampled training
    # matrix, cached on the CPU in u_dic so each SVD is computed only once.
    idx = perm[t % n]
    if idx in u_dic:
        U = u_dic[idx].cuda()
    else:
        Ac = dataset[idx].cuda()
        U = torch.svd(Ac, some=False)[0]
        u_dic[idx] = U.cpu()
        del Ac

    # Do SGD on vals. S0 fixes the sparsity pattern; vals scales its columns.
    vals.requires_grad_(True)
    S = S0 @ torch.diag(vals)
    SU = S @ U
    # Push (SU[:, :k])^T (SU) towards the target I0.
    loss = torch.norm((SU[:,:k].T @ SU @ J0) - I0)
    loss.backward()
    # S is only used for evaluation from here on; drop its autograd history.
    S = S.detach()
    with torch.no_grad():
        vals -= eta * vals.grad
    # backward() accumulates into .grad. Without this reset, every step
    # would apply the sum of all previous gradients.
    vals.grad = None
    vals.requires_grad_(False)
    del U
    del SU
    del loss

# Rebuild the sketch from the final vals. The S left over from the loop was
# formed before the last SGD update, and S is undefined when sgd_iters == 0.
S = S0 @ torch.diag(vals)
print("Final loss:", get_loss(S, testset, k) - opt_loss)

