#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
readout_effect_both.py
----------------------
Show the effect of READOUT (assignment) noise on a tiny VQC using:
- one sample from CIFAR-10 (downloaded via torchvision; skipped if torchvision or the dataset is unavailable)
- one sample from EuroSAT (user-provided CSV + RGB + TIF folders)

It computes the exact probability vector p_true for a fixed VQC, applies a
tensor-product assignment (confusion) matrix A(e0,e1) to simulate readout noise,
and sweeps epsilon in [0, eps_max]. It then overlays two curves (CIFAR, EuroSAT)
for:
  1) Total Variation Distance TVD(p_meas, p_true) vs epsilon
  2) Observable error |<Z_0>_meas - <Z_0>_true| vs epsilon

Markers:
  - CIFAR: '.'
  - EuroSAT: '|'
All fonts are size 16. Outputs saved as PDF in the current directory.
"""

import argparse, sys, os, json, math
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")  # for headless servers
import matplotlib.pyplot as plt

# PennyLane for exact state-vector probabilities
import pennylane as qml

# Optional heavy deps imported lazily when needed
def _try_import_torchvision():
    import importlib
    try:
        tv = importlib.import_module("torchvision")
        from torchvision import datasets, transforms
        return tv, datasets, transforms
    except Exception as e:
        print("[readout_effect] torchvision not available; CIFAR-10 will be skipped.", file=sys.stderr)
        return None, None, None

def _try_import_pil():
    import importlib
    try:
        Image = importlib.import_module("PIL.Image")
        return Image
    except Exception as e:
        print("[readout_effect] PIL not available; EuroSAT RGB will be skipped.", file=sys.stderr)
        return None

def _try_import_rasterio():
    import importlib
    try:
        rio = importlib.import_module("rasterio")
        return rio
    except Exception as e:
        print("[readout_effect] rasterio not available; EuroSAT TIF will be skipped.", file=sys.stderr)
        return None

# ---------------- Readout model ----------------

def confusion_1q(e0, e1):
    """Single-qubit readout confusion (assignment) matrix.

    Column ``j`` holds the distribution of the reported bit given true bit ``j``::

        [[P(read 0 | true 0), P(read 0 | true 1)],
         [P(read 1 | true 0), P(read 1 | true 1)]]

    Args:
        e0: probability that a true 0 is reported as 1.
        e1: probability that a true 1 is reported as 0.

    Returns:
        2x2 float ndarray whose columns sum to 1.
    """
    keep0 = 1 - e0
    keep1 = 1 - e1
    rows = [
        [keep0, e1],
        [e0, keep1],
    ]
    return np.array(rows, dtype=float)

def kronN(mats):
    """Kronecker product of a sequence of matrices, folded left to right.

    ``kronN([M0, M1, M2])`` equals ``kron(kron(M0, M1), M2)``. An empty
    sequence yields the 1x1 float matrix ``[[1.0]]``.
    """
    result = np.ones((1, 1), dtype=float)
    for factor in mats:
        result = np.kron(result, factor)
    return result

def assignment_matrix(n_wires, e0, e1):
    """Tensor-product readout confusion matrix for independent, identical qubits.

    Each of the ``n_wires`` qubits gets ``confusion_1q(e0, e1)``; the result is
    their Kronecker product (first wire as the left-most factor), with shape
    ``(2**n_wires, 2**n_wires)``. ``n_wires == 0`` yields ``[[1.0]]``.
    """
    per_wire = (confusion_1q(e0, e1) for _wire in range(n_wires))
    return kronN(per_wire)

def z_expect_from_probs(p, n_wires, wire=0):
    """Expectation value <Z_wire> from a computational-basis probability vector.

    <Z_wire> = sum_x p(x) * (-1)**bit_wire(x). Basis index ``x`` is big-endian:
    wire 0 is the most significant bit (the ordering PennyLane's ``qml.probs``
    uses for ``wires=range(n_wires)``).

    Args:
        p: 1-D probability vector of length ``2**n_wires``.
        n_wires: number of qubits.
        wire: qubit index in ``[0, n_wires)``.

    Returns:
        float; lies in [-1, 1] when ``p`` is a normalized distribution.

    Raises:
        ValueError: if ``wire`` is out of range or ``p`` has the wrong shape.
            (Previously a negative ``wire`` silently returned ``sum(p)``.)
    """
    if not 0 <= wire < n_wires:
        raise ValueError(f"wire must be in [0, {n_wires}), got {wire}")
    probs = np.asarray(p, dtype=float)
    expected = 2 ** n_wires
    if probs.ndim != 1 or probs.size != expected:
        raise ValueError(
            f"expected a 1-D vector of {expected} probabilities for {n_wires} wires, "
            f"got shape {probs.shape}"
        )
    # Extract bit `wire` (counted from the most significant end) of every index.
    bits = (np.arange(expected) >> (n_wires - 1 - wire)) & 1
    signs = 1.0 - 2.0 * bits  # bit 0 -> +1, bit 1 -> -1
    return float(np.dot(probs, signs))

# ---------------- Tiny VQC ----------------

def build_probs_fn(n_wires, n_layers):
    """Build an analytic (shots=None) QNode that returns basis-state probabilities.

    Circuit on wires ``0..n_wires-1``: ``AngleEmbedding(theta, rotation="Y")``
    followed by ``StronglyEntanglingLayers(weights)``, then ``qml.probs`` over
    all wires.

    ``n_layers`` is not used here; the depth comes from the shape of the
    ``weights`` array passed when the returned function is called.

    Returns:
        Callable ``probs_fn(theta, weights)``.
    """
    wire_ids = range(n_wires)
    device = qml.device("default.qubit", wires=n_wires, shots=None)

    def probs_fn(theta, weights):
        qml.AngleEmbedding(theta, wires=wire_ids, rotation="Y")
        qml.StronglyEntanglingLayers(weights, wires=wire_ids)
        return qml.probs(wires=wire_ids)

    return qml.QNode(probs_fn, device)

def angle_map(x):
    """Squash real features into rotation angles in [0, pi], elementwise.

    Computes ``pi * sigmoid(x)``. The sigmoid is evaluated as
    ``0.5 * (1 + tanh(x / 2))``, which is algebraically identical to
    ``1 / (1 + exp(-x))`` but cannot overflow (the exp form raises overflow
    RuntimeWarnings for large negative inputs).

    Args:
        x: scalar or array-like of real features.

    Returns:
        ndarray (or numpy scalar) with the shape of ``x``, values in [0, pi].
    """
    x = np.asarray(x, dtype=float)
    return np.pi * 0.5 * (1.0 + np.tanh(0.5 * x))

# ---------------- CIFAR sample -> vector ----------------

def get_cifar_vec(n_wires, root="./_cache_cifar", seed=0):
    """Return an ``n_wires``-long feature vector built from one CIFAR-10 image.

    Loads the CIFAR-10 training split through torchvision, first allowing a
    download and, if that raises, retrying with ``download=False`` against an
    existing copy under ``root``. The first image (index 0) is converted to
    grayscale, scaled to [0, 1], and average-pooled over its flattened pixels
    into ``n_wires`` chunks via ``_pool_to_n``.

    Args:
        n_wires: number of features to produce.
        root: torchvision dataset root / cache directory.
        seed: currently unused; the sample is always index 0.

    Returns:
        1-D float ndarray of length ``n_wires``, or ``None`` if torchvision
        cannot be imported or the dataset cannot be loaded.
    """
    tv, datasets, _transforms = _try_import_torchvision()
    if tv is None:
        return None

    def _load(download):
        return datasets.CIFAR10(root=root, train=True, download=download, transform=None)

    try:
        dataset = _load(True)
    except Exception as first_err:
        print(f"[readout_effect] CIFAR download failed ({first_err}); trying download=False...", file=sys.stderr)
        try:
            dataset = _load(False)
        except Exception as second_err:
            print(f"[readout_effect] CIFAR not available locally either: {second_err}", file=sys.stderr)
            return None

    # With transform=None each item is an (image, label) pair; the label is unused.
    first_image, _label = dataset[0]
    gray = np.asarray(first_image.convert("L"), dtype=float) / 255.0  # 32x32 grayscale
    return _pool_to_n(gray, n_wires)

def _pool_to_n(arr2d, n):
    # simple equal-chunk pooling of a flattened image to n features
    v = arr2d.reshape(-1)
    splits = np.array_split(v, n)
    return np.array([s.mean() for s in splits], dtype=float)

# ---------------- EuroSAT sample -> vector ----------------

def get_eurosat_vec(n_wires, csv_path, rgb_root, tif_root):
    """Return an ``n_wires``-long feature vector from one EuroSAT sample.

    Reads the first row of ``csv_path``, which must have a ``Filename`` column.
    It locates the RGB image at ``rgb_root/Filename`` and the multispectral
    TIF at ``tif_root/Filename`` with its extension swapped to ``.tif``. The
    TIF is scaled by 1/10000 and reduced to one spatial mean per band. The
    first ``n_wires`` band means are returned. If there are fewer bands than
    ``n_wires``, the last band mean is repeated.

    Both PIL and rasterio must be importable and both files must exist,
    although only the TIF is actually read.

    Returns:
        1-D float ndarray of length ``n_wires``. Returns ``None`` (with a
        message on stderr where applicable) if a dependency is missing, the
        CSV is empty or lacks ``Filename``, or either file does not exist.
    """
    Image = _try_import_pil()
    rio = _try_import_rasterio()
    if Image is None or rio is None:
        return None

    import pandas as pd
    df = pd.read_csv(csv_path)
    if len(df) == 0:
        print("[readout_effect] EuroSAT CSV has no rows.", file=sys.stderr)
        return None
    if "Filename" not in df.columns:
        print(f"[readout_effect] EuroSAT CSV has no 'Filename' column: {csv_path}", file=sys.stderr)
        return None

    rel_name = str(df.iloc[0]["Filename"])
    rgb_path = os.path.join(rgb_root, rel_name)
    # Swap only the file extension; str.replace(".jpg", ...) would also rewrite
    # any ".jpg" appearing earlier in the path and miss other extensions.
    tif_path = os.path.join(tif_root, os.path.splitext(rel_name)[0] + ".tif")

    if not os.path.isfile(rgb_path) or not os.path.isfile(tif_path):
        print(f"[readout_effect] Missing files:\n  RGB: {rgb_path}\n  TIF: {tif_path}", file=sys.stderr)
        return None

    with rio.open(tif_path) as src:
        tif = src.read()  # (bands, H, W)
    tif = tif.astype(float) / 10000.0  # same scaling as the original pipeline

    # One spatial mean per band, then take/pad to exactly n_wires features.
    band_means = tif.reshape(tif.shape[0], -1).mean(axis=1)
    if len(band_means) >= n_wires:
        return band_means[:n_wires]
    return np.pad(band_means, (0, n_wires - len(band_means)), mode="edge")

# ---------------- Sweep function ----------------

def sweep_readout(theta, n_wires, n_layers, eps_max=0.15, steps=16, seed=7,
                  e0_scale=0.9, e1_scale=1.1):
    """Sweep readout-error strength and measure its effect on a fixed VQC.

    Steps:
    - Draw a seeded weight tensor of shape ``(n_layers, n_wires, 3)`` with
      entries ``0.3 * N(0, 1)``.
    - Compute ``p_true`` for embedding angles ``theta``.
    - For each ``eps`` in ``linspace(0, eps_max, steps)``, form
      ``p_meas = A(e0, e1) @ p_true``, where ``e0 = e0_scale * eps`` and
      ``e1 = e1_scale * eps``. The defaults give a slight 0/1 asymmetry.

    Args:
        theta: embedding angles, one per wire.
        n_wires: number of qubits.
        n_layers: number of StronglyEntanglingLayers.
        eps_max: largest epsilon in the sweep.
        steps: number of epsilon points (includes 0 and ``eps_max``).
        seed: RNG seed for the circuit weights.
        e0_scale: multiplier giving P(0 -> 1) from epsilon.
        e1_scale: multiplier giving P(1 -> 0) from epsilon.

    Returns:
        ``(E, tvd, z0_err)``. ``E`` is the epsilon grid, ``tvd`` is
        ``0.5 * sum|p_meas - p_true|`` and ``z0_err`` is
        ``|<Z_0>_meas - <Z_0>_true|``. Each is a 1-D ndarray of length ``steps``.

    Raises:
        ValueError: if a flip probability would leave [0, 1] anywhere in the
            sweep, which would turn ``A`` into a non-stochastic matrix.
    """
    # Flip probabilities are linear in eps and zero at eps=0, so checking the
    # eps_max endpoint covers the whole sweep.
    for name, scale in (("e0", e0_scale), ("e1", e1_scale)):
        flip = scale * eps_max
        if not 0.0 <= flip <= 1.0:
            raise ValueError(
                f"{name} = {scale} * eps_max = {flip} is not a valid probability in [0, 1]"
            )

    rng = np.random.default_rng(seed)
    weights = rng.standard_normal((n_layers, n_wires, 3)) * 0.3
    probs_fn = build_probs_fn(n_wires, n_layers)
    p_true = np.asarray(probs_fn(theta, weights), dtype=float)
    z_true = z_expect_from_probs(p_true, n_wires, wire=0)  # loop-invariant

    E = np.linspace(0.0, eps_max, steps)
    tvd_vals = np.empty(len(E), dtype=float)
    z0_err = np.empty(len(E), dtype=float)
    for i, eps in enumerate(E):
        A = assignment_matrix(n_wires, e0=e0_scale * eps, e1=e1_scale * eps)
        p_meas = A @ p_true
        # A is column-stochastic, so this only removes floating-point drift.
        p_meas = p_meas / p_meas.sum()

        tvd_vals[i] = 0.5 * np.abs(p_meas - p_true).sum()
        z_meas = z_expect_from_probs(p_meas, n_wires, wire=0)
        z0_err[i] = abs(z_meas - z_true)
    return E, tvd_vals, z0_err

# ---------------- Main ----------------

def _sweep_or_none(vec, args):
    """Run the readout sweep for one feature vector; return (None, None, None) if it is missing."""
    if vec is None:
        return None, None, None
    theta = angle_map(vec)
    return sweep_readout(theta, args.n_wires, args.layers, args.eps_max, args.steps)


def _save_overlay_plot(curves, ylabel, title, out_path):
    """Overlay each available (label, marker, x, y) curve, add a legend, save to out_path, close the figure."""
    fig = plt.figure()
    plotted_any = False
    for label, marker, xs, ys in curves:
        if xs is None:
            continue
        plt.plot(xs, ys, marker=marker, label=label)
        plotted_any = True
    plt.xlabel("epsilon (readout misclassification rate)")
    plt.ylabel(ylabel)
    plt.title(title)
    if plotted_any:
        plt.legend()
    plt.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)  # avoid accumulating open figures
    print(f"[readout_effect] Wrote {out_path}")


def main():
    """CLI entry point: sweep readout noise for CIFAR-10 and/or EuroSAT samples.

    Writes ``readout_noise_effect_tvd.pdf`` (TVD vs epsilon) and
    ``readout_noise_effect_z0err.pdf`` (|<Z_0> error| vs epsilon) to the
    current directory. Each plot overlays the available datasets: CIFAR-10
    with marker '.', EuroSAT with marker '|'. Both plots are written even if
    no dataset is available, and a warning is printed in that case.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--n_wires", type=int, default=8, help="Number of qubits/features (use 4,6,8,...)")
    ap.add_argument("--layers", type=int, default=2, help="VQC depth (StronglyEntanglingLayers)")
    ap.add_argument("--eps_max", type=float, default=0.15, help="Max epsilon for readout misclassification rate")
    ap.add_argument("--steps", type=int, default=16, help="Number of epsilon points in the sweep")
    # CIFAR params
    ap.add_argument("--cifar_root", type=str, default="./_cache_cifar", help="Torchvision cache/root for CIFAR")
    # EuroSAT params
    ap.add_argument("--eurosat_csv", type=str, default="", help="Path to train.csv/validation.csv/test.csv")
    ap.add_argument("--eurosat_rgb_root", type=str, default="", help="Folder containing EuroSAT RGB .jpg files")
    ap.add_argument("--eurosat_tif_root", type=str, default="", help="Folder containing EuroSAT 13-band .tif files")
    args = ap.parse_args()

    if args.n_wires < 1:
        ap.error("--n_wires must be >= 1")
    if args.layers < 1:
        ap.error("--layers must be >= 1")
    if args.steps < 1:
        ap.error("--steps must be >= 1")

    # "font.size" alone leaves titles/labels at relative (larger) sizes; pin
    # every text element to 16 as the module docstring promises.
    plt.rcParams.update({
        "font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
    })

    # --- CIFAR vector ---
    cifar_vec = get_cifar_vec(args.n_wires, root=args.cifar_root)
    E_c, TVD_c, Zerr_c = _sweep_or_none(cifar_vec, args)

    # --- EuroSAT vector ---
    eurosat_vec = None
    if args.eurosat_csv and args.eurosat_rgb_root and args.eurosat_tif_root:
        eurosat_vec = get_eurosat_vec(args.n_wires, args.eurosat_csv, args.eurosat_rgb_root, args.eurosat_tif_root)
    E_e, TVD_e, Zerr_e = _sweep_or_none(eurosat_vec, args)

    # --- Plot TVD ---
    _save_overlay_plot(
        [("CIFAR-10", ".", E_c, TVD_c), ("EuroSAT", "|", E_e, TVD_e)],
        ylabel="Total Variation Distance (TVD)",
        title="Readout Noise: Histogram Distortion",
        out_path=Path("readout_noise_effect_tvd.pdf"),
    )

    # --- Plot observable error (now labelled so the two curves are distinguishable) ---
    _save_overlay_plot(
        [("CIFAR-10", ".", E_c, Zerr_c), ("EuroSAT", "|", E_e, Zerr_e)],
        ylabel="|<Z_0>_meas - <Z_0>_true|",
        title="Readout Noise: Observable Error",
        out_path=Path("readout_noise_effect_z0err.pdf"),
    )

    if E_c is None and E_e is None:
        print("[readout_effect] Nothing plotted (no datasets available). Provide CIFAR or EuroSAT paths.", file=sys.stderr)
        print("[readout_effect] Nothing plotted (no datasets available). Provide CIFAR or EuroSAT paths.", file=sys.stderr)

if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()
