#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
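Prepare CIFAR-10 and EuroSAT feature files (no model training).
- Downloads the datasets (via torchvision) into a local cache directory.
- Converts each image to a flattened grayscale vector, z-scores every feature with
  training-split statistics, fits PCA on the training split, and saves the projected
  features (8 components by default; see --pca).
- EuroSAT is split 80/20 with a stratified shuffle (random_state=42); this needs scikit-learn.
- Outputs (for the --outdir values used in the examples below):
    ./cifar/features_train.npz, ./cifar/features_test.npz
    ./eurosat/features_train.npz, ./eurosat/features_test.npz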

Run examples:
  python prepare_data.py --dataset cifar   --outdir ./cifar   --pca 8
  python prepare_data.py --dataset eurosat --outdir ./eurosat --pca 8
"""
import argparse, os, sys, json
from pathlib import Path
import numpy as np

# Optional imports handled gracefully
def _try_import_torchvision():
    try:
        import torch
        from torchvision import datasets, transforms
        return torch, datasets, transforms
    except Exception as e:
        print("[prepare_data] Please install torchvision first: pip install torch torchvision", file=sys.stderr)
        raise

def _pca_fit(X, k):
    # Center + PCA via SVD
    X = X.astype(np.float32)
    mean = X.mean(axis=0, keepdims=True)
    Xc = X - mean
    # Economy SVD
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    comps = Vt[:k]
    Xk = (Xc @ comps.T)
    return Xk, mean.squeeze().astype(np.float32), comps.astype(np.float32)

def _pca_apply(X, mean, comps):
    X = X.astype(np.float32)
    Xc = X - mean
    return Xc @ comps.T

def _to_gray_np(img_pil):
    # Convert PIL image to grayscale numpy vector
    g = img_pil.convert("L")  # 1 channel
    arr = np.asarray(g, dtype=np.float32) / 255.0
    return arr.reshape(-1)

def _save_npz(path, **arrays):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(path, **arrays)
    print(f"[prepare_data] Wrote {path} ({path.stat().st_size/1024:.1f} KB)")

def prepare_cifar(outdir, pca_k):
    """Build PCA features for CIFAR-10 and write them under ``outdir``.

    Both official splits are downloaded into ``./_cache_cifar`` (relative to
    the working directory). Every 32x32 image becomes a grayscale vector of
    length 1024. Features are z-scored with training statistics, a
    ``pca_k``-component PCA is fitted on the training split only, and both
    splits are projected with it.

    ``features_train.npz`` and ``features_test.npz`` each hold ``X``, ``y``
    and ``class_names``, plus the z-score and PCA parameters needed to
    transform new data.
    """
    _, datasets, _ = _try_import_torchvision()
    cache = Path("./_cache_cifar")
    train_split = datasets.CIFAR10(root=cache, train=True, download=True, transform=None)
    test_split = datasets.CIFAR10(root=cache, train=False, download=True, transform=None)
    labels_text = train_split.classes

    def vectorize(split):
        # Single pass over the split: grayscale vectors plus integer labels.
        vectors, targets = [], []
        for image, target in split:
            vectors.append(_to_gray_np(image))
            targets.append(target)
        return np.stack(vectors, 0), np.array(targets, dtype=np.int64)

    train_X, train_y = vectorize(train_split)
    test_X, test_y = vectorize(test_split)

    # Standardise with training-split statistics only, so no test data leaks in.
    feat_mean = train_X.mean(0)
    feat_std = train_X.std(0) + 1e-8
    train_std = (train_X - feat_mean) / feat_std
    test_std = (test_X - feat_mean) / feat_std

    train_proj, pca_center, pca_axes = _pca_fit(train_std, pca_k)
    test_proj = _pca_apply(test_std, pca_center, pca_axes)

    # Metadata written identically into both archives.
    shared = {
        "class_names": np.array(labels_text),
        "z_mu": feat_mean.astype(np.float32),
        "z_sigma": feat_std.astype(np.float32),
        "pca_mean": pca_center,
        "pca_components": pca_axes,
    }
    target_dir = Path(outdir)
    for fname, feats, lbls in (("features_train.npz", train_proj, train_y),
                               ("features_test.npz", test_proj, test_y)):
        _save_npz(target_dir / fname, X=feats.astype(np.float32), y=lbls, **shared)

def prepare_eurosat(outdir, pca_k):
    """Build PCA features for EuroSAT and write them under ``outdir``.

    The dataset is downloaded into ``./_cache_eurosat`` (relative to the
    working directory) and each image becomes a flattened grayscale vector.
    The loaded images are split 80/20 with a stratified shuffle
    (``random_state=42``). Features are z-scored with training statistics, a
    ``pca_k``-component PCA is fitted on the training split only, and both
    splits are projected with it.

    ``features_train.npz`` and ``features_test.npz`` each hold ``X``, ``y``
    and ``class_names``, plus the z-score and PCA parameters.

    Raises:
        ImportError: torchvision or scikit-learn is not installed.
        AttributeError: the installed torchvision has no ``datasets.EuroSAT``.
    """
    _, datasets, _ = _try_import_torchvision()
    # EuroSAT (RGB): 64x64 RGB -> grayscale -> flatten -> PCA
    # Torchvision provides EuroSAT; ensure torchvision>=0.13
    try:
        EuroSAT = datasets.EuroSAT
    except AttributeError:
        print("[prepare_data] Your torchvision is too old for EuroSAT. Upgrade: pip install 'torchvision>=0.13'", file=sys.stderr)
        raise
    # Resolve the scikit-learn dependency up front, so a missing package fails
    # before the slow download and image decoding rather than after it.
    try:
        from sklearn.model_selection import StratifiedShuffleSplit
    except ImportError:
        print("[prepare_data] EuroSAT needs scikit-learn for the stratified split: pip install scikit-learn", file=sys.stderr)
        raise

    root = Path("./_cache_eurosat")
    ds = EuroSAT(root=root, download=True)  # iterates as (image, label int) pairs
    X, y = [], []
    for img, label in ds:
        X.append(_to_gray_np(img))  # 64*64 = 4096
        y.append(label)
    X = np.stack(X, 0)
    y = np.array(y, dtype=np.int64)

    # Stratified 80/20 split
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    idx_train, idx_test = next(sss.split(X, y))
    Xtr, ytr = X[idx_train], y[idx_train]
    Xte, yte = X[idx_test], y[idx_test]
    # Index-array selection copies the rows, so the full matrix can be freed
    # now. This lowers peak memory before standardisation and the SVD.
    del X

    # z-score per feature using train stats. This is done in place on the
    # private copies: same arithmetic as (X - mu) / sigma, without temporaries.
    mu, sigma = Xtr.mean(0), Xtr.std(0) + 1e-8
    Xtr -= mu
    Xtr /= sigma
    Xte -= mu
    Xte /= sigma

    Xtr_pca, pmean, pcomps = _pca_fit(Xtr, pca_k)
    Xte_pca = _pca_apply(Xte, pmean, pcomps)

    # Class names come from the dataset metadata when it exposes them.
    try:
        class_names = ds.classes
    except AttributeError:
        # Fallback generic names
        nclasses = int(y.max() + 1)
        class_names = [f"class_{i}" for i in range(nclasses)]

    shared = {
        "class_names": np.array(class_names),
        "z_mu": mu.astype(np.float32),
        "z_sigma": sigma.astype(np.float32),
        "pca_mean": pmean,
        "pca_components": pcomps,
    }
    outdir = Path(outdir)
    _save_npz(outdir / "features_train.npz", X=Xtr_pca.astype(np.float32), y=ytr, **shared)
    _save_npz(outdir / "features_test.npz", X=Xte_pca.astype(np.float32), y=yte, **shared)

def main():
    """Command-line entry point: parse arguments and prepare the chosen dataset.

    Exits with status 2 (via ``argparse``) when ``--pca`` is not a positive
    integer, before any download or file-system work happens.
    """
    ap = argparse.ArgumentParser(description="Prepare PCA feature files for CIFAR-10 or EuroSAT.")
    ap.add_argument("--dataset", choices=["cifar","eurosat"], required=True,
                    help="Which dataset to prepare.")
    ap.add_argument("--outdir", type=str, required=True, help="Folder to write features_*.npz")
    ap.add_argument("--pca", type=int, default=8, help="PCA components (use 6–12 to match qubits).")
    args = ap.parse_args()

    # A non-positive component count would otherwise yield empty or truncated
    # features only after the full download and decoding.
    if args.pca < 1:
        ap.error(f"--pca must be a positive integer, got {args.pca}")

    Path(args.outdir).mkdir(parents=True, exist_ok=True)

    if args.dataset == "cifar":
        prepare_cifar(args.outdir, args.pca)
    else:
        prepare_eurosat(args.outdir, args.pca)

# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
