import itertools
import numpy as np
import copy

from scipy.linalg import eigvalsh, expm, eig
from scipy.stats import entropy
from scipy.optimize import minimize
from qiskit.quantum_info import random_density_matrix
from qiskit.quantum_info.operators import Operator, Pauli

from numpy.linalg import pinv

class Hamiltonian:
    """The Hamiltonian class: a list of Hermitian terms acting on C^dim.

    Members:
    ops: list of ndarray operators of shape (dim, dim)
    label: the name of each Hamiltonian term (parallel to ``ops``)
    category: a string describing how the Hamiltonian was built
    dim: Hilbert-space dimension

    Constructors:
    Ising, Local1D, Transversal1D, Random, Discontinuous

    Completers:
    sum_complete: add an operator so that the sum of all terms is I
    complete: add I as an extra term

    positive_shift, positive_normalized, normalized, etc.: shift/rescale the
    terms so that they satisfy the conditions required by the solvers. Every
    normalizer replaces the entries of ``ops`` (it does not write into the
    existing arrays) and returns ``self`` so calls can be chained.
    """

    ops = None  # list of ndarray operators of dimension dim
    label = None  # label of each entry of ops
    category = None  # type of Hamiltonian
    dim = None  # Hilbert-space dimension

    def __init__(self, dim = 4, category='Empty'):
        self.dim = dim
        self.category = category
        self.ops = []
        self.label = []

    def terms(self):
        """Return the number of terms."""
        return len(self.ops)

    @classmethod
    def Ising(cls, n = 2):
        """Ising terms on ``n`` qubits: open-chain ZZ bonds plus single-site X."""
        H = cls(dim = 2 ** n)
        H.category = 'Ising'

        for i in range(n-1):
            lab = 'I' * i + 'ZZ' + 'I' * (n-i-2)
            H.ops.append(operator_pauli(lab))
            H.label.append(lab)

        for i in range(n):
            lab = 'I' * i + 'X' + 'I' * (n-i-1)
            H.ops.append(operator_pauli(lab))
            H.label.append(lab)

        return H

    @classmethod
    def Local1D(cls, n, loop = True):
        """1D 2-local Hamiltonian on ``n`` qubits.

        Contains every single-site X/Y/Z term and every nearest-neighbour
        Pauli pair; when ``loop`` is true and n > 2 the pair that wraps
        around the chain is added as well.
        """
        H = cls(dim = 2 ** n)
        H.category = 'Local'
        pauli = ['X', 'Y', 'Z']

        for single in pauli:
            for t in range(n):
                lab = 'I' * t + single + 'I' * (n-t-1)
                H.ops.append(operator_pauli(lab))
                H.label.append(lab)

        for pair in itertools.product(pauli, repeat = 2):
            for t in range(n-1):
                lab = 'I' * t + pair[0] + pair[1] + 'I' * (n-t-2)
                H.ops.append(operator_pauli(lab))
                H.label.append(lab)
            if n > 2 and loop:
                lab = pair[1] + 'I' * (n-2) + pair[0]
                H.ops.append(operator_pauli(lab))
                H.label.append(lab)
        return H

    @classmethod
    def Transversal1D(cls, n, loop = True):
        """Translation-summed 1D Hamiltonian on ``n`` qubits.

        One term per single Pauli (summed over all sites) and one per Pauli
        pair (summed over all bonds, plus the wrap-around bond when ``loop``
        and n > 2).
        """
        dim = 2 ** n
        H = cls(dim)
        H.category = 'Transversal'
        pauli = ['X', 'Y', 'Z']

        for single in pauli:
            # Explicit matrix start value: an empty sum must still be a
            # (dim, dim) operator rather than the integer 0.
            op = sum((operator_pauli('I' * t + single + 'I' * (n-t-1))
                      for t in range(n)),
                     np.zeros((dim, dim), dtype=complex))
            H.ops.append(op)
            H.label.append(single)

        for pair in itertools.product(pauli, repeat = 2):
            # For n == 1 there are no bonds; a fresh zero matrix keeps the
            # term well-formed (and unshared between labels).
            op = sum((operator_pauli('I' * t + pair[0] + pair[1] + 'I' * (n-t-2))
                      for t in range(n-1)),
                     np.zeros((dim, dim), dtype=complex))
            if n > 2 and loop:
                op += operator_pauli(pair[1] + 'I' * (n-2) + pair[0])
            H.ops.append(op)
            H.label.append(pair[0] + pair[1])
        return H

    @classmethod
    def Random(cls, n = 4, terms=12):
        """``terms`` random Hermitian terms A^H + A with complex Gaussian A."""
        dim = 2 ** n
        H = cls(dim)
        H.category = 'Random'

        for i in range(terms):
            A = np.random.randn(dim, dim) + 1j * np.random.randn(dim, dim)
            H.ops.append(A.conjugate().transpose() + A)
            H.label.append('Rand-' + str(i))
        return H

    @classmethod
    def Discontinuous(cls):
        """The fixed two-term 3x3 example with terms labelled F1 and F2."""
        H = cls(dim=3)
        H.category = 'Discontinuous'
        H.ops.append(np.real(0.1 * np.array([[1,0,0],[0,1,0],[0,0,-1]],
                                            complex)))
        H.label.append('F1')
        H.ops.append(np.real(0.1 * np.array([[1,0,1],[0,1,1],[1,1,-1]],
                                            complex)))
        H.label.append('F2')
        return H

    def sum_complete(self):
        """Append I - sum(ops) so that all terms sum to the identity."""
        s = operator_sum(self.ops)
        self.ops.append(np.eye(self.dim) - s)
        self.label.append('Complete')
        return self

    def complete(self):
        """Append the identity as an extra term."""
        self.ops.append(np.eye(self.dim))
        self.label.append('I')
        return self

    def positive_shift(self):
        """Shift every term by its smallest eigenvalue so it becomes PSD."""
        for i, op in enumerate(self.ops):
            minev = min(eigvalsh(op))
            # Out-of-place: works for any dtype and never writes into an
            # array that might be shared with another Hamiltonian.
            self.ops[i] = op - minev * np.eye(self.dim)
        return self

    def positive_normalized(self):
        """Shift each term to be PSD and scale its spectrum into [0, 1].

        A term with a flat spectrum (max == min eigenvalue) becomes the zero
        matrix instead of being divided by zero.
        """
        for i, op in enumerate(self.ops):
            evs = eigvalsh(op)  # computed once, used for both extremes
            minev, maxev = min(evs), max(evs)
            shifted = op - minev * np.eye(self.dim)
            spread = maxev - minev
            self.ops[i] = shifted / spread if spread > 0 else shifted
        return self

    def positive_fully_normalized(self):
        """Make every term PSD and scale so that sum(ops) <= I.

        After the shift the terms are rescaled by the largest eigenvalue of
        their sum, so that eigenvalue becomes exactly 1.
        """
        self.positive_shift()
        maxev = max(eigvalsh(sum(self.ops)))
        # ``self.ops`` is a list; dividing the list itself by a float raises
        # TypeError, so scale each operator individually.
        self.ops = [op / maxev for op in self.ops]
        return self

    def normalized(self):
        """Scale each term so that its operator norm (max |eigenvalue|) is 1.

        A zero term is left unchanged instead of being divided by zero.
        """
        for i, op in enumerate(self.ops):
            scale = max(abs(eigvalsh(op)))
            if scale > 0:
                self.ops[i] = op / scale
        return self

    def fully_normalized(self):
        """Scale all terms by one factor so that sum_i |ops[i]| <= I.

        |F| = P + N where F = P - N is the positive/negative split of F.
        """
        # Built with ``sum`` so the accumulator takes the result's dtype
        # (an in-place += into a real zeros_like array can fail for complex
        # parts).
        opsum = sum(opp + opn for opp, opn in
                    map(positive_negative_parts, self.ops))

        scale = max(np.real(eigvalsh(opsum)))

        for i, op in enumerate(self.ops):  # keep ops a list (no ndarray)
            self.ops[i] = op / scale
        return self

    def aggressive_normalized(self):
        """Map each term F to (F + I) / (2m), m = number of terms.

        Assumes every term has spectrum within [-1, 1]; then each result is
        PSD and the sum of all terms is at most I.
        """
        m = self.terms()
        for i in range(m):
            self.ops[i] = (self.ops[i] + np.eye(self.dim)) / (2.0 * m)
        return self

    def get(self, coeff):
        """Return the Hamiltonian sum_i coeff[i] * ops[i]."""
        return operator_sum(self.ops, coeff)

    def average(self, density):
        """Return the array of Re tr(op @ density) for every term."""
        return average(self.ops, density)

class State:
    """Quantum density matrix generators.

    dim: Hilbert-space dimension
    data: density matrix (ndarray of shape (dim, dim))
    name: label of the state
    options: the options dict passed to the constructor (``Gibbs`` reads
        'beta' and 'Hamiltonian' from it)
    """

    dim = None
    data = None
    name = None

    def __init__(self, dim, options = None):
        self.dim = dim
        # Fresh dict per instance: a ``{}`` default would be shared by every
        # State created without explicit options.
        self.options = {} if options is None else options
        # Backward-compatible alias for the previously misspelled attribute.
        self.optinos = self.options

    @classmethod
    def GHZ(cls, n):
        """GHZ state (|0..0> + |1..1>) / sqrt(2) on ``n`` qubits."""
        dim = 2 ** n
        state = cls(dim)
        state.data = np.zeros([dim, dim], complex)
        state.name = 'GHZ'
        for pair in itertools.product([0, dim-1], repeat = 2):
            state.data[pair[0], pair[1]] = 0.5
        return state

    @classmethod
    def W(cls, n):
        """W state: uniform superposition of the n single-excitation states."""
        dim = 2 ** n
        state = cls(dim)
        state.data = np.zeros([dim, 1], complex)
        state.name = 'W'
        for i in range(n):
            t = 2 ** i
            state.data[t, 0] = 1.0
        state.data = state.data @ state.data.transpose() / n
        return state

    @classmethod
    def random_state(cls, dim):
        """Random density matrix of dimension ``dim`` (via qiskit)."""
        state = cls(dim)
        state.name = 'Random'
        state.data = random_density_matrix(dim)._data
        return state

    @classmethod
    def Gibbs(cls, dim, options):
        """Gibbs state exp(-beta * H) / tr(exp(-beta * H)).

        ``options`` must provide 'beta' and 'Hamiltonian' (a square matrix).
        ``state.dim`` is taken from the resulting matrix so it always agrees
        with ``state.data``, even if ``dim`` was given as e.g. a qubit count.
        """
        state = cls(dim, options)
        beta = options.get('beta')
        op = options.get('Hamiltonian')
        state.name = 'Gibbs'
        state.data = expm(-beta * op)
        state.data /= state.data.trace()
        state.dim = state.data.shape[0]
        return state

    def set(self, data):
        """Use ``data`` as a custom density matrix."""
        self.dim = len(data)
        self.data = data
        self.name = 'Custom'

def operator_pauli(paulistr):
    """Return the dense matrix of the Pauli string ``paulistr`` (e.g. 'IXZ')."""
    pauli = Pauli(paulistr)
    matrix = Operator(pauli)._data
    return matrix

def average(ops, density):
    """Return Re tr(op @ density) for every operator in ``ops``.

    The result is a 1-D ndarray with one entry per operator.
    """
    values = [np.real(np.trace(op @ density)) for op in ops]
    return np.array(values)

def operator_sum(ops, coeff = None):
    """Return the Hamiltonian sum_i coeff[i] * ops[i].

    ``coeff`` defaults to all ones. As with ``zip``, pairing stops at the
    shorter of the two sequences; an empty ``ops`` yields the integer 0.
    """
    n_terms = len(ops)
    weights = [1.0] * n_terms if coeff is None else coeff
    return sum(w * op for w, op in zip(weights, ops))

def positive_negative_parts(op):
    """Split a Hermitian operator into its positive and negative parts.

    Returns ``(P, N)`` with ``op == P - N``, both positive semidefinite and
    ``P @ N == 0``; hence ``P + N`` is the operator absolute value |op|.
    Zero eigenvalues contribute to neither part.

    ``numpy.linalg.eigh`` is used rather than the general ``eig``: for
    Hermitian input it returns an orthonormal eigenbasis, whereas ``eig``
    gives no orthogonality guarantee inside degenerate eigenspaces (e.g.
    Pauli strings, whose +-1 eigenvalues are highly degenerate), which made
    the spectral reconstruction wrong. The input is assumed Hermitian.
    """
    phi, vec = np.linalg.eigh(op)
    pos = np.where(phi >= 0, phi, 0.0)
    neg = np.where(phi < 0, -phi, 0.0)
    # Columns of vec are eigenvectors: (vec * w) @ vec^H == sum_i w_i |v_i><v_i|.
    P = (vec * pos) @ vec.conj().T
    N = (vec * neg) @ vec.conj().T
    return P, N

def norm2(x):
    """Return ``numpy.linalg.norm`` of ``x`` with ord=2.

    Euclidean norm for vectors, spectral norm for matrices.
    """
    return np.linalg.norm(x, 2)

# Hamiltonian constructors keyed by category name; looked up by
# ``min_partition_function_run``.
hamiltonians = dict(
    Ising=Hamiltonian.Ising,
    Random=Hamiltonian.Random,
    Local1D=Hamiltonian.Local1D,
    Transversal1D=Hamiltonian.Transversal1D,
)

# State constructors keyed by state name.
states = dict(
    Random=State.random_state,
    W=State.W,
    GHZ=State.GHZ,
    Gibbs=State.Gibbs,
)

class BregmanLegendreSolver:
    """Base class for iterative Bregman-Legendre projection solvers.

    Holds the problem dimension, the list of terms and the current
    coefficient vector ``cur`` (one entry per term, initially all 0.0).
    Subclasses override ``L`` and ``step``; the base versions only print a
    reminder and return None.
    """

    dim = None
    ops = None
    cur = None
    parallel = True

    def __init__(self, H, parallel=True):
        self.dim, self.ops = H.dim, H.ops
        self.cur = [0.0] * self.terms()
        self.parallel = parallel

    def terms(self):
        """Return the number of terms being fitted."""
        return len(self.ops)

    def L(self):
        """Placeholder for the projection operator; subclasses override it."""
        print('Implement the Bregman-Legendre projection function!')

    def step(self, normalize=False):
        """Placeholder for one update step; subclasses override it."""
        print('Implement the update step!')

class QISSolver(BregmanLegendreSolver):
    """Quantum iterative scaling (QIS) for the max-entropy problem.

    ``target`` holds the expectation values of the terms in the given state.
    Each step multiplicatively corrects the coefficients ``cur`` by
    log(target / current averages) of Y = exp(sum_i cur_i F_i), which drives
    the averages of Y towards ``target``.
    """
    target = None

    def __init__(self, H, state, parallel=True):
        BregmanLegendreSolver.__init__(self, H, parallel)
        self.target = H.average(state.data)

    def L(self):
        """Return exp(sum_i cur_i F_i) (not trace-normalised)."""
        hs = operator_sum(self.ops, self.cur)
        return expm(hs)

    def step(self, normalize=False):
        """Perform one QIS update and return ``self``.

        normalize: divide Y by its trace before taking averages.
        Parallel mode updates every coefficient; sequential mode updates only
        the coefficient whose average is furthest from its target.
        ``self.cur`` is replaced by a new float ndarray.
        """
        Y = self.L()
        if normalize:
            Y /= Y.trace()

        vals = average(self.ops, Y)
        # Work on a float ndarray copy: ``cur`` starts out as a Python list,
        # and ``list += ndarray`` *extends* the list instead of adding
        # element-wise.
        cur = np.array(self.cur, dtype=float)
        if self.parallel:
            cur += np.log(self.target) - np.log(vals)
        else:
            i = np.argmax(np.abs(vals - self.target))
            cur[i] += np.log(self.target[i]) - np.log(vals[i])
        self.cur = cur
        return self

class AdaBoostSolver(BregmanLegendreSolver):
    """AdaBoost-style updates on the coefficients of exp(sum_i cur_i F_i).

    Every term is split as F = P - N (positive/negative parts); an update
    adds 0.5 * log(<N> / <P>) evaluated in Y = exp(sum_i cur_i F_i).
    """

    def __init__(self, H, parallel=True):
        BregmanLegendreSolver.__init__(self, H, parallel)
        self.pos, self.neg = map(list,zip(*map(positive_negative_parts,
                                               self.ops)))

    def L(self):
        """Return exp(sum_i cur_i F_i)."""
        hs = operator_sum(self.ops, self.cur)
        return expm(hs)

    def step(self, normalize=False):
        """Perform one update and return ``self`` (``normalize`` is ignored).

        Parallel mode updates every coefficient; sequential mode updates only
        the coefficient of the term with the largest |<F_i>|.
        ``self.cur`` is replaced by a new float ndarray.
        """
        Y = self.L()
        # ``cur`` starts out as a Python list; ``list += ndarray`` would extend
        # it, and operator_sum (which pairs like zip) would then keep using
        # the unchanged leading zeros, so the update never took effect.
        cur = np.array(self.cur, dtype=float)
        if self.parallel:
            valp = average(self.pos, Y)
            valn = average(self.neg, Y)
            cur += 0.5 * (np.log(valn) - np.log(valp))
        else:
            i = np.argmax(np.abs(average(self.ops, Y)))
            # Only term i is updated, so only its averages are needed.
            valp = average([self.pos[i]], Y)[0]
            valn = average([self.neg[i]], Y)[0]
            cur[i] += 0.5 * (np.log(valn) - np.log(valp))
        self.cur = cur
        return self

def Z(op):
    """Return the partition function tr(exp(op)) of a Hermitian operator."""
    spectrum = np.real(eigvalsh(op))
    boltzmann = np.exp(spectrum)
    return sum(boltzmann)

def qis_objective(H, x):
    """Return the max-entropy objective at ``x``.

    This is the (natural-log) entropy of the normalised spectrum of
    exp(H.get(x)), i.e. the von Neumann entropy of the corresponding Gibbs
    state.
    """
    eigs = np.real(eigvalsh(H.get(x)))
    # Subtracting the largest eigenvalue leaves the normalised distribution
    # unchanged but prevents exp() from overflowing to inf (and then nan).
    p = np.exp(eigs - eigs.max())
    p /= p.sum()
    return entropy(p)

def dual_objective(H, state, x):
    """Return the dual objective log Z(H.get(x)) - <target, x>.

    ``target`` is the vector of term averages in ``state``. log Z is computed
    with the log-sum-exp trick so large coefficients do not overflow.
    """
    target = H.average(state.data)
    eigs = np.real(eigvalsh(H.get(x)))
    top = eigs.max()
    log_z = top + np.log(np.sum(np.exp(eigs - top)))
    return log_z - np.real(np.dot(target, x))

def dual_grad(H, state, x):
    """Return the gradient of the dual objective as a list of floats.

    Component i is Re tr(rho F_i) - target_i with rho = exp(H.get(x)) / Z and
    ``target`` the term averages in ``state``.
    """
    target = H.average(state.data)
    hx = H.get(x)
    # Shifting by the largest eigenvalue cancels in the normalisation but
    # keeps expm() from overflowing to inf/nan for large coefficients.
    shift = np.max(np.real(eigvalsh(hx)))
    op = expm(hx - shift * np.eye(hx.shape[0]))
    op /= np.real(op.trace())
    return [np.real(np.trace(op @ F)) - t for F, t in zip(H.ops, target)]

def print_progress(iteration, total, prefix = '', suffix = '',
                   decimals = 1, length = 100, fill = '█', printEnd = "\r"):
    """
    Draw a one-line terminal progress bar (adapted from Stack Overflow).

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : text printed before the bar (Str)
        suffix      - Optional  : text printed after the percentage (Str)
        decimals    - Optional  : number of decimals in the percentage (Int)
        length      - Optional  : width of the bar in characters (Int)
        fill        - Optional  : character used for the filled part (Str)
        printEnd    - Optional  : line terminator, e.g. "\r" or "\r\n" (Str)

    A newline is emitted once ``iteration == total``.
    """
    fraction = iteration / float(total)
    percent = f'{100 * fraction:.{decimals}f}'
    n_filled = int(length * iteration // total)
    bar = fill * n_filled + '-' * (length - n_filled)
    print(f'\r{prefix} |{bar}| {percent}% {suffix}', end = printEnd)

    if iteration == total:
        print()

def max_entropy_QIS(H, state, x0, repeats = 2000, tol = 1e-12, show_progress = False):
    """Run the parallel QIS iteration from ``x0``.

    The iteration stops after ``repeats`` steps or once successive iterates
    differ by less than ``tol`` in 2-norm. A summary line is always printed.
    ``x0`` may be a list or array and is never modified.

    Returns the final coefficient vector (the solver's ``cur``).
    """
    solver = QISSolver(H, state, parallel=True)
    # QISSolver.step may update ``cur`` with ``+=``: a private float copy keeps
    # the caller's x0 untouched and avoids ``list += ndarray`` extending a list.
    solver.cur = np.array(x0, dtype=float)

    # Defined up front so the summary works even when repeats < 1.
    error, steps = np.inf, 0
    for steps in range(1, repeats+1):
        x = solver.cur.copy()
        solver.step(normalize=True)
        error = norm2(x - solver.cur)
        if show_progress:
            print_progress(steps, repeats,
                           prefix = f'QIS ({steps}/{repeats})',
                           suffix = 'Complete', length = 50)
        if error < tol:
            if show_progress:
                print()  # terminate the partially drawn progress bar
            break

    print(f'Standard QIS method converges to a residual of {error} in {steps} iterations.')
    return solver.cur

def anderson(g, x0, repeats = 2000, t = 2, beta = 1, m = 10, tol = 1e-12, disp=True):
    """Anderson mixing for the fixed-point problem x = g(x).

    Parameters
    ----------
    g : callable taking a 1-D float array and returning an array of the same
        length.
    x0 : initial guess (list or 1-D array); it is copied and never modified.
    repeats : maximum number of iterations.
    t : 2 selects Type-II mixing (least squares on residual differences);
        any other value selects Type-I.
    beta : mixing step for the first iteration. Later iterations use a
        Barzilai-Borwein step computed from the newest difference pair.
    m : number of stored difference pairs (ring buffer).
    tol : stop once ||x - g(x)||_2 <= tol.
    disp : print a convergence summary.

    Returns
    -------
    ndarray : the final iterate.
    """

    def delta(x):
        dx = x - g(x)
        return np.linalg.norm(dx, ord=2), dx

    steps = 0
    # Private float copy: the in-place ``xt += ...`` below would otherwise
    # mutate the caller's array, or extend x0 when it is a Python list.
    xt = np.array(x0, dtype=float)
    nt, dt = delta(xt)

    N = len(xt)
    Xt = np.zeros((N, m))  # columns: recent iterate differences
    Rt = np.zeros((N, m))  # columns: recent residual differences

    while steps < repeats and nt > tol:
        res = - dt

        if steps >= 1:
            k = (steps-1) % m  # ring-buffer slot for the newest pair
            Xt[:, k] = (xt - x_prev).reshape(N)
            Rt[:, k] = (res - res_prev).reshape(N)
        x_prev = xt.copy()
        res_prev = res.copy()

        if steps == 0:
            xt += beta * res
        else:
            styt = np.vdot(Rt[:, k], Xt[:, k])
            ytyt = np.vdot(Rt[:, k], Rt[:, k])
            # Barzilai-Borwein stepsize; if the residual did not change
            # (ytyt == 0) keep the previous step instead of dividing by zero.
            if ytyt != 0:
                beta = -styt / ytyt

            if t == 2: # Type-2
                Gamma = pinv(Rt.T @ Rt, rcond = 1e-7) @ (Rt.T @ res)
            else: # Type-1
                Gamma = pinv(Xt.T @ Rt, rcond = 1e-7) @ (Xt.T @ res)
            xt_bar = xt - Xt @ Gamma
            rt_bar = res - Rt @ Gamma
            xt = xt_bar + beta * rt_bar

        steps += 1
        nt, dt = delta(xt)

    if disp:
        print(f'Anderson Mixing of Type-{t} converges to a residual of {nt} in {steps} iterations.')
    return xt

# Anderson accelerated QIS
def max_entropy_QIS_AM(H, state, x0, repeats = 2000, beta = 1, m = 10, tol = 1e-12):
    """Type-II Anderson-accelerated QIS.

    The fixed-point map is one normalised parallel QIS step. ``x0`` may be a
    list or array and is never modified. Returns the final iterate.
    """
    solver = QISSolver(H, state, parallel=True)

    def g(x):
        # Private float copy so the solver's update cannot touch Anderson's
        # iterate (and a list can never be extended by ``+=``).
        solver.cur = np.array(x, dtype=float)
        return solver.step(normalize=True).cur

    return anderson(g, np.array(x0, dtype=float), repeats, 2, beta, m, tol)

def max_entropy_QIS_AM_1(H, state, x0, repeats = 2000, beta = 1, m = 10, tol = 1e-12):
    """Type-I Anderson-accelerated QIS.

    The fixed-point map is one normalised parallel QIS step. ``x0`` may be a
    list or array and is never modified. Returns the final iterate.
    """
    solver = QISSolver(H, state, parallel=True)

    def g(x):
        # Private float copy so the solver's update cannot touch Anderson's
        # iterate (and a list can never be extended by ``+=``).
        solver.cur = np.array(x, dtype=float)
        return solver.step(normalize=True).cur

    return anderson(g, np.array(x0, dtype=float), repeats, 1, beta, m, tol)

def max_entropy_dual_gd(H, state, x0, repeats = 2000, disp = False, eta = 1,
                        tol = 1e-12, show_progress = True):
    """
    Gradient descent for the dual problem: x <- x - eta * grad.

    The loop stops once ||grad||_2 < tol or after ``repeats`` iterations.
    ``x0`` may be a list or array and is never modified. Returns the final x.
    """
    # Float copy: the in-place update below must not mutate a caller's array.
    x = np.array(x0, dtype=float)
    # Defined up front so the summary works even when repeats < 1.
    ft, t = np.inf, 0
    for t in range(1, repeats+1):
        gt = np.array(dual_grad(H, state, x))
        ft = norm2(gt)
        if ft < tol:
            break
        x -= eta * gt

        if show_progress:
            print_progress(t, repeats,
                           prefix = f'GDM ({t}/{repeats})',
                           suffix = 'Complete', length = 50)
    if show_progress:
        print()  # finish the progress-bar line
    if disp:
        print(f'Dual GD method converges to a residual of {ft} in {t} iterations.')
    return x

def max_entropy_Nelder_Mead(H, state, x0, repeats = 2000, disp = False):
    """Minimise the dual objective with SciPy's gradient-free Nelder-Mead.

    ``repeats`` bounds the number of function evaluations. Returns the
    minimiser found.
    """

    def f(x):
        return dual_objective(H, state, x)

    # Nelder-Mead uses no gradient, so none is defined here.
    res = minimize(f, x0, method='Nelder-Mead',
                   options={'disp': disp, 'maxfev': repeats})
    return res.x

def max_entropy_L_BFGS_B(H, state, x0, repeats = 2000, disp = False):
    """Minimise the dual objective with SciPy's L-BFGS-B and the exact gradient.

    ``repeats`` bounds both iterations and function evaluations. Returns the
    minimiser found.
    """

    def f(x):
        return dual_objective(H, state, x)

    def f_grad(x):
        return dual_grad(H, state, x)

    res = minimize(f, x0, method='L-BFGS-B', jac=f_grad,
                   options={'disp': disp, 'maxls': 300,
                            'ftol': 1e-12, 'gtol': 1e-10, 'eps': 1e-10,
                            'maxfun': repeats, 'maxiter': repeats})
    return res.x

def l_bfgs(grad, x0, repeats = 2000, beta = 1, m = 10, tol = 1e-12, disp = False):
    """
    L-BFGS with unit steps (no line search), driven only by gradients.

    Parameters
    ----------
    grad : callable returning the gradient at x.
    x0 : starting point (list or 1-D array); copied, never modified.
    repeats : maximum number of iterations.
    beta : first step is -beta * grad; also the initial inverse-Hessian
        scale whenever the newest pair fails the curvature condition.
    m : number of stored curvature pairs.
    tol : stop once ||grad||_2 <= tol.
    disp : print a convergence summary.

    Returns the final iterate as an ndarray.
    """

    def delta(x):
        grad_x = np.array(grad(x))
        return np.linalg.norm(grad_x, ord=2), grad_x

    steps = 0
    # Private float copy: ``xt += pt`` must neither mutate the caller's array
    # nor extend x0 when it is a Python list.
    xt = np.array(x0, dtype=float)
    nt, dt = delta(xt)

    S, Y, rho = [], [], []  # curvature pairs (s, y) and 1 / <y, s>
    alpha = [None] * m

    while steps < repeats and nt > tol:

        if steps == 0:
            pt = -beta * dt
        else:
            # Quantities for the newest curvature pair.
            st = xt - xt_prev
            yt = dt - dt_prev
            styt = np.vdot(yt, st)
            ytyt = np.vdot(yt, yt)

            H_diag = beta
            if styt > 0:  # only store pairs satisfying the curvature condition
                rhot = 1 / styt
                H_diag = styt / ytyt

                # Keep information from the last m iterations only.
                if len(S) == m:
                    S.pop(0)
                    Y.pop(0)
                    rho.pop(0)

                S.append(st)
                Y.append(yt)
                rho.append(rhot)

            # L-BFGS two-loop recursion.
            q = -dt
            for i in range(len(S) - 1, -1, -1):
                alpha[i] = rho[i] * np.vdot(S[i], q)
                q -= alpha[i] * Y[i]
            r = H_diag * q
            for i in range(len(S)):
                be_i = rho[i] * np.vdot(Y[i], r)
                r += (alpha[i] - be_i) * S[i]
            pt = r

        xt_prev = xt.copy()
        dt_prev = dt.copy()

        xt += pt

        steps += 1
        nt, dt = delta(xt)

    if disp:
        print(f'Custom L-BFGS method converges to a residual of {nt} in {steps} iterations.')
    return xt

def max_entropy_lbfgs(H, state, x0, repeats = 2000, eta = 1, beta = 1, m = 10, tol = 1e-12, disp = False):
    """
    L-BFGS on the dual problem, for comparison.

    The gradient is scaled by ``eta``. ``x0`` may be a list or array and is
    never modified. Returns the final iterate.
    """

    def grad(x):
        return eta * np.array(dual_grad(H, state, x))

    # Hand l_bfgs a private float array: it updates its iterate in place.
    return l_bfgs(grad, np.array(x0, dtype=float), repeats, beta, m, tol, disp)

# Anderson mixing for gradient descent
def max_entropy_dual_AM(H, state, x0, repeats = 2000, eta = 1, beta = 1, m = 10, tol = 1e-12):
    """Type-II Anderson mixing applied to dual gradient descent.

    The fixed-point map is x -> x - eta * grad. ``x0`` may be a list or array
    and is never modified. Returns the final iterate.
    """

    def g(x):
        return x - eta * np.array(dual_grad(H, state, x))

    # Hand anderson a private float array: it updates its iterate in place.
    return anderson(g, np.array(x0, dtype=float), repeats, 2, beta, m, tol)

# Type-I Anderson mixing for gradient descent
def max_entropy_dual_AM_1(H, state, x0, repeats = 2000, eta = 1, beta = 1, m = 10, tol = 1e-12):
    """Type-I Anderson mixing applied to dual gradient descent.

    The fixed-point map is x -> x - eta * grad. ``x0`` may be a list or array
    and is never modified. Returns the final iterate.
    """

    def g(x):
        return x - eta * np.array(dual_grad(H, state, x))

    # Hand anderson a private float array: it updates its iterate in place.
    return anderson(g, np.array(x0, dtype=float), repeats, 1, beta, m, tol)

def min_partition_function_run(n = 4, category = 'Random', repeats = 100,
                               show_progress = True):
    """Run the parallel and sequential AdaBoost solvers and print Z.

    Builds a Hamiltonian of the given ``category`` (a key of
    ``hamiltonians``) on ``n`` qubits. It runs ``repeats`` steps with each
    update rule and prints the resulting partition function.

    Raises ValueError for an unknown ``category``.
    """
    factory = hamiltonians.get(category)
    if factory is None:
        raise ValueError(f'Unknown Hamiltonian category {category!r}; '
                         f'expected one of {sorted(hamiltonians)}')
    H = factory(n)
    # The normalizers modify their instance and return ``self``. Without
    # independent copies Hp and Hs were the same object, so the parallel
    # Hamiltonian was overwritten by the sequential normalization.
    Hp = copy.deepcopy(H).fully_normalized() # for parallel update
    Hs = copy.deepcopy(H).normalized() # for sequential update

    solver = AdaBoostSolver(Hp, parallel=True)
    print("Start the parallel solver.")
    for steps in range(1, repeats+1):
        solver.step(normalize=False)
        if show_progress:
            print_progress(steps, repeats,
                           prefix = f'AdaBoost ({steps}/{repeats})',
                           suffix = 'Complete', length = 50)
    print(f'Partition function: {Z(Hp.get(solver.cur))}')

    solseq = AdaBoostSolver(Hs, parallel=False)
    print("Start the sequantial solver.")

    for steps in range(1, repeats+1):
        solseq.step(normalize=False)
        if show_progress:
            print_progress(steps, repeats,
                           prefix = f'AdaBoost ({steps}/{repeats})',
                           suffix = 'Complete', length = 50)
    print(f'Partition function: {Z(Hs.get(solseq.cur))}')

def evaluate(H, state, x, x_opt, obj):
    """Print the objective at ``x`` and its errors against a reference.

    ``x_opt`` is the reference solution and ``obj`` the reference objective.
    ``state`` is accepted for signature compatibility but is not used.
    """
    # NOTE: one could first shift the identity freedom so that solutions are
    # comparable:
    #   x -= sum(x)/len(x)
    #   x_opt -= sum(x_opt)/len(x)
    value = qis_objective(H, x)
    print(f'Objective value: {value}')
    print(f'Error in solution : {norm2(x - x_opt)}')
    print(f'Error in objective: {np.abs(value - obj)}\n')

def max_entropy_run(disp=True, n=6):
    """
    Compare max entropy solvers on a random Gibbs state of ``n`` qubits.

    The state is exp(-beta * H(r)) / Z for random coefficients r of an
    aggressively normalised ``Local1D`` Hamiltonian, and ``-beta * r`` is
    used as the reference optimal solution for every solver.
    """

    # Generate a random Gibbs state
    if disp:
        print(f'\nGenerating random Gibbs states of {n} qubits.')

    # H.positive_fully_normalized().sum_complete() # for parallel update
    H = Hamiltonian.Local1D(n).aggressive_normalized() # not complete
    m = H.terms()
    if disp:
        print(f'The Hamiltonian has {m} terms')

    r = np.random.randn(m)
    op = H.get(r)
    beta = 1
    # Gibbs takes the Hilbert-space dimension (2**n), not the qubit count.
    state = State.Gibbs(H.dim, options = {'beta': beta, 'Hamiltonian': op})

    # optimal solution
    optimal = - beta * r
    objective = qis_objective(H, optimal)
    print(f'Opti objective: {objective}\n')

    # Every solver gets its own fresh float array as starting point: some of
    # them update their iterate in place, and a Python list would be
    # extended (not added to) by ``+=``.

    # qis
    print('----- QIS -----')
    x_qis = max_entropy_QIS(H, state, x0 = np.zeros(m),
                            repeats = 2000, tol = 1e-12, show_progress = True)
    evaluate(H, state, x_qis, optimal, objective)

    # am type-2
    print('----- AM-QIS -----')
    x_am = max_entropy_QIS_AM(H, state, x0 = np.zeros(m),
                              repeats = 40, tol = 1e-12)
    evaluate(H, state, x_am, optimal, objective)

    # am type-1
    # x_am1 = max_entropy_QIS_AM_1(H, state, x0 = np.zeros(m),
    #                              repeats = 40, tol = 1e-12)
    # evaluate(H, state, x_am1, optimal, objective)

    # dual = max_entropy_dual_gd(H, state, x0 = np.zeros(m),
    #                            repeats = 2000, eta = m, disp = True)
    # evaluate(H, state, dual, optimal, objective)

    # l-bfgs for dual optimization
    print('----- L-BFGS -----')
    dual_lbfgs = max_entropy_lbfgs(H, state, x0 = np.zeros(m),
                                   repeats = 40, eta = 2*m, disp = True)
    evaluate(H, state, dual_lbfgs, optimal, objective)

    # am type-2 for dual
    print('----- AM-Dual -----')
    dual_am = max_entropy_dual_AM(H, state, x0 = np.zeros(m),
                                  repeats = 40, eta = m, tol = 1e-12)
    evaluate(H, state, dual_am, optimal, objective)

    # am type-1 for dual
    # dual_am1 = max_entropy_dual_AM_1(H, state, x0 = np.zeros(m),
    #                                  repeats = 40, eta = m, tol = 1e-12)
    # evaluate(H, state, dual_am1, optimal, objective)

    # compare qis and gd of the same number of iterations
    print('----- QIS vs GD -----')
    x_qis_1K = max_entropy_QIS(H, state, x0 = np.zeros(m),
                               repeats = 1000, tol = 1e-10, show_progress = True)
    evaluate(H, state, x_qis_1K, optimal, objective)
    dual_1k = max_entropy_dual_gd(H, state, x0 = np.zeros(m),
                                  repeats = 1000, tol=1e-10,
                                  eta = m, disp = True)
    evaluate(H, state, dual_1k, optimal, objective)

    # x_neme = max_entropy_Nelder_Mead(H, state, x0 = np.ones(m))
    # print(f'NeMe objective: {dual_objective(H, state, x_neme)}')

if __name__ == '__main__':
    # Script entry point: compare the max-entropy solvers with output enabled.
    max_entropy_run(disp=True)
