# # CIFAR-10, random θ (paper-like), 8 qubits, 2000 samples, shots ∈ {2,126,256,512,1048}, ε=0.02
# python generate_noise_csv.py --dataset cifar \
#   --out_csv ./cifar/noise_calibration_cifar.csv \
#   --n_qubits 8 --num_samples 2000 \
#   --shots_levels 2 126 256 512 1048 \
#   --epsilon 0.02 --theta_mode random \
#   --cifar_root ./_cache_cifar

# # EuroSAT, random θ, 8 qubits, 2000 samples, same shots and epsilon
# python generate_noise_csv.py --dataset eurosat \
#   --out_csv ./eurosat/noise_calibration_eurosat.csv \
#   --n_qubits 8 --num_samples 2000 \
#   --shots_levels 2 126 256 512 1048 \
#   --epsilon 0.02 --theta_mode random \
#   --eurosat_csv      ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
#   --eurosat_rgb_root ~/quantum/rem/eurosat/EuroSAT_extracted \
#   --eurosat_tif_root ~/quantum/rem/eurosat/EuroSAT_extracted/allBands

# # (Optional) use dataset FEATURES to map inputs to θ (binds rows to real samples)
# python generate_noise_csv.py --dataset cifar \
#   --out_csv ./cifar/noise_features_cifar.csv \
#   --n_qubits 8 --num_samples 2000 \
#   --shots_levels 2 126 256 512 1048 \
#   --epsilon 0.02 --theta_mode features \
#   --cifar_root ./_cache_cifar --classes 0 1

# Notes
# -----
# - Following Kim et al., training pairs come from single-qubit Ry rotations with
#   independent random angles per qubit and p̂ obtained from measured counts divided by shots.
# - We JSON-encode vectors so you can load rows directly in Python (or Pandas) later.


import argparse, sys, os, json, csv, time
from pathlib import Path
import numpy as np

# ---------------- Helpers ----------------

def now_ts():
    """Return the current local time formatted as ``YYYY-MM-DDTHH:MM:SS``."""
    stamp_format = "%Y-%m-%dT%H:%M:%S"
    local_now = time.localtime()
    return time.strftime(stamp_format, local_now)

def angle_map_features(x):
    """Turn real-valued features into rotation angles.

    Each entry of ``x`` is passed through the logistic sigmoid and scaled by
    2π, so the result has the same shape as ``x`` with values between 0 and 2π.
    """
    squashed = 1.0 / (1.0 + np.exp(-x))
    full_turn = 2.0 * np.pi
    return full_turn * squashed

def pool_to_n(arr2d, n):
    """Average-pool an array down to ``n`` values.

    The input is flattened, cut into ``n`` nearly equal consecutive chunks
    (``np.array_split``), and the mean of each chunk is returned as a
    float64 vector of length ``n``.
    """
    flat = arr2d.reshape(-1)
    chunk_means = []
    for chunk in np.array_split(flat, n):
        chunk_means.append(float(chunk.mean()))
    return np.array(chunk_means, dtype=np.float64)

def p_from_theta_product(theta):
    """Ideal outcome distribution of a product state with one Ry(θ) per qubit.

    Every qubit contributes the pair [cos²(θ/2), sin²(θ/2)] for outcomes 0
    and 1. The joint distribution is the Kronecker product of those pairs,
    taken in the order the angles are given, so the first angle corresponds
    to the most significant bit of the basis-state index.

    Returns a float64 vector of length 2**len(theta), indexed by the integer
    basis state, clipped to [0, 1] and renormalized to absorb round-off.
    """
    joint = np.ones(1, dtype=np.float64)
    for angle in theta:
        half = angle / 2.0
        single = np.array([np.cos(half) ** 2, np.sin(half) ** 2], dtype=np.float64)
        joint = np.kron(joint, single)
    # guard against tiny negative / >1 values produced by floating-point error
    joint = np.clip(joint, 0.0, 1.0)
    return joint / joint.sum()

def confusion_1q(e0, e1):
    """Build the single-qubit readout confusion matrix.

    Columns index the true bit and rows the measured bit: column 0 is
    [1-e0, e0] and column 1 is [e1, 1-e1], so ``e0`` is the probability of
    reading 1 when the true bit is 0 and ``e1`` of reading 0 when it is 1.
    """
    measured_zero = [1 - e0, e1]
    measured_one = [e0, 1 - e1]
    return np.array([measured_zero, measured_one], dtype=np.float64)

def assignment_matrix(n_qubits, e0, e1):
    """Build the n-qubit readout assignment matrix.

    The same single-qubit confusion matrix (from ``confusion_1q(e0, e1)``) is
    Kronecker-multiplied with itself ``n_qubits`` times, giving a
    (2**n_qubits, 2**n_qubits) float64 matrix.
    """
    single = confusion_1q(e0, e1)
    full = np.ones((1, 1), dtype=np.float64)
    for _qubit in range(n_qubits):
        full = np.kron(full, single)
    return full

def apply_readout_and_sample(p_true, shots, A, rng=None):
    """Apply the assignment matrix, then draw multinomial shot counts.

    Args:
        p_true: ideal probability vector over basis states.
        shots: number of measurement shots; values that truncate to 0 or
            less yield all-zero counts.
        A: assignment (readout confusion) matrix applied as ``A @ p_true``.
        rng: optional ``numpy.random.Generator`` used for sampling. When
            omitted, NumPy's global random state is used (as before), which
            is only reproducible if the caller seeds it via ``np.random.seed``.

    Returns:
        ``(pm, counts)`` where ``counts`` is the integer vector of sampled
        outcomes and ``pm`` is ``counts`` divided by the number of shots.
    """
    p_meas = A @ p_true
    # clip and renormalize so round-off cannot produce an invalid distribution
    p_meas = np.clip(p_meas, 0.0, 1.0)
    p_meas = p_meas / p_meas.sum()

    shots = int(shots)
    if shots > 0:
        sampler = np.random if rng is None else rng
        counts = sampler.multinomial(shots, p_meas)
    else:
        counts = np.zeros(p_meas.shape, dtype=int)
    # max(1, shots) avoids dividing by zero when no shots were requested
    pm = counts / max(1, shots)
    return pm, counts

# ---------------- CIFAR utilities (only if theta_mode=features) ----------------

def load_cifar_features(n_qubits, num_samples, classes=None, root="./_cache_cifar"):
    """Load up to ``num_samples`` pooled grayscale feature vectors from CIFAR-10.

    The train split is scanned first, then the test split. Each selected
    image is converted to grayscale, scaled to [0, 1] and mean-pooled to
    ``n_qubits`` values with ``pool_to_n``.

    Args:
        n_qubits: length of each feature vector.
        num_samples: stop once this many samples have been collected.
        classes: optional iterable of integer labels to keep; a falsy value
            keeps every class.
        root: directory passed to ``torchvision.datasets.CIFAR10`` (the
            dataset is downloaded there if missing).

    Returns:
        ``(features, labels, sources)``: a 2-D array of stacked feature
        vectors, an int array of labels, and a list of
        ``"cifar_idx_<i>"`` strings where ``i`` is the position in the
        concatenated train+test sequence.

    Exits the process with status 1 if torchvision cannot be imported or if
    no image matches ``classes``.
    """
    from itertools import chain

    try:
        from torchvision import datasets
    except Exception:
        print("[csv] torchvision not available; cannot use CIFAR features.", file=sys.stderr)
        sys.exit(1)
    train = datasets.CIFAR10(root=root, train=True, download=True, transform=None)
    test = datasets.CIFAR10(root=root, train=False, download=True, transform=None)

    wanted = set(classes) if classes else None
    feats, labels, sources = [], [], []
    # Iterate lazily so images are only decoded until enough samples are found,
    # instead of materializing both splits up front.
    for idx, (img, label) in enumerate(chain(train, test)):
        if wanted is not None and label not in wanted:
            continue
        arr = np.asarray(img.convert("L"), dtype=float) / 255.0  # grayscale in [0, 1]
        feats.append(pool_to_n(arr, n_qubits))
        labels.append(int(label))
        sources.append(f"cifar_idx_{idx}")
        if len(feats) >= num_samples:
            break
    if not feats:
        # np.stack([]) would raise an opaque ValueError; fail with a clear message
        print("[csv] CIFAR features empty; check classes.", file=sys.stderr)
        sys.exit(1)
    return np.stack(feats), np.array(labels, dtype=int), sources

# ---------------- EuroSAT utilities (only if theta_mode=features) ----------------

def load_eurosat_features(n_qubits, num_samples, csv_path, tif_root, classes=None):
    """Load up to ``num_samples`` per-band mean feature vectors from EuroSAT.

    Rows of the CSV at ``csv_path`` are read in order, using the ``Label``
    and ``Filename`` columns. For each kept row the matching ``.tif`` under
    ``tif_root`` is opened with rasterio, divided by 10000 and averaged per
    band. The band means are truncated to ``n_qubits`` entries, or
    edge-padded when the file has fewer bands. Rows whose label is not in
    ``classes`` (when given) or whose ``.tif`` is missing are skipped.

    Returns ``(features, labels, sources)``: stacked feature vectors, an int
    label array, and the list of ``Filename`` values used.

    Exits the process with status 1 if pandas/rasterio cannot be imported or
    if no usable row is found.
    """
    try:
        import pandas as pd
        import rasterio
    except Exception:
        print("[csv] pandas/rasterio required for EuroSAT features.", file=sys.stderr)
        sys.exit(1)

    table = pd.read_csv(os.path.expanduser(csv_path))
    keep = set(classes) if classes else None
    vectors, class_ids, names = [], [], []

    for _, record in table.iterrows():
        class_id = int(record["Label"])
        if keep is not None and class_id not in keep:
            continue

        tif_name = record["Filename"].replace(".jpg", ".tif")
        tif_file = os.path.join(os.path.expanduser(tif_root), tif_name)
        if not os.path.isfile(tif_file):
            continue

        with rasterio.open(tif_file) as src:
            cube = src.read().astype(float) / 10000.0  # (bands,H,W)
        per_band = cube.reshape(cube.shape[0], -1).mean(axis=1)

        # fit the band means to exactly n_qubits entries
        shortfall = n_qubits - len(per_band)
        if shortfall > 0:
            feature = np.pad(per_band, (0, shortfall), mode="edge")
        else:
            feature = per_band[:n_qubits]

        vectors.append(feature)
        class_ids.append(class_id)
        names.append(record["Filename"])
        if len(vectors) >= num_samples:
            break

    if not vectors:
        print("[csv] EuroSAT features empty; check paths/classes.", file=sys.stderr)
        sys.exit(1)
    return np.stack(vectors), np.array(class_ids, dtype=int), names

# ---------------- Main generator ----------------

def main():
    """Command-line entry point: generate the noise-calibration CSV.

    For every sample an angle vector θ is chosen (uniformly at random, or
    derived from dataset features), the ideal product-state distribution is
    computed, and for each requested shot count a noisy measurement is
    simulated through the readout assignment matrix. One CSV row is written
    per (sample, shots) pair, with vectors stored as JSON strings.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True)
    ap.add_argument("--out_csv", type=str, required=True)
    ap.add_argument("--n_qubits", type=int, default=8)
    ap.add_argument("--num_samples", type=int, default=2000)
    ap.add_argument("--shots_levels", type=int, nargs="+", default=[8192])
    ap.add_argument("--epsilon", type=float, default=0.02, help="base misclassification rate; e0=0.9*eps, e1=1.1*eps")
    ap.add_argument("--e0", type=float, default=None, help="override misclass 0->1")
    ap.add_argument("--e1", type=float, default=None, help="override misclass 1->0")
    ap.add_argument("--theta_mode", choices=["random","features"], default="random")
    ap.add_argument("--seed", type=int, default=123)

    # CIFAR opts
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    ap.add_argument("--classes", type=int, nargs="+", default=None, help="optional subset for CIFAR")

    # EuroSAT opts
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")  # unused here; kept for parity
    ap.add_argument("--eurosat_tif_root", type=str, default="")

    args = ap.parse_args()
    rng = np.random.default_rng(args.seed)
    # Shot sampling falls back on NumPy's global RNG, so seed that as well;
    # otherwise --seed would fix θ but not the sampled counts. The legacy
    # seeder only accepts values below 2**32.
    np.random.seed(args.seed % (2**32))

    out_csv = os.path.expanduser(args.out_csv)
    out_dir = os.path.dirname(out_csv)
    if out_dir:
        # os.makedirs("") raises, so only create a directory when one is named
        os.makedirs(out_dir, exist_ok=True)

    # misclassification parameters
    e0 = args.e0 if args.e0 is not None else 0.9*args.epsilon
    e1 = args.e1 if args.e1 is not None else 1.1*args.epsilon
    A = assignment_matrix(args.n_qubits, e0, e1)

    # Prepare per-sample metadata + θ
    if args.theta_mode == "random":
        # paper-like: random independent θ ∈ [0,2π)
        labels = np.full((args.num_samples,), -1, dtype=int)
        sources = [f"{args.dataset}_rand_{i}" for i in range(args.num_samples)]
        thetas = rng.uniform(0.0, 2.0*np.pi, size=(args.num_samples, args.n_qubits))
    else:
        if args.dataset == "cifar":
            X, labels, sources = load_cifar_features(args.n_qubits, args.num_samples, classes=args.classes, root=os.path.expanduser(args.cifar_root))
        else:
            if not (args.eurosat_csv and args.eurosat_tif_root):
                print("[csv] EuroSAT features require --eurosat_csv and --eurosat_tif_root", file=sys.stderr); sys.exit(1)
            X, labels, sources = load_eurosat_features(args.n_qubits, args.num_samples, args.eurosat_csv, args.eurosat_tif_root, classes=args.classes)
        thetas = angle_map_features(X)

    # A feature loader may return fewer rows than requested (class filter,
    # missing files); never index past what is actually available.
    n_samples = min(args.num_samples, len(thetas))
    if n_samples < args.num_samples:
        print(f"[csv] Only {len(thetas)} samples available (requested {args.num_samples}); writing those.", file=sys.stderr)

    # Open CSV
    fieldnames = [
        "id","dataset","source","label","n_qubits","shots","epsilon","e0","e1",
        "theta_json","p_true_json","p_meas_json","counts_json","timestamp"
    ]
    with open(out_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()

        row_id = 0
        for i in range(n_samples):
            theta = thetas[i]
            p_true = p_from_theta_product(theta)  # analytic p
            for shots in args.shots_levels:
                p_meas, counts = apply_readout_and_sample(p_true, int(shots), A)
                row = {
                    "id": row_id,
                    "dataset": args.dataset,
                    "source": sources[i],
                    "label": int(labels[i]),
                    "n_qubits": args.n_qubits,
                    "shots": int(shots),
                    "epsilon": float(args.epsilon),
                    "e0": float(e0),
                    "e1": float(e1),
                    "theta_json": json.dumps([float(x) for x in theta]),
                    "p_true_json": json.dumps([float(x) for x in p_true]),
                    "p_meas_json": json.dumps([float(x) for x in p_meas]),
                    "counts_json": json.dumps([int(x) for x in counts]),
                    "timestamp": now_ts(),
                }
                w.writerow(row)
                row_id += 1

    print(f"[csv] Wrote {row_id} rows to {out_csv}")
    print(f"[csv] Settings: dataset={args.dataset}, n_qubits={args.n_qubits}, samples={n_samples}, shots_levels={args.shots_levels}, eps={args.epsilon}")

# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()