"""
This code trains a 1 Qubit VQE with varios number of layers uisng the estiamted optimal lr and comparing it with other lrs
You can switch the optimizer or use the train_vqe for more qubit examples
to estimate the optimal lr use the vqe_lr_1q script
"""


import pennylane as qml
from pennylane import numpy as pnp
import matplotlib.pyplot as plt
import seaborn as sns

# Circuit size: one qubit, N_LAYERS repeated layers, and one angle per
# rotation gate (RX, RY, RZ) for every qubit in every layer.
N_QUBITS = 1
N_LAYERS = 40
N_GATES_PER_ROTATION = 3

# Number of optimizer steps in each training run.
EPOCHS = 400

# Learning rates under comparison. ETA_OPTIMAL is the estimated optimal rate
# (see the vqe_lr_1q script); the high/low runs bracket it by a factor of 5.
ETA_OPTIMAL = 0.021056
ETA_HIGH = 5.0 * ETA_OPTIMAL
ETA_LOW = 0.2 * ETA_OPTIMAL
# Fixed baseline step size, not derived from the estimate above.
ETA_STANDARD = 0.01


def create_vqe_circuit(n_layers, n_qubits):
    """Build the layered-rotation ansatz QNode and the exact ground-state energy.

    The Hamiltonian is H = 0.6 * X_0 + 0.8 * Z_0 and acts on wire 0 only.

    Args:
        n_layers: number of repeated rotation layers in the ansatz.
        n_qubits: number of wires on the ``default.qubit`` device; every wire
            receives RX -> RY -> RZ in each layer.

    Returns:
        A pair ``(circuit, exact_eigenvalue)``:
        ``circuit`` is a QNode taking ``params`` indexed as
        ``params[layer][qubit][gate]`` and returning <H>, and
        ``exact_eigenvalue`` is -sqrt(0.6**2 + 0.8**2), the lowest eigenvalue
        of a*X + b*Z (whose eigenvalues are +/- sqrt(a**2 + b**2)).
    """
    dev = qml.device('default.qubit', wires=n_qubits)

    x_weight, z_weight = 0.6, 0.8
    hamiltonian = qml.Hamiltonian([x_weight, z_weight], [qml.PauliX(0), qml.PauliZ(0)])

    exact_eigenvalue = -pnp.sqrt(x_weight ** 2 + z_weight ** 2)

    # Applied in this order on every wire of every layer; the position in the
    # tuple selects the matching angle from params[layer][qubit].
    rotation_gates = (qml.RX, qml.RY, qml.RZ)

    @qml.qnode(dev)
    def circuit(params):
        for layer in range(n_layers):
            layer_params = params[layer]
            for wire in range(n_qubits):
                angles = layer_params[wire]
                for gate_idx, gate in enumerate(rotation_gates):
                    gate(angles[gate_idx], wires=wire)
        return qml.expval(hamiltonian)

    return circuit, exact_eigenvalue


def train_vqe(learning_rate, cost_fn, *, n_layers=None, n_qubits=None,
              epochs=None, seed=42, optimizer_cls=None, log_every=20):
    """Optimize ``cost_fn`` from a seeded random start and record the cost per step.

    Every keyword argument defaults to the value the script used before it was
    parameterized, so ``train_vqe(lr, circuit)`` behaves as it always has.

    Args:
        learning_rate: optimizer step size (passed as ``stepsize``).
        cost_fn: callable taking a params array of shape
            ``(n_layers, n_qubits, N_GATES_PER_ROTATION)``, e.g. the QNode
            returned by ``create_vqe_circuit``.
        n_layers: layer count of the parameter array; defaults to ``N_LAYERS``.
            Must match the layer count the circuit was built with.
        n_qubits: qubit count of the parameter array; defaults to ``N_QUBITS``.
        epochs: number of optimization steps; defaults to ``EPOCHS``.
        seed: value passed to ``pnp.random.seed`` before drawing the initial
            parameters, so runs with different learning rates start from the
            same point. ``None`` skips reseeding.
        optimizer_cls: PennyLane optimizer class constructed as
            ``optimizer_cls(stepsize=learning_rate)``; defaults to
            ``qml.GradientDescentOptimizer`` (e.g. ``qml.AdamOptimizer`` also
            accepts ``stepsize``).
        log_every: print the cost every ``log_every`` epochs; 0/None disables.

    Returns:
        List with one cost per epoch. ``step_and_cost`` reports the objective
        evaluated before each update, so entry i is the energy before step i.
    """
    # Resolve defaults at call time, matching the original's global lookups.
    n_layers = N_LAYERS if n_layers is None else n_layers
    n_qubits = N_QUBITS if n_qubits is None else n_qubits
    epochs = EPOCHS if epochs is None else epochs
    if optimizer_cls is None:
        optimizer_cls = qml.GradientDescentOptimizer

    print(f"\n--- Training {n_qubits}-Qubit VQE with Learning Rate: {learning_rate:.4f} ---")

    # Reseeding the global RNG makes every call draw identical initial params.
    if seed is not None:
        pnp.random.seed(seed)
    param_shape = (n_layers, n_qubits, N_GATES_PER_ROTATION)
    params = pnp.random.uniform(0, 2 * pnp.pi, size=param_shape, requires_grad=True)

    optimizer = optimizer_cls(stepsize=learning_rate)
    energy_history = []

    for epoch in range(epochs):
        params, cost = optimizer.step_and_cost(cost_fn, params)
        energy_history.append(cost)
        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch + 1:3d}: Cost = {cost:.8f}")

    return energy_history


if __name__ == '__main__':
    # Local import: argparse is only needed when the file runs as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description='Compare 1-qubit VQE convergence for several learning rates.')
    parser.add_argument(
        '--plot', action='store_true',
        help='show a convergence plot of all runs after training')
    args = parser.parse_args()

    vqe_circuit, exact_eigenvalue = create_vqe_circuit(N_LAYERS, N_QUBITS)

    # train_vqe reseeds before drawing initial params, so all four runs start
    # from the same point and differ only in learning rate.
    history_optimal = train_vqe(ETA_OPTIMAL, vqe_circuit)
    history_high = train_vqe(ETA_HIGH, vqe_circuit)
    history_low = train_vqe(ETA_LOW, vqe_circuit)
    history_standard = train_vqe(ETA_STANDARD, vqe_circuit)

    print(f'Qubits: {N_QUBITS}, Layers: {N_LAYERS}')
    print(f'Optimal {ETA_OPTIMAL}, {history_optimal}')
    print(f'5x Optimal {ETA_HIGH}, {history_high}')
    print(f'0.2 Optimal {ETA_LOW}, {history_low}')
    print(f'Standard Adam/SGD {ETA_STANDARD}, {history_standard}')

    # Plotting is opt-in (--plot) so the default run stays non-interactive.
    if args.plot:
        sns.set_theme(style="whitegrid")
        plt.figure(figsize=(12, 7))

        plt.plot(history_optimal, label=f'Optimal η = {ETA_OPTIMAL:.4f}', color='royalblue', linewidth=2.5)
        plt.plot(history_high, label=f'High η = {ETA_HIGH:.4f}', color='indianred', linestyle='--', linewidth=2)
        plt.plot(history_low, label=f'Low η = {ETA_LOW:.4f}', color='purple', linewidth=2)
        plt.plot(history_standard, label=f'Standard η = {ETA_STANDARD:.4f}', color='mediumseagreen', linestyle=':', linewidth=2)

        # exact_eigenvalue comes from pnp.sqrt; cast to a plain float for
        # matplotlib and the label's format spec.
        ground_energy = float(exact_eigenvalue)
        plt.axhline(y=ground_energy, color='black', linestyle='-.', linewidth=1.5,
                    label=f'Exact Ground State = {ground_energy:.4f}')

        plt.title(f'1-Qubit VQE Convergence ({N_LAYERS} Layers)', fontsize=16, fontweight='bold')
        plt.xlabel('Optimization Steps (Epochs)', fontsize=12)
        plt.ylabel('Energy (Expectation Value)', fontsize=12)
        plt.legend(fontsize=11)
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.tight_layout()

        print("\nDisplaying convergence plot...")
        plt.show()
