# python train_readout_accuracy_loss.py --dataset cifar \
#   --classes 0 1 \
#   --max_train 1000 --max_val 400 \
#   --epochs 10 --n_wires 8 --layers 2 \
#   --noise_levels 0.0 0.02 0.05 0.10 0.15

# python train_readout_accuracy_loss.py --dataset eurosat \
#   --eurosat_csv      ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
#   --eurosat_rgb_root ~/quantum/rem/eurosat/EuroSAT_extracted \
#   --eurosat_tif_root ~/quantum/rem/eurosat/EuroSAT_extracted/allBands \
#   --max_train 800 --max_val 300 \
#   --epochs 8 --n_wires 8 --layers 2 \
#   --noise_levels 0.0 0.02 0.05 0.10


import argparse, sys, os
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import pennylane as qml
from pennylane import numpy as pnp

# ---------------- Readout / math utils ----------------

def confusion_1q(e0, e1):
    """Return the 2x2 single-qubit readout confusion matrix (float64).

    Rows index the *measured* bit and columns the *true* bit, so each column
    is a probability distribution (sums to 1):

    * a true 0 is read as 1 with probability ``e0``;
    * a true 1 is read as 0 with probability ``e1``.
    """
    keep0 = 1 - e0  # P(read 0 | true 0)
    keep1 = 1 - e1  # P(read 1 | true 1)
    rows = [
        [keep0, e1],
        [e0, keep1],
    ]
    return pnp.array(rows, dtype=pnp.float64)

def assignment_matrix(n_wires, e0, e1):
    """Return the n-qubit readout assignment matrix of shape (2**n, 2**n).

    It is the ``n_wires``-fold Kronecker product of ``confusion_1q(e0, e1)``,
    i.e. the same, independent readout error on every wire. ``n_wires == 0``
    gives ``[[1.0]]``. Memory grows as ``4**n_wires`` entries.
    """
    single = confusion_1q(e0, e1)
    result = pnp.array([[1.0]], dtype=pnp.float64)
    remaining = n_wires
    while remaining > 0:
        result = pnp.kron(result, single)
        remaining -= 1
    return result

def softmax(z):
    """Numerically stable softmax of a logit vector.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow; the result is unchanged mathematically.
    """
    exps = pnp.exp(z - pnp.max(z))
    normaliser = pnp.sum(exps)
    return exps / normaliser

def xent_from_logits(logits, y):
    """Cross-entropy ``-log p[y]`` of integer label ``y`` under ``softmax(logits)``.

    A 1e-12 offset inside the log keeps the loss finite if ``p[y]`` underflows to 0.
    """
    target_prob = softmax(logits)[int(y)]
    return -pnp.log(target_prob + 1e-12)

def angle_map(x):
    """Map real features to rotation angles via ``pi * sigmoid(x)``.

    Results lie in [0, pi] (strictly inside except for float saturation);
    ``x == 0`` maps to ``pi / 2``.
    """
    sigmoid = 1.0 / (1.0 + pnp.exp(-x))
    return pnp.pi * sigmoid

# ---------------- Circuit ----------------

def build_probs_fn(n_wires, n_layers):
    """Build an autograd QNode ``probs_fn(theta, weights)`` on ``default.qubit``.

    The circuit RY-embeds ``theta`` (one angle per wire), applies
    ``qml.StronglyEntanglingLayers(weights)`` and returns the analytic
    (``shots=None``) probabilities of all ``2**n_wires`` basis states.
    ``n_layers`` is not used here; the layer count comes from ``weights.shape``.
    """
    wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def circuit(theta, weights):
        qml.AngleEmbedding(theta, wires=wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=wires)
        return qml.probs(wires=wires)

    return qml.QNode(circuit, device, interface="autograd")

# ---------------- Feature helpers ----------------

def pool_to_n(arr2d, n):
    """Average-pool an array into ``n`` features.

    The array is flattened row-major and cut into ``n`` contiguous, nearly equal
    chunks (``np.array_split``); the mean of each chunk becomes one float64
    feature. If ``n`` exceeds the element count, empty chunks yield NaN.
    """
    flat = arr2d.reshape(-1)
    return pnp.array(
        [float(piece.mean()) for piece in np.array_split(flat, n)],
        dtype=pnp.float64,
    )

def load_cifar_subset(n_wires, classes=(0,1), max_train=1000, max_val=400, root="./_cache_cifar"):
    """Load a small CIFAR-10 subset as pooled grayscale feature vectors.

    Each image is converted to grayscale, scaled to [0, 1] and average-pooled
    with ``pool_to_n`` into ``n_wires`` features. Samples are taken in dataset
    order until ``max_train`` / ``max_val`` are reached (a falsy limit means no
    limit); no per-class balancing is done. The CIFAR-10 *test* split serves as
    the validation split. Datasets are downloaded into ``root`` if missing.

    Args:
        n_wires: number of features per sample.
        classes: CIFAR-10 label ids (0-9) to keep; duplicates are ignored.
            Labels are remapped to ``0..num_classes-1`` in the given order.
        max_train, max_val: caps on the number of train / validation samples.
        root: torchvision download/cache directory.

    Returns:
        ``((Xtr, ytr), (Xva, yva), num_classes)`` with ``X`` of shape
        ``(N, n_wires)``.

    Exits the process with status 1 if torchvision is missing, the class ids
    are invalid, or a split ends up empty.
    """
    try:
        from torchvision import datasets
    except ImportError:
        print("[train] torchvision not available; cannot use CIFAR.", file=sys.stderr)
        sys.exit(1)

    # De-duplicate (keeping order) so num_classes counts distinct labels only.
    wanted = list(dict.fromkeys(classes))
    # Validate before downloading / scanning 60k images for nothing.
    if not wanted or any(c not in range(10) for c in wanted):
        print(f"[train] Invalid CIFAR classes {list(classes)}; expected ids in 0-9.", file=sys.stderr)
        sys.exit(1)

    train = datasets.CIFAR10(root=root, train=True,  download=True,  transform=None)
    test  = datasets.CIFAR10(root=root, train=False, download=True,  transform=None)

    def filter_pack(ds, limit, split_name):
        # Collect up to `limit` samples of the wanted classes, in dataset order.
        X, y = [], []
        for img, label in ds:
            if label not in wanted:
                continue
            arr = np.asarray(img.convert("L"), dtype=float) / 255.0  # 32x32
            X.append(pool_to_n(arr, n_wires))
            y.append(wanted.index(label))  # remap to 0..C-1
            if limit and len(X) >= limit:
                break
        if not X:
            # pnp.stack([]) would otherwise raise an opaque ValueError.
            print(f"[train] CIFAR {split_name} subset is empty for classes {wanted}.", file=sys.stderr)
            sys.exit(1)
        return pnp.stack(X, axis=0), pnp.array(y, dtype=int)

    Xtr, ytr = filter_pack(train, max_train, "train")
    Xva, yva = filter_pack(test, max_val, "val")  # test split used as "val" for simplicity
    num_classes = len(wanted)
    return (Xtr, ytr), (Xva, yva), num_classes

def load_eurosat_subset(n_wires, csv_path, rgb_root, tif_root, classes=None, max_train=800, max_val=300):
    """Load a small, class-balanced EuroSAT subset as per-band mean features.

    Each sample is the vector of band means of its multiband GeoTIFF (pixel
    values scaled by 1/10000). The first ``n_wires`` band means are kept; with
    fewer bands the last mean is repeated (``np.pad(..., mode="edge")``).

    CSV rows are scanned in file order. For every selected class the first
    ``max(1, max_train // len(classes))`` usable rows go to the training split
    and the next ``max(1, max_val // len(classes))`` to the validation split, so
    the two splits never share a row.

    Args:
        n_wires: feature length.
        csv_path: CSV with at least ``Filename`` and ``Label`` columns.
        rgb_root: directory joined with ``Filename``; the file must exist for a
            row to be used, but it is not read.
        tif_root: directory joined with ``Filename`` (``.jpg`` -> ``.tif``)
            holding the multiband GeoTIFFs.
        classes: label ids to keep (duplicates ignored); defaults to the two
            smallest labels present in the CSV.
        max_train, max_val: approximate split sizes, divided evenly per class.

    Returns:
        ``((Xtr, ytr), (Xva, yva), num_classes)`` with ``X`` of shape
        ``(N, n_wires)`` and labels remapped to ``0..num_classes-1`` in the
        order of ``classes``.

    Exits the process with status 1 if pandas/rasterio are missing, no class
    is selected, or either split ends up empty.
    """
    try:
        import pandas as pd, rasterio
    except ImportError:
        print("[train] pandas/rasterio required for EuroSAT.", file=sys.stderr)
        sys.exit(1)
    df = pd.read_csv(csv_path)
    # Rows are compared as int(row["Label"]) below, so classes must be ints too
    # (raw CSV values may be strings or numpy scalars).
    if classes is None:
        classes = sorted({int(v) for v in df["Label"].unique()})[:2]
        print(f"[train] EuroSAT classes defaulted to: {classes}")
    else:
        classes = list(dict.fromkeys(int(c) for c in classes))
    if not classes:
        print("[train] No EuroSAT classes selected.", file=sys.stderr); sys.exit(1)

    # simple class-balanced take from top
    Xtr, ytr, Xva, yva = [], [], [], []
    per_class_tr = max(1, max_train // len(classes))
    per_class_va = max(1, max_val   // len(classes))
    counts_tr = {c: 0 for c in classes}
    counts_va = {c: 0 for c in classes}

    for _, row in df.iterrows():
        lab = int(row["Label"])
        if lab not in counts_tr:
            continue
        # Skip classes whose quotas are already filled *before* any file I/O;
        # reading the GeoTIFF is the expensive step.
        if counts_tr[lab] >= per_class_tr and counts_va[lab] >= per_class_va:
            continue
        fname = row["Filename"]
        rgb_path = os.path.join(rgb_root, fname)
        tif_path = os.path.join(tif_root, fname.replace(".jpg", ".tif"))
        if not (os.path.isfile(rgb_path) and os.path.isfile(tif_path)):
            continue
        with rasterio.open(tif_path) as ds:
            tif = ds.read().astype(float) / 10000.0  # (bands,H,W)
        band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
        if len(band_means) >= n_wires:
            vec = band_means[:n_wires]
        else:
            vec = np.pad(band_means, (0, n_wires - len(band_means)), mode="edge")
        vec = pnp.array(vec, dtype=pnp.float64)

        if counts_tr[lab] < per_class_tr:
            Xtr.append(vec); ytr.append(classes.index(lab)); counts_tr[lab] += 1
        elif counts_va[lab] < per_class_va:
            Xva.append(vec); yva.append(classes.index(lab)); counts_va[lab] += 1
        if all(counts_tr[c] >= per_class_tr for c in classes) and all(counts_va[c] >= per_class_va for c in classes):
            break

    if not Xtr or not Xva:
        print("[train] EuroSAT subset too small or paths wrong.", file=sys.stderr); sys.exit(1)

    Xtr = pnp.stack(Xtr, axis=0); ytr = pnp.array(ytr, dtype=int)
    Xva = pnp.stack(Xva, axis=0); yva = pnp.array(yva, dtype=int)
    num_classes = len(classes)
    return (Xtr, ytr), (Xva, yva), num_classes

# ---------------- Model / training ----------------

class VQCClassifier:
    """Variational quantum classifier followed by a noisy readout and a linear head.

    The circuit (``build_probs_fn``) maps rotation angles to the probabilities of
    all ``2**n_wires`` basis states; a readout assignment matrix ``A`` is applied
    to that vector, and a linear layer (``W``, ``b``) turns it into class logits.

    Trainable parameters (created with ``requires_grad=True``):
        weights: shape (n_layers, n_wires, 3), drawn as 0.2 * N(0, 1).
        W: shape (num_classes, 2**n_wires), drawn as 0.01 * N(0, 1).
        b: shape (num_classes,), zeros.
    """

    def __init__(self, n_wires, n_layers, num_classes, seed=0):
        self.n_wires = n_wires
        self.n_layers = n_layers
        self.num_classes = num_classes
        self.dim = 2 ** n_wires
        gen = np.random.default_rng(seed)
        # Draw order (circuit weights first, then head) fixes the init per seed.
        circuit_init = gen.standard_normal((n_layers, n_wires, 3)) * 0.2
        head_init = gen.standard_normal((num_classes, self.dim)) * 0.01
        self.weights = pnp.array(circuit_init, requires_grad=True)
        self.W = pnp.array(head_init, requires_grad=True)
        self.b = pnp.array(np.zeros((num_classes,)), requires_grad=True)
        self.probs_fn = build_probs_fn(n_wires, n_layers)

    def forward_logits_with(self, theta, A, weights, W, b):
        """Return class logits for one sample using explicitly passed parameters.

        Parameters are arguments (not read from ``self``) so the training loop
        can differentiate with respect to them.

        Args:
            theta: per-wire embedding angles.
            A: readout assignment matrix applied to the ideal probabilities.
            weights, W, b: circuit and head parameters.
        """
        ideal = self.probs_fn(theta, weights)
        observed = A @ ideal
        observed = observed / pnp.sum(observed)  # renormalise to sum to 1
        return W @ observed + b

def train_one_noise(model, A, Xtr, ytr, Xva, yva, epochs=10, lr=1e-2, batch=32):
    """Train ``model`` with Adam under a fixed readout assignment matrix ``A``.

    Each epoch walks the training set in contiguous mini-batches of ``batch``
    samples (original order, no shuffling). After each epoch the mean
    cross-entropy and the accuracy on the validation split are computed
    through the same noisy readout and printed.

    Returns:
        ``(acc_hist, loss_hist)``: per-epoch validation accuracy (percent) and
        mean validation cross-entropy.

    Side effect: the final parameters are stored back on ``model.weights``,
    ``model.W`` and ``model.b``.
    """
    opt = qml.AdamOptimizer(lr)
    n_train = Xtr.shape[0]

    def batch_loss(weights, W, b, sl):
        # Mean cross-entropy over the training rows in `sl`; differentiable in
        # the parameters the optimizer passes in.
        total = 0.0
        for x_i, y_i in zip(Xtr[sl], ytr[sl]):
            logits = model.forward_logits_with(angle_map(x_i), A, weights, W, b)
            total = total + xent_from_logits(logits, y_i)
        return total / (sl.stop - sl.start)

    def evaluate(X, y, weights, W, b):
        # (mean cross-entropy, accuracy in percent) over a whole split.
        loss_sum, n_correct = 0.0, 0
        for x_i, y_i in zip(X, y):
            logits = model.forward_logits_with(angle_map(x_i), A, weights, W, b)
            loss_sum += float(xent_from_logits(logits, y_i))
            predicted = int(pnp.argmax(softmax(logits)))
            n_correct += int(predicted == int(y_i))
        return loss_sum / len(X), 100.0 * n_correct / len(X)

    acc_hist, loss_hist = [], []
    weights, W, b = model.weights, model.W, model.b

    for ep in range(epochs):
        n_batches = max(1, int(np.ceil(n_train / batch)))
        for i in range(n_batches):
            start = (i * batch) % n_train
            sl = slice(start, min(start + batch, n_train))
            # step_and_cost takes the parameters as separate positional args
            # (not one tuple) and returns them updated plus the pre-step cost.
            (weights, W, b), _ = opt.step_and_cost(
                lambda w, W_, b_: batch_loss(w, W_, b_, sl),
                weights, W, b
            )

        val_loss, val_acc = evaluate(Xva, yva, weights, W, b)
        loss_hist.append(val_loss)
        acc_hist.append(val_acc)
        print(f"  epoch {ep+1:02d}: val_acc={val_acc:5.2f}%  val_loss={val_loss:.4f}")

    model.weights, model.W, model.b = weights, W, b
    return acc_hist, loss_hist

# ---------------- Main ----------------

def _plot_curves(curves, noise_levels, marker_cycle, ylabel, out_path):
    """Plot one per-epoch curve per noise level, save to ``out_path`` and close the figure.

    ``curves`` maps each noise level to its per-epoch history; markers cycle
    through ``marker_cycle`` in ``noise_levels`` order.
    """
    fig = plt.figure()
    for i, e in enumerate(noise_levels):
        hist = curves[e]
        marker = marker_cycle[i % len(marker_cycle)]
        plt.plot(range(1, len(hist) + 1), hist, marker=marker, label=f"ε={e:g}")
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches="tight")
    plt.close(fig)  # release the figure; one is created per metric
    print(f"[train] Wrote {out_path}")

def main():
    """Train one classifier per readout-noise level and plot validation curves.

    Parses the command line, loads the CIFAR or EuroSAT subset, trains a freshly
    initialised ``VQCClassifier`` (same seed) for every noise level and writes
    ``<prefix>_acc.pdf`` and ``<prefix>_loss.pdf`` to the working directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--n_wires", type=int, default=8)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--epochs", type=int, default=10)
    ap.add_argument("--lr", type=float, default=1e-2)
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--noise_levels", type=float, nargs="+", default=[0.0, 0.02, 0.05, 0.10, 0.15])
    # CIFAR options
    ap.add_argument("--classes", type=int, nargs="+", default=[0,1], help="CIFAR class ids (e.g., 0 1)")
    ap.add_argument("--max_train", type=int, default=1000)
    ap.add_argument("--max_val", type=int, default=400)
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    # EuroSAT options
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    ap.add_argument("--eurosat_classes", type=int, nargs="+", default=None, help="EuroSAT label ids (e.g., 0 1)")
    args = ap.parse_args()

    # Each level e becomes flip probabilities e0 = 0.9*e and e1 = 1.1*e (below);
    # both must lie in [0, 1] or the assignment matrix gets negative entries.
    max_level = 1.0 / 1.1
    for e in args.noise_levels:
        if not 0.0 <= e <= max_level:
            ap.error(f"--noise_levels values must be in [0, {max_level:.4f}]; got {e}")
    # Duplicate levels would retrain an identical model (same seed) and plot the
    # same curve twice; keep the first occurrence of each.
    noise_levels = list(dict.fromkeys(args.noise_levels))

    plt.rcParams.update({"font.size": 16})

    # Load dataset
    if args.dataset == "cifar":
        (Xtr, ytr), (Xva, yva), num_classes = load_cifar_subset(
            args.n_wires, tuple(args.classes), args.max_train, args.max_val, root=args.cifar_root
        )
        out_prefix = "readout_train_cifar"
        marker_cycle = ["o","s","^","x","d","v"]
    else:
        if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
            print("[train] Provide EuroSAT paths: --eurosat_csv --eurosat_rgb_root --eurosat_tif_root", file=sys.stderr)
            sys.exit(1)
        (Xtr, ytr), (Xva, yva), num_classes = load_eurosat_subset(
            args.n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root,
            classes=args.eurosat_classes, max_train=args.max_train, max_val=args.max_val
        )
        out_prefix = "readout_train_eurosat"
        marker_cycle = ["|","_","1","2","3","4"]  # visually distinct markers

    # Train per noise level
    curves_acc = {}
    curves_loss = {}
    for e in noise_levels:
        e0, e1 = 0.9*e, 1.1*e  # slight asymmetry
        A = assignment_matrix(args.n_wires, e0, e1)
        print(f"\n[train] Noise level ε={e:.3f} (e0={e0:.4f}, e1={e1:.4f})")
        # Same seed for every level so curves differ only by readout noise.
        model = VQCClassifier(args.n_wires, args.layers, num_classes, seed=42)
        acc_hist, loss_hist = train_one_noise(model, A, Xtr, ytr, Xva, yva,
                                              epochs=args.epochs, lr=args.lr, batch=args.batch)
        curves_acc[e] = acc_hist
        curves_loss[e] = loss_hist

    # Plots (no titles)
    _plot_curves(curves_acc, noise_levels, marker_cycle, "accuracy (%)",
                 Path(f"{out_prefix}_acc.pdf"))
    _plot_curves(curves_loss, noise_levels, marker_cycle, "cross-entropy loss",
                 Path(f"{out_prefix}_loss.pdf"))

if __name__ == "__main__":
    main()
