#!/usr/bin/env python
# coding: utf-8

# ### Import libraries and define problem parameters

# In[1]:


import math
import torch
import random
from scipy.stats import ortho_group
from matplotlib import pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from tqdm import trange

# Notebook-only: render matplotlib figures inline. `get_ipython` exists only
# inside an IPython/Jupyter session, so this line fails under plain `python`.
get_ipython().run_line_magic('matplotlib', 'inline')
# Font type 42 makes the PDF/PS backends embed TrueType fonts (see the
# matplotlib rcParams documentation).
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'Helvetica'
mpl.rcParams['font.size'] = 12
mpl.rcParams['figure.dpi'] = 350
# NOTE(review): bare lookup -- it reads the setting but assigns nothing, so it
# has no effect here. Presumably left over from inspecting the value.
mpl.rcParams["text.usetex"]
# Per-curve colours and markers; the comparison figures index both lists by
# curve position, so each list needs at least as many entries as curves drawn.
colors = ['C1', 'C6', 'C2', 'C9', 'C10', 'C3', 'C4', 'C7', 'C8', 'C5']
markers = [r'$\bigcirc$', r'$\boxdot$', r'$\bigtriangleup$', r'$\heartsuit$', r'$\diamondsuit$', 'v']
# Alternative plain-marker set (disabled):
# markers = ['s', 'o', 'D', 'v', '^', '*', '1', '2', '3', '4']

# For reproducibility
# seed = 37492347
# random.seed(seed)          # control python seed
# np.random.seed(seed)       # control numpy seed
# torch.manual_seed(seed);   # control pytorch seed


# ### Define the Spectral Neural Networks and helper functions

# In[2]:


def dot(A, X):
    """Inner product of X with every matrix in the stack A.

    Args:
        A: stacked matrices; the leading dimension indexes the stack
            (assumed 3-D, as produced by ``torch.stack`` in this file).
        X: the matrix paired with each slice of A.

    Returns:
        1-D tensor whose k-th entry is ``sum(A[k] * X)``.
    """
    stack_size = A.shape[0]
    # Tile X along a new leading axis so it lines up with A slice by slice.
    tiled = X.repeat(stack_size, 1, 1)
    products = A * tiled
    return products.sum(dim=[1, 2])

def rec_diag_matrix(a, m, n):
    """Return an m x n matrix with the vector a on its main diagonal.

    The main diagonal of an m x n matrix has ``min(m, n)`` entries, so only
    the first ``min(m, n)`` entries of ``a`` are used. The previous
    implementation wrote into ``Q[:m, :m]`` and therefore failed with a shape
    mismatch whenever ``m > n``.

    Args:
        a: 1-D tensor of diagonal values (at least ``min(m, n)`` entries).
        m: number of rows.
        n: number of columns.

    Returns:
        Tensor of shape (m, n), zero off the main diagonal.
    """
    Q = torch.zeros(m, n)
    k = min(m, n)
    # Fill only the leading k x k square: that is where the diagonal lives
    # for both wide (m <= n) and tall (m > n) shapes.
    Q[:k, :k] = torch.diagflat(a[:k])
    return Q

def spectral_init(n, Phi, PsiT, epsilon=1): # Assumption 2 in our paper
    """Draw factors U, V aligned with the bases Phi and PsiT.

    Both factors share one random orthogonal matrix G and each gets its own
    ascending vector of non-negative scales (|N(0, 1)| * epsilon, sorted).

    Args:
        n: side length of the square factors.
        Phi: left basis, multiplied on the left of the U factor.
        PsiT: transposed right basis; its transpose multiplies the V factor.
        epsilon: scale applied to the random diagonal entries.

    Returns:
        Tuple ``(U_init, V_init)``.
    """
    # Shared rotation, sampled through scipy (NumPy RNG) before the torch draws.
    rotation = torch.tensor(ortho_group.rvs(n), dtype=torch.float)
    left_scales = (torch.randn(n) * epsilon).abs().sort().values
    right_scales = (torch.randn(n) * epsilon).abs().sort().values
    left_factor = Phi @ torch.diag(left_scales) @ rotation
    right_factor = PsiT.T @ torch.diag(right_scales) @ rotation
    return left_factor, right_factor

def random_init(n, nlow=None):
    """Draw two independent standard-normal factors of shape (n, nlow).

    Args:
        n: number of rows of each factor.
        nlow: number of columns; defaults to ``n`` when omitted.

    Returns:
        Tuple ``(U_init, V_init)``; U is sampled before V.
    """
    width = n if nlow is None else nlow
    return torch.randn(n, width), torch.randn(n, width)

def identity_init(n, eps=1e-2):
    """Return two separate copies of the n x n identity scaled by eps.

    Args:
        n: side length of each factor.
        eps: scale applied to the identity.

    Returns:
        Tuple ``(U_init, V_init)`` of distinct tensors with equal values.
    """
    left, right = (torch.eye(n) * eps for _ in range(2))
    return left, right

def spectral_nonlinear(Q):
    """Apply a sigmoid to the singular values of each matrix in a batch.

    Args:
        Q: batch of matrices (3-D, since the result is rebuilt with
            ``torch.bmm``).

    Returns:
        Batch with the same singular vectors as Q and sigmoid-transformed
        singular values.
    """
    left, singular_values, right_t = torch.linalg.svd(Q, full_matrices=False)
    squashed = torch.sigmoid(singular_values)
    # Lift the transformed spectrum back to a batch of diagonal matrices.
    spectrum = torch.diag_embed(squashed)
    scaled_left = torch.bmm(left, spectrum)
    return torch.bmm(scaled_left, right_t)

def mse_loss(X, A, y):
    """Mean squared error between y and the measurements of X under A.

    Args:
        X: candidate matrix.
        A: stack of measurement matrices.
        y: observed measurements, one per matrix in A.

    Returns:
        Scalar tensor: the mean of ``(y - dot(A, X)) ** 2``.
    """
    residual = y - dot(A, X)
    return residual.square().mean()

def test_error(X, Xtrue):
    return torch.norm(X - Xtrue, p='fro') / torch.norm(Xtrue, p='fro')
    

class SNN(torch.nn.Module):
    """Spectral neural network over n x n matrices.

    The model holds ``hidden_dims[0]`` factor pairs (U, V). Each pair forms a
    matrix ``U @ V.T``; every layer then applies ``spectral_nonlinear`` to the
    stack of matrices and mixes them linearly with a learned ``alpha``.

    Args:
        n: side length of the square matrices.
        hidden_dims: widths of the network. ``hidden_dims[0]`` is the number
            of (U, V) pairs and each following entry is the number of
            matrices produced by that layer. The last entry must be 1
            because ``forward`` reshapes its result to ``(n, n)``.
        Phi: optional left basis. When given together with ``PsiT`` the
            factors use ``spectral_init``; otherwise ``random_init``.
        PsiT: optional transposed right basis.
    """

    # Tuple default instead of a list: avoids sharing one mutable default
    # object between calls.
    def __init__(self, n, hidden_dims=(1, 1), Phi=None, PsiT=None):
        super(SNN, self).__init__()
        # Keep the problem size and bases on the instance so forward() and
        # sv() do not depend on module-level variables of the same name.
        self.n = n
        self.Phi = Phi
        self.PsiT = PsiT
        self.in_dim = hidden_dims[0]
        self.n_layers = len(hidden_dims) - 1
        self.Us = torch.nn.ParameterList()
        self.Vs = torch.nn.ParameterList()
        self.alphas = torch.nn.ParameterList()
        # Initialize Us and Vs
        for _ in range(self.in_dim):
            if Phi is not None and PsiT is not None:
                U, V = spectral_init(n, Phi, PsiT, epsilon=1.)
            else:
                U, V = random_init(n, n)
            self.Us.append(torch.nn.Parameter(U))
            self.Vs.append(torch.nn.Parameter(V))
        # Initialize alphas: layer i mixes hidden_dims[i-1] matrices into
        # hidden_dims[i] matrices.
        for i in range(1, self.n_layers + 1):
            alpha_i = torch.nn.Parameter(torch.randn(hidden_dims[i-1], hidden_dims[i]))
            self.alphas.append(alpha_i)

    def forward(self):
        """Return the n x n matrix currently represented by the network."""
        n = self.n
        X = torch.stack([U @ V.T for U, V in zip(self.Us, self.Vs)], dim=0)
        for i in range(self.n_layers):
            X = spectral_nonlinear(X)
            # Flatten each matrix to a row so the mixing is one matmul.
            X = self.alphas[i].T @ X.reshape(-1, n * n)
            X = X.reshape(-1, n, n)
        return X.reshape(n, n)

    def _bases(self):
        """Return the (Phi, PsiT) pair used to read off the spectrum."""
        if self.Phi is not None and self.PsiT is not None:
            return self.Phi, self.PsiT
        # Backward-compatible fallback: models built without bases used the
        # module-level Phi / PsiT, so keep doing that for them.
        return Phi, PsiT

    def sv(self):
        """Return the diagonal of ``Phi.T @ X @ PsiT.T`` for the current X."""
        left_basis, right_basis_t = self._bases()
        Q = self.forward()
        return torch.diag(left_basis.T @ Q @ right_basis_t.T)

    def nuclear_norm(self):
        """Return the sum of the entries of ``sv()``."""
        return torch.sum(self.sv())


# In[3]:


def generate_data(m, n, nlow, well_specified=True):
    ''' Build a synthetic matrix-sensing problem.

        Inputs:
        - m: number of training measurement matrices (2 * m are generated;
             the second half is the test split)
        - n: dimension of the true matrix
        - nlow: rank of the true matrix
        - well_specified: if True, every measurement matrix is
             Phi @ diag(s) @ PsiT for a shared basis pair; otherwise the
             measurement matrices are dense uniform draws scaled by 10

        Returns:
        Xtrue, A_train, y_train, A_test, y_test, Phi, PsiT

        In both settings Xtrue is the product of two uniform random factors
        of inner dimension nlow; it is not built from Phi / PsiT.
    '''
    # Shared bases: singular vectors of a random Gram matrix.
    seed_matrix = torch.randn(n, n)
    Phi, _, PsiT = torch.linalg.svd(seed_matrix @ seed_matrix.T)

    Xtrue = torch.rand(n, nlow) @ torch.rand(nlow, n)

    def shared_basis_measurement(index):
        # The spectrum is always sampled (keeps the random stream aligned);
        # the very first measurement then uses an all-ones spectrum instead.
        spectrum = torch.abs(torch.randn(n)).sort(descending=True).values
        if index == 0:
            spectrum = torch.ones(n)
        return Phi @ torch.diag(spectrum) @ PsiT

    total = 2 * m
    if well_specified:
        measurements = [shared_basis_measurement(k) for k in range(total)]
    else:
        measurements = [torch.rand(n, n) * 10 for _ in range(total)]

    A = torch.stack(measurements)
    y = dot(A, Xtrue)

    # First m measurements train, the remaining m test.
    return Xtrue, A[:m], y[:m], A[m:], y[m:], Phi, PsiT


# In[4]:


def run_gradient_descent(model, A_train, y_train, A_test, y_test, lr=1e-4, num_iter=50000, freq=5):
    """Train ``model`` with full-batch SGD and record diagnostics.

    Args:
        model: module whose call returns the current matrix estimate and
            which provides ``sv()`` and ``nuclear_norm()``.
        A_train, y_train: training measurement matrices and observations.
        A_test, y_test: held-out measurement matrices and observations.
        lr: SGD learning rate.
        num_iter: number of gradient steps.
        freq: diagnostics are recorded at step 0 and whenever
            ``(step + 1) % freq == 0``.

    Returns:
        dict of NumPy arrays with keys ``'singular_value_traces'``
        (transposed, so one row per singular value),
        ``'nuclear_norm_traces'``, ``'train_error_traces'`` and
        ``'test_error_traces'``.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    singular_value_traces = []
    nuclear_norm_traces = []
    train_error_traces = []
    test_error_traces = []

    with trange(num_iter, unit='epochs') as pbar:
        for i in pbar:
            if i == 0 or (i+1) % freq == 0:
                # Diagnostics never need gradients. Recording them without
                # no_grad would keep one autograd graph alive per logged
                # value for the whole run.
                with torch.no_grad():
                    # One forward pass serves both error metrics.
                    X = model()
                    singular_value_traces.append(model.sv())
                    nuclear_norm_traces.append(model.nuclear_norm())
                    train_error_traces.append(mse_loss(X, A_train, y_train))
                    test_error_traces.append(mse_loss(X, A_test, y_test))
            optimizer.zero_grad()
            # Training objective: half the sum of squared residuals.
            loss = 0.5 * torch.sum(torch.square(y_train - dot(A_train, model())))
            loss.backward()
            optimizer.step()
            pbar.set_description(f'Training loss: {loss.item()}')

    singular_value_traces = torch.stack(singular_value_traces).detach().numpy().T
    nuclear_norm_traces = torch.stack(nuclear_norm_traces).detach().numpy().T
    train_error_traces = torch.stack(train_error_traces).detach().numpy()
    test_error_traces = torch.stack(test_error_traces).detach().numpy()

    traces = {
        'singular_value_traces': singular_value_traces,
        'nuclear_norm_traces': nuclear_norm_traces,
        'train_error_traces': train_error_traces,
        'test_error_traces': test_error_traces,
    }
    return traces


# In[5]:


def plot_sv(traces):
    """Plot each recorded singular-value trace against its target value.

    Reads the module-level ``Phi``, ``PsiT``, ``Xtrue``, ``n`` and ``freq``.

    Args:
        traces: dict with key ``'singular_value_traces'`` holding an array
            with one row per singular value and one column per logged step.

    Returns:
        The matplotlib figure, a grid four panels wide.
    """
    sv_history = traces['singular_value_traces']
    num_logged = sv_history.shape[1]
    # Target spectrum: the true matrix expressed in the shared bases.
    targets = torch.diag(Phi.T @ Xtrue @ PsiT.T)

    iterations = np.arange(num_logged) * freq

    n_rows = math.ceil(n / 4)
    fig, grid = plt.subplots(n_rows, 4, figsize=(4 * 4, 4 * n_rows))
    panels = grid.ravel()
    for idx in range(n):
        panel = panels[idx]
        # Every 50th logged point keeps the curves light.
        panel.plot(iterations[::50], sv_history[idx, ::50], label='X singular values')
        panel.axhline(y=targets[idx], linestyle='dashed', label='$X^*$ singular values')
        panel.legend()

    return fig


# In[6]:


def plot_nuclear_norm(traces):
    """Plot the recorded nuclear-norm trace against the nuclear norm of Xtrue.

    Reads the module-level ``Xtrue`` and ``freq``.

    The previous version could not run: it indexed ``shape[1]`` of the 1-D
    trace, used an undefined row index ``i``, and referenced ``Xtrue_bar``,
    which only exists inside ``plot_sv``.

    Args:
        traces: dict with key ``'nuclear_norm_traces'`` holding one value per
            logged step.

    Returns:
        The matplotlib figure.
    """
    # Flatten so a 1-D trace and a (1, T) trace are handled alike.
    nuclear_norm_traces = np.asarray(traces['nuclear_norm_traces']).reshape(-1)
    size = nuclear_norm_traces.shape[0]

    iter_traces = np.arange(size) * freq
    # Reference line: nuclear norm of the true matrix.
    Xtrue_nuclear_norm = torch.norm(Xtrue, p='nuc').item()

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.plot(iter_traces[::50], nuclear_norm_traces[::50], label='X nuclear norm')
    ax.axhline(y=Xtrue_nuclear_norm, linestyle='dashed', label='$X^*$ nuclear norm')
    ax.legend()

    return fig


# ### Well-specified setting
# - The measurement matrices share the same left and right singular vectors $\Phi$ and $\Psi$.
# - The true matrix $X^*$ has $\Phi$ and $\Psi$ as its left and right singular vectors.
# 
# ### Misspecified setting
# - The above two assumptions no longer hold.
# - The measurement matrices are randomly generated.
# - The true matrix $X^*$ is random and low-rank.

# In[8]:


# Experiment: sweep the SGD learning rate on one well-specified problem.
# These names are module globals; functions in this file (e.g. plot_sv) read
# n, freq, Phi, PsiT and Xtrue directly, so they must stay defined here.
m, n = 60, 10 # m training measurements (generate_data draws 2 * m in total); true matrix is n x n
s = 3         # number of layers (NOTE(review): not read anywhere in the visible code)
nlow = 5      # rank of the true matrix

num_iter = 200000  # gradient steps per run
freq = 25          # diagnostics logging period, in steps
lrs = [5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6]
results = {}       # learning rate -> traces dict returned by run_gradient_descent

# One dataset is shared by every learning rate; each run gets a freshly
# initialised model with 5 factor pairs mixed down 5 -> 3 -> 1.
Xtrue, A_train, y_train, A_test, y_test, Phi, PsiT = generate_data(m, n, nlow, well_specified=True)
for lr in lrs:
    model = SNN(n, hidden_dims=[5, 3, 1], Phi=Phi, PsiT=PsiT)
    traces = run_gradient_descent(model, A_train, y_train, A_test, y_test, lr=lr, num_iter=num_iter, freq=freq)
    results[lr] = traces


# In[9]:


# Figure: log nuclear norm, log train loss and log test loss versus iteration,
# one dashed curve per learning rate in `lrs`.
fig, axes = plt.subplots(nrows=1, ncols=3)
axes = axes.ravel()

for i, lr in enumerate(lrs):
    run = results[lr]
    nuclear_norm_traces = run['nuclear_norm_traces']
    train_error_traces = run['train_error_traces']
    test_error_traces = run['test_error_traces']
    Xtrue_nuclear_norm = torch.norm(Xtrue, p='nuc')

    size = nuclear_norm_traces.shape[0]
    iter_traces = np.arange(size) * freq
    num_markers = 25
    step_size = size  // num_markers

    # Skip the first 160 logged points, then keep every step_size-th one so
    # each curve carries roughly num_markers markers.
    shown = slice(160, None, step_size)
    curve_style = dict(marker=markers[i], color=colors[i], linestyle='dashed', mew=0.7)
    panel_series = (nuclear_norm_traces, train_error_traces, test_error_traces)
    for ax, series in zip(axes, panel_series):
        ax.plot(iter_traces[shown], np.log(series[shown]), **curve_style)

    # Per-panel y label and y tick spacing; the x axis is identical everywhere.
    panel_ylabels = ('Nuclear Norm (log scale)', 'Train Loss (log scale)', 'Test Loss (log scale)')
    panel_yticks = (0.1, 5, 5)
    for ax, ylabel, ytick in zip(axes, panel_ylabels, panel_yticks):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Iterations')
        ax.xaxis.set_major_locator(plt.MultipleLocator(100000))
        ax.yaxis.set_major_locator(plt.MultipleLocator(ytick))

# Shared legend below the panels, built from proxy lines so that it does not
# depend on which curves were actually drawn.
labels = [r'$5 \times 10^{-4}$', r'$1 \times 10^{-4}$', r'$5 \times 10^{-5}$', r'$1 \times 10^{-5}$', r'$5 \times 10^{-6}$', r'$1 \times 10^{-6}$']
handles = [
    mlines.Line2D([], [], color=colors[k], marker=markers[k], linestyle='dashed', mew=0.7)
    for k, _ in enumerate(labels)
]
fig_leg = plt.figlegend(
    handles=handles,
    labels=labels,
    loc=(0.21, 0.04),
    ncol=6,
    handletextpad=0.3,
    columnspacing=1,
    handlelength=2.0,
)
fig_leg.get_frame().set_edgecolor('black')

# Leave the bottom 15% of the figure free for the legend.
fig.tight_layout(rect=[0, 0.15, 1, 0.99], h_pad=0, w_pad=-2)
fig.set_figheight(4)
fig.set_figwidth(14)
# NOTE(review): the architecture-comparison figure in this file is saved to
# the same 'plot.pdf' path, so whichever cell runs last overwrites the other.
plt.savefig('plot.pdf', bbox_inches='tight')


# In[12]:


# Experiment: compare network shapes at a fixed learning rate.
# These names are module globals; functions in this file (e.g. plot_sv) read
# n, freq, Phi, PsiT and Xtrue directly, so they must stay defined here.
m, n = 60, 10 # m training measurements (generate_data draws 2 * m in total); true matrix is n x n
s = 3         # number of layers (NOTE(review): not read anywhere in the visible code)
nlow = 5      # rank of the true matrix

num_iter = 200000  # gradient steps per run
freq = 25          # diagnostics logging period, in steps
lr = 5e-6
# hidden_dims per run: [number of (U, V) factor pairs, output width].
nets = [[1, 1], [5, 1], [10, 1], [20, 1]]
results = []       # one traces dict per entry of nets, in the same order

# A fresh dataset is drawn here, rebinding Xtrue / Phi / PsiT and the splits.
Xtrue, A_train, y_train, A_test, y_test, Phi, PsiT = generate_data(m, n, nlow, well_specified=True)
for net in nets:
    model = SNN(n, hidden_dims=net, Phi=Phi, PsiT=PsiT)
    traces = run_gradient_descent(model, A_train, y_train, A_test, y_test, lr=lr, num_iter=num_iter, freq=freq)
    results.append(traces)


# In[11]:


# Figure: log nuclear norm, log train loss and log test loss versus iteration,
# one dashed curve per network shape in `nets`.
fig, axes = plt.subplots(nrows=1, ncols=3)
axes = axes.ravel()

for i in range(len(nets)):
    run = results[i]
    nuclear_norm_traces = run['nuclear_norm_traces']
    train_error_traces = run['train_error_traces']
    test_error_traces = run['test_error_traces']
    Xtrue_nuclear_norm = torch.norm(Xtrue, p='nuc')

    size = nuclear_norm_traces.shape[0]
    iter_traces = np.arange(size) * freq
    num_markers = 25
    step_size = size  // num_markers

    # Skip the first 30 logged points, then keep every step_size-th one so
    # each curve carries roughly num_markers markers.
    shown = slice(30, None, step_size)
    curve_style = dict(marker=markers[i], color=colors[i], linestyle='dashed', mew=0.7)
    panel_series = (nuclear_norm_traces, train_error_traces, test_error_traces)
    for ax, series in zip(axes, panel_series):
        ax.plot(iter_traces[shown], np.log(series[shown]), **curve_style)

    # Per-panel y label and y tick spacing; the x axis is identical everywhere.
    panel_ylabels = ('Nuclear Norm (log scale)', 'Train Loss (log scale)', 'Test Loss (log scale)')
    panel_yticks = (1, 5, 5)
    for ax, ylabel, ytick in zip(axes, panel_ylabels, panel_yticks):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Iterations')
        ax.xaxis.set_major_locator(plt.MultipleLocator(100000))
        ax.yaxis.set_major_locator(plt.MultipleLocator(ytick))

# Shared legend below the panels, built from proxy lines so that it does not
# depend on which curves were actually drawn.
labels = [r'$1$-block', r'$1$-layer', r'$2$-layer', r'$3$-layer']
handles = [
    mlines.Line2D([], [], color=colors[k], marker=markers[k], linestyle='dashed', mew=0.7)
    for k, _ in enumerate(labels)
]
fig_leg = plt.figlegend(
    handles=handles,
    labels=labels,
    loc=(0.33, 0.04),
    ncol=4,
    handletextpad=0.3,
    columnspacing=1,
    handlelength=2.0,
)
fig_leg.get_frame().set_edgecolor('black')

# Leave the bottom 15% of the figure free for the legend.
fig.tight_layout(rect=[0, 0.15, 1, 0.99], h_pad=0, w_pad=-2)
fig.set_figheight(4)
fig.set_figwidth(14)
# NOTE(review): the learning-rate comparison figure in this file is saved to
# the same 'plot.pdf' path, so whichever cell runs last overwrites the other.
plt.savefig('plot.pdf', bbox_inches='tight')