#!/usr/bin/env python3
"""
Dataset preprocessing utilities.

Usage:
    python datasets/preprocess.py npy2txt <file.npy>      # Convert numpy to txt
    python datasets/preprocess.py labels <dataset>        # Generate label files
    python datasets/preprocess.py clip <dataset>          # Generate CLIP embeddings
    python datasets/preprocess.py info <dataset>          # Show dataset info
    python datasets/preprocess.py list                    # List available datasets

Examples:
    python datasets/preprocess.py npy2txt cifar10.npy
    python datasets/preprocess.py labels mnist
    python datasets/preprocess.py clip cifar10
    python datasets/preprocess.py info mnist
"""

import argparse
import sys
from pathlib import Path
import numpy as np

# Directory containing this script. Label files, CLIP embeddings and the files
# inspected by the info command are all resolved relative to it.
DATASETS_DIR = Path(__file__).parent

# Registry of known datasets, keyed by dataset name.
# "n" and "d" are the expected number of points and dimensions; they are what
# the info and list commands print. The remaining keys ("source", "name",
# "base") are descriptive metadata that the code visible here does not read.
DATASET_INFO = {
    "mnist": {
        "source": "sklearn",
        "name": "mnist_784",
        "n": 70000,
        "d": 784,
    },
    "fmnist": {
        "source": "sklearn",
        "name": "Fashion-MNIST",
        "n": 70000,
        "d": 784,
    },
    "cifar10": {
        "source": "torchvision",
        "n": 60000,
        "d": 3072,
    },
    "cifar100": {
        "source": "torchvision",
        "n": 60000,
        "d": 3072,
    },
    "mnist_clip": {
        "source": "clip",
        "base": "mnist",
        "n": 70000,
        "d": 512,
    },
    "fmnist_clip": {
        "source": "clip",
        "base": "fmnist",
        "n": 70000,
        "d": 512,
    },
    "cifar10_clip": {
        "source": "clip",
        "base": "cifar10",
        "n": 60000,
        "d": 512,
    },
    "cifar100_clip": {
        "source": "clip",
        "base": "cifar100",
        "n": 60000,
        "d": 512,
    },
}


def npy_to_txt(npy_path: Path, output_path: Path = None):
    """Convert a numpy array file to a space-separated text file.

    One line is written per entry along the first axis of the array. A 1-D
    array yields one value per line; entries that are themselves
    multi-dimensional are flattened in C order onto a single line.

    Args:
        npy_path: Path to the input ``.npy`` file.
        output_path: Destination text file. Defaults to ``npy_path`` with its
            suffix replaced by ``.txt``.
    """
    # Memory-map the input instead of reading the whole array up front.
    X = np.load(npy_path, mmap_mode="r")

    if output_path is None:
        output_path = npy_path.with_suffix(".txt")

    print(f"Converting {npy_path} -> {output_path}")
    print(f"  Shape: {X.shape}")

    with open(output_path, "w") as f:
        # atleast_1d makes a 0-d array iterable; ravel turns every entry
        # (scalar, vector or N-d sub-array) into a flat sequence of values,
        # so each output line is always a plain list of numbers.
        for row in np.atleast_1d(X):
            f.write(" ".join(map(str, np.ravel(row))))
            f.write("\n")

    print(f"  Done: {output_path}")


def generate_labels(dataset: str):
    """Fetch the class labels of a dataset and write them to a text file.

    The labels are saved as integers, one per line, to
    ``<dataset>_labels.txt`` inside DATASETS_DIR. A ``*_clip`` dataset uses
    the labels of the dataset it is derived from. For the CIFAR datasets the
    training labels come first, followed by the test labels.

    Args:
        dataset: Dataset name. An unknown name prints an error and returns
            without writing anything.
    """
    print(f"Generating labels for {dataset}...")

    # Dataset name -> OpenML dataset identifier.
    openml_ids = {
        "mnist": 'mnist_784',
        "mnist_clip": 'mnist_784',
        "fmnist": 'Fashion-MNIST',
        "fmnist_clip": 'Fashion-MNIST',
    }
    # Dataset name -> class name inside torchvision.datasets.
    torchvision_classes = {
        "cifar10": "CIFAR10",
        "cifar10_clip": "CIFAR10",
        "cifar100": "CIFAR100",
        "cifar100_clip": "CIFAR100",
    }

    if dataset in openml_ids:
        from sklearn.datasets import fetch_openml
        bunch = fetch_openml(openml_ids[dataset], version=1, as_frame=False, parser='auto')
        labels = bunch.target.astype(int)
    elif dataset in torchvision_classes:
        import torchvision
        dataset_cls = getattr(torchvision.datasets, torchvision_classes[dataset])
        # Train split first, then test split, so the label order is fixed.
        splits = [
            dataset_cls(root='/tmp/data', train=is_train, download=True)
            for is_train in (True, False)
        ]
        labels = np.concatenate([split.targets for split in splits])
    else:
        print(f"ERROR: Unknown dataset '{dataset}'")
        return

    output_path = DATASETS_DIR / f"{dataset}_labels.txt"
    np.savetxt(output_path, labels, fmt='%d')
    print(f"  Saved {len(labels)} labels to {output_path}")


def generate_clip_embeddings(dataset: str, batch_size: int = 256):
    """Generate L2-normalized CLIP (ViT-B/32) image embeddings for a dataset.

    Loads ``<dataset>.npy`` from DATASETS_DIR, encodes it in batches and
    writes ``<dataset>_clip.npy`` plus a space-separated
    ``<dataset>_clip.txt`` (one embedding per line) to the same directory.
    Invalid input (bad batch size, missing or empty base file) is reported
    with an ``ERROR:`` message and nothing is written.

    Args:
        dataset: Base dataset name; ``<dataset>.npy`` must exist.
        batch_size: Number of images encoded per forward pass (at least 1).
    """
    import torch
    import clip
    from PIL import Image
    from tqdm import tqdm

    # A zero step makes range() raise, and a negative one yields no batches,
    # which would then fail in np.concatenate with an unhelpful message.
    if batch_size < 1:
        print(f"ERROR: Batch size must be at least 1, got {batch_size}")
        return

    # Load base dataset
    npy_path = DATASETS_DIR / f"{dataset}.npy"
    if not npy_path.exists():
        print(f"ERROR: Base dataset not found: {npy_path}")
        return

    print(f"Generating CLIP embeddings for {dataset}...")
    X = np.load(npy_path)
    N = X.shape[0]
    if N == 0:
        # With no rows there are no batches to concatenate further down.
        print(f"ERROR: Base dataset is empty: {npy_path}")
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"  Device: {device}")

    model, preprocess = clip.load("ViT-B/32", device=device)
    model.eval()

    def to_rgb(img):
        # Replicate a 2-D (grayscale) array across three channels.
        # NOTE(review): only 2-D entries are expanded; DATASET_INFO lists
        # flattened dimensions (d=784 / d=3072) for the base datasets, so if
        # <dataset>.npy stores flat rows they would reach Image.fromarray as
        # 1-D arrays. Confirm the file holds image-shaped arrays.
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        return Image.fromarray(img.astype(np.uint8))

    embeddings = []
    with torch.no_grad():
        for i in tqdm(range(0, N, batch_size), desc="Computing embeddings"):
            batch = X[i:i + batch_size]
            imgs = [preprocess(to_rgb(img)) for img in batch]
            imgs = torch.stack(imgs).to(device)

            feats = model.encode_image(imgs)
            # Scale every embedding to unit L2 norm.
            feats = feats / feats.norm(dim=1, keepdim=True)
            embeddings.append(feats.cpu().numpy())

    embeddings = np.concatenate(embeddings, axis=0)

    # Save as npy and txt
    npy_out = DATASETS_DIR / f"{dataset}_clip.npy"
    txt_out = DATASETS_DIR / f"{dataset}_clip.txt"

    np.save(npy_out, embeddings)
    print(f"  Saved: {npy_out} (shape: {embeddings.shape})")

    # Also save as txt
    with open(txt_out, "w") as f:
        for row in embeddings:
            f.write(" ".join(map(str, row)))
            f.write("\n")
    print(f"  Saved: {txt_out}")


def show_info(dataset: str):
    """Print a summary of the files stored for a dataset.

    Reports the number of points, the dimensionality and the size of
    ``<dataset>.txt``, the number of distinct classes found in
    ``<dataset>_labels.txt``, and the expected n/d from DATASET_INFO when the
    dataset is registered there. A missing file is reported, not raised.

    Args:
        dataset: Dataset name used to build both file names.
    """
    data_file = DATASETS_DIR / f"{dataset}.txt"
    label_file = DATASETS_DIR / f"{dataset}_labels.txt"

    print(f"\nDataset: {dataset}")
    print("=" * 40)

    if not data_file.exists():
        print(f"  Data file: NOT FOUND ({data_file})")
    else:
        with open(data_file) as handle:
            # Dimension = number of whitespace-separated fields on line one.
            first_line = handle.readline()
            d = len(first_line.split())
            # Count the line already consumed (if any) plus all the others.
            n = (1 if first_line else 0) + sum(1 for _ in handle)

        size_mb = data_file.stat().st_size / 1e6
        print(f"  Data file: {data_file}")
        print(f"  Points (n): {n}")
        print(f"  Dimensions (d): {d}")
        print(f"  Size: {size_mb:.1f} MB")

    if not label_file.exists():
        print("  Labels file: NOT FOUND")
    else:
        distinct = np.unique(np.loadtxt(label_file, dtype=int))
        print(f"  Labels file: {label_file}")
        print(f"  Classes: {len(distinct)}")

    # Show expected info
    expected = DATASET_INFO.get(dataset)
    if expected is not None:
        print(f"\n  Expected: n={expected['n']}, d={expected['d']}")


def main():
    """Parse the command line and run the requested sub-command.

    Sub-commands: ``npy2txt``, ``labels``, ``clip``, ``info`` and ``list``.
    When no sub-command is given the help text is printed instead.
    """
    parser = argparse.ArgumentParser(
        description="Dataset preprocessing utilities",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    commands = parser.add_subparsers(dest="command", help="Command to run")

    # npy2txt
    convert_cmd = commands.add_parser("npy2txt", help="Convert numpy to txt")
    convert_cmd.add_argument("file", help="Input .npy file")
    convert_cmd.add_argument("--output", "-o", help="Output .txt file (default: same name)")

    # labels
    labels_cmd = commands.add_parser("labels", help="Generate label files")
    labels_cmd.add_argument("dataset", help="Dataset name (e.g., mnist, cifar10)")

    # clip
    clip_cmd = commands.add_parser("clip", help="Generate CLIP embeddings")
    clip_cmd.add_argument("dataset", help="Base dataset name (e.g., mnist, cifar10)")
    clip_cmd.add_argument("--batch-size", type=int, default=256)

    # info
    info_cmd = commands.add_parser("info", help="Show dataset info")
    info_cmd.add_argument("dataset", help="Dataset name")

    # list
    commands.add_parser("list", help="List available datasets")

    args = parser.parse_args()
    command = args.command

    if command == "npy2txt":
        source = Path(args.file)
        # An omitted --output lets npy_to_txt derive the destination itself.
        target = Path(args.output) if args.output else None
        npy_to_txt(source, target)
        return

    if command == "labels":
        generate_labels(args.dataset)
        return

    if command == "clip":
        generate_clip_embeddings(args.dataset, args.batch_size)
        return

    if command == "info":
        show_info(args.dataset)
        return

    if command == "list":
        print("Available datasets:")
        for name, info in DATASET_INFO.items():
            print(f"  {name}: n={info['n']}, d={info['d']}")
        return

    # No sub-command was supplied.
    parser.print_help()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
