# urbansound8k_zne_utils.py
#
# ZNE utilities for UrbanSound8K distillation.

import numpy as np
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory


def build_noise_model(base_eps: float):
    """Return an Aer ``NoiseModel`` with uniform depolarizing noise.

    The same depolarizing parameter ``base_eps`` is attached to every qubit
    for the listed one-qubit gates, and (as a two-qubit channel) to ``cx``.

    Args:
        base_eps: Depolarizing error parameter used for both channels.

    Returns:
        The populated ``NoiseModel``.
    """
    one_qubit_gates = ['x', 'rx', 'ry', 'rz', 'u1', 'u2', 'u3']
    two_qubit_gates = ['cx']

    model = NoiseModel()
    # Register the 1-qubit channel first, then the 2-qubit one.
    for width, gate_names in ((1, one_qubit_gates), (2, two_qubit_gates)):
        channel = depolarizing_error(base_eps, width)
        model.add_all_qubit_quantum_error(channel, gate_names)
    return model


def make_executor(base_eps: float, seed: int = 123, shots: int = 1024):
    """Wrap a noisy AerSimulator into a callable executor for Mitiq ZNE.

    Args:
        base_eps: Depolarizing strength passed to ``build_noise_model``.
        seed: Seed used for both the simulator's sampler and the transpiler.
        shots: Number of shots per circuit execution.

    Returns:
        A function ``circuit -> float`` giving the fraction of shots in
        which every measured classical bit read ``0``.
    """
    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=seed)
    # optimization_level=0 is essential for ZNE: higher transpiler levels
    # cancel adjacent inverse gates (CX.CX, merged 1q rotations), which
    # undoes the unitary folding Mitiq applies and erases the noise scaling.
    qinst = QuantumInstance(backend=backend,
                            shots=shots,
                            seed_transpiler=seed,
                            optimization_level=0)

    def _executor(circuit):
        counts = qinst.execute(circuit).get_counts()
        total = sum(counts.values())
        if total == 0:
            return 0.0
        # Count keys are classical-bit strings (registers separated by
        # spaces), so test each key for "all zeros" instead of looking up a
        # fixed '0' * num_qubits key, which misses whenever the clbit layout
        # does not match the qubit count.
        zero_hits = sum(n for bits, n in counts.items()
                        if set(bits.replace(' ', '')) <= {'0'})
        return zero_hits / total

    return _executor


def zne_expectation_zero(vqc, x, base_eps: float,
                         scale_factors=(1, 3),
                         seed: int = 123,
                         shots: int = 1024):
    """Estimate zero-noise expectations for each class circuit of vqc at input x.

    For every circuit returned by ``vqc._neural_network.construct_circuit(x)``,
    the all-zeros probability is measured at each noise scale (global folding)
    and Richardson-extrapolated to zero noise. The result is clipped into
    ``[1e-12, 1 - 1e-12]`` and mapped to ``(1 - p) / 2``.

    Args:
        vqc: Classifier object exposing ``_neural_network.construct_circuit``.
        x: Input passed to ``construct_circuit``.
        base_eps: Depolarizing strength for the simulated backend.
        scale_factors: Noise scale factors for extrapolation. Richardson
            extrapolation uses a polynomial of degree
            ``len(scale_factors) - 1``, so the default ``(1, 3)`` is linear.
        seed: Seed for the simulator and transpiler.
        shots: Shots per circuit execution.

    Returns:
        np.ndarray of shape (n_circuits,) with energy-like values strictly
        inside (0, 0.5).
    """
    executor = make_executor(base_eps, seed=seed, shots=shots)
    fold_global = folding.fold_global
    factors = list(scale_factors)

    # NOTE(review): relies on the private ``_neural_network`` attribute and
    # its ``construct_circuit`` method; confirm against the installed
    # qiskit-machine-learning version.
    qc_list = vqc._neural_network.construct_circuit(x)
    zne_energies = []

    for circ in qc_list:
        # RichardsonFactory takes no ``order`` keyword (passing one raises
        # TypeError). A fresh factory per circuit ensures no extrapolation
        # data carries over between circuits.
        factory = RichardsonFactory(scale_factors=factors)
        ezero = zne.execute_with_zne(
            circ, executor, scale_noise=fold_global, factory=factory
        )
        # Extrapolation can overshoot [0, 1]; keep it a valid probability.
        ezero = float(np.clip(ezero, 1e-12, 1.0 - 1e-12))
        # Convert probability of |0...0> into an Ising-style "energy"
        energy = (1.0 - ezero) / 2.0
        zne_energies.append(energy)

    return np.array(zne_energies, dtype=float)
