# ---------------------------------------------------------------
# 0. Imports (same as before + text utilities)
# ---------------------------------------------------------------
import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from qiskit.circuit.library import TwoLocal
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC

from mitiq import zne, folding
from mitiq.zne.inference import RichardsonFactory
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.utils import QuantumInstance

SEED = 123  # one seed shared by the train/test split, the simulators and the transpiler

# ---------------------------------------------------------------
# 1. Data loading & preprocessing  (AG News → TF-IDF → numpy)
# ---------------------------------------------------------------
ds = load_dataset("ag_news")            # provides "train" and "test" splits

# Merge both official splits; a fresh stratified split is drawn below.
# list() keeps the concatenation independent of the concrete column type
# returned by `datasets` (presumably not a plain list in every release).
texts = list(ds['train']['text']) + list(ds['test']['text'])
labels = list(ds['train']['label']) + list(ds['test']['label'])

# TF-IDF vocabulary size.
# NOTE(review): 1024 equals 2**10, not the qubit count (NUM_QUBITS = 10), and
# feature_map further down reads only the first NUM_QUBITS entries of a row.
# Confirm which dimensionality is actually intended.
VECTOR_DIM = 1024
tfidf = TfidfVectorizer(max_features=VECTOR_DIM, stop_words="english")
# NOTE(review): the vectorizer is fitted on the merged corpus *before* the
# split, so held-out rows influence the vocabulary and IDF weights.
# Cast while the matrix is still sparse so that no dense float64 copy is
# materialised; the resulting float32 values are the same.
X = tfidf.fit_transform(texts).astype(np.float32).toarray()
y = np.array(labels, dtype=int)

# Rescale by the global maximum so the largest entry maps to (about) π for
# angle encoding; the epsilon avoids dividing by zero on an all-zero matrix.
X = np.pi * X / (X.max() + 1e-12)

# train / test split (stratified on the labels, reproducible via SEED)
train_features, test_features, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=SEED, stratify=y
)

# ---------------------------------------------------------------
# 2. VQC definition (angle encoding on 10 qubits; only the first 10 of the 1024 features are encoded)
# ---------------------------------------------------------------
from qiskit.circuit.library import EfficientSU2
from qiskit import QuantumCircuit, ClassicalRegister

NUM_QUBITS = 10                        # width of the feature-map and ansatz circuits

# 2a. Feature map: simple RY angle encoding (θ_i = x_i) of one block, i.e. the first NUM_QUBITS features
def feature_map(x):
    """Return a NUM_QUBITS-qubit circuit that angle-encodes the head of ``x``.

    Qubit ``i`` receives an ``RY(x[i])`` rotation for each of the first
    NUM_QUBITS entries of ``x``; any further entries are ignored.

    Args:
        x: Sliceable sequence of rotation angles.

    Returns:
        The encoding ``QuantumCircuit`` (no measurements are added).
    """
    circuit = QuantumCircuit(NUM_QUBITS)
    leading_angles = x[:NUM_QUBITS]        # a single block of features
    for qubit, angle in enumerate(leading_angles):
        circuit.ry(angle, qubit)
    return circuit

# 2b. Ansatz: depth-2 hardware-efficient SU(2) circuit
from qiskit.algorithms.optimizers import COBYLA
from qiskit.circuit import ParameterVector

ansatz = EfficientSU2(NUM_QUBITS, reps=2)

# VQC and SamplerQNN expect a parameterized QuantumCircuit rather than a
# Python callable, so build the encoding circuit once with symbolic inputs.
input_params = ParameterVector("x", NUM_QUBITS)
encoding_circuit = feature_map(list(input_params))

# 2c. Build a SamplerQNN over encoding + ansatz.
# The inputs are the feature-map parameters and the weights are the ansatz
# parameters. No sampler is passed, so the library default is used.
qnn = SamplerQNN(
    circuit=encoding_circuit.compose(ansatz),
    input_params=list(input_params),          # θ's from the feature map
    weight_params=list(ansatz.parameters),    # trainable params
)

# The optimizer must be an optimizer object, not its name. VQC has no
# `num_classes` argument; the classes are taken from the labels at fit time
# (AG News has 4).
# NOTE(review): `quantum_instance` is only accepted by library releases that
# still support QuantumInstance; confirm against the pinned version.
# NOTE(review): the feature matrix has VECTOR_DIM columns while this feature
# map binds NUM_QUBITS inputs; confirm the data given to fit/score is reduced
# to match.
vqc = VQC(
    feature_map=encoding_circuit,
    ansatz=ansatz,
    optimizer=COBYLA(),
    quantum_instance=AerSimulator(seed_simulator=SEED),
)

# ---------------------------------------------------------------
# 3. Train the VQC (no noise, quick demo)
# ---------------------------------------------------------------
# Fit on the training split; the simulator given to `vqc` above carries no noise model.
vqc.fit(train_features, train_labels)

# ---------------------------------------------------------------
# 4. ZNE helper functions (noise model, executor, mitigated prediction)
# ---------------------------------------------------------------
def build_noise_model(base_eps: float):
    """Return an Aer noise model with depolarizing errors on the listed gates.

    Args:
        base_eps: Parameter passed to ``depolarizing_error`` for both the
            single-qubit and the two-qubit error.

    Returns:
        A ``NoiseModel`` with a 1-qubit depolarizing error attached to
        x/rx/ry/rz/u1/u2/u3 and a 2-qubit one attached to cx, on all qubits.
    """
    single_qubit_gates = ['x', 'rx', 'ry', 'rz', 'u1', 'u2', 'u3']
    two_qubit_gates = ['cx']

    noise_model = NoiseModel()
    for width, gate_names in ((1, single_qubit_gates), (2, two_qubit_gates)):
        noise_model.add_all_qubit_quantum_error(
            depolarizing_error(base_eps, width), gate_names
        )
    return noise_model

def make_executor(base_eps: float, shots: int = 1024):
    """Create an executor that runs a circuit on a noisy Aer simulator.

    Args:
        base_eps: Depolarizing parameter forwarded to ``build_noise_model``.
        shots: Number of shots per execution. The same value is used as the
            divisor that turns the all-zeros count into a frequency, so the
            two can no longer drift apart.

    Returns:
        A callable ``executor(circuit) -> float`` giving the fraction of
        shots whose outcome was the all-zeros bitstring.

    Raises:
        ValueError: If ``shots`` is not positive.
    """
    if shots <= 0:
        raise ValueError(f"shots must be positive, got {shots}")

    backend = AerSimulator(noise_model=build_noise_model(base_eps),
                           seed_simulator=SEED)
    qinst = QuantumInstance(backend=backend, shots=shots, seed_transpiler=SEED)

    def _executor(circuit):
        counts = qinst.execute(circuit).get_counts()
        # Assumes one classical bit per qubit in a single register, so the
        # key is num_qubits zeros. TODO confirm for the circuits passed in.
        all_zeros = '0' * circuit.num_qubits
        return counts.get(all_zeros, 0) / shots

    return _executor

def zne_predict(vqc, data, base_eps: float, λ: int = 3):
    """Predict one label per sample using zero-noise extrapolation.

    For each sample, every circuit obtained from the model is executed with
    global folding at scale factors 1 and ``λ``; the all-zeros frequency is
    extrapolated to zero noise, and the index of the circuit with the
    highest extrapolated value is the prediction.

    Args:
        vqc: Trained classifier that the per-sample circuits are taken from.
        data: Iterable of feature vectors.
        base_eps: Depolarizing parameter of the simulated noise.
        λ: Second noise scale factor used next to 1.

    Returns:
        ``np.ndarray`` with one predicted index per element of ``data``.
    """
    executor = make_executor(base_eps)
    # RichardsonFactory takes no `order` argument. With two scale factors
    # Richardson extrapolation is already the linear (order-1) fit.
    factory = RichardsonFactory(scale_factors=[1, λ])
    # Use the folding function exposed by the `zne` package rather than the
    # separate top-level `folding` module alias.
    scale_noise = zne.scaling.fold_global

    preds = []
    for x in data:
        # NOTE(review): this relies on a private attribute and on a
        # `construct_circuit` method returning one circuit per class.
        # Confirm that the installed model actually provides it.
        qc_list = vqc._neural_network.construct_circuit(x)
        probs = [
            # Extrapolation can leave [0, 1]; clip to a valid probability.
            np.clip(
                zne.execute_with_zne(
                    circ, executor, scale_noise=scale_noise, factory=factory
                ),
                1e-12,
                1 - 1e-12,
            )
            for circ in qc_list
        ]
        # Lowest "energy" corresponds to the highest all-zeros probability.
        energy = [(1 - p) / 2 for p in probs]
        preds.append(np.argmin(energy))
    return np.array(preds)

# ---------------------------------------------------------------
# 5. Evaluate ZNE vs. raw noise   (ε = 0.001, 0.01, 0.1)
# ---------------------------------------------------------------
# Noise strengths to sweep; each one is passed to zne_predict as base_eps.
noise_grid = [0.001, 0.01, 0.1]
# Maps each noise strength to the ZNE-mitigated test accuracy.
mitigated_results = {}

for eps in noise_grid:
    # Mitigated predictions for the whole test split (second scale factor 3).
    y_hat = zne_predict(vqc, test_features, base_eps=eps, λ=3)
    # Fraction of predictions equal to the true labels.
    acc   = (y_hat == test_labels).mean()
    mitigated_results[eps] = acc
    print(f"ε = {eps:<6} → ZNE-mitigated test accuracy = {acc:.3f}")

def raw_accuracy(eps):
    """Score the module-level ``vqc`` on the test split under unmitigated noise.

    Side effect: ``vqc.quantum_instance`` is rebound to a new noisy
    ``QuantumInstance`` and stays that way after the call.

    Args:
        eps: Depolarizing parameter forwarded to ``build_noise_model``.

    Returns:
        The value of ``vqc.score(test_features, test_labels)``.
    """
    noisy_instance = QuantumInstance(
        backend=AerSimulator(noise_model=build_noise_model(eps),
                             seed_simulator=SEED),
        seed_simulator=SEED,
        seed_transpiler=SEED,
    )
    # No shots argument is given, so the QuantumInstance default applies.
    # NOTE(review): confirm that assigning this attribute really switches
    # the backend used by `score` in the installed library version.
    vqc.quantum_instance = noisy_instance
    return vqc.score(test_features, test_labels)

# Same sweep without mitigation, for comparison with the ZNE accuracies printed above.
for eps in noise_grid:
    print(f"ε = {eps:<6} → No mitigation accuracy      = {raw_accuracy(eps):.3f}")
