# urbansound8k_znkd.py
#
# Full ZNKD pipeline for UrbanSound8K:
#  - Teacher: 10-qubit VQC trained on true class labels
#  - ZNE:    zero-noise energies per class (teacher)
#  - Targets: tanh-stabilized regression labels from ZNE-corrected energies
#  - Student: 6-qubit QNN trained via regression on those soft targets

import random
import numpy as np
import librosa
from datasets import load_dataset

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit.utils import QuantumInstance
from qiskit.circuit.library import EfficientSU2
from qiskit import QuantumCircuit

from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC, NeuralNetworkRegressor

from urbansound8k_zne_utils import zne_expectation_zero

# Shared seed for Python's `random`, NumPy, scikit-learn and the simulators.
SEED: int = 123

# Problem / model sizes.
N_CLASSES: int = 10       # number of target classes
TEACHER_QUBITS: int = 10  # width of the teacher VQC
STUDENT_QUBITS: int = 6   # width of the student QNN

# Feature pipeline: MFCC_COEFFS time-averaged MFCCs per clip, reduced by PCA
# to one component per teacher qubit (one angle-encoded feature per qubit).
MFCC_COEFFS: int = 40
PCA_COMPONENTS: int = TEACHER_QUBITS


# =========================================================
# 1. Dataset loading & MFCC feature extraction + PCA
# =========================================================
def mfcc_mean(audio, sr=16_000, n_mfcc=MFCC_COEFFS, orig_sr=None):
    """Return time-averaged MFCC coefficients for a 1D audio array.

    Args:
        audio: 1D waveform (array-like of samples).
        sr: Sample rate (Hz) at which the MFCCs are computed.
        n_mfcc: Number of MFCC coefficients to return.
        orig_sr: Sample rate (Hz) of ``audio``. When given and different from
            ``sr`` the waveform is resampled to ``sr`` first. When ``None``,
            ``audio`` is assumed to already be sampled at ``sr``.

    Returns:
        Array of shape ``(n_mfcc,)``: the MFCCs averaged over time frames.
    """
    audio = np.asarray(audio, dtype=np.float32)
    # The source rate must come from the clip's metadata: the number of
    # samples is the clip *length*, not its sample rate, and using it here
    # would squash every clip to exactly `sr` samples (a length-dependent
    # time/pitch warp).
    if orig_sr is not None and orig_sr != sr:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
    mfcc = librosa.feature.mfcc(
        y=audio,
        sr=sr,
        n_mfcc=n_mfcc,
        hop_length=512,
    )
    return mfcc.mean(axis=1)


def load_urbansound8k_pca(max_clips=None):
    """Load UrbanSound8K, extract MFCCs, apply PCA and angle encoding.

    The train/test split is made on the raw MFCC features. PCA and the
    [0, π] angle scaling are then fitted on the training split only and
    applied to the test split, so no test-set statistics leak into training.

    Args:
        max_clips: If given, use a seeded random subset of at most this many
            clips; ``None`` uses every clip.

    Returns:
        ``(X_train, X_test, y_train, y_test)``. The ``X`` arrays have
        ``PCA_COMPONENTS`` columns with values in [0, π]; the ``y`` arrays hold
        the integer ``classID`` labels.
    """
    print("⇨ Downloading / preparing UrbanSound8K …")
    ds = load_dataset("urbansound8k", "audio", split="train")

    if max_clips is not None:
        ds = ds.shuffle(seed=SEED).select(range(min(max_clips, len(ds))))

    features, labels = [], []
    for sample in ds:
        clip = sample["audio"]
        # Pass the clip's real sample rate so mfcc_mean resamples correctly.
        features.append(mfcc_mean(clip["array"], orig_sr=clip["sampling_rate"]))
        labels.append(sample["classID"])

    features = np.stack(features).astype(np.float32)
    labels = np.array(labels, dtype=int)

    # Split first; the stratified split depends only on the labels and the
    # seed, so the train/test membership is the same as splitting after PCA.
    F_train, F_test, y_train, y_test = train_test_split(
        features,
        labels,
        test_size=0.2,
        stratify=labels,
        random_state=SEED,
    )

    pca = PCA(n_components=PCA_COMPONENTS, random_state=SEED)
    Z_train = pca.fit_transform(F_train)
    Z_test = pca.transform(F_test)

    # Scale to [0, π] for angle encoding using the *training* range only;
    # test values outside that range are clipped into [0, π].
    lo, hi = Z_train.min(), Z_train.max()
    scale = np.pi / (hi - lo + 1e-12)
    X_train = (Z_train - lo) * scale
    X_test = np.clip((Z_test - lo) * scale, 0.0, np.pi)

    print(f"   • Train size: {X_train.shape[0]:,}")
    print(f"   • Test  size: {X_test.shape[0]:,}")
    return X_train, X_test, y_train, y_test


def make_feature_map(num_qubits):
    """Return a factory that angle-encodes a feature vector with Ry gates.

    The returned callable takes a feature vector ``x`` and builds a new
    ``num_qubits``-qubit ``QuantumCircuit``. Qubit ``i`` receives
    ``Ry(float(x[i]))`` for each of the first ``num_qubits`` entries. Extra
    entries are ignored; a shorter ``x`` leaves the remaining qubits untouched.
    """
    def encode(x):
        angles = [float(theta) for theta in x[:num_qubits]]
        circuit = QuantumCircuit(num_qubits)
        for qubit, angle in enumerate(angles):
            circuit.ry(angle, qubit)
        return circuit

    return encode


# =========================================================
# 2. Teacher VQC
# =========================================================
def build_teacher_vqc():
    """Build the TEACHER_QUBITS-qubit teacher VQC (Ry encoding + EfficientSU2).

    Returns:
        An unfitted ``VQC``. The number of classes is taken from the labels
        passed to ``fit``.
    """
    from qiskit.algorithms.optimizers import COBYLA
    from qiskit.circuit import ParameterVector

    # VQC needs a *parameterized* feature-map circuit whose input parameters
    # it binds to each data row. A Python callable that returns an already
    # bound circuit per sample cannot be used as `feature_map`.
    inputs = ParameterVector("x", TEACHER_QUBITS)
    feature_map = QuantumCircuit(TEACHER_QUBITS)
    for qubit, theta in enumerate(inputs):
        feature_map.ry(theta, qubit)

    ansatz = EfficientSU2(TEACHER_QUBITS, reps=2)
    backend = AerSimulator(seed_simulator=SEED)
    qinst = QuantumInstance(
        backend=backend,
        seed_simulator=SEED,
        seed_transpiler=SEED,
    )

    # VQC has no `num_classes` argument, and `optimizer` must be an
    # Optimizer instance (or a Minimizer callable), not a string.
    return VQC(
        feature_map=feature_map,
        ansatz=ansatz,
        optimizer=COBYLA(),
        quantum_instance=qinst,
    )


# =========================================================
# 3. Student QNN regressor (ZNKD student)
# =========================================================
def build_student_regressor():
    """Build the STUDENT_QUBITS-qubit student regressor.

    The first STUDENT_QUBITS parameters of an EfficientSU2 circuit act as data
    inputs and the rest are trainable weights. Sampled outcome indices
    (0 .. 2**STUDENT_QUBITS - 1) are folded into N_CLASSES buckets with
    ``% N_CLASSES``. The network therefore emits one value per class, which
    matches per-class regression targets and the ``argmin(axis=1)`` readout
    used in student_accuracy.

    Returns:
        An unfitted ``NeuralNetworkRegressor`` wrapping a ``SamplerQNN``.
    """
    from qiskit.algorithms.optimizers import COBYLA

    ansatz = EfficientSU2(STUDENT_QUBITS, reps=1)

    def to_class_bucket(outcome_index):
        # NOTE(review): 2**6 = 64 outcomes do not split evenly into 10
        # buckets (buckets 0-3 get 7 outcomes, 4-9 get 6); this mapping is a
        # design choice that only fixes the output width.
        return outcome_index % N_CLASSES

    # No sampler is passed, so SamplerQNN uses its default sampler primitive.
    qnn = SamplerQNN(
        circuit=ansatz,
        input_params=ansatz.parameters[:STUDENT_QUBITS],
        weight_params=ansatz.parameters[STUDENT_QUBITS:],
        interpret=to_class_bucket,
        output_shape=N_CLASSES,
        sparse=False,
    )

    # NeuralNetworkRegressor takes no `quantum_instance`. Its loss strings are
    # "absolute_error" / "squared_error" / "cross_entropy" ("l2" is rejected),
    # and the optimizer must be an Optimizer instance.
    # NOTE(review): the QNN outputs probabilities in [0, 1], while tanh
    # targets lie in [-1, 1], so negative targets cannot be matched exactly.
    return NeuralNetworkRegressor(
        neural_network=qnn,
        loss="squared_error",
        optimizer=COBYLA(),
    )


# =========================================================
# 4. ZNE-based tanh targets
# =========================================================
def compute_zne_tanh_targets(vqc, X, base_eps=0.01, tau=1.0, log_every=100):
    """Compute tanh-stabilized ZNE energies for each training example.

    For every row ``x`` of ``X``, the values returned by
    ``zne_expectation_zero(vqc, x, base_eps=base_eps, seed=SEED)`` are mapped
    through ``tanh(E / tau)``.

    Args:
        vqc: Teacher model, forwarded to ``zne_expectation_zero``.
        X: Sized iterable of feature rows.
        base_eps: Base noise level forwarded to ``zne_expectation_zero``.
        tau: Positive temperature; larger values keep ``tanh(E / tau)``
            further from saturation.
        log_every: Print progress every ``log_every`` samples; ``<= 0``
            disables progress output.

    Returns:
        ``np.ndarray`` stacking the per-sample targets along axis 0.

    Raises:
        ValueError: If ``tau`` is not positive or ``X`` is empty.
    """
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau!r}")
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    all_targets = []
    for i, x in enumerate(X, start=1):
        if log_every > 0 and i % log_every == 0:
            print(f"[ZNE targets] {i}/{n_samples} samples", flush=True)

        # asarray so list/tuple results from the ZNE helper divide correctly.
        energies = np.asarray(
            zne_expectation_zero(vqc, x, base_eps=base_eps, seed=SEED),
            dtype=float,
        )
        all_targets.append(np.tanh(energies / tau))

    return np.stack(all_targets, axis=0)


# =========================================================
# 5. Accuracy helpers
# =========================================================
def teacher_accuracy(vqc, X, y):
    """Return the fraction of rows in X whose teacher prediction equals y."""
    correct_mask = vqc.predict(X) == y
    return correct_mask.mean()


def student_accuracy(student_regressor, X, y):
    """Return the fraction of rows in X whose student-predicted class equals y.

    The predicted class is the column with the smallest regressed value
    (lower energy => more confident).
    """
    predicted_class = np.argmin(student_regressor.predict(X), axis=1)
    return np.mean(predicted_class == y)


# =========================================================
# 6. Main ZNKD pipeline
# =========================================================
def main(max_clips=None):
    """Run the ZNKD pipeline end to end and print teacher/student accuracy.

    Steps:
        1. Load and featurize UrbanSound8K.
        2. Train the teacher VQC on the true labels and report its test accuracy.
        3. Compute tanh-stabilized ZNE targets from the teacher on the training set.
        4. Train the student regressor on those targets and report its accuracy.

    Args:
        max_clips: Optional cap on the number of clips, forwarded to
            ``load_urbansound8k_pca``; ``None`` uses every clip.
    """
    random.seed(SEED)
    np.random.seed(SEED)

    X_train, X_test, y_train, y_test = load_urbansound8k_pca(max_clips=max_clips)

    # The student QNN has STUDENT_QUBITS input parameters, whereas the PCA
    # features have PCA_COMPONENTS columns. Give it the leading PCA components
    # (PCA orders components by explained variance).
    X_train_student = X_train[:, :STUDENT_QUBITS]
    X_test_student = X_test[:, :STUDENT_QUBITS]

    print("⇨ Building teacher VQC …")
    teacher = build_teacher_vqc()

    print("⇨ Training teacher on true labels …")
    teacher.fit(X_train, y_train)

    print("\n=== Teacher accuracy (noiseless simulator) ===")
    acc_teacher = teacher_accuracy(teacher, X_test, y_test)
    print(f"Teacher: {acc_teacher:.3f}")

    print("\n⇨ Computing ZNE-based tanh targets for distillation …")
    # The teacher sees the full feature vectors; only the student is narrowed.
    zne_targets = compute_zne_tanh_targets(
        teacher, X_train, base_eps=0.01, tau=1.0
    )

    print("⇨ Building student QNN regressor …")
    student = build_student_regressor()

    print("⇨ Training student on ZNE-tanh targets …")
    student.fit(X_train_student, zne_targets)

    print("\n=== Distilled student accuracy ===")
    acc_student = student_accuracy(student, X_test_student, y_test)
    print(f"Student (ZNKD regression): {acc_student:.3f}")


# Run the full pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
