#!/usr/bin/env python3
"""
Hard negative miner for (image, text description) datasets using a SigLIP model.

Example usage:
    python mine_negatives.py \
        --data train.jsonl \
        --output hard_negatives.jsonl \
        --model google/siglip2-so400m-patch16-512 \
        --num_negatives 5

Input dataset format (one JSON line per example):
    {"image_path": "path/to/img1.jpg", "text": "A dog playing in the park."}

Output file (JSONL):
    {
        "image_path": "path/to/img1.jpg",
        "pos_text": "A dog playing in the park.",
        "neg_texts": [
            "A small terrier runs across green grass.",
            ...  # total == --num_negatives
        ],
        "neg_scores": [0.83, ...]  # cosine-similarity scores, same length as neg_texts
    }
"""

import argparse
import json
from pathlib import Path
from typing import List

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor


def _positive_int(value: str) -> int:
    """argparse ``type=`` callable: parse `value` as an integer that must be >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the hard-negative miner.

    Returns:
        Namespace with ``data`` (Path), ``output`` (Path), ``model`` (str),
        ``num_negatives`` (int >= 1) and ``batch_size`` (int >= 1).

    Raises:
        SystemExit: raised by argparse (exit status 2) when ``--data`` is
            missing or when ``--num_negatives`` / ``--batch_size`` is not a
            positive integer.
    """
    parser = argparse.ArgumentParser(
        description="Mine hard negatives from an (image, text) dataset using a CLIP‑like model.")

    parser.add_argument("--data", type=Path, required=True,
                        help="Path to input JSONL file containing image_path and text fields.")
    # The default is spelled as a Path so the attribute type does not depend on
    # argparse's implicit conversion of string defaults.
    parser.add_argument("--output", type=Path, default=Path("hard_negatives.jsonl"),
                        help="Where to save the mined hard negatives (JSONL).")
    parser.add_argument("--model", type=str,
                        default="google/siglip2-so400m-patch16-512",
                        help="Hugging Face model checkpoint to use. Should expose get_image_features/ get_text_features.")
    # Zero/negative values are rejected here: a batch size of 0 would otherwise
    # only fail later, deep inside the embedding loop, with an opaque error.
    parser.add_argument("--num_negatives", type=_positive_int, default=5,
                        help="Number of hard negatives to mine for each (image, text) example.")
    parser.add_argument("--batch_size", type=_positive_int, default=32,
                        help="Mini‑batch size for embedding. Increase for more speed if you have GPU memory.")
    return parser.parse_args()


def _text_padding_kwargs(model) -> dict:
    """Return the tokenizer padding arguments that suit `model`'s text tower."""
    config = getattr(model, "config", None)
    model_type = str(getattr(config, "model_type", ""))
    if not model_type.startswith("siglip"):
        # Models other than SigLIP keep dynamic padding to the longest text in the batch.
        return {"padding": True}

    # SigLIP / SigLIP2 text towers were trained on sequences padded to a fixed
    # length; padding only to the longest text in the batch degrades the
    # embeddings (see the Hugging Face SigLIP usage notes).
    kwargs = {"padding": "max_length"}
    text_config = getattr(config, "text_config", None)
    max_length = getattr(text_config, "max_position_embeddings", None)
    if isinstance(max_length, int):
        # Passed explicitly because a tokenizer may not declare a usable
        # maximum length of its own, in which case neither padding nor
        # truncation would have a length to apply.
        kwargs["max_length"] = max_length
    return kwargs


@torch.no_grad()
def embed_texts(texts: List[str], processor, model, device: torch.device, batch_size: int = 64) -> torch.Tensor:
    """Embed `texts` in mini-batches and return the stacked features as float32.

    Args:
        texts: Captions to embed; must be non-empty.
        processor: Callable tokenizer/processor; its output must support
            ``.to(device)`` and ``**`` unpacking into the model call.
        model: Object exposing ``get_text_features(**inputs)``. When
            ``model.config.model_type`` starts with ``"siglip"``, texts are
            padded to the text tower's fixed length instead of dynamically.
        device: Device the tokenized inputs are moved to.
        batch_size: Number of texts per forward pass.

    Returns:
        Tensor with one row per input text, in input order, cast to float32.

    Raises:
        RuntimeError: from ``torch.cat`` when `texts` is empty.
    """
    pad_kwargs = _text_padding_kwargs(model)  # loop-invariant, so computed once
    all_embs = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = processor(text=batch, truncation=True, return_tensors="pt", **pad_kwargs).to(device)
        all_embs.append(model.get_text_features(**inputs))
    # Cast once at the end so a half-precision model still yields float32 output.
    return torch.cat(all_embs, dim=0).float()


@torch.no_grad()
def embed_images(paths: List[Path], processor, model, device: torch.device, batch_size: int = 32) -> torch.Tensor:
    """Embed the images at `paths` in mini-batches and return the stacked features as float32.

    Args:
        paths: Image files to load; must be non-empty.
        processor: Callable image processor; its output must support
            ``.to(device)`` and ``**`` unpacking into the model call.
        model: Object exposing ``get_image_features(**inputs)``.
        device: Device the preprocessed inputs are moved to.
        batch_size: Number of images loaded and embedded per forward pass.

    Returns:
        Tensor with one row per input path, in input order, cast to float32.

    Raises:
        RuntimeError: from ``torch.cat`` when `paths` is empty.
    """
    all_embs = []
    for start in range(0, len(paths), batch_size):
        imgs = []
        for path in paths[start:start + batch_size]:
            # Image.open() is lazy and holds the file handle; convert() returns
            # an independent in-memory image, so the handle can be released
            # immediately instead of lingering until garbage collection.
            with Image.open(path) as img:
                imgs.append(img.convert("RGB"))
        inputs = processor(images=imgs, return_tensors="pt").to(device)
        all_embs.append(model.get_image_features(**inputs))
    # Cast once at the end so a half-precision model still yields float32 output.
    return torch.cat(all_embs, dim=0).float()


def _read_examples(path: Path) -> List[dict]:
    """Parse the JSONL dataset at `path` into ``{"image_path": Path, "text": str}`` records."""
    examples = []
    with path.open("r", encoding="utf-8") as fp:
        for raw in fp:
            line = raw.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing in json.loads.
                continue
            item = json.loads(line)
            examples.append({
                "image_path": Path(item["image_path"]),
                "text": item["text"],
            })
    return examples


def main() -> None:
    """Run the miner end to end: load the model, embed the data, write the negatives.

    Steps:
        1. Load the processor and model named by ``--model`` (float16 on CUDA,
           float32 on CPU).
        2. Read the ``--data`` JSONL file.
        3. Embed every *distinct* caption once and L2-normalise the result.
        4. Embed the images in batches of ``--batch_size``, rank all captions
           by cosine similarity, and write the top ones to ``--output``.

    Each output line holds ``image_path``, ``pos_text``, ``neg_texts`` and
    ``neg_scores`` (rounded to 6 decimals). A caption whose text equals the
    example's own caption is never emitted as a negative, and a caption that
    occurs several times in the dataset is emitted at most once per example.
    When the dataset has fewer than ``--num_negatives`` other distinct
    captions, the lists are shorter than requested.

    Raises:
        RuntimeError: if the dataset contains no examples.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1) Load model & processor
    print(f"Loading model {args.model} on {device}…")
    processor = AutoProcessor.from_pretrained(args.model)
    model = AutoModel.from_pretrained(
        args.model,
        torch_dtype=(torch.float16 if device.type == "cuda" else torch.float32))
    model = model.to(device).eval()

    # 2) Load dataset
    print("Reading dataset…")
    examples = _read_examples(args.data)
    if not examples:
        raise RuntimeError("Dataset appears empty – nothing to mine.")

    # 3) Pre-compute text embeddings once. Captions are de-duplicated first so
    #    that an identical caption belonging to another example can neither be
    #    returned as a "negative" of its own text nor appear twice in one list.
    candidates = list(dict.fromkeys(ex["text"] for ex in examples))
    candidate_index = {text: i for i, text in enumerate(candidates)}
    print("Embedding all texts…")
    text_embs = embed_texts(candidates, processor, model, device, batch_size=args.batch_size)
    text_embs = torch.nn.functional.normalize(text_embs, dim=-1)  # [M, D]

    # topk raises when k exceeds the row length, and with k == M it would hand
    # back the masked positive itself; only M - 1 captions are real candidates.
    k = max(0, min(args.num_negatives, len(candidates) - 1))
    if k < args.num_negatives:
        print(f"Only {k} distinct negative caption(s) available; "
              f"writing fewer than the requested {args.num_negatives}.")

    # 4) Iterate over images in batches, mine hardest negatives
    print("Mining hard negatives…")
    with args.output.open("w", encoding="utf-8") as out_fp:
        for start in tqdm(range(0, len(examples), args.batch_size)):
            chunk = examples[start:start + args.batch_size]
            img_embs = embed_images([ex["image_path"] for ex in chunk],
                                    processor, model, device, batch_size=args.batch_size)
            img_embs = torch.nn.functional.normalize(img_embs, dim=-1)  # [B, D]

            similarities = img_embs @ text_embs.T  # [B, M] cosine similarity

            # Mask out each example's own caption so it cannot be selected.
            rows = torch.arange(len(chunk), device=similarities.device)
            pos_cols = torch.tensor([candidate_index[ex["text"]] for ex in chunk],
                                    device=similarities.device)
            similarities[rows, pos_cols] = -float("inf")

            neg_scores, neg_indices = torch.topk(similarities, k=k, dim=-1, largest=True)
            for ex, scores, indices in zip(chunk, neg_scores.tolist(), neg_indices.tolist()):
                record = {
                    "image_path": str(ex["image_path"]),
                    "pos_text": ex["text"],
                    "neg_texts": [candidates[i] for i in indices],
                    "neg_scores": [round(score, 6) for score in scores],
                }
                out_fp.write(json.dumps(record, ensure_ascii=False) + "\n")

    print(f"✔ Done! Hard negatives saved to {args.output}")


# Run the miner only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
