"""
urbansound8k_zne.py
-------------------
Zero-Noise Extrapolation with a Variational Quantum Classifier (VQC)
on the UrbanSound8K dataset.

▪ Feature extraction : 40-coef MFCC  → PCA → 10 values
▪ Encoding           : angle encoding (θ = feature_i) onto 10 qubits
▪ Noise mitigation   : Richardson extrapolation with λ = 3
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import os, pathlib, random, warnings
import numpy as np
import librosa

# HuggingFace dataset (keeps the example self-contained)
from datasets import load_dataset

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory

# One seed shared by every stochastic step visible in this script: the two
# global RNGs below, PCA, the train/test split, and the Aer simulators.
SEED = 123
np.random.seed(SEED)  # NumPy's global generator
random.seed(SEED)  # stdlib `random` global generator

# ---------------------------------------------------------------
# 1. Dataset loading & MFCC feature extraction
# ---------------------------------------------------------------
print("⇨ Downloading / preparing UrbanSound8K …")
ds = load_dataset("urbansound8k", "audio", split="train")  # all 8 732 clips

def mfcc_mean(audio, sr=16_000, n_mfcc=40, orig_sr=None):
    """Return time-averaged MFCCs (len = n_mfcc) for one clip.

    Args:
        audio: Waveform samples (assumed mono, samples along the last axis).
        sr: Sampling rate in Hz at which the MFCCs are computed.
        n_mfcc: Number of MFCC coefficients to compute.
        orig_sr: Sampling rate of ``audio`` as supplied.  When given and
            different from ``sr`` the clip is resampled to ``sr`` first.
            When ``None`` the clip is assumed to be at ``sr`` already.

    Returns:
        The MFCC matrix averaged over its frame axis.
    """
    audio = np.asarray(audio, dtype=np.float32)
    # The number of samples says nothing about the sampling rate, so the
    # decision to resample is driven by the caller-supplied `orig_sr` only.
    if orig_sr is not None and orig_sr != sr:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
    mfcc = librosa.feature.mfcc(y=audio, sr=sr,
                                n_mfcc=n_mfcc, hop_length=512)
    return mfcc.mean(axis=1)                                # average over time

# Extract MFCC features – this may take a minute
mfcc_features, labels = [], []
for sample in ds:
    audio = sample["audio"]["array"]
    # NOTE(review): assumes the audio column also carries "sampling_rate"
    # next to "array" – confirm against the dataset schema.
    orig_sr = sample["audio"]["sampling_rate"]
    mfcc_features.append(mfcc_mean(audio, orig_sr=orig_sr))
    labels.append(sample["classID"])

mfcc_features = np.stack(mfcc_features).astype(np.float32)
labels        = np.array(labels, dtype=int)

# ---------------------------------------------------------------
# 2. Dimensionality reduction → NUM_QUBITS principal components
# ---------------------------------------------------------------
NUM_QUBITS = 10
pca = PCA(n_components=NUM_QUBITS, random_state=SEED)
X_reduced = pca.fit_transform(mfcc_features)

# Scale each component into [0, π] for angle encoding.  The min/max are taken
# per column (axis=0) so every principal component spans the full range;
# a single global min/max would squeeze the low-variance components into a
# narrow band of angles.
# NOTE(review): PCA and the min/max are fitted on all clips before the split
# below, so test-set statistics influence the training features.
_pc_min = X_reduced.min(axis=0)
_pc_span = X_reduced.max(axis=0) - _pc_min
X_scaled = np.pi * (X_reduced - _pc_min) / (_pc_span + 1e-12)

X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, labels, test_size=0.2, stratify=labels, random_state=SEED
)

# ---------------------------------------------------------------
# 3. Build & train the Variational Quantum Classifier (teacher model)
# ---------------------------------------------------------------
def feature_map(x):
    """Build the angle-encoding circuit for one feature vector.

    Creates a fresh ``NUM_QUBITS``-qubit circuit and applies ``RY(x[i])`` to
    qubit ``i`` for every entry of ``x``.

    Args:
        x: Iterable of rotation angles, one per qubit.

    Returns:
        The populated ``QuantumCircuit``.
    """
    from qiskit import QuantumCircuit

    encoder = QuantumCircuit(NUM_QUBITS)
    for qubit, angle in enumerate(x):
        encoder.ry(angle, qubit)
    return encoder

# Imported here so this section carries everything it needs.
# NOTE(review): `qiskit.algorithms` matches the Qiskit generation that still
# ships `qiskit.utils.QuantumInstance`; confirm against the installed version.
from qiskit.algorithms.optimizers import COBYLA
from qiskit.circuit import ParameterVector

ansatz = EfficientSU2(NUM_QUBITS, reps=2)

# NOTE(review): `qnn` is not referenced again in the visible part of this
# file, and it treats the first NUM_QUBITS ansatz parameters as data inputs.
# It is kept unchanged in case other code relies on the name.
qnn = SamplerQNN(
    circuit=ansatz,
    input_params=ansatz.parameters[:NUM_QUBITS],
    weight_params=ansatz.parameters[NUM_QUBITS:],
)

# VQC takes circuits, not callables: calling feature_map() with symbolic
# parameters yields the parameterised RY encoding circuit it can bind data to.
# The optimizer must be an optimizer object rather than its name, and
# `num_classes` is not a constructor argument, so it is no longer passed.
vqc = VQC(
    feature_map=feature_map(ParameterVector("x", NUM_QUBITS)),
    ansatz=ansatz,
    optimizer=COBYLA(),
    quantum_instance=AerSimulator(seed_simulator=SEED),
)

print("⇨ Training VQC (≈ few minutes on CPU) …")
vqc.fit(X_train, y_train)

# ---------------------------------------------------------------
# 4. ZNE helper functions
# ---------------------------------------------------------------
def build_noise_model(base_eps: float):
    """Return a noise model with depolarizing errors of strength ``base_eps``.

    The same error parameter is attached, for all qubits, to the listed
    single-qubit gates (as a 1-qubit error) and to ``cx`` (as a 2-qubit error).

    Args:
        base_eps: Parameter forwarded to ``depolarizing_error``.

    Returns:
        The configured ``NoiseModel``.
    """
    single_qubit_gates = ['x', 'rx', 'ry', 'rz', 'u1', 'u2', 'u3']
    two_qubit_gates = ['cx']

    noise_model = NoiseModel()
    for gate_names, width in ((single_qubit_gates, 1), (two_qubit_gates, 2)):
        noise_model.add_all_qubit_quantum_error(
            depolarizing_error(base_eps, width), gate_names
        )
    return noise_model

def make_executor(base_eps: float, shots: int = 1024):
    """Create an executor that returns the all-zeros probability of a circuit.

    Args:
        base_eps: Depolarizing strength passed to ``build_noise_model``.
        shots: Number of shots per execution.  The same value is used both to
            run the circuit and to normalise the counts, so the two cannot
            drift apart.

    Returns:
        A callable ``circuit -> float`` giving the fraction of shots whose
        outcome is the all-zeros bitstring.
    """
    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=SEED)
    qinst = QuantumInstance(backend=backend, shots=shots, seed_transpiler=SEED)

    def _executor(circuit):
        counts = qinst.execute(circuit).get_counts()
        # NOTE(review): the lookup key has one '0' per qubit, which assumes
        # every qubit is measured into a single classical register – confirm.
        return counts.get('0' * circuit.num_qubits, 0) / shots

    return _executor

def zne_predict(vqc, data, base_eps: float, λ: int = 3):
    """Predict labels for ``data`` using zero-noise-extrapolated probabilities.

    For each sample, every circuit obtained from the classifier is executed
    through ZNE (global folding at scale factors ``1`` and ``λ``) to estimate
    its all-zeros probability.  The predicted label is the index of the
    circuit with the lowest ``(1 - p) / 2``, i.e. the highest mitigated
    probability.

    Args:
        vqc: Trained classifier whose circuits are evaluated.
        data: Iterable of feature vectors.
        base_eps: Depolarizing strength of the simulated backend.
        λ: Second noise-scale factor used for the extrapolation.

    Returns:
        Integer array with one predicted index per sample (empty for empty
        ``data``).
    """
    executor = make_executor(base_eps)
    fold_global = folding.fold_global
    # RichardsonFactory takes no `order` argument; with two scale factors the
    # Richardson fit is already the straight line through both points.
    factory = RichardsonFactory(scale_factors=[1, λ])

    y_hat = []
    for x in data:
        # NOTE(review): relies on a private attribute and on
        # `construct_circuit` returning one circuit per class – confirm both
        # against the installed qiskit-machine-learning version.
        circuits = vqc._neural_network.construct_circuit(x)
        probs = []
        for circ in circuits:
            p0 = zne.execute_with_zne(circ, executor,
                                      scale_noise=fold_global,
                                      factory=factory)
            # Extrapolation can leave [0, 1]; clamp to a valid probability.
            probs.append(np.clip(p0, 1e-12, 1 - 1e-12))
        energies = [(1 - p) / 2 for p in probs]
        y_hat.append(np.argmin(energies))
    # Explicit dtype keeps the result integer-typed even when `data` is empty.
    return np.array(y_hat, dtype=int)

# ---------------------------------------------------------------
# 5. Evaluate ZNE vs. raw noisy backend
# ---------------------------------------------------------------
# Depolarizing strengths to sweep, from nearly noiseless to heavily noisy.
noise_grid = [0.001, 0.01, 0.1]

print("\n=== Zero-Noise Extrapolation Results (UrbanSound8K) ===")
for eps in noise_grid:
    # Mitigated predictions on the held-out split at this noise level.
    preds = zne_predict(vqc, X_test, λ=3, base_eps=eps)
    # Fraction of test samples whose predicted label matches the truth.
    acc = np.mean(preds == y_test)
    print(f"ε = {eps:<6} → ZNE-mitigated test accuracy = {acc:.3f}")

def raw_accuracy(eps):
    """Score the classifier on the test split with an unmitigated noisy backend.

    Builds a noisy simulator for ``eps``, assigns a ``QuantumInstance`` wrapping
    it to the module-level ``vqc.quantum_instance`` and returns
    ``vqc.score(X_test, y_test)``.  The assignment is not undone afterwards.

    NOTE(review): confirm that assigning ``quantum_instance`` on the
    classifier actually reroutes execution in the installed library version.
    """
    noisy_backend = AerSimulator(
        noise_model=build_noise_model(eps),
        seed_simulator=SEED,
    )
    noisy_instance = QuantumInstance(
        backend=noisy_backend,
        seed_simulator=SEED,
        seed_transpiler=SEED,
    )
    vqc.quantum_instance = noisy_instance
    return vqc.score(X_test, y_test)

# Baseline: same noise levels, no error mitigation.
print("\n--- Without mitigation ---")
for eps in noise_grid:
    print(f"ε = {eps:<6} → Raw noisy accuracy          = {raw_accuracy(eps):.3f}")
