"""
Group lasso with overlap
========================

Comparison of solvers for a logistic-loss problem with
overlapping group lasso regularization.

References
----------
This example is modeled after the experiments in `Adaptive Three Operator Splitting <https://arxiv.org/pdf/1804.02339.pdf>`_, Appendix E.3.
"""

import scipy.io
import copt.penalty
import numpy as np
from scipy.sparse import linalg as splinalg
import copt as cp
from optimizer_adaptos import minimize_tos
import pickle
import scipy.io
import sys
import os
import argparse

# Fix the NumPy seed so repeated runs start from the same random state.
np.random.seed(1)

# .. command-line interface ..
parser = argparse.ArgumentParser(
    description='Experiments for overlapping group lasso regularization')
# --beta has no sensible default: the penalties, the printed banner and the
# output file name all consume it, so reject a missing value in the parser
# instead of carrying beta=None through the rest of the script.
parser.add_argument('--beta', required=True, type=float,
                    help="Regularization parameter")
parser.add_argument('--gamma0', required=False, type=float, default=0.0,
                    help="gamma_0 for AdapTOS")
parser.add_argument('--tau', required=False, type=float, default=0.0,
                    help="tau for PDHG and PDHG-LS")
parser.add_argument('--method', required=True,
                    choices=["TOS", "AdapTOS", "PDHG", "PDHG-LS", "TOS-LS"],
                    help="optimizer to use (TOS, TOS-LS, PDHG, PDHG-LS, AdapTOS)")
args = parser.parse_args()

# Module-level aliases used by the rest of the script.
beta = args.beta
gamma_0 = args.gamma0
method = args.method
tau = args.tau

# Iteration budget handed to whichever solver is selected.
max_iter = 10000

# .. load data ..
# The pickle is expected to hold a mapping whose values are, in insertion
# order, (A, b, groups, n_features, n_samples); the unpacking below relies on
# that order and on there being exactly five entries.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a trusted data file.
with open("data/overlapping_group_lasso.pkl", 'rb') as fh:
    # the context manager closes the handle even if unpickling fails
    data = pickle.load(fh)
A, b, groups, n_features, n_samples = data.values()

# .. compute the step-size ..
# Only the Lipschitz constant exposed by the loss object enters the step
# sizes, so no separate singular-value computation on A is performed here.
f = copt.loss.LogLoss(A, b)
if method in ["PDHG", "PDHG-LS"]:
    # Primal-dual methods take two step sizes, both controlled by tau:
    # the first scales with (1 - tau), the second is tau / step_size.
    step_size = 2 * (1 - tau) / f.lipschitz
    step_size2 = tau / step_size
else:
    # Three-operator-splitting variants start from 1 / L.
    step_size = 1.0 / f.lipschitz

# Every solver starts from the zero vector.
x0 = np.zeros(n_features)

# Report the regularization strength, then build the two GroupL1 penalties:
# G1 takes the groups at even positions, G2 the groups at odd positions.
# NOTE(review): presumably this split keeps the groups inside each penalty
# non-overlapping -- confirm against how `groups` was generated.
print('beta = %s' % beta)
G1, G2 = [copt.penalty.GroupL1(beta, groups[offset::2]) for offset in (0, 1)]


def loss(x):
    """Return the full objective at ``x``.

    The value is the sum of the smooth term ``f`` and the two group
    penalties ``G1`` and ``G2``, all taken from module scope.
    """
    smooth_term = f(x)
    first_penalty = G1(x)
    second_penalty = G2(x)
    return smooth_term + first_penalty + second_penalty


# run the method
# Every method follows the same recipe: announce the run, attach a fresh
# Trace callback, call the solver, then evaluate the full objective on the
# recorded iterates.  Only the solver call differs, so the shared steps are
# written once around a small dispatch.
if method not in ("PDHG-LS", "PDHG", "TOS-LS", "TOS", "AdapTOS"):
    # Not reachable through the argparse `choices`; kept as a guard.  Exit
    # with a non-zero status so a calling script can detect the failure.
    sys.exit("Unknown method")

print(f"{method} is running...")
cb = cp.utils.Trace()

if method in ("PDHG", "PDHG-LS"):
    # The two primal-dual variants differ only in whether line search is on.
    cp.minimize_primal_dual(
        f.f_grad, x0, G1.prox, G2.prox,
        callback=cb, max_iter=max_iter,
        step_size=step_size,
        step_size2=step_size2, tol=0,
        line_search=(method == "PDHG-LS"))

elif method == "TOS-LS":
    # Line-search variant: starts from a 10x larger step and passes beta as
    # h_Lipschitz.  line_search is spelled out rather than left to the
    # library default so the contrast with plain TOS is explicit.
    cp.minimize_three_split(
        f.f_grad, x0, G1.prox, G2.prox, step_size=10 * step_size,
        max_iter=max_iter, tol=1e-14, verbose=1,
        line_search=True, callback=cb, h_Lipschitz=beta)

elif method == "TOS":
    # Plain TOS is the fixed-step baseline, mirroring PDHG vs. PDHG-LS.
    # NOTE(review): this previously passed line_search=True, which made
    # "TOS" a second line-search run; confirm the intended baseline.
    cp.minimize_three_split(
        f.f_grad, x0, G1.prox, G2.prox,
        step_size=step_size,
        max_iter=max_iter, tol=1e-14, verbose=1,
        line_search=False, callback=cb)

else:  # method == "AdapTOS"
    minimize_tos(
        f.f_grad, x0, G1.prox, G2.prox, gamma_0=gamma_0,
        max_iter=max_iter, tol=1e-14, verbose=0, adaptive=True,
        callback=cb)

# Objective value at every recorded iterate, plus the matching timestamps.
trace = np.array([loss(x) for x in cb.trace_x])
trace_time = cb.trace_time

# Key order matches the historical output: gamma_0 (AdapTOS only) sits
# between beta and method.
savedict = {'beta': beta}
if method == "AdapTOS":
    savedict['gamma_0'] = gamma_0
savedict.update(method=method, trace=trace, trace_time=trace_time)


# .. save results ..
# The file stem encodes the regularization strength, the method and the
# method-specific hyper-parameter, so different runs do not overwrite
# each other.
out_dir = 'runs/overlapping_group_lasso'
savename = f"{out_dir}/Reg{beta}_{method}"
if method == "AdapTOS":
    savename += f"_gamma0={gamma_0}"

if method in ["PDHG", "PDHG-LS"]:
    savename += f"_tau={tau}"

# exist_ok=True avoids the check-then-create race when several runs start
# at the same time.
os.makedirs(out_dir, exist_ok=True)

# Use a dedicated name for the handle so the module-level loss object `f`
# is not rebound to a closed file.
with open(savename + ".pkl", 'wb') as out_file:
    pickle.dump(savedict, out_file)

# Same content again as a MATLAB .mat file.
scipy.io.savemat(savename + ".mat", mdict=savedict)
