#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ablation study over the number of quantum layers (n_layers) with NO READOUT NOISE.
- n_wires is held fixed across the sweep (default 6; set with --n_wires)
- Works with CIFAR-10 or EuroSAT (same dataset options as the base training script)
- Sweeps a list of n_layers values (e.g., 1 3 5)
- Saves:
    * per-epoch validation curves as PDFs (accuracy, loss), overlaid by n_layers
    * a summary CSV of best/last validation metrics per n_layers
"""

import argparse, sys, os
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import pennylane as qml
from pennylane import numpy as pnp

# ---- Use a 16 pt font for every text element in the saved figures ----
matplotlib.rcParams.update(
    dict.fromkeys(
        (
            "font.size",
            "axes.labelsize",
            "xtick.labelsize",
            "ytick.labelsize",
            "legend.fontsize",
        ),
        16,
    )
)

# ---------------- Core math utils ----------------

def softmax(z):
    """Return the softmax of the logit vector ``z``.

    The maximum entry is subtracted before exponentiating so large logits do
    not overflow; this shift does not change the result.
    """
    shifted = z - pnp.max(z)
    exps = pnp.exp(shifted)
    normaliser = pnp.sum(exps)
    return exps / normaliser

def xent_from_logits(logits, y):
    """Cross-entropy loss of the integer class ``y`` under ``softmax(logits)``.

    A 1e-12 floor is added inside the log so a zero probability gives a large
    but finite loss instead of ``inf``.
    """
    target_prob = softmax(logits)[int(y)]
    return -pnp.log(target_prob + 1e-12)

def angle_map(x):
    """Map features to rotation angles in [0, pi] via ``pi * sigmoid(x)``."""
    sigmoid = 1.0 / (1.0 + pnp.exp(-x))
    return pnp.pi * sigmoid

# ---------------- QCNN-ish circuit ----------------

def build_probs_fn(n_wires, n_layers):
    """Build a QNode returning the computational-basis probabilities.

    The circuit Y-angle-embeds ``theta`` on ``n_wires`` qubits, applies
    ``StronglyEntanglingLayers(weights)`` and returns ``qml.probs`` over all
    wires (a vector of length ``2**n_wires``). It runs on an analytic
    (``shots=None``) ``default.qubit`` device with the autograd interface.

    ``n_layers`` is not used here; the number of layers is determined by the
    shape of the ``weights`` passed at call time.
    """
    wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def circuit(theta, weights):
        qml.AngleEmbedding(theta, wires=wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=wires)
        return qml.probs(wires=wires)

    # Same as decorating `circuit` with @qml.qnode(device, interface="autograd").
    return qml.qnode(device, interface="autograd")(circuit)

# ---------------- Feature helpers ----------------

def pool_to_n(arr2d, n):
    """Average-pool an array down to ``n`` scalar features.

    The input is flattened with ``reshape(-1)`` (row-major order) and split
    into ``n`` contiguous chunks by ``np.array_split``, whose chunk lengths
    differ by at most one. Each chunk is replaced by its mean.

    Args:
        arr2d: array-like of numbers. Any shape is accepted because it is
            flattened first.
        n: number of output features. Must satisfy ``1 <= n <= arr2d.size``.

    Returns:
        1-D float64 ``pnp`` array of length ``n``.

    Raises:
        ValueError: if ``n`` is outside ``[1, arr2d.size]``. Otherwise some
            chunks would be empty and their mean would silently be NaN.
    """
    flat = np.asarray(arr2d).reshape(-1)
    if not 1 <= n <= flat.size:
        raise ValueError(
            f"pool_to_n: n must be between 1 and the number of elements "
            f"({flat.size}), got {n}"
        )
    chunks = np.array_split(flat, n)
    return pnp.array([float(chunk.mean()) for chunk in chunks], dtype=pnp.float64)

def load_cifar_subset(n_wires, classes=(0,1), max_train=1000, max_val=400, root="./_cache_cifar"):
    """Load a grayscale, average-pooled CIFAR-10 subset restricted to ``classes``.

    Each kept image is converted to grayscale (PIL mode "L"), scaled to [0, 1]
    and pooled to ``n_wires`` features with ``pool_to_n``. Labels are remapped
    to ``0..len(classes)-1`` in the order given by ``classes``. The CIFAR-10
    test split serves as the validation set. Samples are taken in dataset
    order, so the subset is not class-balanced.

    Args:
        n_wires: number of features per sample.
        classes: CIFAR-10 label ids to keep.
        max_train, max_val: maximum samples per split. A falsy value (0/None)
            means no limit.
        root: torchvision download/cache directory.

    Returns:
        ``((Xtr, ytr), (Xva, yva), num_classes)``.

    Exits with status 1 (message on stderr) if torchvision cannot be imported
    or either split ends up empty.
    """
    try:
        from torchvision import datasets
    except ImportError:
        # Narrow on purpose: other errors raised while importing an installed
        # torchvision should surface with their traceback, not be reported as
        # "not available".
        print("[ablate-layers] torchvision not available; cannot use CIFAR.", file=sys.stderr)
        sys.exit(1)
    train = datasets.CIFAR10(root=root, train=True,  download=True,  transform=None)
    test  = datasets.CIFAR10(root=root, train=False, download=True,  transform=None)

    wanted = list(classes)

    def filter_pack(ds, limit):
        # Collect pooled features and remapped labels for samples whose label is wanted.
        X, y = [], []
        for img, label in ds:
            if label not in wanted:
                continue
            arr = np.asarray(img.convert("L"), dtype=float) / 255.0  # 32x32
            X.append(pool_to_n(arr, n_wires))
            y.append(wanted.index(label))  # remap to 0..C-1
            if limit and len(X) >= limit:
                break
        if not X:
            print(f"[ablate-layers] CIFAR subset is empty for classes {wanted}; check --classes.", file=sys.stderr)
            sys.exit(1)
        return pnp.stack(X, axis=0), pnp.array(y, dtype=int)

    Xtr, ytr = filter_pack(train, max_train)
    Xva, yva = filter_pack(test,  max_val)   # use test as "val"
    num_classes = len(wanted)
    return (Xtr, ytr), (Xva, yva), num_classes

def load_eurosat_subset(n_wires, csv_path, rgb_root, tif_root, classes=None, max_train=800, max_val=300):
    """Load a small class-balanced EuroSAT subset as per-band mean features.

    Rows of ``csv_path`` are scanned in file order. A row is used only if its
    ``Label`` is in ``classes`` and both ``rgb_root/Filename`` and the matching
    GeoTIFF ``tif_root/Filename`` (with ".jpg" replaced by ".tif") exist. Only
    the GeoTIFF is read. Its pixel values are divided by 10000 and averaged
    per band. The first ``n_wires`` band means form the feature vector,
    edge-padded when there are fewer bands than wires.

    Per class, the first ``max(1, max_train // len(classes))`` usable rows go
    to training and the next ``max(1, max_val // len(classes))`` go to
    validation. Scanning stops once every quota is filled.

    Args:
        n_wires: number of features per sample.
        csv_path: CSV with at least "Label" and "Filename" columns.
        rgb_root, tif_root: directories that "Filename" is resolved against.
        classes: label ids to keep. ``None`` selects the first two labels of
            the CSV in sorted order.
        max_train, max_val: approximate total sizes of the two splits.

    Returns:
        ``((Xtr, ytr), (Xva, yva), num_classes)`` with labels remapped to
        their position in ``classes``.

    Exits with status 1 (message on stderr) if pandas/rasterio are missing,
    there are no classes to load, or either split is empty.
    """
    try:
        import pandas as pd
        import rasterio
    except ImportError:
        print("[ablate-layers] pandas/rasterio required for EuroSAT.", file=sys.stderr)
        sys.exit(1)
    df = pd.read_csv(csv_path)
    if classes is None:
        uniq = sorted(df["Label"].unique().tolist())
        classes = uniq[:2]
        print(f"[ablate-layers] EuroSAT classes defaulted to: {classes}")
    if not classes:
        # Guards the per-class quota division below (ZeroDivisionError otherwise).
        print("[ablate-layers] No EuroSAT classes to load (empty CSV or class list).", file=sys.stderr)
        sys.exit(1)
    # simple class-balanced take from top
    Xtr, ytr, Xva, yva = [], [], [], []
    per_class_tr = max(1, max_train // len(classes))
    per_class_va = max(1, max_val   // len(classes))
    counts_tr = {c: 0 for c in classes}
    counts_va = {c: 0 for c in classes}

    for _, row in df.iterrows():
        lab = int(row["Label"])
        if lab not in classes:
            continue
        rgb_path = os.path.join(rgb_root, row["Filename"])
        tif_path = os.path.join(tif_root, row["Filename"].replace(".jpg", ".tif"))
        # The RGB file is only required to exist; its pixels are not used.
        if not (os.path.isfile(rgb_path) and os.path.isfile(tif_path)):
            continue
        with rasterio.open(tif_path) as ds:
            tif = ds.read().astype(float) / 10000.0  # (bands,H,W)
        band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
        vec = band_means[:n_wires] if len(band_means) >= n_wires else np.pad(band_means, (0, n_wires - len(band_means)), mode="edge")
        vec = pnp.array(vec, dtype=pnp.float64)

        if counts_tr[lab] < per_class_tr:
            Xtr.append(vec); ytr.append(classes.index(lab)); counts_tr[lab] += 1
        elif counts_va[lab] < per_class_va:
            Xva.append(vec); yva.append(classes.index(lab)); counts_va[lab] += 1
        if all(counts_tr[c] >= per_class_tr for c in classes) and all(counts_va[c] >= per_class_va for c in classes):
            break

    if not Xtr or not Xva:
        print("[ablate-layers] EuroSAT subset too small or paths wrong.", file=sys.stderr); sys.exit(1)

    Xtr = pnp.stack(Xtr, axis=0); ytr = pnp.array(ytr, dtype=int)
    Xva = pnp.stack(Xva, axis=0); yva = pnp.array(yva, dtype=int)
    num_classes = len(classes)
    return (Xtr, ytr), (Xva, yva), num_classes

# ---------------- Model / training ----------------

class VQCClassifier:
    """Variational quantum circuit followed by a linear classification head.

    The circuit (see ``build_probs_fn``) outputs a probability vector of
    length ``2**n_wires``. That vector is mapped through a readout matrix,
    renormalised and fed to ``W @ p + b`` to produce class logits.

    Trainable parameters (created with ``requires_grad=True``):
        weights: (n_layers, n_wires, 3) circuit angles, ~ N(0, 0.2^2).
        W: (num_classes, 2**n_wires) head weights, ~ N(0, 0.01^2).
        b: (num_classes,) head bias, zeros.
    """

    def __init__(self, n_wires, n_layers, num_classes, seed=0):
        self.n_wires, self.n_layers, self.num_classes = n_wires, n_layers, num_classes
        self.dim = 2**n_wires
        rng = np.random.default_rng(seed)
        # Draw order matters for reproducibility: circuit weights first, then head.
        circuit_init = rng.standard_normal((n_layers, n_wires, 3)) * 0.2
        head_init = rng.standard_normal((num_classes, self.dim)) * 0.01
        self.weights = pnp.array(circuit_init, requires_grad=True)
        self.W = pnp.array(head_init, requires_grad=True)
        self.b = pnp.array(np.zeros((num_classes,)), requires_grad=True)
        self.probs_fn = build_probs_fn(n_wires, n_layers)

    def forward_logits_with(self, theta, A, weights, W, b):
        """Compute logits for one sample using explicitly supplied parameters.

        ``A`` is the readout matrix applied to the ideal probabilities. The
        identity matrix means noise-free readout. The result is renormalised
        to sum to 1 before the linear head.
        """
        measured = A @ self.probs_fn(theta, weights)
        measured = measured / pnp.sum(measured)
        return W @ measured + b

def train_model_no_noise(model, Xtr, ytr, Xva, yva, epochs=10, lr=1e-2, batch=32):
    """Train ``model`` with Adam and record validation metrics after each epoch.

    Readout is noise-free: the readout matrix passed to
    ``model.forward_logits_with`` is the identity. Mini-batches are contiguous
    slices of ``Xtr`` visited in a fixed order (no shuffling). The last batch
    of an epoch may be smaller than ``batch``.

    Args:
        model: object exposing ``n_wires``, ``n_layers``, ``weights``, ``W``,
            ``b`` and ``forward_logits_with`` (e.g. ``VQCClassifier``).
        Xtr, ytr: training features (one row per sample) and integer labels.
        Xva, yva: validation features and integer labels.
        epochs: number of passes over the training set.
        lr: Adam step size.
        batch: mini-batch size; must be positive.

    Returns:
        ``(acc_hist, loss_hist)``: per-epoch validation accuracy (percent) and
        mean validation cross-entropy. The trained parameters are written back
        to ``model.weights``, ``model.W`` and ``model.b``.

    Raises:
        ValueError: if ``batch`` is not positive or the training or
            validation set is empty. Previously these cases crashed with
            ZeroDivisionError or trained on a nonsensical slice.
    """
    N = Xtr.shape[0]
    if batch <= 0:
        raise ValueError(f"batch must be positive, got {batch}")
    if N == 0 or len(Xva) == 0:
        raise ValueError("training and validation sets must both be non-empty")

    opt = qml.AdamOptimizer(lr)
    A = pnp.eye(2**model.n_wires, dtype=pnp.float64)  # identity readout = no noise
    n_batches = max(1, int(np.ceil(N / batch)))
    acc_hist, loss_hist = [], []

    def batch_slice(i):
        start = (i * batch) % N
        return slice(start, min(start + batch, N))

    def batch_loss(weights, W, b, sl):
        # Mean cross-entropy over one mini-batch; differentiated w.r.t. weights, W, b.
        loss = 0.0
        for xi, yi in zip(Xtr[sl], ytr[sl]):
            logits = model.forward_logits_with(angle_map(xi), A, weights, W, b)
            loss = loss + xent_from_logits(logits, yi)
        return loss / (sl.stop - sl.start)

    def eval_split(X, y, weights, W, b):
        # Mean loss and accuracy (percent) of the given parameters on (X, y).
        tot_loss, correct = 0.0, 0
        for xi, yi in zip(X, y):
            logits = model.forward_logits_with(angle_map(xi), A, weights, W, b)
            tot_loss += float(xent_from_logits(logits, yi))
            # argmax(logits) == argmax(softmax(logits)) since exp is monotonic.
            correct += int(int(pnp.argmax(logits)) == int(yi))
        return tot_loss / len(X), 100.0 * correct / len(X)

    weights, W, b = model.weights, model.W, model.b

    for ep in range(epochs):
        for i in range(n_batches):
            sl = batch_slice(i)
            # The lambda is consumed immediately, so capturing `sl` by
            # closure is safe here.
            (weights, W, b), _ = opt.step_and_cost(
                lambda w, W_, b_: batch_loss(w, W_, b_, sl),
                weights, W, b
            )

        val_loss, val_acc = eval_split(Xva, yva, weights, W, b)
        loss_hist.append(val_loss)
        acc_hist.append(val_acc)
        print(f"[ablate-layers][layers={model.n_layers}] epoch {ep+1:02d}: val_acc={val_acc:5.2f}%  val_loss={val_loss:.4f}")

    model.weights, model.W, model.b = weights, W, b
    return acc_hist, loss_hist

# ---------------- Main ----------------

def _load_dataset(args, n_wires):
    """Load the train/val subset selected by the CLI options.

    Returns ``((Xtr, ytr), (Xva, yva), num_classes, out_prefix)``, where
    ``out_prefix`` is ``--out_prefix`` if given, else a dataset-specific
    default. Exits with status 1 when EuroSAT is selected without all three
    paths.
    """
    if args.dataset == "cifar":
        (Xtr, ytr), (Xva, yva), num_classes = load_cifar_subset(
            n_wires, tuple(args.classes), args.max_train, args.max_val, root=args.cifar_root
        )
        return (Xtr, ytr), (Xva, yva), num_classes, args.out_prefix or "ablate_layers_cifar"

    if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
        print("[ablate-layers] Provide EuroSAT paths: --eurosat_csv --eurosat_rgb_root --eurosat_tif_root", file=sys.stderr)
        sys.exit(1)
    (Xtr, ytr), (Xva, yva), num_classes = load_eurosat_subset(
        n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root,
        classes=args.eurosat_classes, max_train=args.max_train, max_val=args.max_val
    )
    return (Xtr, ytr), (Xva, yva), num_classes, args.out_prefix or "ablate_layers_eurosat"


def _write_summary_csv(path, rows):
    """Write the per-n_layers summary rows to ``path`` as CSV with a header row."""
    import csv
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["n_layers", "best_val_acc_percent", "best_epoch", "last_val_acc_percent", "last_val_loss"])
        writer.writerows(rows)


def _plot_curves_by_layers(curves, marker, ylabel, out_path):
    """Overlay one per-epoch curve per n_layers value and save it to ``out_path``.

    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    fig = plt.figure()
    for n_layers in sorted(curves):
        hist = curves[n_layers]
        plt.plot(range(1, len(hist) + 1), hist, marker=marker, label=f"layers={n_layers}")
    plt.xlabel("epoch"); plt.ylabel(ylabel)
    plt.legend(); plt.tight_layout()
    plt.grid(True)
    plt.savefig(out_path, bbox_inches="tight")
    plt.close(fig)


def main():
    """CLI entry point: run the n_layers ablation and write the CSV and PDF outputs.

    The dataset subset does not depend on n_layers, so it is loaded once and
    reused for every model in the sweep.
    """
    ap = argparse.ArgumentParser(description="Ablation over n_layers with NO readout noise (fixed n_wires=6).")
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--n_layers_list", type=int, nargs="+", default=[1,3,5], help="e.g., 1 3 5")
    ap.add_argument("--n_wires", type=int, default=6)  # fixed at 6 per request
    ap.add_argument("--epochs", type=int, default=10)
    ap.add_argument("--lr", type=float, default=1e-2)
    ap.add_argument("--batch", type=int, default=32)
    # CIFAR options
    ap.add_argument("--classes", type=int, nargs="+", default=[0,1], help="CIFAR class ids (e.g., 0 1)")
    ap.add_argument("--max_train", type=int, default=1000)
    ap.add_argument("--max_val", type=int, default=400)
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    # EuroSAT options
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    ap.add_argument("--eurosat_classes", type=int, nargs="+", default=None, help="EuroSAT label ids (e.g., 0 1)")
    ap.add_argument("--out_prefix", type=str, default="")
    args = ap.parse_args()

    # Fail fast with a usage message instead of crashing later (e.g. max([])
    # on an empty history when epochs == 0).
    if args.epochs < 1:
        ap.error("--epochs must be >= 1")
    if min(args.n_layers_list) < 1:
        ap.error("--n_layers_list values must be >= 1")

    n_wires = args.n_wires
    (Xtr, ytr), (Xva, yva), num_classes, out_prefix = _load_dataset(args, n_wires)

    all_curves_acc = {}
    all_curves_loss = {}
    summary_rows = []

    for L in args.n_layers_list:
        print(f"\n[ablate-layers] ==== n_layers = {L} (no noise, n_wires={n_wires}) ====")

        model = VQCClassifier(n_wires=n_wires, n_layers=L, num_classes=num_classes, seed=42)
        acc_hist, loss_hist = train_model_no_noise(
            model, Xtr, ytr, Xva, yva, epochs=args.epochs, lr=args.lr, batch=args.batch
        )
        all_curves_acc[L] = acc_hist
        all_curves_loss[L] = loss_hist

        best_acc = float(max(acc_hist))
        best_ep = int(np.argmax(acc_hist)) + 1
        last_acc = float(acc_hist[-1])
        last_loss = float(loss_hist[-1])
        summary_rows.append((L, best_acc, best_ep, last_acc, last_loss))

    # Save summary CSV
    summary_path = Path(f"{out_prefix}_summary_layers.csv")
    _write_summary_csv(summary_path, summary_rows)
    print(f"[ablate-layers] Wrote {summary_path}")

    # Plot ACC vs epoch (overlay by layers)
    out_acc = Path(f"{out_prefix}_acc_vs_epoch_by_layers.pdf")
    _plot_curves_by_layers(all_curves_acc, "o", "accuracy (%)", out_acc)
    print(f"[ablate-layers] Wrote {out_acc}")

    # Plot LOSS vs epoch (overlay by layers)
    out_loss = Path(f"{out_prefix}_loss_vs_epoch_by_layers.pdf")
    _plot_curves_by_layers(all_curves_loss, "s", "cross-entropy loss", out_loss)
    print(f"[ablate-layers] Wrote {out_loss}")

# Script entry point: run the ablation only when executed directly, not when imported.
if __name__ == "__main__":
    main()
