import numpy as np
import torch
import wandb
from torch import nn
import cvxpy as cp
import scipy
import argparse

from scipy.sparse.linalg import LinearOperator, eigsh

class MatrixCompletionData:
    """A matrix-completion task: target matrix ``M`` observed at positions ``Omega``.

    ``Omega`` is a pair ``(rows, cols)`` of index tensors; ``M[rows, cols]`` are
    the observed entries. A model is a sequence of factor matrices ``Us`` whose
    ordered product ``Us[0] @ Us[1] @ ...`` is the prediction. Predictions are
    divided by their RMS over ``Omega`` and multiplied by ``gamma``, so the
    losses depend only on the direction of the product, not its scale.
    """

    def __init__(self, M, Omega):
        self.M = M
        self.Omega = Omega

    @staticmethod
    def _product(Us):
        """Return the ordered product ``Us[0] @ Us[1] @ ... @ Us[-1]``."""
        W = None
        for U in Us:
            W = U if W is None else W @ U
        if W is None:
            raise ValueError('Us must contain at least one factor matrix')
        return W

    def _observed(self, A):
        """Return the entries of ``A`` at the observed positions ``Omega``."""
        return A[self.Omega[0], self.Omega[1]]

    def _omega_numpy(self):
        """Return ``Omega`` as a pair of host-side numpy index arrays."""
        return tuple(torch.as_tensor(idx).cpu().numpy() for idx in self.Omega)

    def ystd(self):
        """Return the root-mean-square of the observed entries of ``M``."""
        return torch.mean(self._observed(self.M) ** 2) ** 0.5

    def eff_w(self, gamma, Us):
        """Return the full predicted matrix scaled by ``gamma / RMS(prediction on Omega)``."""
        W = self._product(Us)
        out = self._observed(W)
        return gamma * W / torch.mean(out ** 2) ** 0.5

    def train_loss(self, gamma, Us):
        """Return the MSE between the rescaled prediction and ``M`` on ``Omega``."""
        out = self._observed(self._product(Us))
        out = out / torch.mean(out ** 2) ** 0.5
        return ((gamma * out - self._observed(self.M)) ** 2).mean()

    def test_loss(self, w):
        """Return the MSE between ``w`` and ``M`` over all entries."""
        return ((w - self.M) ** 2).mean()

    def hess_eigs(self, gamma, Us, k=3):
        """Return the ``k`` largest eigenvalues of the training-loss Hessian, descending.

        The Hessian is taken jointly over every entry of every factor in ``Us``
        and is never materialized: ``eigsh`` (ARPACK, ``which='LA'``) queries it
        through Hessian-vector products from ``torch.autograd.functional.hvp``.

        Args:
            gamma: output scale passed to ``train_loss``.
            Us: sequence of factor tensors.
            k: number of eigenvalues to return (must be below the total
                parameter count, as required by ``eigsh``).

        Returns:
            A contiguous numpy array of ``k`` eigenvalues, largest first.
        """
        Us = tuple(Us)
        sizes = [U.nelement() for U in Us]
        dim = sum(sizes)

        def cur_loss(*UUs):
            return self.train_loss(gamma, UUs)

        def mv(vs):
            # ARPACK may pass shape (dim,) or (dim, 1); flatten before slicing.
            vs = np.asarray(vs).reshape(-1)
            Vs = []
            offset = 0
            for U, n in zip(Us, sizes):
                chunk = vs[offset:offset + n]
                Vs.append(torch.tensor(chunk, dtype=U.dtype, device=U.device).reshape_as(U))
                offset += n
            _, hvps = torch.autograd.functional.hvp(cur_loss, Us, tuple(Vs))
            return torch.cat([h.reshape(-1) for h in hvps]).cpu().numpy()

        operator = LinearOperator((dim, dim), matvec=mv, dtype=np.float32)
        evals = eigsh(operator, k, which='LA', return_eigenvectors=False)
        # eigsh does not document the return order when eigenvectors are not
        # requested, so sort explicitly instead of just reversing.
        return np.sort(evals)[::-1].copy()

    def min_nuclear_norm_sol(self):
        """Solve the nuclear-norm-minimization baseline with cvxpy.

        Minimizes ``||W||_*`` subject to ``W`` matching ``M`` on every observed
        entry.

        Returns:
            ``(test_mse, W)``: the MSE of the solution against the full ``M`` and
            the solution as a numpy array.

        Raises:
            RuntimeError: if the solver returns no solution.
        """
        M_np = self.M.detach().cpu().numpy()
        rows, cols = self._omega_numpy()
        W = cp.Variable(M_np.shape)
        # Read targets from one host copy; indexing a CUDA tensor per entry
        # would force a device sync for every constraint.
        constraints = [W[x, y] == M_np[x, y].item() for x, y in zip(rows, cols)]
        prob = cp.Problem(cp.Minimize(cp.normNuc(W)), constraints)
        prob.solve()
        if W.value is None:
            raise RuntimeError(f'nuclear-norm solve returned no solution (status: {prob.status})')
        return ((W.value - M_np) ** 2).mean(), W.value

    def zero_filling_sol(self):
        """Return ``(test_mse, sol)`` for the baseline that copies observed entries and zeros the rest.

        ``sol`` is a float64 numpy array.
        """
        M_np = self.M.detach().cpu().numpy()
        rows, cols = self._omega_numpy()
        sol = np.zeros(M_np.shape)
        sol[rows, cols] = M_np[rows, cols]
        return ((sol - M_np) ** 2).mean(), sol


def get_data(N, d, r, dtype=torch.float32, device='cuda'):
    """Build a deterministic synthetic rank-``r`` matrix-completion problem.

    The target is ``M = U @ V`` with ``U`` (d x r) and ``V`` (r x d) drawn
    uniformly from [-1, 1], normalized to unit RMS. ``N`` distinct entries are
    chosen uniformly at random as the observed set ``Omega``. Generation uses a
    fixed seed; every torch RNG state touched by that seed is restored
    afterwards, so the caller's random stream is unaffected.

    Args:
        N: number of observed entries, in ``[1, d * d]``.
        d: side length of the square target matrix.
        r: rank of the target.
        dtype: dtype of the generated factors and target.
        device: torch device to generate on.

    Returns:
        A ``MatrixCompletionData`` holding ``M`` and ``Omega``.

    Raises:
        ValueError: if ``N`` is outside ``[1, d * d]``.
    """
    if not 0 < N <= d * d:
        raise ValueError(f'N must be in [1, d*d={d * d}], got {N}')

    # torch.manual_seed reseeds the CPU *and* all CUDA generators, so saving
    # only the CPU state would leave the CUDA generators pinned to the data
    # seed for the rest of the run. Save both and restore them even on error.
    cpu_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        torch.manual_seed(7725357)
        with torch.no_grad():
            U = torch.empty(d, r, dtype=dtype, device=device).uniform_(-1, 1)
            # The randn values are overwritten in place by uniform_, but the draw
            # still advances the RNG; kept so the seeded dataset stays identical.
            V = torch.randn(r, d, dtype=dtype, device=device).uniform_(-1, 1)
            M = U @ V
            M /= (M ** 2).mean() ** 0.5

            # Flat index k in [0, d*d) encodes entry (k // d, k % d) in
            # row-major order; sample N distinct ones.
            perm = torch.randperm(d * d, device=device, dtype=torch.long)[:N]
            Omega = (torch.div(perm, d, rounding_mode='floor'), perm % d)

            print(M)
            print(Omega)
    finally:
        torch.set_rng_state(cpu_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

    data = MatrixCompletionData(M, Omega)

    # Sanity check: the generating factors must reproduce the observations.
    fit = data.train_loss(data.ystd(), [U, V])
    assert fit < 1e-6, f'loss: {fit}'

    return data

def main(N, d, L, r, eta, lam, T=200_000, train_gamma=False, compute_eigen=True, device='cuda'):
    """Train an L-factor deep linear model on a synthetic completion task, logging to wandb.

    Args:
        N: number of observed entries.
        d: side length of the square target matrix (and of each factor).
        L: number of d x d factor matrices in the product.
        r: rank of the synthetic target.
        eta: learning rate.
        lam: weight-decay coefficient (applied as ``eta * lam * U`` per step).
        T: number of gradient steps.
        train_gamma: if True, the output scale ``gamma`` is learned starting at 1;
            otherwise it is fixed to the RMS of the observed targets.
        compute_eigen: if True, log the top Hessian eigenvalues each step.
        device: torch device for data and parameters.
    """
    wandb.init(
        project='__project_name__',
        name=f'L{L}-N{N}-lr{eta}-wd{lam}-d{d}-r{r}',
        save_code=True,
        config={
            'N': N,
            'd': d,
            'L': L,
            'r': r,
            'lr': eta,
            'wd': lam,
            'T': T,
            'train_gamma': train_gamma
        }
    )

    data = get_data(N, d, r, device=device)

    val, _ = data.min_nuclear_norm_sol()
    wandb.summary.update({"min_nuclear_norm_sol/loss/test": val})

    val, _ = data.zero_filling_sol()
    wandb.summary.update({"zero_filling_sol/loss/test": val})

    # Initialization draws from the global torch RNG; call torch.manual_seed
    # (e.g. torch.manual_seed(325232)) before this point for reproducible runs.
    with torch.no_grad():
        if train_gamma:
            gamma = nn.Parameter(torch.ones([], dtype=torch.float32, device=device), requires_grad=True)
        else:
            gamma = nn.Parameter(data.ystd(), requires_grad=False)
        Us = [
            nn.Parameter(torch.randn([d, d], dtype=torch.float32, device=device) / d ** 0.5, requires_grad=True)
            for _ in range(L)
        ]

    for t in range(T):
        loss = data.train_loss(gamma, Us)

        with torch.no_grad():
            W = data.eff_w(gamma, Us)
            Us_norm = sum(torch.linalg.norm(U).item() ** 2 for U in Us) ** 0.5

            extra = {}
            weigs = sorted(scipy.linalg.svdvals(W.cpu().numpy()), reverse=True)
            extra.update({f'weigs/{k}': weigs[k] for k in range(d)})
            if compute_eigen:
                eigs = data.hess_eigs(gamma, Us) * Us_norm ** 2
                extra.update({f'seigs/{k}': eigs[k] for k in range(3)})

            eeta = eta / ((1 - eta * lam) * Us_norm ** 2)

            # Log plain Python scalars rather than (graph-attached) tensors.
            wandb.log({
                'gamma': gamma.item(),
                'norm/U': Us_norm,
                'norm/W': torch.linalg.norm(W).item(),
                'loss/train': loss.item(),
                'loss/test': data.test_loss(W).item(),
                **extra,
                'elr/si': eeta,
                'two_over_elr/si': 2 / eeta
            }, step=t, commit=(t % 1000 == 0))

            if t % 1000 == 0:
                print(f"iteration #{t}:")
                print('W:', W)
                print('M:', data.M)

        loss.backward()

        with torch.no_grad():
            if train_gamma:
                gamma.data -= eta * gamma.grad
                gamma.grad = None

            # Plain gradient descent with decoupled L2 weight decay.
            for U in Us:
                U.data -= eta * (U.grad + lam * U.data)
                U.grad = None

    # Close the run explicitly so a later wandb.init in the same process starts fresh.
    wandb.finish()
            

if __name__ == '__main__':
    np.set_printoptions(precision=4, suppress=True)
    # The matrices printed during the run are torch tensors, which ignore
    # numpy's print options; apply the same settings to torch.
    torch.set_printoptions(precision=4, sci_mode=False)

    parser = argparse.ArgumentParser(
        description='Deep linear matrix completion with wandb logging.')
    parser.add_argument('--L', type=int, default=2, help='number of factor matrices')
    parser.add_argument('--d', type=int, default=50, help='matrix side length')
    parser.add_argument('--N', type=int, default=300, help='number of observed entries')
    parser.add_argument('--r', type=int, default=3, help='rank of the target')
    parser.add_argument('--eta', type=float, default=0.1, help='learning rate')
    parser.add_argument('--lam', type=float, default=0.01, help='weight decay')
    parser.add_argument('--T', type=int, default=200_000, help='number of training steps')
    parser.add_argument('--train_gamma', action='store_true',
                        help='learn the output scale gamma instead of fixing it')
    parser.add_argument('--no_compute_eigen', dest='compute_eigen', action='store_false',
                        help='skip the per-step Hessian eigenvalue computation')
    args = parser.parse_args()

    main(N=args.N, d=args.d, r=args.r, eta=args.eta, lam=args.lam, L=args.L,
         T=args.T, train_gamma=args.train_gamma, compute_eigen=args.compute_eigen)