import qis
import matplotlib.pyplot as plt
import numpy as np

def compute_loss(H, state, x, x_opt, obj):
    """Return |qis_objective(H, x) - obj|, the objective gap at ``x``.

    ``state`` and ``x_opt`` are accepted so every data collector can call
    this with the same arguments, but neither is used in the computation.
    """
    current_value = qis.qis_objective(H, x)
    gap = current_value - obj
    return np.abs(gap)

def max_entropy_QIS_data_collect(H, state, x0, optimal, objective,
                                 record_step_size = 10, repeats = 2000,
                                 tol = 1e-12, show_progress = False, disp = True):
    """Run the QIS solver from ``x0`` and record the objective gap.

    The loss (see ``compute_loss``) is recorded on the iterate *before*
    steps 1, 1 + record_step_size, 1 + 2*record_step_size, ..., and once
    more on the solver's final iterate after the loop ends.

    Args:
        H: Hamiltonian passed to ``qis.QISSolver`` and ``compute_loss``.
        state: state passed to ``qis.QISSolver``.
        x0: initial point assigned to ``solver.cur``.
        optimal: reference optimum, forwarded to ``compute_loss``.
        objective: reference objective value the loss is measured against.
        record_step_size: record the loss every this many steps.
        repeats: maximum number of solver steps.
        tol: stop once ``qis.norm2`` of the per-step change is below this.
        show_progress: draw a progress bar via ``qis.print_progress``.
        disp: print a convergence summary at the end.

    Returns:
        list of recorded loss values.
    """
    solver = qis.QISSolver(H, state, parallel=True)
    solver.cur = x0
    collect = []
    # Defaults so the summary below is well-defined even when repeats < 1.
    steps = 0
    error = float('inf')

    for steps in range(1, repeats+1):
        # Snapshot the current iterate to measure the size of this step.
        x = solver.cur.copy()
        if (steps-1) % record_step_size == 0:
            collect.append(compute_loss(H, state, x, optimal, objective))
        solver.step(normalize=True)
        error = qis.norm2(x - solver.cur)
        if show_progress:
            qis.print_progress(steps, repeats,
                           prefix = f'QIS ({steps}/{repeats})',
                           suffix = 'Complete', length = 50)
        if error < tol:
            break

    if show_progress:
        # Terminate the progress-bar line whether or not we converged.
        print()

    # Record the loss at the iterate the solver actually ended on (the
    # post-step point), matching the dual-GD collector's final entry.
    collect.append(compute_loss(H, state, solver.cur, optimal, objective))

    if disp:
        print(f'Standard QIS method converges to a residual of {error} in {steps} iterations.')
    return collect

def max_entropy_dual_gd_data_collect(H, state, x0, optimal, objective,
                                     recorded_step_size=10, repeats = 2000,
                                     disp = True, eta = 1, tol = 1e-12, show_progress = True):
    """
    Gradient descent for the dual problem.

    Iterates ``x <- x - eta * g`` with ``g = qis.dual_grad(H, state, x)``.
    The loss (see ``compute_loss``) is recorded on the iterate before steps
    1, 1 + recorded_step_size, 1 + 2*recorded_step_size, ..., and once more
    on the final iterate after the loop ends.

    Args:
        H: Hamiltonian passed to ``qis.dual_grad`` and ``compute_loss``.
        state: state passed to ``qis.dual_grad``.
        x0: initial point; it is never modified in place.
        optimal: reference optimum, forwarded to ``compute_loss``.
        objective: reference objective value the loss is measured against.
        recorded_step_size: record the loss every this many steps.
        repeats: maximum number of gradient steps.
        disp: print a convergence summary at the end.
        eta: step size.
        tol: stop once ``qis.norm2`` of the gradient is below this value.
        show_progress: draw a progress bar via ``qis.print_progress``.

    Returns:
        list of recorded loss values.
    """
    # The update below is out-of-place (``x = x - ...``, not ``x -= ...``),
    # so an ndarray x0 owned by the caller is never mutated, and an integer
    # x0 is upcast instead of failing on an in-place float cast.
    x = x0
    collect = []
    # Defaults so the summary below is well-defined even when repeats < 1.
    t = 0
    ft = float('inf')

    for t in range(1, repeats+1):
        gt = np.array(qis.dual_grad(H, state, x))
        ft = qis.norm2(gt)
        if (t-1) % recorded_step_size == 0:
            collect.append(compute_loss(H, state, x, optimal, objective))
        if ft < tol:
            break
        x = x - eta*gt

        if show_progress:
            qis.print_progress(t, repeats,
                           prefix = f'GDM ({t}/{repeats})',
                           suffix = 'Complete', length = 50)

    if show_progress:
        # Terminate the progress-bar line; nothing to end otherwise.
        print()

    collect.append(compute_loss(H, state, x, optimal, objective))
    if disp:
        print(f'Dual GD method converges to a residual of {ft} in {t} iterations.')
    return collect

def qis_draw_comparison_qis_gd(disp=True):
    """Plot QIS vs. dual gradient descent on a random Gibbs state.

    Builds a random 6-qubit ``Local1D`` Hamiltonian and prepares its Gibbs
    state at ``beta = 1``. It then runs QIS and dual GD both on that
    Hamiltonian and on its ``sum_complete()`` variant, and saves a log-scale
    plot of the recorded losses to ``fig/QIS_GD_Compare.pdf``.

    Args:
        disp: print diagnostic messages to stdout.
    """
    import os

    n = 6
    beta = 1
    repeats = 2000
    record_step_size = 100
    tol = 1e-12
    out_path = 'fig/QIS_GD_Compare.pdf'

    # Generate a random Gibbs state using the qis module.
    if disp:
        print(f'\nGenerating random Gibbs states of {n} qubits.')

    H = qis.Hamiltonian.Local1D(n).aggressive_normalized() # not complete
    if disp:
        print(f'The Hamiltonian has {H.terms()} terms')

    np.random.seed(100)
    r = np.random.randn(H.terms())
    op = H.get(r)
    state = qis.State.Gibbs(n, options = {'beta': beta, 'Hamiltonian': op})

    # Optimal solution and its objective value (the reference for the loss).
    optimal = - beta * r
    objective = qis.qis_objective(H, optimal)
    if disp:
        print(f'Opti objective: {objective}\n')

    # Arguments shared by every run. The two collectors spell the recording
    # interval differently (record_step_size vs. recorded_step_size).
    shared = dict(optimal=optimal, objective=objective, repeats=repeats,
                  disp=False, tol=tol, show_progress=False)

    # QIS
    if disp:
        print('----- QIS -----')
    qis_result = max_entropy_QIS_data_collect(
        H, state, x0=[0.0] * H.terms(),
        record_step_size=record_step_size, **shared)

    # GD, with step size eta equal to the number of terms.
    if disp:
        print('\n----- GD -----')
    gd_result = max_entropy_dual_gd_data_collect(
        H, state, x0=[0.0] * H.terms(),
        recorded_step_size=record_step_size, eta=H.terms(), **shared)

    Hc = H.positive_fully_normalized().sum_complete()

    # QIS_complete
    if disp:
        print('\n----- QIS_complete -----')
    qis_result_complete = max_entropy_QIS_data_collect(
        Hc, state, x0=[0.0] * Hc.terms(),
        record_step_size=record_step_size, **shared)

    # GD_complete
    if disp:
        print('\n----- GD_complete -----')
    gd_result_complete = max_entropy_dual_gd_data_collect(
        Hc, state, x0=[0.0] * Hc.terms(),
        recorded_step_size=record_step_size, eta=Hc.terms(), **shared)

    def iterations(result):
        # x-coordinates for one curve. Losses are recorded before steps
        # 1, 1 + record_step_size, ..., and the final iterate is appended
        # last. For a full run this equals range(1, 2101, 100). Deriving it
        # from len(result) keeps x and y the same length even when a run
        # stops early at ``tol``; its last point is then only approximately
        # placed.
        return range(1, 1 + record_step_size * len(result), record_step_size)

    # Use a fresh figure so repeated calls do not stack curves.
    fig, ax = plt.subplots()
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.plot(iterations(qis_result), qis_result, label='QIS', marker='.')
    ax.plot(iterations(gd_result), gd_result, label='GD', marker='^')
    ax.plot(iterations(qis_result_complete), qis_result_complete,
            label='QIS_complete', marker='h', linestyle='dashed')
    ax.plot(iterations(gd_result_complete), gd_result_complete,
            label='GD_complete', marker='>', linestyle=':')
    ax.legend()

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, format='pdf', dpi=1200)

if __name__ == '__main__':
    # Script entry point: build and save the comparison plot with messages on.
    qis_draw_comparison_qis_gd(disp=True)
