"""
This codes verifies the sampling methodology. I.e. you can run 1000 samples several times or 
compare 1000 samples with more like 4000 samples
"""

import itertools
import pennylane as qml
import pennylane.numpy as pnp
import matplotlib.pyplot as plt
import seaborn as sns

def create_qnn(n_layers, n_qubits, n_gates, observable_coeffs, observable_ops, entangled=True):
    """Build a layered, parameterised QNode on a ``default.qubit`` device.

    Each layer applies a fixed sequence of single-qubit rotations to every
    qubit and, optionally, a ring of CNOTs. The circuit returns the
    expectation value of ``qml.Hamiltonian(observable_coeffs, observable_ops)``.

    Args:
        n_layers: Number of layers in the circuit.
        n_qubits: Number of wires on the device.
        n_gates: Rotations per qubit per layer: 1 -> RX, 2 -> RX, RZ,
            3 -> RX, RZ, RY. Any other value applies no rotations at all.
        observable_coeffs: Coefficients passed to ``qml.Hamiltonian``.
        observable_ops: Operators passed to ``qml.Hamiltonian``.
        entangled: If truthy, each layer ends with CNOT(q, (q + 1) % n_qubits)
            for every qubit q (skipped when there is at most one qubit).

    Returns:
        The QNode. It takes ``params`` indexed as
        ``params[layer][qubit][gate]``.
    """
    # Rotation sequences keyed by the supported gates-per-qubit counts.
    rotation_table = (
        (1, (qml.RX,)),
        (2, (qml.RX, qml.RZ)),
        (3, (qml.RX, qml.RZ, qml.RY)),
    )
    rotations = next(
        (gates for count, gates in rotation_table if n_gates == count),
        (),  # unsupported counts fall through to "no rotations"
    )
    # A CNOT ring needs at least two wires.
    apply_ring = bool(entangled) and n_qubits > 1

    device = qml.device('default.qubit', wires=n_qubits)

    @qml.qnode(device)
    def circuit(params):
        for layer in range(n_layers):
            for wire in range(n_qubits):
                for slot, rotation in enumerate(rotations):
                    rotation(params[layer][wire][slot], wires=wire)

            if apply_ring:
                for control in range(n_qubits):
                    target = (control + 1) % n_qubits
                    qml.CNOT(wires=[control, target])

        hamiltonian = qml.Hamiltonian(observable_coeffs, observable_ops)
        return qml.expval(hamiltonian)

    return circuit

def generate_parameter_samples(n_layers, n_qubits, n_samples, n_gates=2):
    """Draw random circuit parameters uniformly from [0, 2*pi).

    All samples are drawn in a single vectorised call rather than one call
    per sample, so the result always has four dimensions, including when
    ``n_samples`` is 0.

    Args:
        n_layers: Number of circuit layers.
        n_qubits: Number of qubits.
        n_samples: Number of independent parameter sets to draw.
        n_gates: Number of rotation angles per qubit per layer.

    Returns:
        Array of shape ``(n_samples, n_layers, n_qubits, n_gates)``;
        iterating over it yields one ``(n_layers, n_qubits, n_gates)``
        parameter set per sample.
    """
    draws = pnp.random.uniform(
        0, 2 * pnp.pi, size=(n_samples, n_layers, n_qubits, n_gates)
    )
    # Re-wrap with pnp.array so the return type is the same as before.
    return pnp.array(draws)

def calculate_hessian_norms(qnn, samples):
    """Compute the spectral norm of the Hessian of ``qnn`` at each sample.

    For every parameter set, the parameters are flattened to one dimension
    so that the Hessian with respect to them is a square 2-D matrix. The
    Hessian is obtained as the Jacobian of the gradient.

    Args:
        qnn: Differentiable callable mapping a parameter array to a scalar
            (e.g. a QNode returning an expectation value).
        samples: Iterable of parameter arrays; each must provide
            ``.shape`` and ``.flatten()``.

    Returns:
        A list with one entry per sample: the 2-norm (largest singular
        value) of the Hessian evaluated at that sample.
    """
    hessian_norms = []
    for params in samples:
        flat_params = params.flatten()

        def cost_fn_flat(p_flat):
            # Restore the original layout before evaluating the circuit.
            return qnn(p_flat.reshape(params.shape))

        hessian_matrix = qml.jacobian(qml.grad(cost_fn_flat))(flat_params)
        # ord=2 on a 2-D array is the spectral norm.
        hessian_norms.append(pnp.linalg.norm(hessian_matrix, ord=2))

    return hessian_norms

def main():
    """Repeat the sampling experiment and print each run's largest Hessian norm.

    Every repetition draws a fresh set of random parameters for the same
    circuit, computes the Hessian spectral norm at each sample, and prints
    the maximum. Comparing the printed maxima across repetitions shows how
    stable the sampling estimate is. A warning is printed for any run in
    which a sampled norm exceeds the bound ``P * norm_M``.
    """
    n_repeats = 10
    n_samples = 1000
    n_gates = 3
    n_qubits = 4
    n_layers = 5
    entanglement = False
    observable_coeffs = [1 / n_qubits] * n_qubits
    observable_ops = [qml.PauliZ(i) for i in range(n_qubits)]  # Z measurement on every qubit

    # The circuit and the bound do not change between repetitions, so they
    # are built once here instead of inside the loop.
    qnn = create_qnn(n_layers, n_qubits, n_gates, observable_coeffs, observable_ops, entangled=entanglement)

    P = n_layers * n_qubits * n_gates  # total number of circuit parameters
    norm_M = 1.0  # NOTE(review): presumably the operator norm of the observable - confirm
    L_bound = P * norm_M

    for i in range(n_repeats):
        samples = generate_parameter_samples(n_layers, n_qubits, n_samples, n_gates=n_gates)
        hessian_norms = calculate_hessian_norms(qnn, samples)

        print(f"Largest Hessian Norm Sample {i}: {pnp.max(hessian_norms)}")

        # Report the bound check only when it fails, so normal output is unchanged.
        all_within_bound = all(norm <= L_bound for norm in hessian_norms)
        if not all_within_bound:
            print(f"WARNING: Hessian norm exceeded the bound {L_bound:.4f} in sample run {i}")


if __name__ == '__main__':
    main()
