#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Show the effect of READOUT (assignment) noise on quantum measurements.
- Loads PCA features from features_*.npz (created by prepare_data.py).
- Builds a tiny VQC (AngleEmbedding + entanglers).
- Computes the true probability vector p_true.
- Applies an assignment (confusion) matrix A(eps0, eps1) to simulate readout noise: p_meas = A @ p_true
- Sweeps epsilon in [0, 0.15] and plots:
    1) Total Variation Distance TVD(p_meas, p_true) vs epsilon
    2) Observable error |<Z_0>_meas - <Z_0>_true| vs epsilon
Outputs:
  readout_noise_effect_<name>_tvd_vs_eps.pdf
  readout_noise_effect_<name>_z0err_vs_eps.pdf
"""
import argparse, json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import pennylane as qml
from pennylane import numpy as pnp

def load_features(npz_path):
    """Load the feature matrix and labels from an ``.npz`` archive.

    Args:
        npz_path: Path to an archive containing arrays ``"X"`` and ``"y"``.

    Returns:
        ``(X, y)`` with ``X`` cast to float32 and ``y`` cast to int64.

    Raises:
        KeyError: if the archive lacks ``"X"`` or ``"y"``.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only load archives you trust.
    # The context manager closes the underlying zip file handle.
    with np.load(npz_path, allow_pickle=True) as data:
        X = data["X"].astype(np.float32)
        y = data["y"].astype(np.int64)
    return X, y

def angle_map(x):
    """Map real features to rotation angles in [0, pi] via ``pi * sigmoid(x)``.

    Uses the overflow-free form of the logistic function: ``exp`` is only
    evaluated at ``-|x| <= 0``, so large-magnitude features (float32 ``exp``
    overflows for arguments above ~88) no longer raise overflow warnings.
    The output saturates to 0 / pi and keeps the input's float dtype.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))  # always in (0, 1], never overflows
    # sigmoid(x) = 1/(1+e^-x) for x >= 0, and e^x/(1+e^x) for x < 0.
    sig = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    return np.pi * sig

def confusion_1q(e0, e1):
    """Return the 2x2 single-qubit readout confusion matrix (float64).

    Columns index the true bit and rows the reported bit:
        [[P(read 0 | true 0), P(read 0 | true 1)],
         [P(read 1 | true 0), P(read 1 | true 1)]]
    so ``e0`` is the chance a true 0 is read as 1 and ``e1`` the chance a
    true 1 is read as 0. Each column sums to 1.
    """
    keep0 = 1 - e0
    keep1 = 1 - e1
    reads_zero = [keep0, e1]
    reads_one = [e0, keep1]
    return np.array([reads_zero, reads_one], dtype=np.float64)

def kronN(mats):
    """Return the Kronecker product ``mats[0] (x) mats[1] (x) ...``, left to right.

    The first matrix occupies the most significant index position. An empty
    sequence yields the 1x1 identity ``[[1.0]]``.
    """
    result = np.ones((1, 1), dtype=np.float64)
    for factor in mats:
        result = np.kron(result, factor)
    return result

def assignment_matrix(n_wires, e0, e1):
    """Return the (2**n_wires, 2**n_wires) readout assignment matrix.

    Every wire gets the same independent confusion matrix
    ``confusion_1q(e0, e1)``; the full matrix is their Kronecker product.
    """
    single_qubit = confusion_1q(e0, e1)
    return kronN([single_qubit] * n_wires)

def z_expect_from_probs(p, n_wires, wire=0):
    """Return <Z_wire> computed from a computational-basis probability vector.

    <Z_wire> = sum_x p(x) * (-1)**bit_wire(x), where wire 0 is the most
    significant bit of the basis-state index x.

    Args:
        p: Probability vector indexed by basis state (length 2**n_wires).
        n_wires: Number of qubits encoded in the index.
        wire: Qubit whose Z expectation is returned.

    Raises:
        ValueError: if ``wire`` is outside ``[0, n_wires)``.
    """
    # Without this check a negative wire silently returns sum(p) and a too
    # large wire fails with an obscure negative-shift error.
    if not 0 <= wire < n_wires:
        raise ValueError(f"wire must be in [0, {n_wires}), got {wire}")
    p = np.asarray(p)
    bits = (np.arange(p.shape[0]) >> (n_wires - 1 - wire)) & 1
    signs = 1.0 - 2.0 * bits  # bit 0 -> +1, bit 1 -> -1
    return np.dot(p, signs)

def build_circuit(n_wires, n_layers):
    """Create a QNode that returns basis-state probabilities of a small VQC.

    The returned ``probs_fn(theta, weights)`` applies ``qml.AngleEmbedding``
    with RY rotations by ``theta`` (length n_wires), then
    ``qml.StronglyEntanglingLayers(weights)``, and returns ``qml.probs`` over
    all wires on an analytic (``shots=None``) ``default.qubit`` device.

    ``n_layers`` is accepted but not used here; the depth is whatever
    ``weights`` carries (expected shape ``(n_layers, n_wires, 3)``).
    """
    all_wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=all_wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=all_wires)
        return qml.probs(wires=all_wires)

    return qml.qnode(device)(probs_fn)

# Readout-error asymmetry: e0 = 0.9*eps (0 read as 1), e1 = 1.1*eps (1 read as 0).
_READOUT_ASYMMETRY = (0.9, 1.1)

def _sweep_readout_noise(p_true, n_wires, eps_values):
    """Return (tvd_vals, z0_err) lists, one entry per epsilon in eps_values."""
    # <Z_0> of the noiseless distribution does not depend on eps: compute once.
    z_true = z_expect_from_probs(p_true, n_wires, wire=0)
    s0, s1 = _READOUT_ASYMMETRY
    tvd_vals, z0_err = [], []
    for e in eps_values:
        A = assignment_matrix(n_wires, e0=e * s0, e1=e * s1)
        p_meas = A @ p_true
        # A's columns sum to 1 (see confusion_1q), so this only removes round-off drift.
        p_meas = p_meas / p_meas.sum()
        tvd_vals.append(0.5 * np.abs(p_meas - p_true).sum())
        z_meas = z_expect_from_probs(p_meas, n_wires, wire=0)
        z0_err.append(abs(z_meas - z_true))
    return tvd_vals, z0_err

def _save_line_plot(x, y, xlabel, ylabel, title, out_path):
    """Plot y vs x with markers, save to out_path, close the figure, and log it."""
    fig = plt.figure()
    plt.plot(x, y, marker="o")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)  # avoid accumulating open figures
    print(f"[readout_effect] Wrote {out_path}")

def main():
    """Command-line entry point.

    Loads features, evaluates the noiseless output distribution of a randomly
    initialised VQC on the first sample, sweeps the readout-error rate over
    16 values in [0, --eps_max], and writes two PDF plots to the current
    working directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--feature_file", type=str, required=True, help="Path to features_train.npz or features_test.npz")
    ap.add_argument("--name", type=str, default="cifar", help="Name for outputs (cifar/eurosat)")
    ap.add_argument("--layers", type=int, default=2, help="VQC depth")
    ap.add_argument("--seed", type=int, default=7)
    ap.add_argument("--eps_max", type=float, default=0.15, help="Max epsilon for sweep (0..0.3 reasonable)")
    args = ap.parse_args()
    # Both scaled rates must stay probabilities, otherwise A gets negative entries.
    max_scale = max(_READOUT_ASYMMETRY)
    if args.eps_max < 0.0 or args.eps_max * max_scale > 1.0:
        ap.error(f"--eps_max must satisfy 0 <= eps_max and {max_scale}*eps_max <= 1")

    X, _ = load_features(args.feature_file)
    if X.ndim != 2 or X.shape[0] == 0:
        raise SystemExit(f"[readout_effect] expected a non-empty 2-D 'X' array, got shape {X.shape}")
    n_wires = X.shape[1]  # one qubit per feature column
    theta0 = angle_map(X[0])  # single sample to illustrate effect
    rng = np.random.default_rng(args.seed)
    weights = rng.standard_normal((args.layers, n_wires, 3)) * 0.3

    probs_fn = build_circuit(n_wires, args.layers)
    p_true = probs_fn(theta0, weights)

    # Sweep readout noise
    E = np.linspace(0.0, args.eps_max, 16)
    tvd_vals, z0_err = _sweep_readout_noise(p_true, n_wires, E)

    eps_label = "epsilon (readout misclassification rate)"
    # --- Plot 1: TVD vs epsilon ---
    _save_line_plot(
        E, tvd_vals, eps_label, "Total Variation Distance (TVD)",
        "Effect of Readout Noise on Histogram (TVD)",
        Path(f"readout_noise_effect_{args.name}_tvd_vs_eps.pdf"),
    )
    # --- Plot 2: |<Z_0>_meas - <Z_0>_true| vs epsilon ---
    _save_line_plot(
        E, z0_err, eps_label, "|<Z_0>_meas - <Z_0>_true|",
        "Effect of Readout Noise on Observable",
        Path(f"readout_noise_effect_{args.name}_z0err_vs_eps.pdf"),
    )

if __name__ == "__main__":
    main()
