"""
Estimating a sparse and low rank matrix
=======================================

References
----------
This example is modeled after the experiments in `Adaptive Three Operator Splitting <https://arxiv.org/pdf/1804.02339.pdf>`_, Appendix E.5.

"""
import copt.loss
import copt.penalty

import numpy as np
from scipy.sparse import linalg as splinalg
import copt as cp
from optimizer_adaptos import minimize_tos
import pickle
import scipy.io
import sys
import os
import argparse

# Seed NumPy's global random number generator with a fixed value.
np.random.seed(1)

# .. command-line interface ..
parser = argparse.ArgumentParser(description='Experiments for Sparse Nuclear Norm Reg Experiments')
# --beta is mandatory: it is passed to both penalties and embedded in the
# output filename, so a missing value (None) would only surface later as an
# obscure failure inside the solver. Requiring it makes argparse report the
# problem immediately with a usage message.
parser.add_argument('--beta', required=True, type=float, help="Regularization parameter")
parser.add_argument('--gamma0', required=False, type=float, default=0.0, help="gamma_0 for AdapTOS")
parser.add_argument('--tau', required=False, type=float, default=0.0, help="tau for PDHG and PDHG-LS")
parser.add_argument(
    '--method',
    required=True,
    choices=["TOS", "AdapTOS", "PDHG", "PDHG-LS", "TOS-LS"],
    help="optimizer to use (TOS, TOS-LS, PDHG, PDHG-LS, AdapTOS)",
)
args = parser.parse_args()

# Unpack the parsed options into the module-level names used by the rest of
# the script.
beta = args.beta
gamma_0 = args.gamma0
method = args.method
tau = args.tau

# Iteration cap handed to every solver.
max_iter = 100000

# .. load data ..
# NOTE: unpickling can execute arbitrary code; only load a data file that
# comes from a trusted source.
# The context manager guarantees the file handle is closed even if
# unpickling raises.
with open("data/sparse_nuclear_norm.pkl", 'rb') as fh:
    data = pickle.load(fh)
# The unpacking depends on the order in which `data` yields its values
# (presumably a dict written with keys in this order -- verify against the
# script that generated the pickle).
A, b, Sigma, n_features, n_samples = data.values()

# .. compute the step-size ..
# NOTE(review): `s` (a k=1 singular value computed by svds) is not read by any
# later statement visible in this script; the step sizes below come from
# f.lipschitz instead. Kept as-is -- confirm whether the call is still needed.
s = splinalg.svds(A, k=1, return_singular_vectors=False, tol=1e-3, maxiter=500)[0]
f = copt.loss.HuberLoss(A, b)
uses_primal_dual = method in ("PDHG", "PDHG-LS")
if not uses_primal_dual:
    # The three-operator-splitting variants only need a single step size.
    step_size = 1.0 / f.lipschitz
else:
    # The primal-dual solvers take two step sizes, both derived from tau.
    # NOTE(review): tau == 1 makes step_size zero (so the division below
    # breaks) and the default tau == 0 makes step_size2 zero -- confirm the
    # intended range of tau.
    step_size = 2 * (1 - tau) / f.lipschitz
    step_size2 = tau / step_size

# Every solver starts from the all-zeros vector of length n_features.
x0 = np.zeros(n_features)

print(f"beta = {beta}")
# Both penalties are weighted by the same regularization parameter.
G1 = copt.penalty.TraceNorm(beta, Sigma.shape)
G2 = copt.penalty.L1Norm(beta)


def loss(x):
    """Return the full objective value at ``x``.

    The objective is the sum of the module-level Huber loss ``f``, the trace
    norm penalty ``G1`` and the L1 penalty ``G2``, each evaluated at ``x``.
    """
    data_term = f(x)
    trace_norm_term = G1(x)
    l1_term = G2(x)
    return data_term + trace_norm_term + l1_term


# run the method
# Reject an unexpected method name before doing any work. Passing a string to
# sys.exit writes it to stderr and exits with a non-zero status, so a calling
# shell script can detect the failure (a bare sys.exit() reports success).
if method not in ("PDHG-LS", "PDHG", "TOS-LS", "TOS", "AdapTOS"):
    sys.exit("Unknown method")

print(f"{method} is running...")

# One Trace callback is handed to whichever solver runs; its trace_x and
# trace_time attributes are read back once the solver returns.
trace_cb = cp.utils.Trace()
# Settings passed identically to every solver.
shared_kwargs = {"max_iter": max_iter, "tol": 1e-14, "callback": trace_cb}

if method == "PDHG-LS":
    # Primal-dual solver with line search enabled.
    cp.minimize_primal_dual(
        f.f_grad,
        x0,
        G1.prox,
        G2.prox,
        step_size=step_size,
        step_size2=step_size2,
        verbose=0,
        line_search=True,
        **shared_kwargs,
    )

elif method == "PDHG":
    # Primal-dual solver with fixed step sizes.
    cp.minimize_primal_dual(
        f.f_grad,
        x0,
        G1.prox,
        G2.prox,
        step_size=step_size,
        step_size2=step_size2,
        verbose=1,
        line_search=False,
        **shared_kwargs,
    )

elif method == "TOS-LS":
    # Three-operator splitting with the solver's default line-search setting.
    # Note the proximal operators are passed in the order (G2, G1) here, the
    # initial step size is 5x the 1/L value, and beta is supplied as
    # h_Lipschitz.
    cp.minimize_three_split(
        f.f_grad,
        x0,
        G2.prox,
        G1.prox,
        step_size=5 * step_size,
        verbose=1,
        h_Lipschitz=beta,
        **shared_kwargs,
    )

elif method == "TOS":
    # Three-operator splitting with line search disabled.
    cp.minimize_three_split(
        f.f_grad,
        x0,
        G1.prox,
        G2.prox,
        step_size=step_size,
        verbose=1,
        line_search=False,
        **shared_kwargs,
    )

else:
    # method == "AdapTOS": adaptive solver from the local optimizer_adaptos
    # module, parameterized by gamma_0 rather than a step size.
    minimize_tos(
        f.f_grad,
        x0,
        G2.prox,
        G1.prox,
        gamma_0=gamma_0,
        verbose=0,
        adaptive=True,
        **shared_kwargs,
    )

# The objective is evaluated on the recorded iterates only after the solver
# has returned.
trace = np.array([loss(x) for x in trace_cb.trace_x])
trace_time = trace_cb.trace_time

# Assemble the record to save. Key order is kept stable: 'gamma_0' (AdapTOS
# only) sits right after 'beta'.
savedict = {'beta': beta}
if method == "AdapTOS":
    savedict['gamma_0'] = gamma_0
savedict.update({'method': method, 'trace': trace, 'trace_time': trace_time})


# .. save results ..
# The output name encodes the regularization parameter, the method, and the
# method-specific hyperparameter (gamma_0 for AdapTOS, tau for PDHG variants).
output_dir = 'runs/sparse_nuclear_norm'
savename = f"{output_dir}/Reg{beta}_{method}"
if method == "AdapTOS":
    savename += "_gamma0=" + str(gamma_0)

if method in ["PDHG", "PDHG-LS"]:
    savename += "_tau=" + str(tau)

# exist_ok=True creates the directory when missing and is a no-op otherwise,
# avoiding the check-then-create race of a separate os.path.exists test.
os.makedirs(output_dir, exist_ok=True)

# The handle is deliberately not named `f`: that name is the HuberLoss
# objective defined earlier in the script and should not be clobbered.
with open(savename + ".pkl", 'wb') as out_file:
    pickle.dump(savedict, out_file)

# Write the same record in MATLAB format alongside the pickle.
scipy.io.savemat(savename + ".mat", mdict=savedict)
