import os

import matplotlib.pyplot as plt
import numpy as np

import qis

# Number of random problem instances generated by max_entropy_draw_AM_LBFGS;
# instance i is seeded with i and the recorded loss traces are averaged.
experiments = 3

def compute_loss(H, state, x, x_opt, obj):
    """Return the absolute gap between the objective at ``x`` and ``obj``.

    The objective is evaluated with ``qis.qis_objective(H, x)``.  ``state``
    and ``x_opt`` are accepted for a uniform call signature but are not read.
    """
    current_value = qis.qis_objective(H, x)
    return np.abs(current_value - obj)
 
def anderson_data_collect(g, H, state, x0, optimal, objective, record = 10,
                          repeats = 2000, t = 2, beta = 1, m = 10, tol = 1e-12,
                          disp = False, bb = False):
    """Run Anderson mixing on the map ``g`` and record a loss trace.

    Parameters
    ----------
    g : callable
        Fixed-point map; the residual driven to zero is ``x - g(x)``.
    H, state : passed through to ``compute_loss``.
    x0 : array-like
        Starting point.  It is copied, so the caller's object is never
        modified.
    optimal, objective : passed through to ``compute_loss``; the recorded
        loss is measured against ``objective``.
    record : int
        A loss value is recorded every ``record`` iterations.
    repeats : int
        Number of iterations; the loop always runs exactly this many.
    t : int
        ``2`` selects type-2 mixing; any other value selects type-1.
    beta : float
        Mixing parameter (initial value when ``bb`` is true).
    m : int
        Number of stored difference columns (history depth).
    tol : float
        Accepted for interface compatibility but not used: there is no
        early stopping, so every trace has the same length.
    disp : bool
        Print the final residual norm and iteration count.
    bb : bool
        Re-estimate ``beta`` each iteration with a Barzilai-Borwein formula.

    Returns
    -------
    list
        Loss values recorded every ``record`` iterations, followed by the
        loss at the final iterate.
    """

    def delta(x):
        # Residual of the fixed-point equation and its norm.
        dx = x - g(x)
        return qis.norm2(dx), dx

    # Private float copy: the caller's x0 must not be aliased or mutated by
    # the updates below, and integer input must not break float arithmetic.
    xt = np.array(x0, dtype=float)
    N = len(xt)

    # Column k holds the most recent difference written at slot k; the
    # slots are reused cyclically.
    Xt = np.zeros((N, m))
    Rt = np.zeros((N, m))

    collect = []
    steps = 0
    nt, dt = delta(xt)

    while steps < repeats:
        if steps % record == 0:
            collect.append(compute_loss(H, state, xt, optimal, objective))
        res = -dt

        if steps >= 1:
            k = (steps - 1) % m
            Xt[:, k] = (xt - x_prev).reshape(N)
            Rt[:, k] = (res - res_prev).reshape(N)
        x_prev = xt.copy()
        res_prev = res.copy()

        if steps == 0:
            # No history yet: plain damped fixed-point step.
            xt = xt + beta * res
        else:
            if bb:
                styt = np.vdot(Rt[:, k], Xt[:, k])
                ytyt = np.vdot(Rt[:, k], Rt[:, k])
                # Keep the previous beta when the residual did not change;
                # dividing by zero would turn every later iterate into NaN.
                if ytyt > 0:
                    beta = -styt / ytyt    # Barzilai–Borwein stepsize

            if t == 2:  # Type-2
                Gamma = qis.pinv(Rt.T @ Rt, rcond=1e-7) @ (Rt.T @ res)
            else:  # Type-1
                Gamma = qis.pinv(Xt.T @ Rt, rcond=1e-7) @ (Xt.T @ res)
            xt_bar = xt - Xt @ Gamma
            rt_bar = res - Rt @ Gamma

            xt = xt_bar + beta * rt_bar

        steps += 1
        nt, dt = delta(xt)

    collect.append(compute_loss(H, state, xt, optimal, objective))

    if disp:
        print(f'Anderson Mixing of Type-{t} converges to a residual of {nt} in {steps} iterations.')
    return collect

def l_bfgs_data_collect(grad, H, state, x0, optimal, objective, mod = 10, repeats = 2000,
                        beta = 1, m = 10, tol = 1e-12, disp = False, bb = False):
    """Run a fixed-step L-BFGS iteration and record a loss trace.

    Parameters
    ----------
    grad : callable
        Returns the gradient at a point.
    H, state : passed through to ``compute_loss``.
    x0 : array-like
        Starting point.  It is copied, so the caller's object is never
        modified.
    optimal, objective : passed through to ``compute_loss``; the recorded
        loss is measured against ``objective``.
    mod : int
        A loss value is recorded every ``mod`` iterations.
    repeats : int
        Number of iterations; the loop always runs exactly this many.
    beta : float
        Scale of the first gradient step and default initial inverse-Hessian
        scaling.
    m : int
        Maximum number of stored curvature pairs.  With ``m <= 0`` no pair
        is stored and every step is a scaled gradient step.
    tol : float
        Accepted for interface compatibility but not used: there is no
        early stopping, so every trace has the same length.
    disp : bool
        Print the final gradient norm and iteration count.
    bb : bool
        Use the Barzilai-Borwein ratio ``s.y / y.y`` as the initial
        inverse-Hessian scaling whenever the newest pair is accepted.

    Returns
    -------
    list
        Loss values recorded every ``mod`` iterations, followed by the loss
        at the final iterate.
    """

    def delta(x):
        grad_x = np.array(grad(x))
        return qis.norm2(grad_x), grad_x

    # Private float copy so the caller's x0 is never aliased or mutated.
    xt = np.array(x0, dtype=float)
    nt, dt = delta(xt)
    collect = []

    # Accepted curvature pairs, oldest first.
    S = []
    Y = []
    rho = []

    steps = 0
    while steps < repeats:
        if steps % mod == 0:
            collect.append(compute_loss(H, state, xt, optimal, objective))

        if steps == 0:
            pt = -beta * dt
        else:
            # Precompute quantities used in this iteration
            st = xt - xt_prev
            yt = dt - dt_prev
            styt = np.vdot(yt, st)

            H_diag = beta
            # Only pairs with positive curvature are accepted.
            if styt > 0:
                if bb:
                    H_diag = styt / np.vdot(yt, yt)

                if m > 0:
                    # Use information from last m iterations only
                    if len(S) == m:
                        S.pop(0)
                        Y.pop(0)
                        rho.pop(0)
                    S.append(st)
                    Y.append(yt)
                    rho.append(1 / styt)

            # L-BFGS two-loop recursion, newest pair first ...
            q = -dt
            alphas = []
            for s_i, y_i, rho_i in zip(reversed(S), reversed(Y), reversed(rho)):
                a_i = rho_i * np.vdot(s_i, q)
                q -= a_i * y_i
                alphas.append(a_i)
            # ... then oldest pair first; ``alphas`` was filled newest-first.
            r = H_diag * q
            for s_i, y_i, rho_i, a_i in zip(S, Y, rho, reversed(alphas)):
                b_i = rho_i * np.vdot(y_i, r)
                r += (a_i - b_i) * s_i
            pt = r

        xt_prev = xt.copy()
        dt_prev = dt.copy()

        xt = xt + pt

        steps += 1
        nt, dt = delta(xt)

    collect.append(compute_loss(H, state, xt, optimal, objective))

    if disp:
        print(f'Custom L-BFGS method converges to a residual of {nt} in {steps} iterations.')
    return collect

def max_entropy_QIS_AM_data_collect(H, state, x0, optimal, objective, mod = 10,
                                    repeats = 2000, beta = 1, m = 10, tol = 1e-12, bb = False):
    """Collect a loss trace for type-2 Anderson mixing applied to a QIS step.

    A ``qis.QISSolver`` is built for ``(H, state)``; the map handed to
    ``anderson_data_collect`` sets the solver's current point, performs one
    normalized solver step, and returns the resulting point.  ``mod`` is
    forwarded as the recording interval.
    """
    qis_solver = qis.QISSolver(H, state, parallel=True)

    def qis_step(point):
        # Copy so the solver never holds a reference to the caller's iterate.
        qis_solver.cur = point.copy()
        stepped = qis_solver.step(normalize=True)
        return stepped.cur

    return anderson_data_collect(
        qis_step, H, state, x0, optimal, objective,
        record=mod, repeats=repeats, t=2, beta=beta, m=m, tol=tol, bb=bb)

def max_entropy_lbfgs_data_collect(H, state, x0, optimal, objective, mod = 10,
                                   repeats = 2000, eta = 1, beta = 1, m = 10,
                                   tol = 1e-12, disp = False, bb = False):
    """Collect a loss trace for L-BFGS, used as the comparison baseline.

    The gradient supplied to ``l_bfgs_data_collect`` is
    ``qis.dual_grad(H, state, x)`` multiplied by ``eta``; all remaining
    arguments are forwarded unchanged.
    """

    def scaled_gradient(point):
        raw = np.array(qis.dual_grad(H, state, point))
        return eta * raw

    return l_bfgs_data_collect(
        scaled_gradient, H, state, x0, optimal, objective,
        mod=mod, repeats=repeats, beta=beta, m=m, tol=tol, disp=disp, bb=bb)

def _collect_solver_traces(ham, state, optimal, objective, suffix, title_suffix,
                           plain, with_bb):
    """Run AM-QIS and L-BFGS on ``ham`` and append their loss traces.

    Each solver is run twice from a zero starting point: 40 iterations with
    the fixed step size (trace appended to ``plain``) and 10 iterations with
    the Barzilai-Borwein step size (trace appended to ``with_bb``).  Traces
    are stored under ``'am_qis' + suffix`` and ``'bfgs_gd' + suffix``.
    """
    solvers = (
        ('AM-QIS', 'am_qis', max_entropy_QIS_AM_data_collect, {}),
        ('L-BFGS', 'bfgs_gd', max_entropy_lbfgs_data_collect,
         {'eta': ham.terms(), 'disp': False}),
    )
    for title, key, solve, extra in solvers:
        print(f'----- {title}{title_suffix} -----')
        plain[key + suffix].append(solve(
            ham, state, [0.0] * ham.terms(), optimal, objective, mod=1,
            repeats=40, tol=1e-18, **extra))
        with_bb[key + suffix].append(solve(
            ham, state, [0.0] * ham.terms(), optimal, objective, mod=1,
            repeats=10, tol=1e-18, bb=True, **extra))


def _save_loss_figure(traces, path):
    """Plot the per-iteration mean of each trace group and save it as a PDF.

    The output directory is created when missing so that a finished run is
    not lost to a ``FileNotFoundError`` at save time.
    """
    styles = (
        ('am_qis', {'label': 'AM-QIS', 'marker': 'o'}),
        ('bfgs_gd', {'label': 'L-BFGS-GD', 'marker': '^'}),
        ('am_qis_complete',
         {'label': 'AM-QIS_complete', 'marker': 'h', 'linestyle': '--'}),
        ('bfgs_gd_complete',
         {'label': 'L-BFGS-GD_complete', 'marker': '>', 'linestyle': ':'}),
    )

    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.yscale('log')
    for key, style in styles:
        plt.plot(np.mean(traces[key], axis=0), **style)
    plt.legend()

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(path, format='pdf', dpi=1200)


def max_entropy_draw_AM_LBFGS(disp=False):
    """
    Compare max entropy solvers

    For each of ``experiments`` seeded random instances, AM-QIS and L-BFGS
    are run on the local Hamiltonian and on its sum-completed version, with
    and without the Barzilai-Borwein step size.  The averaged loss traces
    are written to ``fig/AM_BFGS.pdf`` and ``fig/AM_BFGS_BB.pdf``.

    When ``disp`` is true, the qubit count and number of Hamiltonian terms
    are printed for every instance.
    """
    keys = ('am_qis', 'bfgs_gd', 'am_qis_complete', 'bfgs_gd_complete')
    plain = {key: [] for key in keys}
    with_bb = {key: [] for key in keys}

    n = 6
    beta = 1

    for seed in range(experiments):
        # Generate a random Gibbs state
        if disp:
            print(f'\nGenerating random Gibbs states of {n} qubits.')

        np.random.seed(seed)
        H = qis.Hamiltonian.Local1D(n).aggressive_normalized() # not complete
        if disp:
            print(f'The Hamiltonian has {H.terms()} terms')

        r = np.random.randn(H.terms())
        op = H.get(r)
        state = qis.State.Gibbs(n, options = {'beta': beta, 'Hamiltonian': op})

        # optimal solution and the objective value used as the loss reference
        optimal = - beta * r
        objective = qis.qis_objective(H, optimal)

        _collect_solver_traces(H, state, optimal, objective,
                               '', '', plain, with_bb)
        # The same reference objective is reused for the sum-completed runs.
        _collect_solver_traces(H.sum_complete(), state, optimal, objective,
                               '_complete', ' with sum completion',
                               plain, with_bb)

    _save_loss_figure(plain, 'fig/AM_BFGS.pdf')
    print("Save figure for AM vs BFGS.")

    plt.clf()

    _save_loss_figure(with_bb, 'fig/AM_BFGS_BB.pdf')
    print("Save figure for AM vs BFGS with BB.")

# Run the solver comparison and write both figures when executed as a script.
if __name__ == '__main__':
    max_entropy_draw_AM_LBFGS()
