import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.stats import ortho_group
from tqdm import tqdm
# Global runtime configuration: enlarge the default plot font and make every
# tensor created afterwards (including torch.tensor literals) float64.
plt.rcParams['font.size'] = 12
torch.set_default_dtype(torch.float64)
# Plot styling: markers[i][k] is the marker for width index i and method k,
# where k = 0, 1, 2 selects GD, AltGD and GD(1/L) respectively.
markers = [['v', '^', 'o'] for _ in range(3)]
linestyles = ['-', ':', '-.', ':']  # indexed by width index i

# Experiment-wide settings.
epochs = 200    # iterations per optimizer run
mark_freq = 40  # intended spacing (in iterations) between plot markers
c = 50          # initialization scale; embedded in the saved figure's filename
def plotMF(ax):
    """Plot averaged loss curves for matrix factorization onto ``ax``.

    A fixed target ``A = U diag(Sr, Se) V^T`` (``m x n``, ``U`` and ``V``
    orthogonal) is factorized by minimizing ``||X Y^T - A||_F^2 / 2`` from
    ``X0 = c * A @ Phi`` (``Phi`` an ``n x d`` Gaussian) and ``Y0 = 0``.  For
    each width ``d`` in ``(5, 20, 80)`` three methods run for ``epochs``
    iterations and their loss curves are averaged over seeds 0..9:

    * ``GD``      -- gradient descent with step ``2 / (L + mu)``;
    * ``AltGD``   -- alternating gradient descent with step ``2 / (L + mu)``;
    * ``GD(1/L)`` -- gradient descent with step ``1 / L``;

    where ``L = sigma_1(X0)^2`` and ``mu = sigma_r(X0)^2``.  Every curve
    records the loss at ``(X_t, Y_t)``, i.e. before iteration ``t`` updates.

    Reads the module-level settings ``markers``, ``linestyles``, ``epochs``,
    ``mark_freq`` and ``c``.

    Args:
        ax: matplotlib Axes that receives three curves per width ``d``.
    """
    def initialize(A, m, n, d, mode="column"):
        """Return ``(A @ Phi, zeros(n, d))`` with ``Phi`` an ``n x d`` Gaussian.

        Raises:
            ValueError: if ``mode`` is not ``"column"`` (the only mode implemented).
        """
        if mode != "column":
            raise ValueError(f"unsupported initialization mode: {mode!r}")
        Phi = torch.randn(n, d)
        return A @ Phi, torch.zeros(n, d)

    def get_slim(X, r):
        """Return the largest and the ``r``-th largest singular values of ``X``."""
        S = torch.linalg.svdvals(X)
        return S[0].item(), S[r - 1].item()

    def grad(A, X, Y):
        """Return ``(loss, dX, dY)`` for ``||X Y^T - A||_F^2 / 2``."""
        res = X @ Y.T - A
        return torch.norm(res) ** 2 / 2, res @ Y, res.T @ X

    def NAG(A, X, Y, lr, beta, epochs):
        """Run momentum gradient descent on ``(X, Y)``; ``beta = 0`` is plain GD.

        Returns the list of losses (floats) at the point where each gradient
        is evaluated -- the current iterate ``(X_t, Y_t)`` when ``beta = 0``.
        """
        tX, tY = X, Y
        losses = []
        for _ in range(epochs):
            loss, gx, gy = grad(A, tX, tY)
            losses.append(loss.item())
            X_prev, Y_prev = X, Y
            X = tX - lr * gx
            Y = tY - lr * gy
            # Momentum extrapolation; a no-op when beta == 0.
            tX = X + beta * (X - X_prev)
            tY = Y + beta * (Y - Y_prev)
        return losses

    def AltGD(A, X, Y, lr, epochs):
        """Run alternating gradient descent (X step, then Y step) on ``(X, Y)``.

        Returns the list of losses (floats) at ``(X_t, Y_t)``, the same point
        ``NAG`` records, so the curves line up iteration for iteration.
        """
        losses = []
        for _ in range(epochs):
            loss, gx, _ = grad(A, X, Y)
            losses.append(loss.item())
            X = X - lr * gx
            # The Y step uses the gradient at the freshly updated X.
            _, _, gy = grad(A, X, Y)
            Y = Y - lr * gy
        return losses

    # ---- problem configuration ----
    torch.manual_seed(0)
    m, n = 100, 80
    Sr = torch.tensor([1., 0.8, 0.6, 0.4, 0.2])  # leading singular values of A
    r = Sr.size()[0]
    eps = 0  # tail singular values are drawn from [0, eps); 0 -> A has rank r

    # ---- target matrix A = U S V^T ----
    p = min(m, n)
    U = torch.empty(m, m)
    torch.nn.init.orthogonal_(U)
    V = torch.empty(n, n)
    torch.nn.init.orthogonal_(V)
    Se = torch.rand(p - r) * eps
    S = torch.diagonal_scatter(torch.zeros(m, n), torch.concat([Sr, Se]))
    A = U @ S @ V.T

    colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
    for i, d in tqdm(enumerate([5, 20, 80])):
        loss_GD, loss_AltGD, loss_GD_diff = [], [], []
        for seed in range(10):
            torch.manual_seed(seed)
            X, Y = initialize(A, m, n, d)
            X = X * c  # uses the module-level scale that also names the output file
            smax, smin = get_slim(X, r)
            L = smax ** 2
            mu = smin ** 2

            loss_GD.append(NAG(A, X, Y, 2 / (L + mu), 0., epochs))
            loss_GD_diff.append(NAG(A, X, Y, 1 / L, 0., epochs))
            loss_AltGD.append(AltGD(A, X, Y, 2 / (L + mu), epochs))

        # (runs, marker column, legend template) in the original plotting order.
        curves = (
            (loss_GD, 0, 'GD,d={}'),
            (loss_AltGD, 1, 'AltGD,d={}'),
            (loss_GD_diff, 2, 'GD(1/L),d={}'),
        )
        for runs, k, fmt in curves:
            ax.plot(np.mean(runs, axis=0), c=colors[i], marker=markers[i][k],
                    markevery=mark_freq, linestyle=linestyles[i],
                    label=fmt.format(d))
    
def plotLNN(ax):
    """Plot averaged loss curves for a two-factor linear network onto ``ax``.

    Minimizes ``||X Y^T D - T||_F^2 / 2`` where ``D = U diag(Sr, Se) V^T`` is a
    fixed ``n x N`` matrix, ``T = A @ D`` with Gaussian ``A`` (``m x n``), and
    the initialization is ``X0 = c * T @ Phi`` (``Phi`` an ``N x d`` Gaussian),
    ``Y0 = 0``.  For each width ``d`` in ``(5, 20, 80)`` three methods run for
    ``epochs`` iterations and their loss curves are averaged over seeds 0..9:

    * ``GD``      -- gradient descent with step ``2 / (L + mu)``;
    * ``AltGD``   -- alternating gradient descent with step ``2 / (L + mu)``;
    * ``GD(1/L)`` -- gradient descent with step ``1 / L``;

    where ``L = sigma_1(X0)^2 * sigma_1(D)^2`` and
    ``mu = sigma_r(X0)^2 * sigma_r(D)^2``.  Every curve records the loss at
    ``(X_t, Y_t)``, i.e. before iteration ``t`` updates.

    Reads the module-level settings ``markers``, ``linestyles``, ``epochs``,
    ``mark_freq`` and ``c``.

    Args:
        ax: matplotlib Axes that receives three curves per width ``d``.
    """
    def initialize(target, m, n, N, d, mode="column"):
        """Return ``(target @ Phi, zeros(n, d))`` with ``Phi`` an ``N x d`` Gaussian.

        Raises:
            ValueError: if ``mode`` is not ``"column"`` (the only mode implemented).
        """
        if mode != "column":
            raise ValueError(f"unsupported initialization mode: {mode!r}")
        Phi = torch.randn(N, d)
        return target @ Phi, torch.zeros(n, d)

    def get_slim(X, r):
        """Return the largest and the ``r``-th largest singular values of ``X``."""
        S = torch.linalg.svdvals(X)
        return S[0].item(), S[r - 1].item()

    def grad(target, D, X, Y):
        """Return ``(loss, dX, dY)`` for ``||X Y^T D - target||_F^2 / 2``."""
        res = X @ Y.T @ D - target
        return torch.norm(res) ** 2 / 2, res @ D.T @ Y, D @ res.T @ X

    def NAG(target, D, X, Y, lr, beta, epochs):
        """Run momentum gradient descent on ``(X, Y)``; ``beta = 0`` is plain GD.

        Returns the list of losses (floats) at the point where each gradient
        is evaluated -- the current iterate ``(X_t, Y_t)`` when ``beta = 0``.
        """
        tX, tY = X, Y
        losses = []
        for _ in range(epochs):
            loss, gx, gy = grad(target, D, tX, tY)
            losses.append(loss.item())
            X_prev, Y_prev = X, Y
            X = tX - lr * gx
            Y = tY - lr * gy
            # Momentum extrapolation; a no-op when beta == 0.
            tX = X + beta * (X - X_prev)
            tY = Y + beta * (Y - Y_prev)
        return losses

    def AltGD(target, D, X, Y, lr, epochs):
        """Run alternating gradient descent (X step, then Y step) on ``(X, Y)``.

        Returns the list of losses (floats) at ``(X_t, Y_t)``, the same point
        ``NAG`` records, so the curves line up iteration for iteration.
        """
        losses = []
        for _ in range(epochs):
            loss, gx, _ = grad(target, D, X, Y)
            losses.append(loss.item())
            X = X - lr * gx
            # The Y step uses the gradient at the freshly updated X.
            _, _, gy = grad(target, D, X, Y)
            Y = Y - lr * gy
        return losses

    # ---- problem configuration ----
    torch.manual_seed(0)
    m, n, N = 100, 80, 120
    Sr = torch.tensor([1., 0.9, 0.8, 0.7, 0.5])
    r = Sr.size()[0]
    eps = 0  # tail values of the diagonal are drawn from [0, eps)

    # ---- fixed matrix D = U S V^T (draw order kept so A below is unchanged) ----
    p = min(n, N)
    U = torch.empty(n, n)
    torch.nn.init.orthogonal_(U)
    V = torch.empty(N, N)
    # NOTE(review): V is filled with i.i.d. Gaussians rather than orthogonalized
    # as in plotMF, so D's singular values are generally not Sr.  lmax/lmin are
    # measured from D directly below, so the step sizes stay consistent either way.
    torch.nn.init.normal_(V)
    Se = torch.rand(p - r) * eps
    S = torch.diagonal_scatter(torch.zeros(n, N), torch.concat([Sr, Se]))
    D = U @ S @ V.T

    A = torch.randn(m, n)
    target = A @ D  # m x N regression target
    lmax, lmin = get_slim(D, r)

    colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
    for i, d in tqdm(enumerate([5, 20, 80])):
        loss_GD, loss_AltGD, loss_GD_diff = [], [], []
        for seed in range(10):
            torch.manual_seed(seed)
            X, Y = initialize(target, m, n, N, d)
            X = X * c  # uses the module-level scale that also names the output file
            smax, smin = get_slim(X, r)
            L = (smax ** 2) * (lmax ** 2)
            mu = (smin ** 2) * (lmin ** 2)

            loss_GD.append(NAG(target, D, X, Y, 2 / (L + mu), 0., epochs))
            loss_AltGD.append(AltGD(target, D, X, Y, 2 / (L + mu), epochs))
            loss_GD_diff.append(NAG(target, D, X, Y, 1 / L, 0., epochs))

        # (runs, marker column, legend template) in the original plotting order.
        curves = (
            (loss_GD, 0, 'GD,d={}'),
            (loss_AltGD, 1, 'AltGD,d={}'),
            (loss_GD_diff, 2, 'GD(1/L),d={}'),
        )
        for runs, k, fmt in curves:
            ax.plot(np.mean(runs, axis=0), c=colors[i], marker=markers[i][k],
                    markevery=mark_freq, linestyle=linestyles[i],
                    label=fmt.format(d))

# Build the two-panel comparison figure and save it as a PDF.
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
fig.set_figwidth(7.2)
fig.set_figheight(4.0)
plotMF(ax1)
plotLNN(ax2)
for ax, title in ((ax1, 'matrix factorization'), (ax2, 'linear network')):
    ax.set_title(title)
    ax.set_xlabel('iteration')
    ax.set_yscale('log')
ax1.set_ylabel('loss')  # y-axis is shared, so only the left panel is labelled
ax2.legend(loc='center left', bbox_to_anchor=(1, 0.5), numpoints=2)
fig.tight_layout()

out_path = f"figgen/results/MF_c{c}_Alt_all.pdf"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# pad_inches only takes effect together with bbox_inches='tight'; the tight
# bbox also keeps the legend (anchored outside ax2) on the saved page.
fig.savefig(out_path, dpi=100, bbox_inches='tight', pad_inches=0.0)
    
