"""
wine_quality_zne.py
-------------------
Zero-Noise Extrapolation with a Variational Quantum Classifier (VQC)
on the UCI Wine-Quality dataset (red + white wines, 6 quality classes).

▪ Features            : 11 physicochemical measurements
▪ Pre-processing       : standardise → PCA(10) → scale to [0, π]
▪ Encoding             : angle encoding (θ_i) on 10 qubits
▪ Teacher VQC          : 10-qubit EfficientSU2 (reps=2)
▪ Noise mitigation     : Richardson extrapolation, λ = 3
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import urllib.request, io, zipfile, pandas as pd, numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance
from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory

SEED = 123  # shared seed for PCA, the train/test split, the simulator and the transpiler

# ---------------------------------------------------------------
# 1. Load & combine red + white wine data
#    (both CSVs are fetched straight from the UCI archive)
# ---------------------------------------------------------------
print("⇨ Downloading UCI wine-quality CSVs …")

# NOTE: despite its name, URL_ZIP is the address of the plain red-wine CSV,
# not of an archive; the name is kept because it is a module-level global.
URL_ZIP = (
    "https://archive.ics.uci.edu/ml/machine-learning-databases/"
    "wine-quality/winequality-red.csv"
)

# Both files are ';'-separated; the white-wine file sits next to the red one.
red, white = (
    pd.read_csv(csv_url, sep=";")
    for csv_url in (URL_ZIP, URL_ZIP.replace("-red.csv", "-white.csv"))
)

# Stack red on top of white and renumber the rows 0..N-1.
data = pd.concat([red, white], ignore_index=True)

# Feature matrix: every column except the target, as float32.
features = data.drop(columns=["quality"])
X = features.to_numpy().astype(np.float32)

# Target: the integer quality score, used directly as the class label.
# NOTE(review): the number of distinct scores is taken from the data later on
# (np.unique); the white-wine file is understood to contain a few 9s on top of
# the 3–8 range of the red file — confirm if an exact class count matters.
y = data["quality"].to_numpy().astype(int)

# ---------------------------------------------------------------
# 2. Standardise → PCA → scale to angle-encoding range
# ---------------------------------------------------------------
NUM_QUBITS = 10

# stratified split (80 / 20) — performed BEFORE fitting any transformer so
# that no test-set statistics leak into the scaler, the PCA basis or the
# angle range.  Same seed, labels and row count as a split done afterwards.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SEED
)

# standardise → PCA, both fitted on the training rows only
scaler = StandardScaler().fit(X_train)
pca = PCA(n_components=NUM_QUBITS, random_state=SEED)
train_components = pca.fit_transform(scaler.transform(X_train))

# map EACH component to [0, π] (angle encoding).  The range is taken per
# column: a single global min/max would squeeze the low-variance components
# into a sliver of the interval and waste most of the rotation range.
comp_min = train_components.min(axis=0)
comp_span = train_components.max(axis=0) - comp_min + 1e-12  # guard against /0


def _to_angles(raw_features):
    """Apply the train-fitted scaler, PCA and per-component range to *raw_features*.

    Values outside the range seen during training are clipped so every angle
    stays inside [0, π].
    """
    components = pca.transform(scaler.transform(raw_features))
    return np.clip(np.pi * (components - comp_min) / comp_span, 0.0, np.pi)


X_train = _to_angles(X_train)
X_test = _to_angles(X_test)

# keep X as the angle-encoded version of the full dataset, as before
X = _to_angles(X)

# ---------------------------------------------------------------
# 3. Build & train the VQC (teacher model)
# ---------------------------------------------------------------
def feature_map(x):
    """Return the angle-encoding circuit for one sample.

    A fresh ``NUM_QUBITS``-qubit circuit is created and, for every entry of
    *x*, an RY rotation by that entry is applied to the qubit with the same
    index (entry 0 → qubit 0, entry 1 → qubit 1, …).

    Args:
        x: iterable of rotation angles; it should hold no more than
           ``NUM_QUBITS`` entries because entry *i* is placed on qubit *i*.

    Returns:
        The ``QuantumCircuit`` containing the RY rotations.
    """
    from qiskit import QuantumCircuit

    encoder = QuantumCircuit(NUM_QUBITS)
    qubit = 0
    for angle in x:
        encoder.ry(angle, qubit)
        qubit += 1
    return encoder


# Local imports: only this section needs them.
# NOTE(review): `qiskit.algorithms` belongs to the same pre-1.0 Qiskit line as
# the `qiskit.utils.QuantumInstance` import this file already relies on; on
# newer stacks the optimizers live in `qiskit_algorithms.optimizers`.
from qiskit.algorithms.optimizers import COBYLA
from qiskit.circuit import ParameterVector

# VQC needs a *parameterised circuit* as feature map, not a Python function.
# Calling feature_map() with symbolic parameters yields exactly that: one
# RY(x_i) per qubit with x_i left unbound.
encoding = feature_map(ParameterVector("x", NUM_QUBITS))

ansatz = EfficientSU2(NUM_QUBITS, reps=2)

# Stand-alone sampler network over the same circuit (encoding followed by the
# ansatz).  Inputs are the encoding angles and weights are ALL ansatz
# parameters; slicing the ansatz parameters into "inputs" would feed the data
# into trainable rotations and leave the circuit without any data encoding.
# Not used by `vqc`, which builds its own network internally.
qnn = SamplerQNN(
    circuit=encoding.compose(ansatz),
    input_params=encoding.parameters,
    weight_params=ansatz.parameters,
)

vqc = VQC(
    feature_map=encoding,
    ansatz=ansatz,
    # an optimizer *instance* is required; a bare string is not accepted
    optimizer=COBYLA(),
    quantum_instance=AerSimulator(seed_simulator=SEED),
    # no `num_classes` argument: the classifier derives the number of classes
    # from the labels passed to fit()
)

print("⇨ Training VQC …")
vqc.fit(X_train, y_train)

# ---------------------------------------------------------------
# 4. ZNE helpers  (unchanged from earlier scripts)
# ---------------------------------------------------------------
def build_noise_model(base_eps: float):
    """Create a depolarising noise model with a single error strength.

    The same parameter *base_eps* is used for a one-qubit depolarising error
    attached to the gates x, rx, ry, rz, u1, u2, u3 and for a two-qubit
    depolarising error attached to cx, each on all qubits.

    Args:
        base_eps: depolarising parameter handed to ``depolarizing_error``.

    Returns:
        The populated ``NoiseModel``.
    """
    single_qubit_gates = ["x", "rx", "ry", "rz", "u1", "u2", "u3"]
    two_qubit_gates = ["cx"]

    noise = NoiseModel()
    # one-qubit gates first, then cx — same registration order as before
    for n_qubits, gate_names in ((1, single_qubit_gates), (2, two_qubit_gates)):
        noise.add_all_qubit_quantum_error(
            depolarizing_error(base_eps, n_qubits), gate_names
        )
    return noise


def _make_counts_sampler(base_eps: float, shots: int):
    """Return ``run(circ) -> counts`` executing on a depolarising-noise simulator."""
    backend = AerSimulator(noise_model=build_noise_model(base_eps), seed_simulator=SEED)
    qinst = QuantumInstance(
        backend=backend,
        shots=shots,
        seed_transpiler=SEED,
        # Level 0 on purpose: higher levels may cancel the G·G† pairs that
        # folding inserts and silently undo the noise scaling.
        optimization_level=0,
    )

    def _run(circ):
        return qinst.execute(circ).get_counts()

    return _run


def _outcome(bitstring: str) -> int:
    """Integer value of a counts key; register separators (spaces) are ignored."""
    return int(bitstring.replace(" ", ""), 2)


def _class_distribution(counts, n_classes: int):
    """Fold measurement counts into class probabilities via ``outcome % n_classes``."""
    probs = np.zeros(n_classes)
    for bits, hits in counts.items():
        probs[_outcome(bits) % n_classes] += hits
    return probs / probs.sum()


def _class_labels(vqc, classes):
    """Return the label of every class index, explicit *classes* taking priority."""
    if classes is not None:
        return np.asarray(classes)
    # NOTE(review): falls back to the classifier's fitted one-hot encoder, a
    # private attribute — confirm it exists in the installed
    # qiskit-machine-learning version or pass `classes` explicitly.
    encoder = getattr(vqc, "_target_encoder", None)
    categories = getattr(encoder, "categories_", None)
    if categories is None:
        raise ValueError(
            "cannot infer the class labels from the classifier; "
            "pass classes=[...] (sorted training labels) explicitly"
        )
    return np.asarray(categories[0])


def make_executor(base_eps: float, shots: int = 1024):
    """Build a scalar executor for use with mitiq.

    Args:
        base_eps: depolarising strength passed to ``build_noise_model``.
        shots: number of shots per execution (default 1024).

    Returns:
        ``executor(circ) -> float``: the fraction of shots in which every
        measured bit of *circ* was 0.  The fraction is normalised by the shots
        actually returned rather than by a hard-coded constant, and the
        all-zero outcome is recognised by value, so it does not depend on the
        number or layout of classical registers.
    """
    run = _make_counts_sampler(base_eps, shots)

    def _executor(circ):
        counts = run(circ)
        total = sum(counts.values())
        zeros = sum(hits for bits, hits in counts.items() if _outcome(bits) == 0)
        return zeros / total

    return _executor


def zne_predict(vqc, data, base_eps: float, λ: int = 3, classes=None):
    """Predict class labels with zero-noise-extrapolated class probabilities.

    For every sample the trained circuit is executed under depolarising noise
    at scale factors 1 and *λ* (global unitary folding).  Each class
    probability is Richardson-extrapolated to zero noise and the class with
    the largest extrapolated probability wins.

    Args:
        vqc: fitted classifier exposing ``circuit``, ``feature_map``,
            ``ansatz`` and ``weights``.
        data: iterable of already angle-encoded feature vectors.
        base_eps: depolarising strength passed to ``build_noise_model``.
        λ: second noise-scale factor; must be greater than 1.
        classes: optional class labels in the classifier's output order
            (ascending label order is assumed — TODO confirm); inferred from
            the fitted classifier when omitted.

    Returns:
        ``np.ndarray`` of predicted labels, one per row of *data*.

    Raises:
        ValueError: if *λ* is not greater than 1, the labels cannot be
            determined, or a sample has the wrong number of features.
    """
    from qiskit import ClassicalRegister, transpile

    if λ <= 1:
        raise ValueError("λ must be > 1 so that the two noise scales differ")

    labels = _class_labels(vqc, classes)
    n_classes = len(labels)
    scale_factors = [1, λ]
    run = _make_counts_sampler(base_eps, shots=1024)

    # Everything that does not depend on the sample is done once: bind the
    # trained weights, attach measurements (explicit measure instead of
    # measure_all, which would add a barrier) and flatten library gates so the
    # folding routine sees elementary ry/rz/cx gates.
    template = vqc.circuit.assign_parameters(
        dict(zip(vqc.ansatz.parameters, vqc.weights))
    )
    n_qubits = template.num_qubits
    template.add_register(ClassicalRegister(n_qubits, "c"))
    template.measure(range(n_qubits), range(n_qubits))
    template = transpile(template, basis_gates=["ry", "rz", "cx"], optimization_level=0)
    input_params = list(vqc.feature_map.parameters)

    preds = []
    for x in data:
        if len(x) != len(input_params):
            raise ValueError(
                f"expected {len(input_params)} features per sample, got {len(x)}"
            )
        circ = template.assign_parameters(
            dict(zip(input_params, (float(v) for v in x)))
        )

        # rows: noise scale, columns: class — one execution per scale factor
        noisy = np.array(
            [
                _class_distribution(run(folding.fold_global(circ, scale)), n_classes)
                for scale in scale_factors
            ]
        )
        # NOTE(review): outcome % n_classes mirrors the parity mapping VQC is
        # understood to use for its outputs — confirm for the installed version.
        mitigated = [
            RichardsonFactory.extrapolate(scale_factors, noisy[:, k].tolist())
            for k in range(n_classes)
        ]
        preds.append(labels[int(np.argmax(mitigated))])
    return np.array(preds)


# ---------------------------------------------------------------
# 5. Evaluate ZNE vs. raw noise
# ---------------------------------------------------------------
# depolarising strengths to sweep, weakest to strongest
noise_grid = [0.001, 0.01, 0.1]

print("\n=== Zero-Noise Extrapolation Results (Wine-Quality) ===")
for eps in noise_grid:
    # mitigated predictions for the held-out wines at this noise strength
    y_hat = zne_predict(vqc, X_test, λ=3, base_eps=eps)
    # accuracy = share of predictions equal to the true label
    acc = np.mean(y_hat == y_test)
    print(f"ε = {eps:<6} → ZNE-mitigated test accuracy = {acc:.3f}")


def raw_accuracy(eps):
    """Score the trained classifier on the test split under unmitigated noise.

    The noisy backend is installed on the classifier's underlying neural
    network, which is the object that actually executes circuits.  Assigning
    ``quantum_instance`` on the classifier itself is believed to only create
    an unused attribute, leaving the score on the noiseless training backend.
    The previous backend is restored afterwards so the global ``vqc`` is not
    left in a noisy state.

    Args:
        eps: depolarising strength passed to ``build_noise_model``.

    Returns:
        The value of ``vqc.score(X_test, y_test)`` under that noise.
    """
    backend = AerSimulator(noise_model=build_noise_model(eps), seed_simulator=SEED)
    noisy_instance = QuantumInstance(
        backend=backend, shots=1024, seed_simulator=SEED, seed_transpiler=SEED
    )

    # NOTE(review): assumes the network is QuantumInstance-based and exposes a
    # settable `quantum_instance` — confirm for the installed version.
    network = vqc.neural_network
    previous_instance = network.quantum_instance
    network.quantum_instance = noisy_instance
    try:
        return vqc.score(X_test, y_test)
    finally:
        network.quantum_instance = previous_instance


print("\n--- Without mitigation ---")
for eps in noise_grid:
    # same noise strengths as the mitigated sweep, but no extrapolation
    noisy_acc = raw_accuracy(eps)
    print(f"ε = {eps:<6} → Raw noisy accuracy          = {noisy_acc:.3f}")
