# python overlay_noisy_vs_noiseless.py --dataset cifar \
#   --classes 0 1 \
#   --max_train 1000 --max_val 400 \
#   --epochs 100 --n_wires 8 --layers 2 \
#   --shots 128 --epsilon 0.02

# python overlay_noisy_vs_noiseless.py --dataset eurosat \
#   --eurosat_csv      ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
#   --eurosat_tif_root ~/quantum/rem/eurosat/EuroSAT_extracted/allBands \
#   --max_train 800 --max_val 300 \
#   --epochs 100 --n_wires 8 --layers 2 \
#   --shots 128 --epsilon 0.02


import argparse, sys, os
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pennylane as qml

# ---------- utils ----------
def softmax_batch(Z):
    """Row-wise softmax of a 2-D array of logits.

    Each row is shifted by its own maximum before exponentiation so that
    large logits do not overflow; the shift does not change the result.
    The input array is not modified.
    """
    row_max = Z.max(axis=1, keepdims=True)
    exps = np.exp(Z - row_max)
    norm = exps.sum(axis=1, keepdims=True)
    return exps / norm

def cross_entropy_batch(logits, y):
    """Mean cross-entropy of integer labels `y` under softmax(`logits`).

    `logits` is a (batch, classes) array and `y` holds one class index per
    row (cast to int here). A small constant (1e-12) is added inside the
    log so a zero probability does not produce -inf.
    """
    probs = softmax_batch(logits)
    rows = np.arange(len(y))
    cols = y.astype(int)
    picked = probs[rows, cols]  # probability assigned to the true class
    return -np.mean(np.log(picked + 1e-12))

def onehot(y, C):
    """Return a (len(y), C) float64 one-hot encoding of the class indices `y`.

    `y` is cast to int before use, so float-typed label arrays are accepted.
    """
    labels = y.astype(int)
    # Row k of the identity matrix is the one-hot vector for class k;
    # fancy indexing returns a fresh array, not a view of the identity.
    return np.eye(C, dtype=np.float64)[labels]

def confusion_1q(e0, e1):
    """Build the 2x2 single-qubit readout confusion matrix.

    The first column is [1-e0, e0] and the second is [e1, 1-e1], so each
    column sums to one. Returned as a float64 array.
    """
    first_row = [1 - e0, e1]
    second_row = [e0, 1 - e1]
    return np.asarray([first_row, second_row], dtype=np.float64)

def assignment_matrix(n_wires, e0, e1):
    """Return the (2**n_wires, 2**n_wires) readout assignment matrix.

    It is the Kronecker product of `n_wires` copies of the single-qubit
    confusion matrix `confusion_1q(e0, e1)`; for `n_wires == 0` the result
    is the 1x1 matrix [[1.0]].
    """
    per_qubit = confusion_1q(e0, e1)
    full = np.ones((1, 1), dtype=np.float64)
    for _wire in range(n_wires):
        # Append one more qubit's confusion matrix on the right.
        full = np.kron(full, per_qubit)
    return full

def angle_map(x):
    """Map features to rotation angles in (0, pi) via pi * sigmoid(x)."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.pi * sigmoid

def pool_to_n(arr2d, n):
    """Reduce an array to `n` features by averaging contiguous chunks.

    The input is flattened and cut into `n` nearly equal pieces with
    `np.array_split`; the mean of each piece becomes one float64 feature.
    """
    flat = arr2d.reshape(-1)
    chunk_means = (chunk.mean() for chunk in np.array_split(flat, n))
    return np.fromiter(chunk_means, dtype=np.float64)

# ---------- data loaders ----------
def load_cifar_subset(n_wires, classes=(0,1), max_train=1000, max_val=400, root="./_cache_cifar"):
    """Load a CIFAR-10 subset as `n_wires` pooled grayscale features per image.

    Images whose label is in `classes` are converted to grayscale, scaled to
    [0, 1], and pooled to `n_wires` values with `pool_to_n`. Labels are
    remapped to the position of the class in `classes`.

    Args:
        n_wires: number of features produced per image.
        classes: CIFAR-10 labels to keep.
        max_train: maximum number of training samples (falsy = no limit).
        max_val: maximum number of validation samples (falsy = no limit).
        root: download/cache directory handed to torchvision (`~` is expanded).

    Returns:
        ((Xtr, ytr), (Xva, yva), num_classes), where the training pair comes
        from the CIFAR-10 train split and the validation pair from the test
        split.

    Exits the process with status 1 if torchvision cannot be imported or if
    no image matches the requested classes.
    """
    try:
        from torchvision import datasets
    except Exception:
        print("[overlay] torchvision not available; cannot use CIFAR.", file=sys.stderr)
        sys.exit(1)

    wanted = list(classes)
    # Label -> output index; setdefault keeps the FIRST position of a
    # duplicated label, matching list.index semantics.
    class_index = {}
    for pos, lab in enumerate(wanted):
        class_index.setdefault(lab, pos)

    root = os.path.expanduser(root)
    train = datasets.CIFAR10(root=root, train=True,  download=True,  transform=None)
    test  = datasets.CIFAR10(root=root, train=False, download=True,  transform=None)

    def pack(ds, limit):
        """Collect pooled features and remapped labels from one split."""
        X, y = [], []
        for img, lab in ds:
            pos = class_index.get(lab)
            if pos is None:
                continue
            arr = np.asarray(img.convert("L"), dtype=float)/255.0
            X.append(pool_to_n(arr, n_wires))
            y.append(pos)
            if limit and len(X) >= limit:
                break
        if not X:
            # np.stack([]) would raise an opaque ValueError; fail explicitly instead.
            print("[overlay] CIFAR subset is empty; check --classes.", file=sys.stderr)
            sys.exit(1)
        return np.stack(X), np.array(y, dtype=int)

    Xtr, ytr = pack(train, max_train)
    Xva, yva = pack(test,  max_val)
    return (Xtr,ytr), (Xva,yva), len(wanted)

def load_eurosat_subset(n_wires, csv_path, tif_root, classes=None, max_train=800, max_val=300):
    """Load a small EuroSAT subset as per-band mean features.

    Rows of the CSV (columns "Label" and "Filename") are scanned once, in
    file order. For each wanted class the first `per_tr` readable samples go
    to the training set and the following `per_va` readable samples go to the
    validation set, so the two sets never share a sample.

    Args:
        n_wires: number of features per sample; band means are truncated to
            this length, or edge-padded if the image has fewer bands.
        csv_path: path to the CSV index (`~` is expanded).
        tif_root: directory containing the .tif files (`~` is expanded).
        classes: labels to keep; defaults to the two smallest labels in the CSV.
        max_train: total training budget, split evenly across classes.
        max_val: total validation budget, split evenly across classes.

    Returns:
        ((Xtr, ytr), (Xva, yva), num_classes) with labels remapped to the
        position of the class in the wanted list.

    Exits the process with status 1 if pandas/rasterio are missing, or if
    either split ends up empty.
    """
    try:
        import pandas as pd, rasterio
    except Exception:
        print("[overlay] pandas/rasterio required for EuroSAT.", file=sys.stderr)
        sys.exit(1)
    df = pd.read_csv(os.path.expanduser(csv_path))
    wanted = sorted(df["Label"].unique().tolist())[:2] if classes is None else list(classes)
    if not wanted:
        # Avoid a ZeroDivisionError in the per-class quota computation below.
        print("[overlay] EuroSAT subset too small or paths incorrect.", file=sys.stderr); sys.exit(1)

    # Label -> output index (first occurrence wins, like list.index).
    class_index = {}
    for pos, lab in enumerate(wanted):
        class_index.setdefault(lab, pos)

    tif_dir = os.path.expanduser(tif_root)
    per_tr = max(1, max_train//len(wanted))
    per_va = max(1, max_val  //len(wanted))
    c_tr = {c:0 for c in wanted}
    c_va = {c:0 for c in wanted}
    Xtr, ytr, Xva, yva = [], [], [], []

    def band_features(tif_path):
        """Return the per-band mean of one TIF, cut or edge-padded to n_wires."""
        with rasterio.open(tif_path) as ds:
            # NOTE(review): 1/10000 is presumably a reflectance scale — confirm.
            tif = ds.read().astype(float)/10000.0
        band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
        if len(band_means) >= n_wires:
            return band_means[:n_wires]
        return np.pad(band_means, (0, n_wires-len(band_means)), "edge")

    for raw_label, filename in zip(df["Label"], df["Filename"]):
        lab = int(raw_label)
        if lab not in class_index:
            continue
        train_full = c_tr[lab] >= per_tr
        if train_full and c_va[lab] >= per_va:
            continue  # both quotas met: skip before touching the disk
        tif_path = os.path.join(tif_dir, filename.replace(".jpg",".tif"))
        if not os.path.isfile(tif_path):
            continue
        vec = band_features(tif_path)
        if not train_full:
            Xtr.append(vec); ytr.append(class_index[lab]); c_tr[lab] += 1
        else:
            # Only samples NOT used for training reach the validation set.
            Xva.append(vec); yva.append(class_index[lab]); c_va[lab] += 1
        if all(c_tr[c] >= per_tr and c_va[c] >= per_va for c in wanted):
            break

    if not Xtr or not Xva:
        print("[overlay] EuroSAT subset too small or paths incorrect.", file=sys.stderr); sys.exit(1)

    return (np.stack(Xtr), np.array(ytr,dtype=int)), (np.stack(Xva), np.array(yva,dtype=int)), len(wanted)

# ---------- quantum precompute ----------
def build_probs_fn(n_wires, n_layers):
    """Create a QNode returning basis-state probabilities over `n_wires` wires.

    The circuit applies a Y-rotation angle embedding of `theta` followed by
    `StronglyEntanglingLayers(weights)`, and runs on the analytic
    (`shots=None`) "default.qubit" device. `n_layers` is not used here; the
    circuit only sees the `weights` array passed at call time.
    """
    wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    @qml.qnode(device)
    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=wires)
        return qml.probs(wires=wires)

    return probs_fn

def precompute_probs(X, n_wires, n_layers, seed=123):
    """Evaluate the fixed random circuit on every sample of `X`.

    Circuit weights of shape (n_layers, n_wires, 3) are drawn once from a
    generator seeded with `seed` and scaled by 0.2. Each sample is mapped to
    angles with `angle_map` before being fed to the circuit.

    Returns a (len(X), 2**n_wires) float64 array of outcome probabilities.
    """
    weights = np.random.default_rng(seed).standard_normal((n_layers, n_wires, 3)) * 0.2
    circuit = build_probs_fn(n_wires, n_layers)
    table = np.zeros((len(X), 2**n_wires), dtype=np.float64)
    for out_row, sample in zip(table, X):
        # out_row is a view into `table`, so this fills the table in place.
        out_row[:] = circuit(angle_map(sample), weights)
    return table

# ---------- optimizer for head ----------
class Adam:
    """Adam optimizer for a single weight matrix `W` and bias vector `b`."""

    def __init__(self, shape_W, shape_b, lr=5e-2, b1=0.9, b2=0.999, eps=1e-8):
        """Store the hyper-parameters and zero the moment estimates."""
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.eps = eps
        # First (m*) and second (v*) moment estimates.
        self.mW = np.zeros(shape_W)
        self.vW = np.zeros(shape_W)
        self.mb = np.zeros(shape_b)
        self.vb = np.zeros(shape_b)
        self.t = 0  # number of steps taken so far

    def step(self, W,b,gW,gb):
        """Apply one Adam update given gradients `gW`, `gb`.

        `W` and `b` are updated with in-place subtraction and then returned.
        """
        self.t += 1
        beta1, beta2 = self.b1, self.b2

        self.mW = beta1*self.mW + (1-beta1)*gW
        self.vW = beta2*self.vW + (1-beta2)*(gW*gW)
        self.mb = beta1*self.mb + (1-beta1)*gb
        self.vb = beta2*self.vb + (1-beta2)*(gb*gb)

        # Bias-correction denominators for the current step.
        corr1 = 1-beta1**self.t
        corr2 = 1-beta2**self.t

        W -= self.lr*(self.mW/corr1)/(np.sqrt(self.vW/corr2)+self.eps)
        b -= self.lr*(self.mb/corr1)/(np.sqrt(self.vb/corr2)+self.eps)
        return W,b

# ---------- simulate & train ----------
def simulate_measurements(P_true, shots, A, rng=None):
    """Turn exact outcome probabilities into simulated measured frequencies.

    Args:
        P_true: (batch, D) array; each row is an exact outcome distribution.
        shots: number of measurement shots per row. A value <= 0 means the
            infinite-shot limit (no sampling noise).
        A: (D, D) readout assignment matrix applied to every row.
        rng: optional `numpy.random.Generator` used for the multinomial
            draws. When None, the global `np.random` state is used.

    Returns:
        (batch, D) array of measured distributions (a new array; `P_true` is
        not modified).
    """
    if shots <= 0:
        # No shot noise, but readout error must still be applied. With an
        # identity A this reproduces P_true exactly.
        return P_true @ A.T

    sampler = np.random if rng is None else rng
    counts = np.array([sampler.multinomial(shots, p) for p in P_true], dtype=np.float64)
    # Keep the (batch, D) shape even when the batch is empty.
    counts = counts.reshape(np.shape(P_true))
    Pm = (counts / shots) @ A.T
    Pm /= Pm.sum(axis=1, keepdims=True)
    return Pm

def train_head(P_true_tr, ytr, P_true_va, yva, num_classes, shots, epsilon,
               epochs=100, batch=64, lr=5e-2, seed_init=0, fixed_orders=None):
    """Train a linear softmax head on (simulated) measurement distributions.

    Args:
        P_true_tr, ytr: exact training distributions (N, D) and integer labels.
        P_true_va, yva: exact validation distributions and integer labels.
        num_classes: number of output classes.
        shots: shots per sample passed to `simulate_measurements`.
        epsilon: readout error scale; the per-qubit rates used are
            0.9*epsilon and 1.1*epsilon.
        epochs, batch, lr: training schedule and Adam learning rate.
        seed_init: seed for weight initialisation (and for shuffling when
            `fixed_orders` is None).
        fixed_orders: optional per-epoch index permutations; entry `ep` is
            used as the sample order of epoch `ep`.

    Returns:
        (acc_hist, loss_hist): per-epoch validation accuracy (percent) and
        validation cross-entropy.
    """
    dim = P_true_tr.shape[1]
    n_samples = P_true_tr.shape[0]

    # Readout assignment matrix; identity when no noise source is enabled.
    if shots > 0 or epsilon > 0:
        n_wires = int(np.log2(dim))
        A = assignment_matrix(n_wires, 0.9*epsilon, 1.1*epsilon)
    else:
        A = np.eye(dim)

    rng = np.random.default_rng(seed_init)
    W = rng.standard_normal((num_classes, dim))*0.01
    b = np.zeros((num_classes,), dtype=np.float64)
    opt = Adam(W.shape, b.shape, lr=lr)

    acc_hist = []
    loss_hist = []

    for ep in range(epochs):
        if fixed_orders is None:
            order = np.arange(n_samples)
            rng.shuffle(order)
        else:
            order = fixed_orders[ep]

        for start in range(0, n_samples, batch):
            sel = order[start:start+batch]
            feats = simulate_measurements(P_true_tr[sel], shots, A)   # (B,D)
            probs = softmax_batch(feats @ W.T + b)                    # (B,C)
            targets = onehot(ytr[sel], num_classes)
            # Gradient of the mean cross-entropy w.r.t. the logits.
            grad_logits = (probs - targets) / max(1, len(sel))
            W, b = opt.step(W, b, grad_logits.T @ feats, grad_logits.sum(axis=0))

        # Validation on a fresh simulated draw of the validation set.
        feats_va = simulate_measurements(P_true_va, shots, A)
        logits_va = feats_va @ W.T + b
        loss_hist.append(cross_entropy_batch(logits_va, yva))
        hits = logits_va.argmax(axis=1) == yva
        acc_hist.append(hits.mean()*100.0)

    return acc_hist, loss_hist

# ---------- main ----------
def main():
    """Train noiseless and noisy heads on one dataset and plot both curves.

    Parses the command line, loads the selected dataset, precomputes exact
    circuit probabilities, trains the head twice (noiseless, then with the
    requested shots/epsilon) using the same initial weights and minibatch
    order, and writes `<prefix>_acc.pdf` and `<prefix>_loss.pdf` to the
    current directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--n_wires", type=int, default=8)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--epsilon", type=float, default=0.02)
    ap.add_argument("--shots", type=int, default=128)
    ap.add_argument("--seed", type=int, default=0,
                    help="seed for the shot-noise sampling of the noisy run")
    # CIFAR
    ap.add_argument("--classes", type=int, nargs="+", default=[0,1])
    ap.add_argument("--max_train", type=int, default=1000)
    ap.add_argument("--max_val", type=int, default=400)
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    # EuroSAT
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    args = ap.parse_args()

    plt.rcParams.update({"font.size": 16})

    if args.dataset == "cifar":
        (Xtr,ytr),(Xva,yva),C = load_cifar_subset(args.n_wires, tuple(args.classes), args.max_train, args.max_val, root=args.cifar_root)
        out_prefix = "compare_cifar"
        marker = "o"
    else:
        if not (args.eurosat_csv and args.eurosat_tif_root):
            print("[overlay] EuroSAT requires --eurosat_csv and --eurosat_tif_root", file=sys.stderr)
            sys.exit(1)
        (Xtr,ytr),(Xva,yva),C = load_eurosat_subset(args.n_wires, args.eurosat_csv, args.eurosat_tif_root, classes=None, max_train=args.max_train, max_val=args.max_val)
        out_prefix = "compare_eurosat"
        marker = "|"

    # precompute exact probabilities (same seed => same circuit weights for both splits)
    P_true_tr = precompute_probs(Xtr, args.n_wires, args.layers, seed=123)
    P_true_va = precompute_probs(Xva, args.n_wires, args.layers, seed=123)

    # fixed epoch orders so both runs see identical minibatch permutations
    rng_order = np.random.default_rng(999)
    N = P_true_tr.shape[0]
    orders = []
    for _ in range(args.epochs):
        idx = np.arange(N)
        rng_order.shuffle(idx)
        orders.append(idx)

    # train noiseless (shots=0, eps=0) and noisy (shots=given, eps=given) with SAME init seed
    acc_noiseless, loss_noiseless = train_head(P_true_tr, ytr, P_true_va, yva, C,
                                               shots=0, epsilon=0.0,
                                               epochs=args.epochs, batch=args.batch, lr=5e-2,
                                               seed_init=2024, fixed_orders=orders)

    # Shot sampling draws from the global NumPy RNG; seed it so the noisy
    # curve is reproducible from run to run.
    np.random.seed(args.seed)
    acc_noisy, loss_noisy = train_head(P_true_tr, ytr, P_true_va, yva, C,
                                       shots=args.shots, epsilon=args.epsilon,
                                       epochs=args.epochs, batch=args.batch, lr=5e-2,
                                       seed_init=2024, fixed_orders=orders)

    # plots (no titles)
    epochs = np.arange(1, args.epochs+1)
    noisy_label = f"noisy (shots={args.shots}, ε={args.epsilon})"

    def save_overlay(noiseless, noisy, ylabel, suffix):
        """Plot both curves against the epoch axis and save them as a PDF."""
        fig = plt.figure()
        plt.plot(epochs, noiseless, marker=marker, linestyle="-", label="noiseless")
        plt.plot(epochs, noisy, marker=marker, linestyle="--", label=noisy_label)
        plt.xlabel("epoch"); plt.ylabel(ylabel)
        plt.legend()
        out_name = f"{out_prefix}_{suffix}.pdf"
        plt.tight_layout(); plt.savefig(Path(out_name), bbox_inches="tight")
        plt.close(fig)  # release the figure once it is on disk
        print(f"[overlay] wrote {out_name}")

    save_overlay(acc_noiseless, acc_noisy, "accuracy (%)", "acc")
    save_overlay(loss_noiseless, loss_noisy, "cross-entropy loss", "loss")

# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()