"""
Estimates k(n) on a shallow circuit letting you set the optimal learnign rate for a deeper one used for trianing in train_vqe_1q script
"""

import pennylane as qml
from pennylane import numpy as pnp
import itertools

# Size of the ansatz whose cost landscape is scanned.
N_QUBITS: int = 4
N_LAYERS: int = 6
# Rotation gates (RX, RY, RZ) applied to each qubit in every layer; this is
# the last dimension of each parameter array.
N_GATES_PER_ROTATION: int = 3
# Number of random parameter points at which the Hessian is evaluated.
N_SAMPLES: int = 1_000

def create_vqe_circuit(n_layers, n_qubits):
    """Build the VQE ansatz QNode and the Hamiltonian whose energy it measures.

    The Hamiltonian is an open-chain transverse-field Ising model,
    ``H = -sum_i Z_i Z_{i+1} + 0.5 * sum_i X_i``. Each of the ``n_layers``
    layers applies RX, RY and RZ to every qubit, followed by a ladder of
    CNOTs between neighbouring wires.

    Args:
        n_layers: Number of ansatz layers.
        n_qubits: Number of wires on the ``default.qubit`` device.

    Returns:
        tuple: ``(circuit, hamiltonian)``. ``circuit(params)`` returns the
        expectation value of ``hamiltonian``; ``params`` is indexed as
        ``params[layer][qubit][k]`` with ``k`` in ``0..2``.
    """
    device = qml.device('default.qubit', wires=n_qubits)

    # Nearest-neighbour ZZ couplings (weight -1.0) come first, then the
    # transverse X field on every qubit (weight 0.5).
    coupling_terms = [qml.PauliZ(w) @ qml.PauliZ(w + 1) for w in range(n_qubits - 1)]
    field_terms = [qml.PauliX(w) for w in range(n_qubits)]
    hamiltonian = qml.Hamiltonian(
        [-1.0] * len(coupling_terms) + [0.5] * len(field_terms),
        coupling_terms + field_terms,
    )

    # Rotation k of a qubit uses angle params[layer][qubit][k].
    rotations = (qml.RX, qml.RY, qml.RZ)

    @qml.qnode(device)
    def circuit(params):
        for layer in range(n_layers):
            layer_params = params[layer]
            for wire in range(n_qubits):
                angles = layer_params[wire]
                for k, rotation in enumerate(rotations):
                    rotation(angles[k], wires=wire)
            # Entangling ladder: CNOT(0,1), CNOT(1,2), ...
            for control in range(n_qubits - 1):
                qml.CNOT(wires=[control, control + 1])
        return qml.expval(hamiltonian)

    return circuit, hamiltonian


def generate_parameter_samples(n_layers, n_qubits, n_samples, seed=None):
    """Draw random ansatz parameter sets uniformly from ``[0, 2*pi)``.

    Args:
        n_layers: Number of ansatz layers.
        n_qubits: Number of qubits.
        n_samples: Number of parameter sets to draw.
        seed: Optional seed for reproducible sampling. ``None`` (default)
            draws from PennyLane's global NumPy random state, as before. Any
            other value uses a dedicated ``default_rng(seed)`` generator and
            leaves the global state untouched.

    Returns:
        A ``pennylane.numpy`` array of shape
        ``(n_samples, n_layers, n_qubits, N_GATES_PER_ROTATION)``. This shape
        is kept even when ``n_samples`` is 0.
    """
    param_shape = (n_layers, n_qubits, N_GATES_PER_ROTATION)
    uniform = pnp.random.uniform if seed is None else pnp.random.default_rng(seed).uniform
    # A single vectorised draw replaces the per-sample Python loop. It draws
    # the same number of values in the same order, and an empty draw keeps its
    # trailing dimensions instead of collapsing to shape (0,).
    samples = uniform(0, 2 * pnp.pi, size=(n_samples,) + param_shape)
    return pnp.array(samples)


def calculate_max_hessian_norm(qnode, samples, log_every=50):
    """Return the largest Hessian spectral norm of ``qnode`` over ``samples``.

    At each parameter sample, the full Hessian of ``qnode`` is computed by
    differentiating its gradient (``qml.jacobian(qml.grad(...))``) with
    respect to the flattened parameters. Its spectral norm
    (``linalg.norm(..., ord=2)``) is recorded, and the maximum over all
    samples is returned. Only a finite set of points is visited, so the
    result is a lower bound on the true maximum curvature of the landscape.

    Args:
        qnode: Callable that takes one parameter array of shape
            ``samples[0].shape`` and returns a scalar cost.
        samples: Non-empty sequence/array of parameter arrays, all with the
            same shape.
        log_every: Print a progress line after every ``log_every`` samples.
            A value ``<= 0`` disables the progress lines.

    Returns:
        float: The maximum spectral norm observed.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("samples must contain at least one parameter array")

    param_shape = samples[0].shape

    # Differentiate with respect to a flat vector so the Hessian is a 2-D
    # (n_params, n_params) matrix, which norm(ord=2) requires.
    def cost_fn_flat(p_flat):
        return qnode(p_flat.reshape(param_shape))

    # Built once; the transform does not depend on the sample.
    hessian_fn = qml.jacobian(qml.grad(cost_fn_flat))

    print(f"Starting Hessian norm calculation for {n_samples} samples...")
    hessian_norms = []
    for i, params in enumerate(samples, start=1):
        hessian_matrix = hessian_fn(params.flatten())
        hessian_norms.append(pnp.linalg.norm(hessian_matrix, ord=2))

        if log_every > 0 and i % log_every == 0:
            print(f"  ...processed {i}/{n_samples} samples.")

    return float(pnp.max(hessian_norms))


if __name__ == '__main__':
    # Scan the landscape of the N_LAYERS x N_QUBITS ansatz at N_SAMPLES random
    # points and turn the largest observed curvature into a step size.
    print("--- VQE Landscape Analysis for Optimal Learning Rate ---")
    print(f"Circuit details: {N_QUBITS} qubits, {N_LAYERS} layers.")

    vqe_circuit, hamiltonian = create_vqe_circuit(N_LAYERS, N_QUBITS)
    print("\nExact ground state energy (for reference):")
    # Take the minimum explicitly instead of assuming qml.eigvals returns
    # eigenvalues in ascending order. The real part drops any numerically
    # zero imaginary components.
    print(float(pnp.min(pnp.real(qml.eigvals(hamiltonian)))))

    param_samples = generate_parameter_samples(N_LAYERS, N_QUBITS, N_SAMPLES)

    L_max = calculate_max_hessian_norm(vqe_circuit, param_samples)

    # Step-size rule eta = 1/L for an L-smooth cost. L_max comes from
    # sampling, so it can underestimate the true L.
    eta_optimal = 1.0 / L_max

    print("\n--- Results ---")
    print(f"Estimated Maximum Curvature (L_max): {L_max:.4f}")
    print(f"Calculated Optimal Learning Rate (η ≈ 1/L_max): {eta_optimal:.4f}")
    # NOTE(review): the module docstring names the train_vqe_1q script;
    # confirm which training script this message should point to.
    print("\nUse this optimal learning rate in the 'train_vqe_demo.py' script.")