"""
Group lasso with overlap
========================

Comparison of solvers (currently only AdapTOS, selected via ``--method``) for a
hinge loss with overlapping group lasso regularization on the RCV1 dataset.

References
----------
This example is modeled after the experiments in `Adaptive Three Operator Splitting <https://arxiv.org/pdf/1804.02339.pdf>`_, Appendix E.3.
"""

import scipy.io
import copt.penalty
from load_realsim import load_realsim
import numpy as np
from scipy.sparse import linalg as splinalg
import copt as cp
from optimizer_adaptos import minimize_tos
import pickle
import scipy.io
import sys
import os
import argparse
from copt_hingeloss import HingeLoss

# Fix NumPy's global seed so any randomness used downstream is reproducible.
np.random.seed(1)

# .. parse command-line options ..
parser = argparse.ArgumentParser(description='Experiments for Overlapping Group Lasso with Hinge Loss')
# --beta has no meaningful default: without it beta would be None, which is
# then handed to the penalties, the solver and the saved results. Require it
# so a missing value is reported immediately by argparse.
parser.add_argument('--beta', required=True, type=float, help="Regularization parameter")
parser.add_argument('--gamma0', required=False, type=float, default=0.0, help="gamma_0 for AdapTOS")
parser.add_argument('--method', required=True, choices=["AdapTOS"], help="optimizer to use (AdapTOS)")
args = parser.parse_args()

beta = args.beta
gamma_0 = args.gamma0
method = args.method

# .. load data ..
from copt.datasets import load_rcv1
A, b = load_rcv1()
max_iter = 1000
n_samples, n_features = A.shape

# .. build overlapping groups ..
# Each group is a window of `group_size` consecutive feature indices and
# successive windows start `group_stride` features apart, so neighbours share
# group_size - group_stride features. group_max counts the windows that fit
# entirely inside [0, n_features).
group_size, group_stride = 10, 8
group_max = (n_features - group_size) // group_stride + 1
groups = [np.arange(group_stride * i, group_stride * i + group_size)
          for i in range(group_max)]

# .. smooth loss and starting point ..
# NOTE(review): the meaning of HingeLoss's third positional argument (False)
# is defined in copt_hingeloss — confirm.
f = HingeLoss(A, b, False)
# NOTE: step_size is not passed to the solver (AdapTOS is driven by gamma_0
# with adaptive=True); kept only for reference.
step_size = 1.0
x0 = np.zeros(n_features)

print('beta = %s' % beta)
# Even- and odd-indexed windows start 2 * group_stride = 16 features apart,
# wider than group_size, so groups within each subset do not overlap. The
# overlapping penalty is thus written as a sum of two non-overlapping ones,
# presumably so each has a cheap prox for the splitting method.
G1 = copt.penalty.GroupL1(beta, groups[::2])
G2 = copt.penalty.GroupL1(beta, groups[1::2])

def loss(x):
    """Evaluate the full objective at ``x``.

    Sums, in order, the smooth term ``f`` and the two group-lasso penalties
    ``G1`` and ``G2``; all three are looked up as module-level globals at call
    time.
    """
    objective = f(x)
    for penalty in (G1, G2):
        objective = objective + penalty(x)
    return objective

# .. run AdapTOS ..
print("AdapTOS is running...")
cb_adaptos = cp.utils.Trace()
adaptos = minimize_tos(
    f.f_grad, x0, G1.prox, G2.prox, gamma_0=gamma_0,
    max_iter=max_iter, tol=1e-14, verbose=0, adaptive=True,
    callback=cb_adaptos)
# Evaluate the full objective at every iterate recorded by the callback.
trace = np.array([loss(x) for x in cb_adaptos.trace_x])
trace_time = cb_adaptos.trace_time
savedict = {'beta': beta, 'gamma_0': gamma_0, 'method': method, 'trace': trace, 'trace_time': trace_time}

# .. save the same results as a pickle and as a MATLAB .mat file ..
out_dir = os.path.join('runs', 'HingeLoss')
savename = os.path.join(out_dir, f"Reg{beta}_{method}_gamma0={gamma_0}")

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(out_dir, exist_ok=True)

# Use a dedicated name for the file handle: rebinding ``f`` here would shadow
# the HingeLoss objective that ``loss`` reads, breaking any later loss() call.
with open(savename + ".pkl", 'wb') as fh:
    pickle.dump(savedict, fh)

scipy.io.savemat(savename + ".mat", mdict=savedict)
