import argparse, sys, os
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import pennylane as qml

# ---------------- Utilities ----------------

def softmax_batch(Z):
    """Return the row-wise softmax of a 2-D array of scores.

    Z has one row per sample and one column per class. The maximum of each
    row is subtracted before exponentiating so that large scores do not
    overflow; the subtraction does not change the result.
    """
    row_max = Z.max(axis=1, keepdims=True)
    exps = np.exp(Z - row_max)
    norm = exps.sum(axis=1, keepdims=True)
    return exps / norm

def cross_entropy_batch(logits, y):
    """Return the mean cross-entropy of a batch.

    logits is a (B, C) array of class scores and y a length-B array of
    class indices (cast to int before indexing). A small constant is added
    inside the logarithm so a zero probability does not produce -inf.
    """
    probs = softmax_batch(logits)
    rows = np.arange(len(y))
    cols = y.astype(int)
    true_class_prob = probs[rows, cols]
    return -np.log(true_class_prob + 1e-12).mean()

def onehot(y, C):
    """Return a (len(y), C) float64 one-hot encoding of the labels in y.

    y is an array of class indices; it is cast to int before being used as
    a column index.
    """
    n_samples = len(y)
    columns = y.astype(int)
    encoded = np.zeros((n_samples, C), dtype=np.float64)
    encoded[np.arange(n_samples), columns] = 1.0
    return encoded

def confusion_1q(e0, e1):
    """Return the 2x2 single-qubit readout confusion matrix.

    The first column is (1 - e0, e0) and the second is (e1, 1 - e1), so
    each column sums to one.
    """
    stay0 = 1 - e0
    stay1 = 1 - e1
    first_row = [stay0, e1]
    second_row = [e0, stay1]
    return np.array([first_row, second_row], dtype=np.float64)

def assignment_matrix(n_wires, e0, e1):
    """Return the n_wires-qubit readout assignment matrix.

    The result is the Kronecker product of n_wires copies of the
    single-qubit confusion matrix built from e0 and e1, giving an array of
    shape (2**n_wires, 2**n_wires). With n_wires == 0 it is [[1.0]].
    """
    single_qubit = confusion_1q(e0, e1)
    combined = np.array([[1.0]], dtype=np.float64)
    for _wire in range(n_wires):
        combined = np.kron(combined, single_qubit)
    return combined

def angle_map(x):
    """Map x through a logistic sigmoid scaled by pi, giving values in [0, pi]."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.pi * sigmoid

def pool_to_n(arr2d, n):
    """Reduce an array to n values by averaging consecutive chunks.

    The input is flattened, split into n nearly equal contiguous pieces
    with np.array_split, and the mean of each piece is returned as a
    float64 vector of length n.
    """
    chunks = np.array_split(arr2d.reshape(-1), n)
    chunk_means = []
    for chunk in chunks:
        chunk_means.append(float(chunk.mean()))
    return np.array(chunk_means, dtype=np.float64)

# ---------------- Data loaders ----------------

def load_cifar_subset(n_wires, classes=(0,1), max_train=1000, max_val=400, root="./_cache_cifar"):
    """Load a class-filtered CIFAR-10 subset as n_wires-dimensional features.

    Images whose label is in classes are converted to grayscale, scaled to
    [0, 1] and pooled down to n_wires values with pool_to_n. Labels are
    remapped to their position in classes. The CIFAR-10 test split is used
    as the validation set. A falsy max_train / max_val disables the cap.

    Returns ((Xtr, ytr), (Xva, yva), num_classes). Exits the process with
    status 1 when torchvision cannot be imported.
    """
    try:
        from torchvision import datasets
    except Exception:
        print("[fast] torchvision not available; cannot use CIFAR.", file=sys.stderr)
        sys.exit(1)

    train_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=None)
    test_ds = datasets.CIFAR10(root=root, train=False, download=True, transform=None)

    def filter_pack(ds, wanted, limit):
        # Collect pooled features for samples of the wanted classes, in
        # dataset order, stopping once `limit` samples have been kept.
        feats, labels = [], []
        for img, label in ds:
            if label not in wanted:
                continue
            gray = np.asarray(img.convert("L"), dtype=float) / 255.0  # 32x32
            feats.append(pool_to_n(gray, n_wires))
            labels.append(wanted.index(label))  # remap to 0..C-1
            if limit and len(feats) >= limit:
                break
        return np.stack(feats, axis=0), np.array(labels, dtype=int)

    wanted = list(classes)
    train_pair = filter_pack(train_ds, wanted, max_train)
    val_pair = filter_pack(test_ds, wanted, max_val)  # test as "val"
    return train_pair, val_pair, len(wanted)

def _eurosat_band_features(rasterio, tif_path, n_wires):
    """Read one GeoTIFF and return its per-band means as an n_wires-long vector."""
    with rasterio.open(tif_path) as ds:
        # NOTE(review): the 1/10000 factor presumably rescales raw band
        # values to reflectance -- confirm against the data source.
        tif = ds.read().astype(float) / 10000.0  # (bands,H,W)
    band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
    if len(band_means) >= n_wires:
        return band_means[:n_wires]
    # Fewer bands than wires: repeat the last band mean to fill the vector.
    return np.pad(band_means, (0, n_wires - len(band_means)), mode="edge")


def load_eurosat_subset(n_wires, csv_path, rgb_root, tif_root, classes=None, max_train=800, max_val=300):
    """Load a class-balanced EuroSAT subset as n_wires-dimensional features.

    The CSV at csv_path must provide "Label" and "Filename" columns. Each
    kept sample is the vector of per-band means of the matching GeoTIFF
    under tif_root (the ".jpg" in the filename is replaced by ".tif"),
    truncated or edge-padded to n_wires entries. Rows whose RGB or TIF file
    is missing are skipped.

    The CSV is scanned once. For every class, the first
    max_train // len(classes) usable rows go to the training set and the
    following max_val // len(classes) usable rows go to the validation
    set, so the two sets are disjoint and both honour their size caps
    (each per-class quota is at least 1).

    Parameters:
        n_wires: length of every returned feature vector.
        csv_path: path of the index CSV.
        rgb_root: directory holding the RGB images (existence check only).
        tif_root: directory holding the multi-band GeoTIFFs.
        classes: labels to keep; defaults to the two smallest labels found
            in the CSV. Labels are remapped to their position in classes.
        max_train: total cap on training samples.
        max_val: total cap on validation samples.

    Returns:
        ((Xtr, ytr), (Xva, yva), num_classes).

    Exits the process with status 1 when pandas/rasterio are unavailable
    or when either split ends up empty.
    """
    try:
        import pandas as pd
        import rasterio
    except Exception:
        print("[fast] pandas/rasterio required for EuroSAT.", file=sys.stderr)
        sys.exit(1)

    df = pd.read_csv(csv_path)
    if classes is None:
        uniq = sorted(df["Label"].unique().tolist())
        classes = uniq[:2]
        print(f"[fast] EuroSAT classes defaulted to: {classes}")
    classes = list(classes)
    if not classes:
        # Nothing to select (e.g. empty CSV); avoid dividing by zero below.
        print("[fast] EuroSAT subset too small or paths wrong.", file=sys.stderr); sys.exit(1)

    per_class_tr = max(1, max_train // len(classes))
    per_class_va = max(1, max_val // len(classes))
    counts_tr = {c: 0 for c in classes}
    counts_va = {c: 0 for c in classes}
    Xtr, ytr, Xva, yva = [], [], [], []

    for _, row in df.iterrows():
        lab = int(row["Label"])
        if lab not in classes:
            continue
        needs_train = counts_tr[lab] < per_class_tr
        needs_val = counts_va[lab] < per_class_va
        if not (needs_train or needs_val):
            continue  # both quotas for this class are already full

        rgb_path = os.path.join(rgb_root, row["Filename"])
        tif_path = os.path.join(tif_root, row["Filename"].replace(".jpg", ".tif"))
        if not (os.path.isfile(rgb_path) and os.path.isfile(tif_path)):
            continue

        vec = _eurosat_band_features(rasterio, tif_path, n_wires)
        target = classes.index(lab)
        if needs_train:
            Xtr.append(vec); ytr.append(target)
            counts_tr[lab] += 1
        else:
            Xva.append(vec); yva.append(target)
            counts_va[lab] += 1

        # Stop reading rasters as soon as every quota is satisfied.
        if all(counts_tr[c] >= per_class_tr and counts_va[c] >= per_class_va for c in classes):
            break

    if not Xtr or not Xva:
        print("[fast] EuroSAT subset too small or paths wrong.", file=sys.stderr); sys.exit(1)

    return (np.stack(Xtr), np.array(ytr, dtype=int)), (np.stack(Xva), np.array(yva, dtype=int)), len(classes)

# ---------------- Quantum precompute ----------------

def build_probs_fn(n_wires, n_layers):
    """Build a QNode returning basis-state probabilities on n_wires qubits.

    The circuit applies a Y-rotation AngleEmbedding of theta followed by
    StronglyEntanglingLayers parameterised by weights, on an analytic
    (shots=None) "default.qubit" device. n_layers is not used here.
    """
    all_wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=all_wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=all_wires)
        return qml.probs(wires=all_wires)

    # Explicit construction; equivalent to decorating with @qml.qnode(device).
    return qml.QNode(probs_fn, device)

def precompute_probs(X, n_wires, n_layers, seed=0):
    """Evaluate a frozen random circuit on every sample of X.

    Circuit weights of shape (n_layers, n_wires, 3) are drawn once from a
    normal distribution (scaled by 0.2) using a generator seeded with seed.
    Each sample is passed through angle_map before being fed to the
    circuit. Returns an array of shape (len(X), 2**n_wires).
    """
    generator = np.random.default_rng(seed)
    weights = generator.standard_normal((n_layers, n_wires, 3))*0.2
    circuit = build_probs_fn(n_wires, n_layers)
    probs = np.zeros((len(X), 2**n_wires), dtype=np.float64)
    for slot, sample in zip(probs, X):
        # slot is a view of one output row, so this writes into probs.
        slot[:] = circuit(angle_map(sample), weights)
    return probs  # (N, D)

# ---------------- Head training (fast) ----------------

class Adam:
    """Adam optimiser for a single weight matrix W and bias vector b."""

    def __init__(self, shape_W, shape_b, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr = lr
        self.b1 = beta1
        self.b2 = beta2
        self.eps = eps
        # First (m*) and second (v*) moment estimates, initialised to zero.
        self.mW = np.zeros(shape_W)
        self.vW = np.zeros(shape_W)
        self.mb = np.zeros(shape_b)
        self.vb = np.zeros(shape_b)
        # Number of steps taken so far; drives the bias correction.
        self.t = 0

    def step(self, W, b, gW, gb):
        """Apply one update from gradients gW, gb.

        W and b are modified in place and also returned.
        """
        self.t += 1
        beta1, beta2 = self.b1, self.b2

        self.mW = beta1*self.mW + (1-beta1)*gW
        self.vW = beta2*self.vW + (1-beta2)*(gW*gW)
        self.mb = beta1*self.mb + (1-beta1)*gb
        self.vb = beta2*self.vb + (1-beta2)*(gb*gb)

        # Bias-corrected moment estimates.
        m_hat_W = self.mW/(1-beta1**self.t)
        v_hat_W = self.vW/(1-beta2**self.t)
        m_hat_b = self.mb/(1-beta1**self.t)
        v_hat_b = self.vb/(1-beta2**self.t)

        W -= self.lr * m_hat_W/(np.sqrt(v_hat_W)+self.eps)
        b -= self.lr * m_hat_b/(np.sqrt(v_hat_b)+self.eps)
        return W, b

def simulate_measurements(P_true, shots, A):
    """Turn exact probabilities into noisy measured distributions.

    P_true is (N, D), shots an int and A the (D, D) assignment matrix.
    With shots == 0 the exact probabilities are used; otherwise each row is
    replaced by the empirical frequencies of a multinomial draw of `shots`
    samples from the global NumPy random state. The readout assignment is
    then applied and every row renormalised. Returns an (N, D) array.
    """
    if shots == 0:
        freqs = P_true.copy()
    else:
        # One multinomial draw per row.
        draws = [np.random.multinomial(shots, row) for row in P_true]
        freqs = np.array(draws, dtype=np.float64) / shots
    # Apply the readout assignment, then renormalise in place.
    measured = freqs @ A.T
    row_totals = measured.sum(axis=1, keepdims=True)
    np.divide(measured, row_totals, out=measured)
    return measured

def train_head(P_true_tr, ytr, P_true_va, yva, num_classes, shots, epsilon, epochs=100, batch=64, lr=5e-2):
    """Train a linear softmax head on simulated measurement distributions.

    Parameters:
        P_true_tr, P_true_va: exact circuit probabilities, shape (N, D)
            with D a power of two.
        ytr, yva: integer class labels for the two sets.
        num_classes: number of output classes C.
        shots: shots per measurement; 0 means use exact probabilities.
        epsilon: base readout error; the two single-qubit error rates
            passed to assignment_matrix are 0.9*epsilon and 1.1*epsilon.
        epochs, batch, lr: number of epochs, mini-batch size and Adam
            learning rate.

    Returns:
        (acc_hist, loss_hist): validation accuracy (percent) and
        cross-entropy loss after every epoch.

    Raises:
        ValueError: if D is not a power of two.

    Weight initialisation and the per-epoch shuffle both draw from one
    generator seeded with 0, so a noiseless (shots == 0) run is
    reproducible. Shot sampling is done by simulate_measurements.
    """
    N, D = P_true_tr.shape
    # Recover the wire count with integer arithmetic instead of float log2.
    n_wires = D.bit_length() - 1
    if D < 1 or (1 << n_wires) != D:
        raise ValueError(f"feature dimension {D} is not a power of two")
    A = assignment_matrix(n_wires, 0.9*epsilon, 1.1*epsilon)

    rng = np.random.default_rng(0)
    W = rng.standard_normal((num_classes, D))*0.01
    b = np.zeros((num_classes,), dtype=np.float64)
    opt = Adam(W.shape, b.shape, lr=lr)

    acc_hist, loss_hist = [], []

    order = np.arange(N)
    for ep in range(epochs):
        # Shuffle with the seeded generator rather than the global state.
        rng.shuffle(order)
        for i in range(0, N, batch):
            idx = order[i:i+batch]
            Pm = simulate_measurements(P_true_tr[idx], shots, A)  # (B,D)
            logits = (Pm @ W.T) + b  # (B,C)
            # Gradient of the mean cross-entropy w.r.t. the logits.
            P = softmax_batch(logits)  # (B,C)
            Y = onehot(ytr[idx], num_classes)  # (B,C)
            G = (P - Y) / len(idx)  # (B,C)
            gW = G.T @ Pm  # (C,D)
            gb = G.sum(axis=0)  # (C,)
            W, b = opt.step(W, b, gW, gb)

        # Validation on a fresh noisy draw of the validation set.
        Pm_va = simulate_measurements(P_true_va, shots, A)
        logits_va = (Pm_va @ W.T) + b
        loss = cross_entropy_batch(logits_va, yva)
        pred = logits_va.argmax(axis=1)
        acc = (pred == yva).mean()*100.0
        loss_hist.append(loss); acc_hist.append(acc)
        print(f"  epoch {ep+1:03d}: val_acc={acc:5.2f}%  val_loss={loss:.4f}")

    return acc_hist, loss_hist

# ---------------- Main ----------------

def _plot_curves(curves, shots_levels, marker_cycle, ylabel, out_path):
    """Plot one per-epoch curve per shots level, save it to out_path and close the figure."""
    fig = plt.figure()
    for i, shots in enumerate(shots_levels):
        m = marker_cycle[i % len(marker_cycle)]
        label = "noiseless" if shots==0 else f"shots={shots}"
        history = curves[shots]
        plt.plot(range(1, len(history)+1), history, marker=m, label=label)
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout(); plt.savefig(out_path, bbox_inches="tight")
    # Release the figure so it does not stay open after it has been saved.
    plt.close(fig)
    print(f"[fast] Wrote {out_path}")


def main():
    """Command-line entry point.

    Parses the arguments, loads the chosen dataset as n_wires-dimensional
    features, precomputes the exact circuit probabilities once, trains one
    head per requested shots level and writes the validation accuracy and
    loss curves to "<prefix>_acc.pdf" and "<prefix>_loss.pdf" in the
    current directory. Exits with status 1 when the EuroSAT paths are
    missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--n_wires", type=int, default=8)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch", type=int, default=64)
    ap.add_argument("--shots_levels", type=int, nargs="+", default=[0, 128, 256, 512, 1024])
    ap.add_argument("--epsilon", type=float, default=0.0)
    # CIFAR options
    ap.add_argument("--classes", type=int, nargs="+", default=[0,1])
    ap.add_argument("--max_train", type=int, default=1000)
    ap.add_argument("--max_val", type=int, default=400)
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    # EuroSAT options
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    ap.add_argument("--eurosat_classes", type=int, nargs="+", default=None)
    args = ap.parse_args()

    plt.rcParams.update({"font.size": 16})

    # Load features
    if args.dataset == "cifar":
        (Xtr, ytr), (Xva, yva), num_classes = load_cifar_subset(
            args.n_wires, tuple(args.classes), args.max_train, args.max_val, root=args.cifar_root
        )
        out_prefix = "shots_fast_cifar"
        marker_cycle = ["o","s","^","x","d","v","<",">","1","2"]
    else:
        if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
            print("[fast] Provide EuroSAT paths: --eurosat_csv --eurosat_rgb_root --eurosat_tif_root", file=sys.stderr)
            sys.exit(1)
        (Xtr, ytr), (Xva, yva), num_classes = load_eurosat_subset(
            args.n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root,
            classes=args.eurosat_classes, max_train=args.max_train, max_val=args.max_val
        )
        out_prefix = "shots_fast_eurosat"
        marker_cycle = ["|","_","1","2","3","4","+",".","*","P"]

    # Precompute exact probabilities once (frozen random circuit). The same
    # seed is used for both sets so they share the same circuit weights.
    print("[fast] Precomputing p_true for train/val...")
    P_true_tr = precompute_probs(Xtr, args.n_wires, args.layers, seed=123)
    P_true_va = precompute_probs(Xva, args.n_wires, args.layers, seed=123)
    print("[fast] Done.")

    curves_acc, curves_loss = {}, {}
    for shots in args.shots_levels:
        print(f"\n[fast] Training head with shots={shots if shots!=0 else 'noiseless'}, epsilon={args.epsilon}")
        acc_hist, loss_hist = train_head(P_true_tr, ytr, P_true_va, yva, num_classes,
                                         shots=shots, epsilon=args.epsilon,
                                         epochs=args.epochs, batch=args.batch, lr=5e-2)
        curves_acc[shots] = acc_hist
        curves_loss[shots] = loss_hist

    # Plot ACC
    _plot_curves(curves_acc, args.shots_levels, marker_cycle,
                 "accuracy (%)", Path(f"{out_prefix}_acc.pdf"))

    # Plot LOSS
    _plot_curves(curves_loss, args.shots_levels, marker_cycle,
                 "cross-entropy loss", Path(f"{out_prefix}_loss.pdf"))

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()