#%%
import numpy as np
import cvxpy as cp

import matplotlib.pyplot as plt
import time

from Toy_utils import Monitor

#%%
def run(n = 100, P = 1, nIter = 500, verbose=2, alphak = 1e-2, resultfile = None):
    """Run the fixed-step hypergradient method on the toy convex bilevel problem.

    With y = (y1, y2), y1 = y[:n], y2 = y[n:]:

        upper level:  F(x, y) = (y1 - 2)^T (x - 1) + ||y2 + 3||^2
        lower level:  min_y  0.5 ||y1||^2 - x^T y1 + sum(y2)
                      s.t.   h(x) + sum(y) = 0

    where h(x) = sum(x) for P == 1 and h(x) = sum(x**3) for P == 3.

    Each iteration solves the lower-level problem with cvxpy (CLARABEL).
    It then computes d = -lstsq(M, N^T) from the lower-level KKT system,
    forms the hypergradient dxF + d[:2n]^T dyF, and takes a step of size
    ``alphak``. Per-iteration metrics are appended to a ``Monitor`` and
    written with ``monitor.save_csv(resultfile)``.

    Parameters
    ----------
    n : int
        Dimension of x; y has dimension 2n.
    P : int
        Selects h; only 1 and 3 are supported.
    nIter : int
        Number of outer iterations.
    verbose : int
        >= 2 prints one line per iteration; >= 1 shows a semilog plot of
        the x-error ``dx`` against accumulated time.
    alphak : float
        Constant step size (default 1e-2, the previously hard-coded value).
    resultfile : str or None
        CSV output path. ``None`` keeps the original default
        'D:/Results/ICML24/Toy_Convex/Toy_P={P}_n={n}_GAM_alphak={alphak}.csv'.

    Raises
    ------
    ValueError
        If ``P`` is not 1 or 3.
    """
    # Validate P before any work. Previously an unsupported P only failed
    # later with an obscure NameError on the unbound `h`.
    if P == 1:
        h = lambda x : np.sum(x)
        dh = lambda x : np.ones(n)
    elif P == 3:
        h = lambda x : np.sum(x ** 3)
        dh = lambda x: 3 * x ** 2
    else:
        raise ValueError(f"P must be 1 or 3, got {P!r}")

    if resultfile is None:
        resultfile = f'D:/Results/ICML24/Toy_Convex/Toy_P={P}_n={n}_GAM_alphak={alphak}.csv'

    # Upper-level objective and its partial gradients.
    F   = lambda x, y: (y[:n] - 2) @ (x - 1) + np.sum(np.square(y[n:] + 3))
    dxF = lambda x, y: y[:n] - 2
    dyF = lambda x, y: np.concatenate([x - 1, 2 * (y[n:] + 3)])

    # Lower-level objective and its y-gradient (M and N below are built from these by hand).
    f   = lambda x, y: .5 * np.sum(np.square(y[:n])) - x @ y[:n] + np.sum(y[n:])
    dyf = lambda x, y: np.concatenate([y[:n] - x, np.ones(n)])

    # Lower-level equality-constraint function g(x, y) = 0 and its gradients.
    g    = lambda x, y: h(x) + np.sum(y)
    dxg  = lambda x, y: dh(x)
    dyg  = lambda x, y: np.ones(2 * n)

    # proj1(v, b) shifts v by a constant so that sum(proj1(v, b)) + b == 0.
    proj1 = lambda v, b : v - (np.sum(v) + b) / len(v)
    # Reference lower-level point used only by the dy metric: y1 = x + 1 and
    # y2 = current y[n:] shifted onto the constraint.
    yx   = lambda x, y: np.concatenate((x + 1, proj1(y[n:], np.sum(x + 1) + h(x))))
    y2yx = lambda x, y: np.linalg.norm( y - yx(x, y) )

    # opt: (1, 2, -3)
    xopt = np.ones(n)

    # Relative errors reported in the monitor.
    metric_x = lambda x, y : np.linalg.norm(x - xopt) / np.linalg.norm(xopt)
    metric_y = lambda x, y : y2yx(x, y) / np.linalg.norm(yx(x, y))

    # initial guess
    xk = np.zeros(n)
    yk = np.zeros(2 * n)
    T  = 0
    monitor = Monitor()
    monitor.append({
        "k": 0, "time": T,
        "F": F(xk, yk), "f": f(xk, yk), "g": g(xk, yk),
        "dx": metric_x(xk, yk), "dy": metric_y(xk, yk),
    })

    yk = proj1(yk, h(xk))

    # Parametrized lower-level problem; x and h(x) are refreshed every iteration.
    y = cp.Variable(n * 2)
    x = cp.Parameter(n)
    hx = cp.Parameter()
    obj = cp.Minimize( .5 * cp.sum_squares(y[:n]) - x @ y[:n] + cp.sum(y[n:]) )
    con = [hx + cp.sum(y) == 0]
    prob = cp.Problem(obj, con)

    for k in range(nIter):
        t0 = time.time()

        # solve lower level problem
        x.value = xk
        hx.value = h(xk)
        prob.solve(cp.CLARABEL)
        yk = y.value
        # vk = con[0].dual_value

        # KKT matrix M = [[Hyy f, dyg], [dyg^T, 0]] and mixed block N = [Hxy f, dxg].
        # NOTE(review): M, M1, M2 and N1 do not depend on (xk, yk) for this
        # problem and could be hoisted. They stay inside the timed loop,
        # presumably so T counts rebuilding them as a general implementation
        # would. Confirm before hoisting.
        M1 = np.diag(np.concatenate([np.ones(n), np.zeros(n)]))
        M2 = np.ones((2*n, 1))
        M  = np.block([[M1, M2], [M2.T, np.zeros((1, 1))]])

        N1 = np.concatenate([-np.eye(n), np.zeros((n, n))], axis = 1)
        N2 = dh(xk).reshape(-1, 1)
        N  = np.concatenate([N1, N2], axis = 1)

        # M is singular (its y2 rows are identical), so use a minimum-norm
        # least-squares solve. rcond=None fixes the rank cutoff to
        # eps * max(M.shape) * s_max on every NumPy version. Before 2.0 the
        # default was -1 (cutoff eps * s_max) and emitted a FutureWarning.
        # np.linalg.pinv(M) @ N.T is an alternative.
        d = - np.linalg.lstsq(M, N.T, rcond=None)[0]

        # solve it by hand (kept for reference)
        # if P == 1:
        #     d = np.zeros((2*n + 1, n))
        #     d[:n, :] = -np.eye(n)
        #     d = -d
        # elif P == 3:
        #     d = np.zeros((2*n + 1, n))
        #     d[:n, :] = -np.eye(n)
        #     tmp = (dh(xk)+1) / n
        #     d[n:(2*n), :] = tmp.repeat(n).reshape([n, n]).T
        #     d = -d

        # Keep only the y rows; the last row pairs with the constraint multiplier.
        d1 = d[:2*n, :]
        grad = d1.T @ dyF(xk, yk)

        gd = dxF(xk, yk) + grad

        # update (constant step; diminishing alternative:
        # xkp = xk - alphak / (1 + k) ** 0.3 * gd)
        xkp = xk - alphak * gd

        T += time.time() - t0

        # Metrics are computed outside the timed region.
        monitor.append({
            "k": k, "time": T,
            "F": F(xk, yk), "f": f(xk, yk), "g": g(xk, yk),
            "dF" : np.linalg.norm(gd),
            "dx": metric_x(xk, yk),
            "dy": metric_y(xk, yk),
        })
        if verbose >= 2: print(f'{k:5d}-th iter: Fk = {F(xk, yk):.2f}, dx = {monitor.dx[-1]:.2f}, gd = {np.linalg.norm(gd):6.2f}, time = {T:.2f}')

        xk = xkp

    if verbose >= 1:
        plt.semilogy(monitor.time, monitor.dx)
        plt.show()

    monitor.save_csv(resultfile)

# %%
if __name__ == "__main__":
    # Script entry point: run the benchmark for both supported choices of h.
    problem_size = 1000
    for power in (1, 3):
        print(f"Experiment with n = {problem_size}, P = {power}")
        run(n=problem_size, P=power, nIter=50, verbose=1)