# python3
"""
Total variation regularization
==============================

Comparison of solvers with total variation regularization.

References
----------
This example is modeled after the experiments in `Adaptive Three Operator Splitting <https://arxiv.org/pdf/1804.02339.pdf>`_, Appendix E.4.

"""
import numpy as np
import copt as cp
from optimizer_adaptos import minimize_tos
import pickle
import scipy.io
import os
import argparse
import sys

# Seed NumPy's global random number generator with a fixed value.
np.random.seed(1)

# .. parse the command line ..
parser = argparse.ArgumentParser(description='Total variation regularized deblurring experiments')
# --beta is mandatory: without it the regularization weight would be None and
# the prox operators / loss below would fail on `gamma * pen` and `pen * ...`.
parser.add_argument('--beta', required=True, type=float, help="Regularization parameter")
parser.add_argument('--gamma0', required=False,type=float, default=0.0, help="gamma_0 for AdapTOS")
parser.add_argument('--tau', required=False, type=float, default=0.0, help="tau for PDHG and PDHG-LS")
parser.add_argument('--method', required=True, choices=["TOS", "AdapTOS", "PDHG", "PDHG-LS", "TOS-LS"], help="optimizer to use (TOS, TOS-LS, PDHG, PDHG-LS, AdapTOS)")
args = parser.parse_args()

beta = args.beta        # regularization weight used by the prox operators and the loss
gamma_0 = args.gamma0   # only forwarded to the AdapTOS solver
method = args.method    # which solver to run
tau = args.tau          # only used to derive the PDHG / PDHG-LS step sizes

# Iteration budget shared by every solver.
max_iter = 2000

# .. load data ..
# NOTE: pickle executes arbitrary code on load; only use a trusted data file.
# The context manager guarantees the file handle is closed after reading.
with open("data/tv_deblurring.pkl", 'rb') as fh:
    data = pickle.load(fh)
# The unpacking depends on the insertion order of the pickled dict's values.
A, b, n_features, n_samples, n_rows, n_cols = data.values()

# .. compute the step-size ..
f = cp.loss.SquareLoss(A, b)

if method in ["PDHG", "PDHG-LS"]:
    # Primal step from tau and the Lipschitz constant; the second step size
    # is tau divided by the primal step.
    # NOTE(review): tau == 1 makes step_size zero, so the division below
    # breaks down; tau == 0 (the default) yields step_size2 == 0 -- confirm
    # the intended range of tau.
    step_size = 2*(1-tau)/f.lipschitz
    step_size2 = tau/step_size
else:
    step_size = 1.0 / f.lipschitz

def loss(x, pen):
    """Evaluate the total-variation regularized objective at ``x``.

    Parameters
    ----------
    x : array
        Flat iterate; it is reshaped to ``(n_rows, n_cols)`` using the
        module-level dimensions.
    pen : float
        Weight applied to the total-variation term.

    Returns
    -------
    The module-level smooth loss ``f`` evaluated at ``x`` plus ``pen`` times
    the sum of absolute differences between neighbouring entries along both
    axes of the reshaped array.
    """
    grid = x.reshape((n_rows, n_cols))
    # Absolute finite differences along axis 0, then along axis 1.
    axis0_variation = np.sum(np.abs(np.diff(grid, axis=0)))
    axis1_variation = np.sum(np.abs(np.diff(grid, axis=1)))
    return f(x) + pen * (axis0_variation + axis1_variation)


def g_prox(x, gamma, pen=beta):
    """Apply copt's column-wise 1D total-variation prox to ``x``.

    Parameters
    ----------
    x : array
        Point at which the prox is evaluated (passed through unchanged).
    gamma : float
        Step size supplied by the solver.
    pen : float
        Regularization weight; defaults to the value of ``beta`` at the time
        this function is defined.

    Returns
    -------
    Whatever ``cp.tv_prox.prox_tv1d_cols`` returns for the scaled step
    ``gamma * pen`` and the module-level ``n_rows`` / ``n_cols``.
    """
    scaled_step = gamma * pen
    return cp.tv_prox.prox_tv1d_cols(scaled_step, x, n_rows, n_cols)


def h_prox(x, gamma, pen=beta):
    """Apply copt's row-wise 1D total-variation prox to ``x``.

    Parameters
    ----------
    x : array
        Point at which the prox is evaluated (passed through unchanged).
    gamma : float
        Step size supplied by the solver.
    pen : float
        Regularization weight; defaults to the value of ``beta`` at the time
        this function is defined.

    Returns
    -------
    Whatever ``cp.tv_prox.prox_tv1d_rows`` returns for the scaled step
    ``gamma * pen`` and the module-level ``n_rows`` / ``n_cols``.
    """
    scaled_step = gamma * pen
    return cp.tv_prox.prox_tv1d_rows(scaled_step, x, n_rows, n_cols)

# .. run the method ..
# argparse `choices` already restricts --method, so this guard is only a
# safety net; it exits with a non-zero status and reports on stderr instead
# of silently exiting with status 0.
if method not in ("PDHG-LS", "PDHG", "TOS-LS", "TOS", "AdapTOS"):
    sys.exit("Unknown method")

print(f"{method} is running...")

# Every solver starts from the zero vector and records its iterates and
# timings through a copt Trace callback, so the post-processing is shared.
x0 = np.zeros(n_features)
callback = cp.utils.Trace()

if method == "PDHG-LS":
    # Primal-dual solver with its default line search; tol=0 as in the
    # original configuration of this experiment.
    cp.minimize_primal_dual(
        f.f_grad,
        x0,
        g_prox,
        h_prox,
        callback=callback,
        max_iter=max_iter,
        step_size=step_size,
        step_size2=step_size2,
        tol=0,
    )

elif method == "PDHG":
    # Primal-dual solver with the line search switched off.
    cp.minimize_primal_dual(
        f.f_grad,
        x0,
        g_prox,
        h_prox,
        callback=callback,
        max_iter=max_iter,
        step_size=step_size,
        step_size2=step_size2,
        tol=1e-14,
        line_search=False,
    )

elif method == "TOS-LS":
    # Three operator splitting with its default line search.
    cp.minimize_three_split(
        f.f_grad,
        x0,
        g_prox,
        h_prox,
        step_size=step_size,
        max_iter=max_iter,
        tol=1e-14,
        callback=callback,
        h_Lipschitz=beta,
    )

elif method == "TOS":
    # Three operator splitting with the line search switched off.
    cp.minimize_three_split(
        f.f_grad,
        x0,
        g_prox,
        h_prox,
        step_size=step_size,
        max_iter=max_iter,
        tol=1e-14,
        callback=callback,
        line_search=False,
    )

elif method == "AdapTOS":
    # Adaptive three operator splitting from the local optimizer module.
    minimize_tos(
        f.f_grad,
        x0,
        g_prox,
        h_prox,
        gamma_0=gamma_0,
        max_iter=max_iter,
        tol=1e-14,
        verbose=0,
        callback=callback,
        adaptive=True,
    )

# .. collect the results ..
# Objective value at every recorded iterate, plus the recorded timings.
trace = np.array([loss(x, beta) for x in callback.trace_x])
trace_time = callback.trace_time

savedict = {'beta': beta}
if method == "AdapTOS":
    # Only AdapTOS stores its initial step-size parameter.
    savedict['gamma_0'] = gamma_0
savedict.update(method=method, trace=trace, trace_time=trace_time)

# .. save the results ..
savename = f"runs/tv_deblurring/Reg{beta}_{method}"
if method == "AdapTOS":
    savename += f"_gamma0={gamma_0}"

if method in ["PDHG", "PDHG-LS"]:
    savename += f"_tau={tau}"

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs('runs/tv_deblurring', exist_ok=True)

# Use a dedicated handle name so the module-level loss object `f` is not
# rebound to a file object.
with open(savename + ".pkl", 'wb') as out_file:
    pickle.dump(savedict, out_file)

scipy.io.savemat(savename + ".mat", mdict=savedict)
