
# zne_utils.py
#
# Helper utilities for building a depolarizing noise model and
# running zero-noise extrapolation (ZNE) with Richardson extrapolation.
#
# This is a refactored version of the logic used in the Fashion-MNIST
# ZNE script. It is dataset-agnostic (it always runs on a local noisy
# AerSimulator), so it can be reused across datasets.

import numpy as np
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory


def build_noise_model(base_eps: float, *,
                      single_qubit_gates=('x', 'rx', 'ry', 'rz',
                                          'u1', 'u2', 'u3'),
                      two_qubit_gates=('cx',)):
    """Create a simple depolarizing noise model with strength ``base_eps``.

    Every listed gate, on every qubit, gets a depolarizing error with the
    same parameter ``base_eps``. Gates that are not listed (for example
    'h', 'sx' or 'p' with the defaults) remain noiseless, so pass explicit
    gate lists if your circuits use them.

    Args:
        base_eps: Depolarizing parameter used for both the 1-qubit and the
            2-qubit errors.
        single_qubit_gates: Gate name(s) that receive a 1-qubit
            depolarizing error. A single string is accepted.
        two_qubit_gates: Gate name(s) that receive a 2-qubit depolarizing
            error. A single string is accepted.

    Returns:
        A ``qiskit_aer.noise.NoiseModel`` with the errors attached.
    """
    def _as_gate_list(gates):
        # Avoid list('cx') -> ['c', 'x'] when a bare string is passed.
        return [gates] if isinstance(gates, str) else list(gates)

    nm = NoiseModel()
    nm.add_all_qubit_quantum_error(
        depolarizing_error(base_eps, 1),
        _as_gate_list(single_qubit_gates),
    )
    nm.add_all_qubit_quantum_error(
        depolarizing_error(base_eps, 2),
        _as_gate_list(two_qubit_gates),
    )
    return nm


def make_executor(base_eps: float, seed: int = 123, shots: int = 1024):
    """
    Wrap a noisy AerSimulator + QuantumInstance into a callable executor
    compatible with Mitiq's ZNE interface.

    Args:
        base_eps: Depolarizing strength passed to ``build_noise_model``.
        seed: Seed used for both the simulator and the transpiler.
        shots: Number of shots per circuit execution.

    Returns:
        A function ``executor(circuit) -> float`` that runs ``circuit`` on
        the noisy simulator and returns the fraction of shots in which every
        measured classical bit was 0. A circuit with no classical bits is
        first measured on all qubits (on a copy).
    """
    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=seed)
    # NOTE(review): QuantumInstance is deprecated in recent Qiskit releases;
    # consider migrating to transpile() + backend.run().
    qinst = QuantumInstance(backend=backend,
                            shots=shots,
                            seed_transpiler=seed)

    def _executor(circuit):
        if circuit.num_clbits == 0:
            # get_counts() raises when nothing is measured; measure every
            # qubit on a copy so the all-zeros probability is defined.
            circuit = circuit.measure_all(inplace=False)
        counts = qinst.execute(circuit).get_counts()
        # Count keys are indexed by classical bits (not qubits) and use
        # spaces to separate registers, so check the bits themselves instead
        # of looking up a key built from num_qubits.
        zero_hits = sum(
            n for bitstring, n in counts.items()
            if set(bitstring.replace(' ', '')) <= {'0'}
        )
        total = sum(counts.values())
        return zero_hits / total if total else 0.0

    return _executor


def zne_expectation_zero(vqc, x, base_eps: float,
                         scale_factors=(1, 3),
                         seed: int = 123,
                         shots: int = 1024):
    """
    Given a trained VQC and a single input x, use ZNE to estimate
    the zero-noise expectation values for each class output circuit.

    For each circuit, Mitiq ZNE estimates the all-zeros outcome probability
    ``p0`` using global unitary folding and Richardson extrapolation. The
    result is clipped into (0, 1) and mapped to the energy ``(1 - p0) / 2``.

    Args:
        vqc: Trained VQC. Its private ``_neural_network`` attribute must
            provide ``construct_circuit(x)`` returning an iterable of
            circuits (NOTE(review): private API, so confirm it against the
            installed qiskit-machine-learning version).
        x: A single input sample.
        base_eps: Depolarizing strength of the simulated noise.
        scale_factors: Noise scale factors used for folding. Richardson
            extrapolation fits a polynomial of degree
            ``len(scale_factors) - 1``, so the default ``(1, 3)`` gives a
            linear fit.
        seed: Seed for the simulator and transpiler.
        shots: Shots per circuit execution.

    Returns:
        energies_zne: np.ndarray of shape (n_circuits,) with values in
                      [0, 0.5] (strictly inside, because of clipping).
                      Lower energy means a higher all-zeros probability,
                      i.e. higher confidence.
    """
    executor = make_executor(base_eps, seed=seed, shots=shots)
    # RichardsonFactory has no `order` argument; its order is fixed by the
    # number of scale factors.
    factory = RichardsonFactory(scale_factors=list(scale_factors))

    # Construct the list of circuits used inside the VQC's neural network
    qc_list = vqc._neural_network.construct_circuit(x)
    zne_energies = []

    for circ in qc_list:
        # ezero is the extrapolated probability of the all-zeros outcome
        ezero = zne.execute_with_zne(
            circ, executor, scale_noise=folding.fold_global, factory=factory
        )
        # Extrapolation can overshoot [0, 1]; clamp to a valid probability.
        ezero = float(np.clip(ezero, 1e-12, 1.0 - 1e-12))
        # Map p0 in [0, 1] to an energy in [0, 0.5]; higher p0 -> lower
        # energy.
        zne_energies.append((1.0 - ezero) / 2.0)

    return np.array(zne_energies, dtype=float)
