"""
Estimates k(n) on a shallow circuit letting you set the optimal learnign rate for a deeper one used for trianing in train_vqe_1q script
This if for the 1 qubit circuit. The VQE_OPTLR can be used for any n > 1
"""

import pennylane as qml
from pennylane import numpy as pnp
import itertools

# Circuit size: number of wires and number of stacked rotation layers.
N_QUBITS: int = 1
N_LAYERS: int = 10
# Angles per qubit per layer: the ansatz applies RX, RY, RZ in sequence,
# so this is the size of the last axis of every parameter tensor.
N_GATES_PER_ROTATION: int = 3
# Number of random parameter points at which the Hessian is evaluated.
N_SAMPLES: int = 1000

def create_vqe_circuit(n_layers, n_qubits, coeffs=(0.6, 0.8)):
    """Build the layered rotation ansatz QNode and the Hamiltonian it measures.

    Each of the ``n_layers`` layers applies RX, RY and RZ (in that order) to
    every qubit; there are no entangling gates. The Hamiltonian is
    ``coeffs[0] * X_0 + coeffs[1] * Z_0``. It acts on wire 0 only, even when
    ``n_qubits > 1``.

    Args:
        n_layers: number of rotation layers.
        n_qubits: number of wires on the ``default.qubit`` device.
        coeffs: the two Hamiltonian coefficients, for ``PauliX(0)`` and
            ``PauliZ(0)`` respectively. Defaults to ``(0.6, 0.8)``, whose
            exact ground-state energy is ``-sqrt(0.6**2 + 0.8**2) = -1``.

    Returns:
        ``(circuit, hamiltonian)``. ``circuit(params)`` returns the
        expectation value of ``hamiltonian``. ``params`` is indexed as
        ``params[layer][qubit][k]`` with ``k`` in ``0..2`` (RX, RY, RZ
        angles), e.g. an array of shape ``(n_layers, n_qubits, 3)``.

    Raises:
        ValueError: if ``coeffs`` does not have exactly two entries.
    """
    # Copy into a fresh list: accepts any iterable and never shares the
    # caller's object with the Hamiltonian.
    coeffs = list(coeffs)
    if len(coeffs) != 2:
        raise ValueError(
            f"coeffs must have exactly 2 entries (for X_0 and Z_0), got {len(coeffs)}"
        )

    dev = qml.device('default.qubit', wires=n_qubits)
    obs = [qml.PauliX(0), qml.PauliZ(0)]
    hamiltonian = qml.Hamiltonian(coeffs, obs)

    @qml.qnode(dev)
    def circuit(params):
        for layer in range(n_layers):
            for qubit in range(n_qubits):
                qml.RX(params[layer][qubit][0], wires=qubit)
                qml.RY(params[layer][qubit][1], wires=qubit)
                qml.RZ(params[layer][qubit][2], wires=qubit)

        return qml.expval(hamiltonian)

    return circuit, hamiltonian


def generate_parameter_samples(n_layers, n_qubits, n_samples):
    """Draw ``n_samples`` random parameter tensors with angles uniform in [0, 2*pi).

    The global ``pnp.random`` generator is reseeded with 42 on every call.
    Results are therefore reproducible, but any other code relying on the
    global random state is reset as a side effect.

    Args:
        n_layers: number of circuit layers.
        n_qubits: number of qubits.
        n_samples: number of parameter tensors to draw.

    Returns:
        A ``pnp`` array stacking the draws. For ``n_samples > 0`` its shape is
        ``(n_samples, n_layers, n_qubits, N_GATES_PER_ROTATION)``.
    """
    pnp.random.seed(42)
    one_shape = (n_layers, n_qubits, N_GATES_PER_ROTATION)
    upper = 2 * pnp.pi

    drawn = []
    for _ in range(n_samples):
        # One independent draw per sample, in sequence from the seeded stream.
        drawn.append(pnp.random.uniform(0, upper, size=one_shape))
    return pnp.array(drawn)


def calculate_max_hessian_norm(qnode, samples):
    """Return the largest Hessian spectral norm of ``qnode`` over ``samples``.

    The Hessian is computed with respect to the flattened parameter vector
    at each sample point. The maximum of its spectral norm (largest singular
    value) serves as an empirical estimate of the cost's curvature bound
    L_max. Progress is printed every 50 samples.

    Args:
        qnode: callable mapping a parameter tensor shaped like ``samples[0]``
            to a scalar cost.
        samples: non-empty sequence of parameter tensors, all with the same
            shape as ``samples[0]``.

    Returns:
        The maximum spectral norm found, as returned by ``pnp.max``.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("samples must contain at least one parameter set")

    param_shape = samples[0].shape

    # Differentiate w.r.t. a flat vector so the Hessian is a 2-D
    # (n_params, n_params) matrix, on which ord=2 gives the spectral norm.
    def cost_fn_flat(p_flat):
        return qnode(p_flat.reshape(param_shape))

    # Loop-invariant: build the Hessian function once, not once per sample.
    hessian_fn = qml.jacobian(qml.grad(cost_fn_flat))

    hessian_norms = []
    print(f"Starting Hessian norm calculation for {n_samples} samples...")
    for i, params in enumerate(samples, start=1):
        hessian_matrix = hessian_fn(params.flatten())
        hessian_norms.append(pnp.linalg.norm(hessian_matrix, ord=2))

        if i % 50 == 0:
            print(f"  ...processed {i}/{n_samples} samples.")

    return pnp.max(hessian_norms)


if __name__ == '__main__':
    # Script entry point: build the circuit, print a reference ground-state
    # energy, estimate the maximum curvature L_max over random parameter
    # points, and report eta = 1 / L_max as the suggested learning rate.
    print("--- 1-Qubit VQE Landscape Analysis ---")
    print(f"Circuit details: {N_QUBITS} qubit, {N_LAYERS} layers.")

    vqe_circuit, hamiltonian = create_vqe_circuit(N_LAYERS, N_QUBITS)
    print("\nExact ground state energy (for reference):")
    # Lowest eigenvalue of the Hamiltonian the circuit actually measures.
    # Deriving it here, rather than re-typing the coefficients, keeps the
    # reference value in sync with create_vqe_circuit.
    exact_energy = pnp.min(pnp.linalg.eigvalsh(qml.matrix(hamiltonian)))
    print(exact_energy)

    param_samples = generate_parameter_samples(N_LAYERS, N_QUBITS, N_SAMPLES)
    L_max = calculate_max_hessian_norm(vqe_circuit, param_samples)
    # 1 / L_max would be inf (or nan) for a flat or broken estimate; fail
    # loudly instead of printing a meaningless learning rate.
    if not L_max > 0:
        raise SystemExit(
            f"Estimated L_max = {L_max}; cannot derive a learning rate "
            "from a non-positive curvature estimate."
        )
    # Step size 1/L is the standard gradient-descent choice for an L-smooth
    # cost. L_max here is only an empirical estimate over sampled points,
    # so treat the result as a heuristic.
    eta_optimal = 1.0 / L_max

    print("\n--- Results ---")
    print(f"Estimated Maximum Curvature (L_max): {L_max:.4f}")
    print(f"Calculated Optimal Learning Rate (η ≈ 1/L_max): {eta_optimal:.4f}")

