import os
import argparse
import json
import importlib.util
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from PIL import Image


def list_images(images_dir: str, exts=(".jpg", ".jpeg", ".png", ".bmp", ".webp")) -> List[str]:
    """Recursively collect image file paths under ``images_dir``.

    Args:
        images_dir: Root directory walked with ``os.walk``.
        exts: A single file-name suffix or any iterable of suffixes to accept.
            Matching is case-insensitive on both the file name and the suffix,
            so ``[".PNG"]`` matches ``a.png``.

    Returns:
        Sorted list of matching paths, each joined onto the directory in
        which the file was found.
    """
    # str.endswith only accepts a str or a tuple, so normalise any other
    # iterable (list, set, ...). Suffixes are lower-cased because they are
    # compared against the lower-cased file name.
    if isinstance(exts, str):
        suffixes = (exts.lower(),)
    else:
        suffixes = tuple(ext.lower() for ext in exts)

    paths: List[str] = []
    for root, _, files in os.walk(images_dir):
        paths.extend(os.path.join(root, fn) for fn in files if fn.lower().endswith(suffixes))
    paths.sort()
    return paths


def read_list_file(list_file: str) -> List[str]:
    """Read image paths from a text file, one path per line.

    Args:
        list_file: Path of the text file to read.

    Returns:
        The lines of the file with surrounding whitespace stripped, in file
        order; lines that are empty after stripping are dropped.
    """
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # byte-order mark, which str.strip() does not remove and which would
    # otherwise end up glued to the first path.
    with open(list_file, "r", encoding="utf-8-sig") as f:
        stripped_lines = (line.strip() for line in f)
        return [path for path in stripped_lines if path]


def is_all_black_or_white(img: Image.Image) -> bool:
    """Tell whether an image is uniformly pure black or pure white.

    The image is viewed as RGB (converted when its mode differs). It counts as
    black when every channel's minimum and maximum are 0, and as white when
    every channel's minimum and maximum are 255.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    # One (min, max) pair per band: ((minR, maxR), (minG, maxG), (minB, maxB)).
    band_ranges = rgb.getextrema()

    def _every_band_is(value: int) -> bool:
        return all(band[0] == value and band[1] == value for band in band_ranges)

    if _every_band_is(0):
        return True
    return _every_band_is(255)


def build_labels(
    image_paths: List[str],
    mode: str = "subdir",
    single_class_label: int = 0,
) -> Tuple[np.ndarray, Optional[List[str]]]:
    """
    Derive one int64 label per image path.

      - mode == 'single': every image gets ``single_class_label``; the second
        return value is None.
      - mode == 'subdir': the name of each file's parent directory is its class;
        classes are numbered in order of first appearance and the list of class
        names (indexed by label) is returned as the second value.

    Raises ValueError for any other mode.
    """
    if mode == "single":
        constant = int(single_class_label)
        return np.full((len(image_paths),), constant, dtype=np.int64), None

    if mode != "subdir":
        raise ValueError("label_mode only supports 'subdir' or 'single'")

    # Insertion-ordered dict: class name -> label, numbered as names are first seen.
    index_of: Dict[str, int] = {}
    ids: List[int] = []
    for path in image_paths:
        folder = os.path.basename(os.path.dirname(path))
        # len(index_of) is evaluated before insertion, i.e. the next free label.
        ids.append(index_of.setdefault(folder, len(index_of)))

    return np.array(ids, dtype=np.int64), list(index_of)


def load_clip(model_name: str = "ViT-B/32", device: str = "cuda"):
    """Import the ``clip`` package lazily and load ``model_name`` onto ``device``.

    Returns:
        ``(model, preprocess, clip)``: the model after ``eval()`` has been
        called on it, the second value returned by ``clip.load``, and the
        imported ``clip`` module itself.

    Raises:
        RuntimeError: If importing ``clip`` fails for any reason; the original
            exception is attached as the cause.
    """
    try:
        import clip as clip_module  # type: ignore
    except Exception as import_error:
        message = "CLIP library not found, please install: pip install git+https://github.com/openai/CLIP.git"
        raise RuntimeError(message) from import_error

    clip_model, transform = clip_module.load(model_name, device=device)
    clip_model.eval()
    return clip_model, transform, clip_module


@torch.no_grad()
def encode_images(
    image_paths: List[str],
    model,
    preprocess,
    device: str = "cuda",
    batch_size: int = 64,
) -> np.ndarray:
    """Encode images with ``model.encode_image`` and return L2-normalized features.

    Args:
        image_paths: Paths of the images to encode; output rows follow this order.
        model: Object exposing ``encode_image(batch)``.
        preprocess: Callable applied to each RGB-converted PIL image; its
            outputs are stacked along a new leading batch dimension.
        device: Device the stacked batch is moved to before encoding.
        batch_size: Number of images per forward pass; must be at least 1.

    Returns:
        float32 numpy array with one row per image. When ``image_paths`` is
        empty the result has shape ``(0, 0)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A negative step would
            otherwise make ``range`` yield nothing and silently return an
            empty array for a non-empty input.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    features: List[torch.Tensor] = []
    for start in range(0, len(image_paths), batch_size):
        images = []
        for p in image_paths[start : start + batch_size]:
            # The context manager releases the file handle right away;
            # convert() returns a separate image, so it stays usable afterwards.
            with Image.open(p) as img:
                images.append(preprocess(img.convert("RGB")))
        batch = torch.stack(images, dim=0).to(device)
        feats = model.encode_image(batch)
        # Scale every row to unit L2 norm.
        feats = feats / feats.norm(dim=-1, keepdim=True)
        features.append(feats.cpu())

    if not features:
        return np.zeros((0, 0), dtype=np.float32)

    return torch.cat(features, dim=0).numpy().astype(np.float32)


def save_arrays(features: np.ndarray, labels: np.ndarray, out_dir: str, prefix: str = "clip_output") -> Tuple[str, str]:
    """Save features and labels as ``.npy`` files and return their paths.

    Args:
        features: Array written to ``<out_dir>/<prefix>_features.npy``.
        labels: Array written to ``<out_dir>/<prefix>_labels.npy``.
        out_dir: Target directory, created (with parents) when missing. An
            empty string means the current working directory.
        prefix: File-name prefix shared by both output files.

    Returns:
        ``(features_path, labels_path)``.
    """
    # os.makedirs("") raises FileNotFoundError, while os.path.join("", name)
    # already resolves to the current directory, so there is nothing to create.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    f_path = os.path.join(out_dir, f"{prefix}_features.npy")
    l_path = os.path.join(out_dir, f"{prefix}_labels.npy")
    np.save(f_path, features)
    np.save(l_path, labels)
    return f_path, l_path


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Batch encode images to CLIP features and visualize")

    # Exactly one image source has to be supplied.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--images_dir", type=str, help="Root directory containing images (supports recursion)")
    source.add_argument("--images_list", type=str, help="Txt file containing full image paths (one path per line)")

    # Labelling options.
    parser.add_argument(
        "--label_mode",
        type=str,
        default="subdir",
        choices=["subdir", "single"],
        help="Label generation mode: subdir=group by parent directory name, single=all same class",
    )
    parser.add_argument("--single_class_label", type=int, default=0, help="Class ID when label_mode=single")

    # Model / runtime options.
    parser.add_argument("--model", type=str, default="ViT-B/32", help="CLIP model name, e.g. ViT-B/32, ViT-B/16")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch_size", type=int, default=64)

    # Output options.
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory (default current directory)")
    parser.add_argument("--output_prefix", type=str, default="clip_output", help="Output prefix, default clip_output")
    parser.add_argument(
        "--save_class_names",
        action="store_true",
        help="If using subdir grouping, save class name mapping to JSON",
    )

    parser.add_argument("--visualize", action="store_true", help="Call visualization immediately after encoding")
    return parser


def _split_usable_images(paths: List[str]) -> Tuple[List[str], List[str]]:
    """Partition paths into (usable, skipped).

    A path is skipped when the image is entirely black or white, or when
    opening/converting it raises; the latter case is reported on stdout.
    """
    usable: List[str] = []
    skipped: List[str] = []
    for path in paths:
        try:
            with Image.open(path) as img:
                bucket = skipped if is_all_black_or_white(img.convert("RGB")) else usable
                bucket.append(path)
        except Exception as exc:
            print(f"Skipping corrupted or unreadable image: {path}, reason: {exc}")
            skipped.append(path)
    return usable, skipped


def _run_visualization(feat_path: str, label_path: str, class_names: Optional[List[str]]) -> None:
    """Best-effort visualization of the saved arrays; failures are printed, not raised."""
    try:
        # Load the sibling script by file path so its import cannot clash with
        # the third-party clip package.
        script = os.path.join(os.path.dirname(__file__), 'visualize_clip_features.py')
        spec = importlib.util.spec_from_file_location('local_visualize_clip_features', script)
        if spec is None or spec.loader is None:
            raise RuntimeError(f'Cannot load visualization script: {script}')
        vis = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(vis)

        feats, labs = vis.load_data(feat_path, label_path)
        vis.print_label_statistics(labs)
        # 1) Local similarity (t-SNE)
        vis.visualize_features(feats, labs, class_names=class_names)
        # 2) Global structure (PCA)
        vis.visualize_pca_overview(feats, labs, class_names=class_names)
        # 3) Overall differences (similarity heatmap)
        vis.visualize_similarity_heatmap(feats)
    except Exception as exc:
        print(f"Visualization failed: {exc}")


def main():
    """Command-line entry point.

    Steps: gather image paths, drop blank/unreadable images, build labels,
    optionally write the class-name mapping, encode with CLIP, save the
    feature and label arrays, and optionally run the visualization script.

    Raises:
        RuntimeError: If no images are found, or none remain after filtering.
    """
    args = _build_arg_parser().parse_args()

    # Gather candidate images from whichever source was given.
    image_paths = list_images(args.images_dir) if args.images_dir else read_list_file(args.images_list)
    if not image_paths:
        raise RuntimeError("No images found in given input")
    print(f"Found {len(image_paths)} images")

    # Drop all-black / all-white images and files that cannot be read.
    image_paths, skipped = _split_usable_images(image_paths)
    if skipped:
        print(f"Skipped all black/white/unreadable images: {len(skipped)} images")
    if not image_paths:
        raise RuntimeError("No usable images after filtering")
    print(f"Valid images for encoding: {len(image_paths)} images")

    labels, class_names = build_labels(
        image_paths,
        mode=args.label_mode,
        single_class_label=args.single_class_label,
    )

    if args.save_class_names and class_names is not None:
        # Persist the label -> class-name mapping next to the arrays.
        os.makedirs(args.output_dir, exist_ok=True)
        mapping_path = os.path.join(args.output_dir, f"{args.output_prefix}_class_names.json")
        with open(mapping_path, "w", encoding="utf-8") as fh:
            json.dump(dict(enumerate(class_names)), fh, ensure_ascii=False, indent=2)

    model, preprocess, _ = load_clip(args.model, args.device)
    features = encode_images(image_paths, model, preprocess, device=args.device, batch_size=args.batch_size)

    feat_path, label_path = save_arrays(features, labels, args.output_dir, prefix=args.output_prefix)
    print(f"Saved features: {feat_path}")
    print(f"Saved labels: {label_path}")

    if args.visualize:
        _run_visualization(feat_path, label_path, class_names)


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


