# =======================================================================
# 0. Imports
# =======================================================================
import numpy as np
from scipy.optimize import minimize
from tensorflow.keras.datasets import fashion_mnist
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit import QuantumCircuit, transpile
from qiskit.circuit import ParameterVector
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory

SEED = 123  # reproducibility throughout

# =======================================================================
# 1. Data loading & preprocessing  (Fashion-MNIST → PCA → angle encoding)
# =======================================================================
# 1a. Load Fashion-MNIST (28×28 greyscale, 10 classes)
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# 1b. Flatten to 784-dim vectors and normalise to [0,1]
X_train = X_train.reshape(-1, 28 * 28).astype(np.float32) / 255.0
X_test  = X_test.reshape(-1, 28 * 28).astype(np.float32) / 255.0

# 1c. Dimensionality reduction → 10 principal components (one per qubit)
NUM_QUBITS = 10
pca = PCA(n_components=NUM_QUBITS, random_state=SEED)
X_train = pca.fit_transform(X_train)
X_test  = pca.transform(X_test)

# 1d. Scale each component into [0, π] for angle encoding.
# The per-column minimum and span come from the *unscaled* training set and
# are reused for the test set, so both splits go through the same affine map.
# (They must be captured before X_train is overwritten; otherwise the test set
# would be scaled with statistics of the already-rescaled training data.)
# Test values may land slightly outside [0, π]; they are still valid angles.
_angle_min  = X_train.min(axis=0)
_angle_span = X_train.max(axis=0) - _angle_min + 1e-12  # 1e-12 avoids 0-division
X_train = np.pi * (X_train - _angle_min) / _angle_span
X_test  = np.pi * (X_test  - _angle_min) / _angle_span

# Fashion-MNIST's training split is already class-balanced, so the full
# train/test split is used as-is (no extra stratified split).
train_features, test_features = X_train, X_test
train_labels,   test_labels   = y_train, y_test

# =======================================================================
# 2. Build & train the Variational Quantum Classifier (teacher model)
# =======================================================================
def feature_map(x):
    """Return a bound Ry angle-encoding circuit for a single sample.

    Qubit ``i`` receives ``Ry(x[i])`` for the first ``NUM_QUBITS`` entries of
    ``x``. Entries beyond ``NUM_QUBITS`` are ignored, because the circuit has
    no qubit to hold them. A shorter ``x`` leaves the remaining qubits
    untouched.

    Args:
        x: sequence of rotation angles (e.g. one row of scaled PCA features).

    Returns:
        A ``NUM_QUBITS``-qubit ``QuantumCircuit`` with the angles bound.
    """
    from qiskit import QuantumCircuit
    qc = QuantumCircuit(NUM_QUBITS)
    # zip with range() stops at NUM_QUBITS, matching the docstring's contract.
    for qubit, angle in zip(range(NUM_QUBITS), x):
        qc.ry(angle, qubit)
    return qc

# VQC binds every sample to the free parameters of a *parameterised* feature-map
# circuit, so the encoding must be a QuantumCircuit rather than a Python
# callable. This circuit applies the same Ry(x_i)-on-qubit-i encoding as
# `feature_map`, but with symbolic parameters.
_x_params = ParameterVector("x", NUM_QUBITS)
feature_map_circuit = QuantumCircuit(NUM_QUBITS)
for _qubit, _angle in enumerate(_x_params):
    feature_map_circuit.ry(_angle, _qubit)

ansatz = EfficientSU2(NUM_QUBITS, reps=2)

# NOTE(review): `qnn` is not passed to the VQC below, which builds its own
# network from `feature_map_circuit` + `ansatz`. It is kept only in case other
# code references it; confirm and remove if unused.
qnn = SamplerQNN(
    circuit=ansatz,
    input_params=ansatz.parameters[:NUM_QUBITS],   # data parameters θ_i
    weight_params=ansatz.parameters[NUM_QUBITS:],  # trainable weights
)


def _cobyla(fun, x0, jac=None, bounds=None):
    """SciPy COBYLA exposed as a qiskit ``Minimizer`` callable.

    COBYLA is derivative-free, so any gradient (``jac``) or ``bounds`` the
    trainer passes are ignored. The SciPy result exposes ``x``/``fun``/``nfev``.
    """
    return minimize(fun, x0, method="COBYLA")


vqc = VQC(
    feature_map=feature_map_circuit,  # parameterised Ry angle encoding
    ansatz=ansatz,
    optimizer=_cobyla,                # must be an Optimizer/callable, not a string
    quantum_instance=AerSimulator(seed_simulator=SEED),
    # No `num_classes` keyword: the 10 classes are taken from the labels at fit time.
)

print("Training VQC … (this may take a few minutes)")
vqc.fit(train_features, train_labels)

# =======================================================================
# 3. ZNE helpers  (noise model, noisy executor, mitigated prediction)
# =======================================================================
def build_noise_model(base_eps: float):
    """Create an Aer noise model with uniform depolarizing noise.

    A 1-qubit depolarizing error of strength ``base_eps`` is attached to the
    x/rx/ry/rz/u1/u2/u3 gates on every qubit. A 2-qubit depolarizing error of
    the same strength is attached to every ``cx``.

    Args:
        base_eps: depolarizing parameter shared by the 1- and 2-qubit errors.

    Returns:
        The populated ``NoiseModel``.
    """
    single_qubit_gates = ['x', 'rx', 'ry', 'rz', 'u1', 'u2', 'u3']
    two_qubit_gates = ['cx']

    model = NoiseModel()
    # (arity, gate names): the error's qubit count must match the gate's.
    for arity, gate_names in ((1, single_qubit_gates), (2, two_qubit_gates)):
        model.add_all_qubit_quantum_error(
            depolarizing_error(base_eps, arity), gate_names
        )
    return model

def make_executor(base_eps: float, shots: int = 1024):
    """Return a mitiq-compatible executor backed by a noisy Aer simulator.

    The executor transpiles each circuit for an ``AerSimulator`` carrying the
    depolarizing model from ``build_noise_model(base_eps)``. It samples the
    circuit ``shots`` times and returns the fraction of shots in which every
    classical bit read 0.

    Args:
        base_eps: depolarizing probability passed to ``build_noise_model``.
        shots: number of shots per execution (default 1024, as before).

    Returns:
        Callable ``circuit -> float`` estimating the all-zeros probability.
    """
    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=SEED)

    def _executor(circuit):
        # With no classical bits there would be no counts to read, so measure
        # every qubit. Circuits that already measure are left unchanged.
        if circuit.num_clbits == 0:
            circuit = circuit.measure_all(inplace=False)
        compiled = transpile(circuit, backend, seed_transpiler=SEED)
        counts = backend.run(compiled, shots=shots).result().get_counts()
        # Keys are bit strings (space-separated when there are several
        # classical registers). A key counts if all of its bits are 0.
        zero_hits = sum(n for bits, n in counts.items() if set(bits) <= {"0", " "})
        # Normalise by the shots actually recorded, not a hard-coded constant.
        return zero_hits / sum(counts.values())

    return _executor

def zne_predict(vqc, data, base_eps: float, λ: int = 3):
    """Predict labels for ``data`` from zero-noise-extrapolated probabilities.

    For each sample ``x``, every circuit returned by
    ``vqc._neural_network.construct_circuit(x)`` goes through mitiq's
    ``execute_with_zne`` on the noisy executor from ``make_executor``. The
    circuit is globally folded at noise scales 1 and ``λ``, and the executor's
    all-zeros probability is Richardson-extrapolated to zero noise. The
    predicted label is the index of the circuit with the largest mitigated
    probability.

    Args:
        vqc: trained classifier whose circuits are evaluated.
        data: iterable of encoded feature vectors.
        base_eps: depolarizing probability of the simulated backend.
        λ: noise-scale factor paired with the unscaled point (1).

    Returns:
        ``np.ndarray`` with one predicted class index per sample.
    """
    executor = make_executor(base_eps)
    # RichardsonFactory has no `order` argument (passing one raises TypeError).
    # Its order follows from the number of scale factors, and with two points
    # the fit is the intended straight line.
    factory = RichardsonFactory(scale_factors=[1, λ])
    fold_global = folding.fold_global
    predictions = []

    for x in data:
        # NOTE(review): `construct_circuit` is not part of the public
        # qiskit-machine-learning VQC/SamplerQNN API. Confirm it exists in the
        # installed version; the trained circuit is otherwise exposed as
        # `vqc.circuit` with weights `vqc.weights`.
        circuits = vqc._neural_network.construct_circuit(x)
        probs = []
        for circ in circuits:
            mitigated = zne.execute_with_zne(
                circ, executor, scale_noise=fold_global, factory=factory
            )
            # Extrapolation can overshoot [0, 1]; clamp to a valid probability.
            probs.append(np.clip(mitigated, 1e-12, 1 - 1e-12))
        # (1 - p) / 2 decreases as p grows, so argmin selects the circuit with
        # the highest mitigated all-zeros probability.
        energies = [(1 - p) / 2 for p in probs]
        predictions.append(np.argmin(energies))
    return np.array(predictions)

# =======================================================================
# 4. Evaluate ZNE vs. raw noise  (ε ∈ {0.001, 0.01, 0.1})
# =======================================================================
# Depolarizing strengths ε to sweep (the unmitigated loop further down reuses this list).
noise_grid = [0.001, 0.01, 0.1]
print("\n=== Zero-Noise Extrapolation Results (Fashion-MNIST) ===")

for eps in noise_grid:
    # Mitigated predictions for the whole test split at this noise level;
    # λ=3 is passed through to zne_predict as the second noise-scale factor.
    y_hat = zne_predict(vqc, test_features, base_eps=eps, λ=3)
    acc   = (y_hat == test_labels).mean()  # fraction of correct predictions
    print(f"ε = {eps:<6} → ZNE-mitigated test accuracy = {acc:.3f}")

def raw_accuracy(eps):
    """Return the unmitigated test accuracy of the global ``vqc`` at noise ``eps``.

    Temporarily points ``vqc`` at an ``AerSimulator`` with the depolarizing
    model from ``build_noise_model(eps)`` and scores it on the test split.
    Afterwards the previous quantum instance is restored, so later code does
    not keep running on the noisy backend by accident.

    Args:
        eps: depolarizing probability of the simulated backend.

    Returns:
        The value of ``vqc.score(test_features, test_labels)``.
    """
    backend = AerSimulator(noise_model=build_noise_model(eps),
                           seed_simulator=SEED)
    # getattr: some VQC versions may not expose this attribute at all.
    previous = getattr(vqc, "quantum_instance", None)
    vqc.quantum_instance = QuantumInstance(
        backend=backend, seed_simulator=SEED, seed_transpiler=SEED
    )
    try:
        return vqc.score(test_features, test_labels)
    finally:
        # Undo the global mutation even if scoring raises.
        if previous is not None:
            vqc.quantum_instance = previous

# Baseline: score the classifier at the same noise levels with no mitigation.
print("\n--- Without mitigation ---")
for eps in noise_grid:
    # One noisy scoring pass per ε via raw_accuracy.
    print(f"ε = {eps:<6} → Raw noisy accuracy          = {raw_accuracy(eps):.3f}")
