import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.stats import ortho_group

# Run every tensor computation in float64 and use a readable plot font size.
torch.set_default_dtype(torch.float64)
plt.rcParams.update({'font.size': 12})

# markers[i][j]: row i belongs to the i-th factor width d, column j to the
# curve type (0: GD, 1: GD theory, 2: NAG, 3: NAG theory).
markers = [['v', '1', '^', '2'] for _ in range(3)]

# linestyles[i] is the line style shared by all curves of the i-th width d.
linestyles = ['-', ':', '-.', ':']

c = 200          # scale factor applied to the initial factor X
epochs = 2000    # number of iterations per optimisation run
mark_freq = 400  # a marker is drawn every `mark_freq` points

def plotGD(ax):
    """Plot measured and theoretical gradient-descent loss curves on ``ax``.

    A fixed 100 x 80 target ``A`` is built from random orthogonal factors and
    the singular values ``[1, 0.8, 0.6, 0.4, 0.1]`` (all others are zero since
    ``eps = 0``).  For each factor width ``d`` in ``[5, 20, 80]`` and each of
    10 seeds, the loss ``||A - X Y^T||_F^2 / 2`` is minimised by plain
    gradient descent starting from ``X0 = c * A @ Phi`` (``Phi`` Gaussian,
    ``n x d``) and ``Y0 = 0`` with step size ``2 / (L + mu)``, where ``L`` and
    ``mu`` are the squared largest and r-th singular values of ``X0``.

    Two seed-averaged curves are drawn per ``d``: the measured loss and the
    reference curve ``f(X0, Y0) * ((kappa - 1) / (kappa + 1)) ** (2 t)`` with
    ``kappa = L / mu``.

    Reads the module-level ``c``, ``epochs``, ``mark_freq``, ``markers`` and
    ``linestyles``.

    Args:
        ax: Axes object whose ``plot`` method receives the curves.
    """
    def initialize(A, n, d):
        """Return the starting factors ``(A @ Phi, 0)`` with Gaussian ``Phi``."""
        Phi = torch.randn(n, d)
        return torch.mm(A, Phi), torch.zeros(n, d)

    def get_slim(X, r):
        """Return the largest and the r-th largest singular value of ``X``."""
        S = torch.linalg.svdvals(X)
        return S[0].item(), S[r-1].item()

    def f(A, X, Y):
        """Return the factorisation loss ``||A - X Y^T||_F^2 / 2``."""
        return torch.norm(A - X @ Y.T) ** 2 / 2

    def grad(A, X, Y):
        """Return the loss and its gradients with respect to ``X`` and ``Y``."""
        res = X @ Y.T - A
        gx = res @ Y
        gy = res.T @ X
        loss = torch.norm(res) ** 2 / 2
        return loss, gx, gy

    def NAG(A, X, Y, lr, beta, epochs):
        """Run ``epochs`` momentum steps and return the per-step losses.

        The loss is recorded at the look-ahead point at which the gradient is
        evaluated; for ``beta = 0`` this is the plain gradient-descent
        iterate.  Losses are stored as Python floats.
        """
        tX, tY = X, Y
        losses = []
        for _ in range(epochs):
            loss, gx, gy = grad(A, tX, tY)

            # X and Y are rebound below, never modified in place, so holding
            # a reference to the previous iterate is sufficient.
            X_prev, Y_prev = X, Y
            X = tX - lr * gx
            Y = tY - lr * gy
            tX = X + beta * (X - X_prev)
            tY = Y + beta * (Y - Y_prev)

            losses.append(loss.item())
        return losses

    # ---- configuration ----
    torch.manual_seed(0)
    m, n = 100, 80
    Sr = torch.tensor([1., 0.8, 0.6, 0.4, 0.1])
    r = Sr.size()[0]
    eps = 0

    # ---- target matrix A = U diag(Sr, Se) V^T ----
    p = min(m, n)
    U = torch.empty(m, m)
    torch.nn.init.orthogonal_(U)
    V = torch.empty(n, n)
    torch.nn.init.orthogonal_(V)
    Se = torch.rand(p-r) * eps
    S = torch.concat([Sr, Se])
    S = torch.diagonal_scatter(torch.zeros(m, n), S)
    A = U @ S @ V.T

    colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']

    for i, d in tqdm(enumerate([5, 20, 80])):
        loss_GD = []
        theory_GD = []
        for seed in range(10):
            torch.manual_seed(seed)
            X, Y = initialize(A, n, d)
            X = X * c
            smax, smin = get_slim(X, r)
            L = smax ** 2
            mu = smin ** 2
            kappa = L / mu

            # Theoretical reference curve.
            base = f(A, X, Y).item()
            rate = (kappa - 1) / (kappa + 1)
            theory_GD.append(
                np.array([base * rate ** (2 * t) for t in range(epochs)]))

            # Console output retained from the original script: the step size
            # and momentum the accelerated method would use for this start.
            tbeta = (np.sqrt(L)-np.sqrt(mu))/(np.sqrt(L)+np.sqrt(mu))
            print(1/L, tbeta)

            # Plain gradient descent is the momentum scheme with beta = 0.
            lr = 2/(L+mu)
            loss_GD.append(np.array(NAG(A, X, Y, lr, 0., epochs)))

        loss_GD = np.array(loss_GD).mean(axis=0)
        theory_GD = np.array(theory_GD).mean(axis=0)

        ax.plot(loss_GD, c=colors[i], marker=markers[i][0], markevery=mark_freq, linestyle=linestyles[i], label='GD,d={}'.format(d))
        ax.plot(theory_GD, c=colors[i], marker=markers[i][1], markevery=mark_freq, linestyle=linestyles[i], label='GD(T),d={}'.format(d))

def plotNAG(ax):
    """Plot GD, NAG and theoretical NAG loss curves on ``ax``.

    A fixed 100 x 80 target ``A`` is built from random orthogonal factors and
    the singular values ``[1, 0.8, 0.6, 0.4, 0.01]`` (all others are zero
    since ``eps = 0``).  For each factor width ``d`` in ``[5, 20, 80]`` and
    each of 10 seeds, the loss ``||A - X Y^T||_F^2 / 2`` is minimised from
    ``X0 = c * A @ Phi`` (``Phi`` Gaussian, ``n x d``) and ``Y0 = 0`` twice:

    * gradient descent with step size ``2 / (L + mu)`` and no momentum;
    * the momentum scheme with step size ``1 / L`` and momentum
      ``(sqrt(L) - sqrt(mu)) / (sqrt(L) + sqrt(mu))``,

    where ``L`` and ``mu`` are the squared largest and r-th singular values
    of ``X0``.  Three seed-averaged curves are drawn per ``d``: both measured
    losses and the reference curve
    ``f(X0, Y0) * (1 - 1 / (2 sqrt(kappa))) ** (2 t)`` with
    ``kappa = L / mu``.

    Reads the module-level ``c``, ``epochs``, ``mark_freq``, ``markers`` and
    ``linestyles``.

    Args:
        ax: Axes object whose ``plot`` method receives the curves.
    """
    def initialize(A, n, d):
        """Return the starting factors ``(A @ Phi, 0)`` with Gaussian ``Phi``."""
        Phi = torch.randn(n, d)
        return torch.mm(A, Phi), torch.zeros(n, d)

    def get_slim(X, r):
        """Return the largest and the r-th largest singular value of ``X``."""
        S = torch.linalg.svdvals(X)
        return S[0].item(), S[r-1].item()

    def f(A, X, Y):
        """Return the factorisation loss ``||A - X Y^T||_F^2 / 2``."""
        return torch.norm(A - X @ Y.T) ** 2 / 2

    def grad(A, X, Y):
        """Return the loss and its gradients with respect to ``X`` and ``Y``."""
        res = X @ Y.T - A
        gx = res @ Y
        gy = res.T @ X
        loss = torch.norm(res) ** 2 / 2
        return loss, gx, gy

    def NAG(A, X, Y, lr, beta, epochs):
        """Run ``epochs`` momentum steps and return the per-step losses.

        The loss is recorded at the look-ahead point at which the gradient is
        evaluated; for ``beta = 0`` this is the plain gradient-descent
        iterate.  Losses are stored as Python floats.
        """
        tX, tY = X, Y
        losses = []
        for _ in range(epochs):
            loss, gx, gy = grad(A, tX, tY)

            # X and Y are rebound below, never modified in place, so holding
            # a reference to the previous iterate is sufficient.
            X_prev, Y_prev = X, Y
            X = tX - lr * gx
            Y = tY - lr * gy
            tX = X + beta * (X - X_prev)
            tY = Y + beta * (Y - Y_prev)

            losses.append(loss.item())
        return losses

    # ---- configuration ----
    torch.manual_seed(0)
    m, n = 100, 80
    Sr = torch.tensor([1., 0.8, 0.6, 0.4, 0.01])
    r = Sr.size()[0]
    eps = 0

    # ---- target matrix A = U diag(Sr, Se) V^T ----
    p = min(m, n)
    U = torch.empty(m, m)
    torch.nn.init.orthogonal_(U)
    V = torch.empty(n, n)
    torch.nn.init.orthogonal_(V)
    Se = torch.rand(p-r) * eps
    S = torch.concat([Sr, Se])
    S = torch.diagonal_scatter(torch.zeros(m, n), S)
    A = U @ S @ V.T

    colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']

    for i, d in tqdm(enumerate([5, 20, 80])):
        loss_GD = []
        loss_NAG = []
        theory_NAG = []
        for seed in range(10):
            torch.manual_seed(seed)
            X, Y = initialize(A, n, d)
            X = X * c
            smax, smin = get_slim(X, r)
            L = smax ** 2
            mu = smin ** 2
            kappa = L / mu

            # Theoretical reference curve for the accelerated method.
            base = f(A, X, Y).item()
            rate = 1-1/2/np.sqrt(kappa)
            theory_NAG.append(
                np.array([base * rate ** (2 * t) for t in range(epochs)]))

            # Step size and momentum of the accelerated run (also printed, as
            # in the original script).
            nag_lr = 1/L
            tbeta = (np.sqrt(L)-np.sqrt(mu))/(np.sqrt(L)+np.sqrt(mu))
            print(nag_lr, tbeta)

            # Plain gradient descent is the momentum scheme with beta = 0.
            gd_lr = 2/(L+mu)
            loss_GD.append(np.array(NAG(A, X, Y, gd_lr, 0., epochs)))

            # Accelerated run from the same starting point; NAG does not
            # modify X or Y in place.
            loss_NAG.append(np.array(NAG(A, X, Y, nag_lr, tbeta, epochs)))

        loss_GD = np.array(loss_GD).mean(axis=0)
        theory_NAG = np.array(theory_NAG).mean(axis=0)
        loss_NAG = np.array(loss_NAG).mean(axis=0)

        ax.plot(loss_GD, c=colors[i], marker=markers[i][0], markevery=mark_freq, linestyle=linestyles[i], label='GD,d={}'.format(d))
        ax.plot(loss_NAG, c=colors[i], marker=markers[i][2], markevery=mark_freq, linestyle=linestyles[i], label='NAG,d={}'.format(d))
        ax.plot(theory_NAG, c=colors[i], marker=markers[i][3], markevery=mark_freq, linestyle=linestyles[i], label='NAG(T),d={}'.format(d))


# Two side-by-side panels that share the loss axis.
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
fig.set_figwidth(7.2)
fig.set_figheight(4.0)

plotGD(ax1)
plotNAG(ax2)

# Panel titles give the ratio between the largest and smallest singular
# value prescribed for the target matrix in the respective plot function.
ax1.title.set_text(r"$\kappa$=10")
ax2.title.set_text(r"$\kappa$=100")
for panel in (ax1, ax2):
    panel.set_xlabel('iteration')
    panel.set_yscale('log')
ax1.set_ylabel('loss')

# Collect one legend handle per label across both panels; the first
# occurrence of a label wins, so curves present in both panels are listed
# only once.
unique_entries = {}
for panel in fig.axes:
    handles, panel_labels = panel.get_legend_handles_labels()
    for handle, label in zip(handles, panel_labels):
        unique_entries.setdefault(label, handle)

def custom_sort_key(item):
    """Return the legend sort rank of a ``(label, handle)`` pair.

    Labels are expected to look like ``'<method>,d=<int>'`` where ``<method>``
    is one of ``GD``, ``GD(T)``, ``NAG`` or ``NAG(T)``.  Entries are ranked by
    ``d`` first and by method (in the order listed) second.

    Args:
        item: Sequence whose first element is the label string.

    Returns:
        ``d * 100 + method_index`` as an int.

    Raises:
        ValueError: If the method is not one of the four known names or the
            text after ``=`` is not an integer.
        IndexError: If the label contains no ``=``.
    """
    order = ('GD', 'GD(T)', 'NAG', 'NAG(T)')
    label = item[0]
    width = int(label.split('=')[1])
    method = label.split(',')[0]
    return width * 100 + order.index(method)


# Order the legend by factor width first and by method second, then split the
# (label, handle) pairs into the two parallel lists the legend call expects.
sorted_items = sorted(unique_entries.items(), key=custom_sort_key)
labels = [entry_label for entry_label, _ in sorted_items]
lines = [entry_handle for _, entry_handle in sorted_items]

# Anchor the legend just outside the right edge of the second panel.
ax2.legend(lines, labels, loc='center left', bbox_to_anchor=(1, 0.5), numpoints=2)
fig.tight_layout()
plt.savefig(f"figgen/results/MF_c{c}_theory_all.pdf", dpi=100)
    
