#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Quick-start commands (the module docstring below repeats these examples):
#
# python readout_effect_separate.py --dataset cifar --n_wires 8 --layers 2 --eps_max 0.15
#
# python readout_effect_separate.py --dataset eurosat \
#   --n_wires 8 --layers 2 --eps_max 0.15 \
#   --eurosat_csv      ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
#   --eurosat_rgb_root ~/quantum/rem/eurosat/EuroSAT_extracted \
#   --eurosat_tif_root ~/quantum/rem/eurosat/EuroSAT_extracted/allBands

"""
readout_effect_separate.py
--------------------------
Plot the effect of READOUT (assignment) noise **separately** for CIFAR-10 or EuroSAT,
using the first sample of the chosen dataset as the circuit input.
Saves two PDFs per dataset (TVD vs epsilon, and |<Z_0>_meas - <Z_0>_true| vs epsilon) with 16pt fonts.

Usage examples:
  # CIFAR-10 (downloads via torchvision cache)
  python readout_effect_separate.py --dataset cifar --n_wires 8 --layers 2 --eps_max 0.15

  # EuroSAT (your CSV + roots)
  python readout_effect_separate.py --dataset eurosat --n_wires 8 --layers 2 --eps_max 0.15 \
    --eurosat_csv ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
    --eurosat_rgb_root  ~/quantum/rem/eurosat/EuroSAT_extracted \
    --eurosat_tif_root  ~/quantum/rem/eurosat/EuroSAT_extracted/allBands
"""

import argparse, sys, os
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import pennylane as qml

# ---------- Utilities ----------

def confusion_1q(e0, e1):
    """Build the 2x2 single-qubit readout confusion matrix.

    The first column is ``[1 - e0, e0]`` and the second is ``[e1, 1 - e1]``,
    so every column sums to one.  Applied to a probability vector
    ``[p(0), p(1)]``, ``e0`` is the rate at which a 0 is reported as 1 and
    ``e1`` the rate at which a 1 is reported as 0.

    Args:
        e0: Misassignment rate for the 0 outcome.
        e1: Misassignment rate for the 1 outcome.

    Returns:
        A ``float`` array of shape ``(2, 2)``.
    """
    keep0 = 1 - e0
    keep1 = 1 - e1
    return np.array([[keep0, e1], [e0, keep1]], dtype=float)

def kronN(mats):
    """Return the left-to-right Kronecker product of the matrices in ``mats``.

    The product is seeded with the 1x1 matrix ``[[1.0]]``, so an empty
    iterable yields that 1x1 identity.

    Args:
        mats: Iterable of 2-D arrays, combined in iteration order.

    Returns:
        The accumulated ``np.kron`` product as an array.
    """
    product = np.ones((1, 1), dtype=float)
    for factor in mats:
        product = np.kron(product, factor)
    return product

def assignment_matrix(n_wires, e0, e1):
    """Build the full readout assignment matrix for ``n_wires`` qubits.

    The same single-qubit confusion matrix (from ``confusion_1q(e0, e1)``)
    is used for every wire and the factors are combined with ``kronN``,
    giving a ``2**n_wires`` by ``2**n_wires`` matrix.

    Args:
        n_wires: Number of qubits (number of Kronecker factors).
        e0: Per-qubit misassignment rate passed to ``confusion_1q``.
        e1: Per-qubit misassignment rate passed to ``confusion_1q``.

    Returns:
        The Kronecker product of ``n_wires`` single-qubit confusion matrices.
    """
    per_wire = (confusion_1q(e0, e1) for _ in range(n_wires))
    return kronN(per_wire)

def z_expect_from_probs(p, n_wires, wire=0):
    """Return <Z> on a single wire from a basis-state probability vector.

    Basis index ``idx`` is decoded with wire 0 as the most significant bit:
    the bit belonging to ``wire`` is ``(idx >> (n_wires - 1 - wire)) & 1``.
    Entries whose bit is 0 contribute ``+p[idx]`` and entries whose bit is 1
    contribute ``-p[idx]``.

    Args:
        p: 1-D sequence of probabilities indexed by basis state.
        n_wires: Total number of wires used to decode the index bits.
        wire: Index of the wire whose Z expectation is wanted.

    Returns:
        The signed sum of the probabilities as a ``float``.

    Raises:
        ValueError: If ``wire`` is not in ``[0, n_wires)``.
    """
    if not 0 <= wire < n_wires:
        raise ValueError(f"wire must be in [0, {n_wires}); got {wire}")

    probs = np.asarray(p, dtype=float).ravel()
    # Indices are non-negative int64, so any shift of 63 or more yields 0;
    # clamping keeps the NumPy shift well-defined for very large n_wires.
    shift = min(n_wires - 1 - wire, 63)
    bits = (np.arange(probs.size, dtype=np.int64) >> shift) & 1
    signs = 1.0 - 2.0 * bits  # bit 0 -> +1, bit 1 -> -1
    return float(np.dot(signs, probs))

def build_probs_fn(n_wires, n_layers):
    """Create a QNode that returns basis-state probabilities of the circuit.

    The circuit runs on a ``default.qubit`` device created with
    ``shots=None``.  It applies ``qml.AngleEmbedding`` with Y rotations to
    ``theta``, then ``qml.StronglyEntanglingLayers`` with ``weights``, both
    on all ``n_wires`` wires, and returns ``qml.probs`` over those wires.

    Args:
        n_wires: Number of wires of the device and circuit.
        n_layers: Accepted for interface symmetry; not used in this function.

    Returns:
        A QNode callable as ``probs_fn(theta, weights)``.
    """
    device = qml.device("default.qubit", wires=n_wires, shots=None)
    all_wires = range(n_wires)

    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=all_wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=all_wires)
        return qml.probs(wires=all_wires)

    return qml.qnode(device)(probs_fn)

def angle_map(x):
    """Map features to rotation angles in (0, pi) via ``pi * sigmoid(x)``.

    Args:
        x: Scalar or array of real-valued features.

    Returns:
        ``np.pi * (1 / (1 + exp(-x)))``, with the same shape as ``x``.
    """
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.pi * sigmoid

def _pool_to_n(arr2d, n):
    v = arr2d.reshape(-1)
    splits = np.array_split(v, n)
    return np.array([s.mean() for s in splits], dtype=float)

# ---------- CIFAR helper ----------
def get_cifar_vec(n_wires, root="./_cache_cifar"):
    """Return an ``n_wires``-long feature vector from the first CIFAR-10 image.

    The training split is opened through ``torchvision.datasets.CIFAR10``,
    first with ``download=True`` and, if that raises, again with
    ``download=False``.  Sample 0 is converted with ``convert("L")``, scaled
    by ``1/255`` and pooled with ``_pool_to_n``.

    Args:
        n_wires: Length of the returned vector.
        root: Dataset root/cache directory passed to torchvision.

    Returns:
        The pooled vector, or ``None`` (after a message on stderr) when
        torchvision cannot be imported or the dataset cannot be opened.
    """
    try:
        from torchvision import datasets
    except Exception:
        print("[readout_effect] torchvision not available; cannot fetch CIFAR-10.", file=sys.stderr)
        return None

    dataset = None
    failure = None
    # Try a (possibly downloading) open first, then fall back to the cache.
    for allow_download in (True, False):
        try:
            dataset = datasets.CIFAR10(root=root, train=True, download=allow_download, transform=None)
        except Exception as exc:
            failure = exc
        else:
            break
    if dataset is None:
        # Only the error from the final (cache-only) attempt is reported.
        print(f"[readout_effect] CIFAR failure: {failure}", file=sys.stderr)
        return None

    first_image, _label = dataset[0]
    gray = np.asarray(first_image.convert("L"), dtype=float) / 255.0  # 32x32
    return _pool_to_n(gray, n_wires)

# ---------- EuroSAT helper ----------
def get_eurosat_vec(n_wires, csv_path, rgb_root, tif_root):
    """Return an ``n_wires``-long feature vector from the first EuroSAT CSV row.

    The ``Filename`` entry of the first CSV row locates an RGB file under
    ``rgb_root`` and, with ``.jpg`` replaced by ``.tif``, a multi-band TIF
    under ``tif_root``.  Both files must exist, although only the TIF is
    read.  The vector holds the per-band means of the TIF: truncated to the
    first ``n_wires`` bands, or padded by repeating the last band mean.

    Args:
        n_wires: Length of the returned vector.
        csv_path: Path to the CSV listing the images (``~`` is expanded).
        rgb_root: Directory the RGB filenames are relative to.
        tif_root: Directory the TIF filenames are relative to.

    Returns:
        The feature vector, or ``None`` (after a message on stderr) when a
        dependency, the CSV, its ``Filename`` column, or a file is missing.
    """
    try:
        import pandas as pd
        import rasterio
    except Exception:
        print("[readout_effect] pandas/rasterio missing; cannot read EuroSAT.", file=sys.stderr)
        return None

    # Expand "~" here so paths still resolve when the shell did not expand
    # them (e.g. quoted arguments or the --opt=~/path form).
    csv_path = os.path.expanduser(csv_path)
    rgb_root = os.path.expanduser(rgb_root)
    tif_root = os.path.expanduser(tif_root)

    # A missing or unparsable CSV is reported like every other failure here,
    # instead of escaping as a traceback.
    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as exc:
        print(f"[readout_effect] Cannot read EuroSAT CSV {csv_path}: {exc}", file=sys.stderr)
        return None
    if len(df) == 0:
        print("[readout_effect] EuroSAT CSV has no rows.", file=sys.stderr)
        return None
    if "Filename" not in df.columns:
        print("[readout_effect] EuroSAT CSV has no 'Filename' column.", file=sys.stderr)
        return None

    filename = str(df.iloc[0]["Filename"])
    rgb_path = os.path.join(rgb_root, filename)
    tif_path = os.path.join(tif_root, filename.replace(".jpg", ".tif"))
    if not (os.path.isfile(rgb_path) and os.path.isfile(tif_path)):
        print(f"[readout_effect] Missing files:\n  RGB: {rgb_path}\n  TIF: {tif_path}", file=sys.stderr)
        return None

    with rasterio.open(tif_path) as ds:
        # NOTE(review): the 1/10000 factor presumably rescales integer
        # reflectance values to roughly [0, 1] - confirm against the data.
        tif = ds.read().astype(float) / 10000.0  # (bands,H,W)

    n_bands = tif.shape[0]
    if n_bands == 0:
        # Neither the reshape nor edge-padding below works without a band.
        print(f"[readout_effect] EuroSAT TIF has no bands: {tif_path}", file=sys.stderr)
        return None

    band_means = tif.reshape(n_bands, -1).mean(axis=1)  # per-band
    if n_bands >= n_wires:
        return band_means[:n_wires]
    return np.pad(band_means, (0, n_wires - n_bands), mode="edge")

# ---------- Sweep ----------
def sweep_readout(theta, n_wires, n_layers, eps_max=0.15, steps=16, seed=7):
    """Sweep the readout-error rate and measure the resulting distortion.

    Circuit weights are drawn from a seeded normal distribution (scaled by
    0.3) and the noise-free distribution ``p_true`` is evaluated once.  For
    each epsilon, an asymmetric assignment matrix with ``e0 = 0.9 * eps``
    and ``e1 = 1.1 * eps`` is applied to ``p_true``, the result is
    renormalised, and two metrics against ``p_true`` are recorded.

    NOTE: ``1.1 * eps_max`` should not exceed 1, otherwise the ``1 - e1``
    entries of the assignment matrix become negative.

    Args:
        theta: Embedding angles passed to the circuit.
        n_wires: Number of qubits.
        n_layers: Number of entangling layers (first axis of the weights).
        eps_max: Largest epsilon in the sweep.
        steps: Number of evenly spaced epsilon values from 0 to ``eps_max``.
        seed: Seed for the random circuit weights.

    Returns:
        Tuple ``(E, tvd, z0_err)``: the epsilon grid, the total variation
        distance per epsilon, and ``|<Z_0>_meas - <Z_0>_true|`` per epsilon.
    """
    rng = np.random.default_rng(seed)
    weights = rng.standard_normal((n_layers, n_wires, 3)) * 0.3
    probs_fn = build_probs_fn(n_wires, n_layers)
    p_true = probs_fn(theta, weights)
    # The noise-free expectation does not depend on epsilon, so it is
    # computed once rather than on every sweep iteration.
    z_true = z_expect_from_probs(p_true, n_wires, wire=0)

    E = np.linspace(0.0, eps_max, steps)
    tvd_vals, z0_err = [], []
    for e in E:
        A = assignment_matrix(n_wires, e0=e * 0.9, e1=e * 1.1)  # asymmetric
        p_meas = A @ p_true
        p_meas = p_meas / p_meas.sum()
        tvd_vals.append(0.5 * np.abs(p_meas - p_true).sum())
        z_meas = z_expect_from_probs(p_meas, n_wires, wire=0)
        z0_err.append(abs(z_meas - z_true))
    return E, np.array(tvd_vals), np.array(z0_err)

# ---------- Main ----------
def _save_curve(x, y, ylabel, out_path, marker):
    """Plot ``y`` against epsilon values ``x`` and save the figure to ``out_path``."""
    fig = plt.figure()
    plt.plot(x, y, marker=marker)
    plt.xlabel("epsilon (readout misclassification rate)")
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches="tight")
    # Release the figure once it is on disk so figures do not accumulate.
    plt.close(fig)
    print(f"[readout_effect] Wrote {out_path}")


def main():
    """Parse CLI options, run the readout-noise sweep, and write two PDFs.

    The feature vector comes from CIFAR-10 or EuroSAT depending on
    ``--dataset``.  It is mapped to angles, swept over epsilon, and the TVD
    and ``<Z_0>`` error curves are saved in the current directory as
    ``readout_noise_effect_<dataset>_tvd.pdf`` and
    ``readout_noise_effect_<dataset>_z0err.pdf``.  Exits with status 1 when
    the feature vector cannot be obtained or EuroSAT paths are missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True, help="Which dataset to use.")
    ap.add_argument("--n_wires", type=int, default=8, help="Number of qubits/features")
    ap.add_argument("--layers", type=int, default=2, help="VQC depth (StronglyEntanglingLayers)")
    ap.add_argument("--eps_max", type=float, default=0.15, help="Max epsilon for readout misclassification rate")
    ap.add_argument("--steps", type=int, default=16, help="Number of epsilon points in sweep")
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar", help="Torchvision cache/root for CIFAR")
    ap.add_argument("--eurosat_csv", type=str, default="", help="EuroSAT CSV (train.csv)")
    ap.add_argument("--eurosat_rgb_root", type=str, default="", help="EuroSAT RGB root")
    ap.add_argument("--eurosat_tif_root", type=str, default="", help="EuroSAT TIF root")
    args = ap.parse_args()

    # fonts
    plt.rcParams.update({"font.size": 16})

    # Only the data loading and the marker differ between datasets; the
    # sweep and plotting below are shared.
    if args.dataset == "cifar":
        marker = "."
        vec = get_cifar_vec(args.n_wires, root=args.cifar_root)
        if vec is None:
            print("[readout_effect] CIFAR vector unavailable.", file=sys.stderr)
            sys.exit(1)
    else:
        marker = "|"
        if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
            print("[readout_effect] Provide --eurosat_csv, --eurosat_rgb_root, --eurosat_tif_root for EuroSAT.", file=sys.stderr)
            sys.exit(1)
        vec = get_eurosat_vec(args.n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root)
        if vec is None:
            print("[readout_effect] EuroSAT vector unavailable.", file=sys.stderr)
            sys.exit(1)

    theta = angle_map(vec)
    E, TVD, Zerr = sweep_readout(theta, args.n_wires, args.layers, args.eps_max, args.steps)

    # args.dataset is restricted by argparse to "cifar" or "eurosat", so
    # these names match the per-dataset file names used previously.
    _save_curve(
        E,
        TVD,
        "Total Variation Distance (TVD)",
        Path(f"readout_noise_effect_{args.dataset}_tvd.pdf"),
        marker,
    )
    _save_curve(
        E,
        Zerr,
        "|<Z_0>_meas - <Z_0>_true|",
        Path(f"readout_noise_effect_{args.dataset}_z0err.pdf"),
        marker,
    )

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
