# python readout_effect_drift.py --dataset cifar \
#   --n_wires 8 --layers 2 --T 40 --base_e0 0.02 --base_e1 0.03 --amp 0.4 --period 40 \
#   --pairs 0-1,2-3 --rho 0.3


# python readout_effect_drift.py --dataset eurosat \
#   --n_wires 8 --layers 2 --T 40 --base_e0 0.02 --base_e1 0.03 --amp 0.4 --period 40 \
#   --pairs 0-1,2-3 --rho 0.3 \
#   --eurosat_csv      ~/quantum/rem/eurosat/EuroSAT_extracted/train.csv \
#   --eurosat_rgb_root ~/quantum/rem/eurosat/EuroSAT_extracted \
#   --eurosat_tif_root ~/quantum/rem/eurosat/EuroSAT_extracted/allBands


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
readout_effect_drift.py
-----------------------
Shows **temporal drift** and **pairwise-correlated** readout effects on a tiny VQC.
Plots are generated **separately** per dataset (CIFAR or EuroSAT), with **no titles**,
16pt fonts, and markers set per dataset:
  - CIFAR: '.' marker
  - EuroSAT: '|' marker
Two curves per plot:
  - independent readout (tensor product confusion)
  - pairwise-correlated readout (selected adjacent qubit pairs i-(i+1), rho > 0)

Outputs (examples):
  readout_drift_cifar_tvd.pdf
  readout_drift_cifar_z0err.pdf
  readout_drift_eurosat_tvd.pdf
  readout_drift_eurosat_z0err.pdf
"""

import argparse, sys, os, math
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pennylane as qml

# ---------- Readout models ----------

def confusion_1q(e0, e1):
    """Return the 2x2 single-qubit readout confusion matrix.

    Columns index the TRUE state and rows the MEASURED state:

        [[P(0->0), P(1->0)],
         [P(0->1), P(1->1)]]

    ``e0`` is the probability of reading 1 when the state is 0 and ``e1``
    the probability of reading 0 when the state is 1, so each column sums
    to one.
    """
    keep0, keep1 = 1 - e0, 1 - e1
    rows = ((keep0, e1), (e0, keep1))
    return np.asarray(rows, dtype=float)

def kron_list(mats):
    """Kronecker-multiply the matrices in ``mats`` from left to right.

    The product is seeded with the 1x1 matrix ``[[1.0]]``, which is also
    what is returned when ``mats`` is empty.
    """
    product = np.ones((1, 1), dtype=float)
    for factor in mats:
        product = np.kron(product, factor)
    return product

def assignment_matrix_indep(n_wires, e0, e1):
    """Return the assignment matrix for ``n_wires`` independent, identical qubits.

    The result is the Kronecker product of ``n_wires`` copies of the
    single-qubit confusion matrix built from ``e0`` and ``e1``.
    """
    single = confusion_1q(e0, e1)
    # Every factor is the same matrix, so one instance can be repeated.
    return kron_list([single] * n_wires)

def correlated_pair_matrix(e0, e1, rho):
    """Return a 4x4 two-qubit confusion matrix with pairwise-correlated errors.

    The starting point is the independent model ``kron(C, C)`` with
    ``C = confusion_1q(e0, e1)``. For every true state, a fraction ``rho`` of
    the probability of the two single-flip outcomes is removed and split
    equally between the "no flip" and "both flip" outcomes.

    Columns index TRUE states in order 00, 01, 10, 11; rows index MEASURED
    states in the same order. Each column keeps its original sum because
    the mass taken from the single-flip entries is exactly what is added
    to the other two.
    """
    single = confusion_1q(e0, e1)
    pair = np.kron(single, single).copy()
    keep = 1 - rho
    for true_idx in range(4):
        # Measured outcomes relative to the true state, found via XOR masks.
        row_none = true_idx              # neither bit flipped
        row_first = true_idx ^ 0b10      # first bit (MSB) flipped only
        row_second = true_idx ^ 0b01     # second bit (LSB) flipped only
        row_both = true_idx ^ 0b11       # both bits flipped

        p_first = pair[row_first, true_idx]
        p_second = pair[row_second, true_idx]
        # Half of the mass removed from the single-flip outcomes.
        moved_half = 0.5 * (rho * (p_first + p_second))

        pair[row_first, true_idx] = keep * p_first
        pair[row_second, true_idx] = keep * p_second
        pair[row_none, true_idx] += moved_half
        pair[row_both, true_idx] += moved_half
    return pair

def assignment_matrix_corr(n_wires, e0, e1, pairs, rho):
    """Build an assignment matrix with correlated 2-qubit confusion on given pairs.

    Wires are ordered ``[0..n_wires-1]`` and the factors are laid out left to
    right in the Kronecker product in that order. ``pairs`` is an iterable of
    2-tuples such as ``(0, 1)``; each tuple is sorted before use, so
    ``(1, 0)`` and ``(0, 1)`` are equivalent.

    Only *adjacent* pairs ``(i, i+1)`` inside the wire range can be
    represented by this layout. The wires are scanned left to right, so a
    pair that overlaps an already-applied one (e.g. ``(1, 2)`` after
    ``(0, 1)``) cannot be applied either. Requested pairs that end up unused
    do not change the result; they are reported through ``warnings.warn`` so
    they are no longer dropped silently.

    Returns:
        A ``2**n_wires`` x ``2**n_wires`` column-stochastic NumPy array.
    """
    import warnings  # local import: only needed for the diagnostic below

    requested = {tuple(sorted(p)) for p in pairs}
    # Both factor matrices are loop-invariant, so build each one once.
    single = confusion_1q(e0, e1)
    correlated = correlated_pair_matrix(e0, e1, rho)

    applied = set()
    mats = []
    i = 0
    while i < n_wires:
        candidate = (i, i + 1)
        if i + 1 < n_wires and candidate in requested:
            mats.append(correlated)
            applied.add(candidate)
            i += 2  # both wires of the pair are consumed
        else:
            mats.append(single)
            i += 1

    ignored = sorted(requested - applied)
    if ignored:
        # warnings de-duplicates per call site under the default filter, so a
        # caller invoking this once per time step is not flooded with messages.
        warnings.warn(
            f"[drift] ignoring readout pairs {ignored}: only non-overlapping "
            f"adjacent pairs (i, i+1) within 0..{n_wires - 1} are supported",
            stacklevel=2,
        )
    return kron_list(mats)

# ---------- Circuit + utilities ----------

def build_probs_fn(n_wires, n_layers):
    """Create a QNode returning the basis-state probabilities of the VQC.

    The circuit applies ``AngleEmbedding`` (Y rotations) of ``theta`` followed
    by ``StronglyEntanglingLayers`` with ``weights`` on all ``n_wires`` wires
    of a ``default.qubit`` device created with ``shots=None``.

    ``n_layers`` is accepted for interface symmetry but is not read here.
    NOTE(review): presumably the layer count is implied by the leading
    dimension of ``weights`` — confirm against the PennyLane template docs.
    """
    wires = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=wires, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=wires)
        return qml.probs(wires=wires)

    # Explicit construction; equivalent to decorating with @qml.qnode(device).
    return qml.QNode(probs_fn, device)

def angle_map(x):
    """Map ``x`` to a rotation angle between 0 and pi via a logistic sigmoid."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.pi * sigmoid

def z_expect_from_probs(p, n_wires, wire=0):
    """Return the Pauli-Z expectation of ``wire`` from a probability vector.

    ``p`` is iterated in basis-state order; wire 0 corresponds to the most
    significant bit of the index. A basis state contributes ``+p`` when the
    wire's bit is 0 and ``-p`` when it is 1.
    """
    shift = n_wires - 1 - wire  # bit position of `wire` within the index
    total = 0.0
    for basis_index, prob in enumerate(p):
        if (basis_index >> shift) & 1:
            total -= prob
        else:
            total += prob
    return total

def pool_to_n(arr2d, n):
    """Reduce an array to ``n`` features by mean-pooling its flattened values.

    The array is flattened, cut into ``n`` nearly equal consecutive chunks
    with ``np.array_split``, and the mean of each chunk is returned as a
    float array of length ``n``.
    """
    chunks = np.array_split(arr2d.reshape(-1), n)
    return np.fromiter((chunk.mean() for chunk in chunks), dtype=float)

# ---------- CIFAR helper ----------
def cifar_vec(n_wires, root="./_cache_cifar"):
    """Return an ``n_wires``-long feature vector from the first CIFAR-10 image.

    The CIFAR-10 training set is loaded through torchvision from ``root``,
    first with downloading enabled and, if that fails, from the local copy
    only. The first image is converted to grayscale, scaled to [0, 1] and
    mean-pooled to ``n_wires`` values.

    Returns ``None`` (after printing a message to stderr) when torchvision
    cannot be imported or the dataset cannot be loaded either way.
    """
    try:
        from torchvision import datasets
    except Exception:
        print("[drift] torchvision not available; CIFAR skipped.", file=sys.stderr)
        return None

    # Try with download first, then fall back to whatever is cached on disk.
    for allow_download in (True, False):
        try:
            dataset = datasets.CIFAR10(root=root, train=True, download=allow_download, transform=None)
        except Exception as err:
            failure = err
        else:
            break
    else:
        # Both attempts failed; report the error from the offline attempt.
        print(f"[drift] CIFAR not available: {failure}", file=sys.stderr)
        return None

    first_image, _ = dataset[0]
    gray = np.asarray(first_image.convert("L"), dtype=float) / 255.0  # 32x32
    return pool_to_n(gray, n_wires)

# ---------- EuroSAT helper ----------
def eurosat_vec(n_wires, csv_path, rgb_root, tif_root):
    """Return an ``n_wires``-long feature vector from the first EuroSAT sample.

    The first row of the CSV at ``csv_path`` names the sample through its
    ``Filename`` column. The matching multi-band ``.tif`` under ``tif_root``
    is read, each band is averaged over all pixels, and the per-band means
    are truncated to ``n_wires`` entries (or edge-padded when there are
    fewer bands). The RGB file under ``rgb_root`` must exist but is not read.

    Every "input unavailable" condition — missing pandas/rasterio, a missing,
    empty or unparsable CSV, a CSV without a ``Filename`` column, or missing
    image files — prints a ``[drift]`` message to stderr and returns ``None``
    instead of raising, so the caller can simply skip the dataset.
    """
    try:
        # Deliberately broad: any import-time failure means "skip EuroSAT".
        import pandas as pd
        import rasterio
    except Exception:
        print("[drift] pandas/rasterio missing; EuroSAT skipped.", file=sys.stderr)
        return None
    import os

    try:
        df = pd.read_csv(csv_path)
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # Same graceful skip as for missing image files below.
        print(f"[drift] Cannot read EuroSAT CSV {csv_path}: {err}", file=sys.stderr)
        return None
    if len(df) == 0:
        print("[drift] EuroSAT CSV empty.", file=sys.stderr)
        return None
    if "Filename" not in df.columns:
        print("[drift] EuroSAT CSV has no 'Filename' column.", file=sys.stderr)
        return None

    row = df.iloc[0]
    rgb_path = os.path.join(rgb_root, row["Filename"])
    tif_path = os.path.join(tif_root, row["Filename"].replace(".jpg",".tif"))
    if not (os.path.isfile(rgb_path) and os.path.isfile(tif_path)):
        print(f"[drift] Missing files:\n  RGB: {rgb_path}\n  TIF: {tif_path}", file=sys.stderr)
        return None

    with rasterio.open(tif_path) as ds:
        # NOTE(review): the 10000 divisor presumably rescales raw sensor
        # values to roughly [0, 1] — confirm against the data source.
        tif = ds.read().astype(float) / 10000.0  # (bands,H,W)
    band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
    if len(band_means) >= n_wires:
        return band_means[:n_wires]
    # Fewer bands than wires: repeat the last band mean to fill the vector.
    return np.pad(band_means, (0, n_wires-len(band_means)), mode="edge")

# ---------- Drift schedule ----------
def drift_eps(t, base0, base1, amp=0.4, period=40):
    """Return the drifted readout error rates ``(e0, e1)`` at time step ``t``.

    Both rates follow the same sinusoid of the given ``period``: ``e0`` is
    scaled by ``1 + amp*s`` and ``e1`` by ``1 - 0.5*amp*s`` (opposite sign,
    half the relative amplitude), where ``s = sin(2*pi*t/period)``. Each
    rate is floored at ``1e-5``.
    """
    phase = math.sin(2*math.pi*t/period)
    floor = 1e-5
    scaled0 = base0*(1 + amp*phase)
    scaled1 = base1*(1 - 0.5*amp*phase)
    return max(floor, scaled0), max(floor, scaled1)

# ---------- Main ----------
def _save_drift_plot(times, y_indep, y_corr, marker, rho, ylabel, out_path):
    """Plot the independent and pairwise-correlated curves and save to ``out_path``."""
    fig = plt.figure()
    try:
        plt.plot(times, y_indep, marker=marker, linestyle="-",  label="independent")
        plt.plot(times, y_corr,  marker=marker, linestyle="--", label=f"pairwise-corr (rho={rho})")
        plt.xlabel("time index")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_path, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def run_dataset(name, marker, vec, n_wires, n_layers, T, base0, base1, amp, period, pairs, rho):
    """Simulate drifting readout errors for one dataset and write two PDF plots.

    The feature vector ``vec`` is mapped to angles and fed through a fixed
    circuit (weights drawn from a generator seeded with 7), giving the ideal
    distribution ``p_true``. For each time step ``t`` in ``range(T)`` the
    drifted error rates are turned into an independent and a
    pairwise-correlated assignment matrix; both are applied to ``p_true`` and
    compared to it by total variation distance and by the absolute error of
    ``<Z>`` on wire 0.

    Writes ``readout_drift_{name}_tvd.pdf`` and
    ``readout_drift_{name}_z0err.pdf`` to the current directory, using
    ``marker`` for both curves. Does nothing except print a message to stderr
    when ``vec`` is ``None``.
    """
    if vec is None:
        print(f"[drift] {name} vector unavailable; skipping.", file=sys.stderr)
        return
    theta = angle_map(vec)
    probs_fn = build_probs_fn(n_wires, n_layers)
    # Fixed circuit state (drift only in readout stage)
    weights = np.random.default_rng(7).standard_normal((n_layers, n_wires, 3))*0.3
    p_true = probs_fn(theta, weights)
    # The ideal expectation does not depend on t, so compute it once.
    z_true = z_expect_from_probs(p_true, n_wires, wire=0)

    tvd_indep, tvd_corr = [], []
    zerr_indep, zerr_corr = [], []
    times = list(range(T))
    for t in times:
        e0, e1 = drift_eps(t, base0, base1, amp=amp, period=period)
        A_ind = assignment_matrix_indep(n_wires, e0, e1)
        A_cor = assignment_matrix_corr(n_wires, e0, e1, pairs, rho)

        # Renormalise to guard against accumulated floating-point drift.
        pm_ind = A_ind @ p_true
        pm_ind = pm_ind/pm_ind.sum()
        pm_cor = A_cor @ p_true
        pm_cor = pm_cor/pm_cor.sum()

        tvd_indep.append(0.5*np.abs(pm_ind - p_true).sum())
        tvd_corr.append(0.5*np.abs(pm_cor - p_true).sum())

        z_ind = z_expect_from_probs(pm_ind, n_wires, wire=0)
        z_cor = z_expect_from_probs(pm_cor, n_wires, wire=0)
        zerr_indep.append(abs(z_ind - z_true))
        zerr_corr.append(abs(z_cor - z_true))

    # plotting (no titles)
    plt.rcParams.update({"font.size": 16})

    # TVD vs time
    out1 = Path(f"readout_drift_{name}_tvd.pdf")
    _save_drift_plot(times, tvd_indep, tvd_corr, marker, rho,
                     "Total Variation Distance (TVD)", out1)
    print(f"[drift] Wrote {out1}")

    # Z error vs time
    out2 = Path(f"readout_drift_{name}_z0err.pdf")
    _save_drift_plot(times, zerr_indep, zerr_corr, marker, rho,
                     "|<Z_0>_meas - <Z_0>_true|", out2)
    print(f"[drift] Wrote {out2}")

def _parse_pairs(spec):
    """Parse a string like ``'0-1,2-3'`` into ``[(0, 1), (2, 3)]``.

    Empty tokens (empty input, stray or trailing commas) are skipped.
    Raises ``ValueError`` with a readable message for any other malformed token.
    """
    pairs = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        parts = token.split("-")
        if len(parts) != 2:
            raise ValueError(f"invalid pair {token!r}; expected 'A-B' such as '0-1'")
        try:
            pairs.append((int(parts[0]), int(parts[1])))
        except ValueError:
            raise ValueError(f"invalid pair {token!r}; wire indices must be integers") from None
    return pairs


def main():
    """Command-line entry point: parse arguments and process the chosen dataset(s).

    Invalid ``--pairs`` or a zero ``--period`` are reported through
    ``argparse`` (usage message, exit status 2) before any dataset is loaded.
    EuroSAT is skipped with a message on stderr unless all three of its path
    options are given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", choices=["cifar","eurosat","both"], default="both", help="Which dataset to process.")
    ap.add_argument("--n_wires", type=int, default=8)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--T", type=int, default=40, help="Number of time steps")
    ap.add_argument("--base_e0", type=float, default=0.02, help="Base P(0->1)")
    ap.add_argument("--base_e1", type=float, default=0.03, help="Base P(1->0)")
    ap.add_argument("--amp", type=float, default=0.4, help="Relative drift amplitude")
    ap.add_argument("--period", type=int, default=40, help="Drift period (steps)")
    ap.add_argument("--pairs", type=str, default="0-1,2-3", help="Comma-separated qubit pairs like '0-1,2-3'")
    ap.add_argument("--rho", type=float, default=0.3, help="Pairwise-correlation strength (0..1)")
    # Paths
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar")
    ap.add_argument("--eurosat_csv", type=str, default="")
    ap.add_argument("--eurosat_rgb_root", type=str, default="")
    ap.add_argument("--eurosat_tif_root", type=str, default="")
    args = ap.parse_args()

    # Parse pairs; ap.error() prints the usage line and exits with status 2.
    try:
        pairs = _parse_pairs(args.pairs)
    except ValueError as err:
        ap.error(f"--pairs: {err}")

    # The drift schedule divides by the period, so reject zero up front
    # rather than failing after the dataset has been loaded.
    if args.period == 0:
        ap.error("--period must be non-zero")

    # CIFAR
    if args.dataset in ("cifar","both"):
        v_cifar = cifar_vec(args.n_wires, root=args.cifar_root)
        run_dataset("cifar", ".", v_cifar, args.n_wires, args.layers, args.T,
                    args.base_e0, args.base_e1, args.amp, args.period, pairs, args.rho)

    # EuroSAT
    if args.dataset in ("eurosat","both"):
        if not (args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root):
            print("[drift] Provide EuroSAT paths: --eurosat_csv --eurosat_rgb_root --eurosat_tif_root", file=sys.stderr)
        else:
            v_eu = eurosat_vec(args.n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root)
            run_dataset("eurosat", "|", v_eu, args.n_wires, args.layers, args.T,
                        args.base_e0, args.base_e1, args.amp, args.period, pairs, args.rho)

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
