"""
This script trains an n-qubit VQE with a configurable number of layers using the estimated optimal learning rate (lr),
and compares it against several other learning rates.
You can switch the optimizer, or reuse train_vqe for examples with more qubits.
To estimate the optimal lr, use the vqe_lr_1q script.
"""

import pennylane as qml
from pennylane import numpy as pnp
import matplotlib.pyplot as plt
import seaborn as sns

# Ansatz size: N_LAYERS repetitions over N_QUBITS qubits, each applying
# N_GATES_PER_ROTATION single-qubit rotation angles per qubit (RX, RY, RZ).
N_QUBITS = 4
N_LAYERS = 8
N_GATES_PER_ROTATION = 3

# Number of optimizer steps in each training run.
EPOCHS = 400

# Learning rates compared in this experiment. ETA_OPTIMAL is presumably the
# estimate obtained from the vqe_lr_1q script (see module docstring).
ETA_OPTIMAL = 0.102466
ETA_HIGH = 5.0 * ETA_OPTIMAL   # five times the estimated optimum
ETA_LOW = 0.2 * ETA_OPTIMAL    # one fifth of the estimated optimum
ETA_STANDARD = 0.01            # fixed "standard" step size used as a baseline

def create_vqe_circuit(n_layers, n_qubits):
    """Build the VQE ansatz QNode together with the Hamiltonian it measures.

    The Hamiltonian is an open transverse-field Ising chain:

        H = -sum_{i=0}^{n-2} Z_i Z_{i+1} + 0.5 * sum_{i=0}^{n-1} X_i

    Each of the ``n_layers`` ansatz layers applies RX, RY and RZ to every qubit,
    with angles taken from ``params[layer][qubit][0]``, ``[1]`` and ``[2]``. A
    ladder of CNOTs between neighbouring wires follows the rotations.

    Args:
        n_layers: Number of rotation + entangling layers in the ansatz.
        n_qubits: Number of wires (and spin sites in the Hamiltonian).

    Returns:
        tuple: ``(circuit, hamiltonian)``. ``circuit`` is a QNode on a
        ``default.qubit`` device that returns the expectation value of
        ``hamiltonian``.
    """
    device = qml.device('default.qubit', wires=n_qubits)

    # Nearest-neighbour ZZ couplings first, then the X field terms (term order kept).
    zz_couplings = [qml.PauliZ(w) @ qml.PauliZ(w + 1) for w in range(n_qubits - 1)]
    x_fields = [qml.PauliX(w) for w in range(n_qubits)]
    hamiltonian = qml.Hamiltonian(
        [-1.0] * len(zz_couplings) + [0.5] * len(x_fields),
        zz_couplings + x_fields,
    )

    # Rotation gates applied to each qubit, indexed by the last parameter axis.
    single_qubit_rotations = (qml.RX, qml.RY, qml.RZ)

    @qml.qnode(device)
    def circuit(params):
        for layer in range(n_layers):
            layer_params = params[layer]
            for wire in range(n_qubits):
                wire_angles = layer_params[wire]
                for axis, rotation in enumerate(single_qubit_rotations):
                    rotation(wire_angles[axis], wires=wire)
            # Entangle neighbouring wires in a linear chain.
            for wire in range(n_qubits - 1):
                qml.CNOT(wires=[wire, wire + 1])
        return qml.expval(hamiltonian)

    return circuit, hamiltonian


def train_vqe(learning_rate, cost_fn, *, epochs=EPOCHS, n_layers=N_LAYERS, n_qubits=N_QUBITS,
              optimizer_cls=qml.GradientDescentOptimizer, seed=42):
    """Optimize a VQE cost function from a fixed random start and record its energy.

    The initial parameters are drawn uniformly from [0, 2*pi), using a fixed
    seed. Calls that share ``seed``, ``n_layers`` and ``n_qubits`` therefore
    start from identical parameters, so runs with different learning rates are
    directly comparable.

    Args:
        learning_rate: Step size passed to the optimizer.
        cost_fn: Callable that maps a parameter array of shape
            ``(n_layers, n_qubits, N_GATES_PER_ROTATION)`` to the energy,
            e.g. the QNode returned by ``create_vqe_circuit``.
        epochs: Number of optimization steps (defaults to ``EPOCHS``).
        n_layers: Number of ansatz layers; must match the circuit behind ``cost_fn``.
        n_qubits: Number of qubits; must match the circuit behind ``cost_fn``.
        optimizer_cls: PennyLane optimizer class, constructed as
            ``optimizer_cls(stepsize=learning_rate)`` (e.g. ``qml.AdamOptimizer``).
        seed: Seed for the global (pennylane.)numpy RNG used to draw the initial parameters.

    Returns:
        list[float]: The energy recorded at each step. ``step_and_cost`` reports
        the cost *before* each update, so entry ``i`` is the energy at the
        parameters that step ``i`` started from.
    """
    print(f"\n--- Training VQE with Learning Rate: {learning_rate:.4f} ---")

    # Re-seed on every call so each learning rate starts from the same point.
    # NOTE: this resets the global numpy RNG state.
    pnp.random.seed(seed)
    param_shape = (n_layers, n_qubits, N_GATES_PER_ROTATION)
    params = pnp.random.uniform(0, 2 * pnp.pi, size=param_shape, requires_grad=True)

    optimizer = optimizer_cls(stepsize=learning_rate)

    energy_history = []
    for epoch in range(epochs):
        params, cost = optimizer.step_and_cost(cost_fn, params)
        # Store plain floats rather than pennylane tensors, so the history is
        # ordinary data that prints cleanly.
        energy = float(cost)
        energy_history.append(energy)
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch + 1:3d}: Cost = {energy:.8f}")

    return energy_history


# Set to True to display the convergence plot after training.
PLOT_RESULTS = False


def _plot_convergence(history_optimal, history_high, history_low, history_standard, exact_eigenvalue):
    """Plot the energy histories of the four learning-rate runs against the exact ground energy.

    Each ``history_*`` argument is a per-epoch sequence of energies, as returned
    by ``train_vqe``. ``exact_eigenvalue`` is drawn as a horizontal reference line.
    """
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(12, 7))

    plt.plot(history_optimal, label=f'Optimal η ≈ {ETA_OPTIMAL:.4f}', color='royalblue', linewidth=2.5)
    plt.plot(history_high, label=f'High η = {ETA_HIGH:.4f}', color='indianred', linestyle='--', linewidth=2)
    plt.plot(history_low, label=f'Low η = {ETA_LOW:.4f}', color='purple', linewidth=2)
    plt.plot(history_standard, label=f'Standard η = {ETA_STANDARD:.4f}', color='mediumseagreen',
             linestyle=':', linewidth=2)

    # Exact ground-state energy as a horizontal reference line.
    plt.axhline(y=exact_eigenvalue, color='black', linestyle='-.', linewidth=1.5,
                label=f'Exact Ground State = {exact_eigenvalue:.4f}')

    # Title is derived from the configuration so it stays in sync with N_QUBITS / N_LAYERS.
    plt.title(f'VQE Convergence ({N_QUBITS} Qubits, {N_LAYERS} Layers)', fontsize=16, fontweight='bold')
    plt.xlabel('Optimization Steps (Epochs)', fontsize=12)
    plt.ylabel('Energy (Expectation Value)', fontsize=12)
    # Only the first 40 epochs are shown; raise the upper limit (up to EPOCHS) to see the full run.
    plt.xlim(0, 40)
    plt.legend(fontsize=11)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()

    print("\nDisplaying convergence plot...")
    plt.show()


if __name__ == '__main__':
    vqe_circuit, hamiltonian = create_vqe_circuit(N_LAYERS, N_QUBITS)
    # The ground-state energy is the smallest eigenvalue. Take the minimum
    # explicitly instead of relying on the ordering of qml.eigvals' output.
    exact_eigenvalue = float(pnp.min(qml.eigvals(hamiltonian)))

    history_optimal = train_vqe(ETA_OPTIMAL, vqe_circuit)
    history_high = train_vqe(ETA_HIGH, vqe_circuit)
    history_low = train_vqe(ETA_LOW, vqe_circuit)
    history_standard = train_vqe(ETA_STANDARD, vqe_circuit)

    print(f'Qubits: {N_QUBITS}, Layers: {N_LAYERS}')
    print(f'Exact ground state energy: {exact_eigenvalue:.8f}')
    print(f'Optimal {ETA_OPTIMAL}, {history_optimal}')
    print(f'5x Optimal {ETA_HIGH}, {history_high}')
    print(f'0.2 Optimal {ETA_LOW}, {history_low}')
    print(f'Standard Adam/SGD {ETA_STANDARD}, {history_standard}')

    if PLOT_RESULTS:
        _plot_convergence(history_optimal, history_high, history_low, history_standard, exact_eigenvalue)


