"""
This figures explains why some convergence curves are not proper curves in the
paper, i.e. the solver seems to ''go back in time''. For that we highlight
timing variability when one only has access to a solver as a black box,
controlling its maximum number of iteration.
"""
import time
import warnings
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso


def lasso_loss(X, y, beta, lmbda):
    """Evaluate the Lasso objective at ``beta``.

    Returns ``||y - X @ beta||_2^2 / (2 * len(y)) + lmbda * ||beta||_1``.
    The data-fit term is scaled by ``1 / (2 n)``, matching the objective
    documented for scikit-learn's ``Lasso``, so ``clf.alpha`` can be passed
    directly as ``lmbda``.
    """
    residual = y - X @ beta
    datafit = norm(residual) ** 2 / (2 * len(y))
    penalty = lmbda * norm(beta, ord=1)
    return datafit + penalty


# Reproducible synthetic problem: a design matrix and target with i.i.d.
# standard-normal entries, with more features than samples.
np.random.seed(0)
n_samples, n_features = 400, 500
X = np.random.randn(n_samples, n_features)
y = np.random.randn(n_samples)

######################## sklearn ##############################################
# Small max_iter values (combined with tol=1e-10) stop the solver before
# convergence on purpose, so silence the resulting warnings.
warnings.filterwarnings("ignore", category=ConvergenceWarning)
sklearn_loss = []
sklearn_time = []
clf = Lasso(alpha=0.01, fit_intercept=False, tol=1e-10, warm_start=False)

# We cannot launch the solver once and monitor times and losses along a
# single run, because we have no access to the inside of sklearn's code.
# Instead, we run the solver for 1 iteration, then 2 (starting from 0 again,
# since warm_start=False), then 3, etc.
# Because code execution time varies from run to run, it MAY HAPPEN that
# running the solver with 5 iterations takes more time than running it with
# 6, hence producing a curve that goes backward in time.
for max_iter in range(1, 100):
    # Configure outside the timed region so only the fit itself is measured.
    clf.max_iter = max_iter
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted and is coarser on some platforms.
    t0 = time.perf_counter()
    clf.fit(X, y)
    sklearn_time.append(time.perf_counter() - t0)
    sklearn_loss.append(lasso_loss(X, y, clf.coef_, clf.alpha))


# Convert the recorded histories to arrays so they can be indexed and sorted.
sklearn_time = np.array(sklearn_time)
sklearn_loss = np.array(sklearn_loss)

# The smallest loss observed over all runs stands in for f(beta_hat). The run
# achieving it has a suboptimality of exactly 0, which a log-scale axis cannot
# represent.
min_loss = sklearn_loss.min()

# Permutation that puts the runs in increasing order of measured time.
idx = np.argsort(sklearn_time)

# Top panel: points in the order the runs were launched (increasing
# max_iter), so the curve may go back in time. Bottom panel: the same points
# sorted by measured time, so the time axis increases monotonically.
fig, axarr = plt.subplots(2, 1, sharex=True)
for ax, order in zip(axarr, (slice(None), idx)):
    ax.semilogy(sklearn_time[order], sklearn_loss[order] - min_loss,
                label='sklearn')
    ax.set_ylabel(r"$f(\beta) - f(\hat\beta)$")

# After the loop ``ax`` is the bottom panel; it carries the shared x label
# and the legend.
ax.set_xlabel("Time (s)")
ax.legend()
# NOTE(review): block=False returns immediately; when run as a plain
# `python` script the window may close at interpreter exit. Confirm this is
# intended for an interactive session.
plt.show(block=False)
