#!/usr/bin/env python
"""
A script to compute VENDI scores for images in subfolders using DINOv2 features, 
loading images sequentially (no concurrent pool).

Usage:
    python script.py --model <model_name> [--directory <path>] [--batch_size <size>] [--output <path>]

Example:
    python script.py --model vitb14 --directory /ROOT --batch_size 512

Description:
    1. This script looks for subfolders within the specified directory.
    2. For each subfolder, it collects all .webp images from its "best" and "worst" subdirectories.
    3. It computes DINOv2 embeddings for those images and calculates one VENDI score per "best"/"worst" folder.
    4. The final results are saved to a JSON file named "<model_name>_vendi.json" unless --output is given.
    5. Progress bars show:
       - One for the overall processing of subfolders.
       - A sub-progress bar for loading images.
       - Another for evaluating images in batches.

Available DINOv2 model names:
    ["vits14", "vitb14", "vitl14", "vitg14"]
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_tensor
from tqdm import trange, tqdm
from vendi_score import vendi


def load(image):
    """
    Load a single image from a file path, or pass an already-loaded object through.

    Args:
        image (str | os.PathLike | PIL.Image.Image): A path to an image file,
            or an object that is already loaded.

    Returns:
        PIL.Image.Image: An RGB copy of the image when a path was given;
        otherwise ``image`` itself, unchanged.
    """
    if isinstance(image, (str, os.PathLike)):
        # The context manager closes the underlying file as soon as the RGB
        # copy exists, mirroring `load_img`. `convert` returns a new image
        # object, so the result does not depend on the closed file.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    return image


def batch_load(images, loading_func=load):
    """
    Apply ``loading_func`` to every entry of ``images``, one after another.

    A transient tqdm progress bar labelled "Loading images" tracks the work
    and is removed once the loop finishes.

    Args:
        images (iterable): Items to load, e.g. image paths or PIL Images.
        loading_func (callable): Called once per item, in order.

    Returns:
        list: The values returned by ``loading_func``, in input order.
    """
    progress = tqdm(images, desc="Loading images", leave=False)
    return [loading_func(item) for item in progress]


class MetricRunner:
    """
    Base class to define a metric runner behavior for single and multi evals.
    Subclasses should override `eval` at a minimum.
    """

    # Capability flags; nothing in this file reads them.
    # NOTE(review): presumably they advertise which evaluation modes a
    # subclass supports -- confirm against external callers.
    single = False
    multi = False
    # Wrapped in `staticmethod` on purpose: a plain function stored on the
    # class would be bound when accessed as `self.img_load_func`, so the
    # runner instance would be passed to `load` as an extra first argument.
    img_load_func = staticmethod(load)

    def eval(self, images, ref_texts=None, is_ref=False):
        """
        Evaluate a batch of images. Override this in subclasses.

        Args:
            images (list): The loaded batch, i.e. the outputs of
                `img_load_func` for each item.
            ref_texts (list[str], optional): Reference texts if needed.
            is_ref (bool): Indicates if these are reference images.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def eval_single(self, image, ref_text=None, is_ref=False):
        """
        Load one image (or one batch of images) and pass it to `eval`.

        Args:
            image: A single path / PIL Image, or a list of them.
            ref_text (str | list, optional): A single reference text or a
                list of reference texts for the batch.
            is_ref (bool): Forwarded to `eval`.

        Returns:
            Whatever `eval` returns for the loaded batch.
        """
        # Promote scalars to one-element batches so `eval` always sees lists.
        if isinstance(image, (Image.Image, str)):
            image = [image]
        if isinstance(ref_text, str):
            ref_text = [ref_text]
        image = batch_load(image, self.img_load_func)

        return self.eval(image, ref_text, is_ref=is_ref)

    def _eval_in_batches(self, items, ref_texts, batch_size, is_ref, desc):
        """Run `eval_single` over `items` in `batch_size` chunks; return per-batch results."""
        return [
            self.eval_single(
                items[start : start + batch_size],
                ref_texts[start : start + batch_size],
                is_ref=is_ref,
            )
            for start in trange(0, len(items), batch_size, desc=desc, unit="batch")
        ]

    def eval_multi(self, images, ref_texts=None, ref_images=None, batch_size=32):
        """
        Evaluate multiple images in batches.
        If reference images are given, evaluate them as well.

        Args:
            images (list[str]): List of image paths.
            ref_texts (list[str], optional): Reference texts, if any.
            ref_images (list[str], optional): Reference image paths, if any.
            batch_size (int): Batch size for evaluation loops.

        Returns:
            tuple[list, list]: Per-batch `eval` results for `images`, and
            per-batch results for `ref_images` (empty if none were given).
        """
        if ref_texts is None:
            # Placeholder texts long enough to be sliced alongside whichever
            # of the two image lists is longer.
            ref_texts = [None] * max(
                len(images), 0 if ref_images is None else len(ref_images)
            )

        results = self._eval_in_batches(
            images, ref_texts, batch_size, False, "Evaluating images"
        )

        ref_results = []
        if ref_images is not None:
            ref_results = self._eval_in_batches(
                ref_images, ref_texts, batch_size, True, "Evaluating reference images"
            )

        return results, ref_results


def load_img(file, size):
    """
    Open an image file, force it to RGB, resize it, and return it as a tensor.

    Args:
        file: Path (or file object) accepted by ``PIL.Image.open``.
        size (tuple): Target size handed to ``Image.resize``.

    Returns:
        The result of torchvision's ``to_tensor`` on the resized RGB image.
    """
    with Image.open(file) as handle:
        rgb = handle.convert("RGB")
        # Bicubic resampling to the requested size.
        resized = rgb.resize(size, Image.BICUBIC)
        return to_tensor(resized)


class DINOv2FeatureExtractor(nn.Module):
    """
    A simple wrapper around the DINOv2 models from Facebook's PyTorch Hub.

    Attributes:
        model (nn.Module): The loaded DINOv2 model, in eval mode with
            gradients disabled.
        normalize (callable): Per-channel normalization applied to inputs.
        size (tuple): Stored as (224, 224); not read inside this class.
        device (torch.device): Device chosen at construction time. It is not
            updated if the module is later moved with ``.to(...)``.
    """

    def __init__(self, name="vitl14", device="cuda"):
        """
        Args:
            name (str): Model suffix; "dinov2_" + name is requested from the
                "facebookresearch/dinov2" hub repository.
            device (str | torch.device): Device the model is placed on and
                on which `forward` runs.
        """
        super().__init__()
        # Remember the device so `forward` honours it instead of assuming CUDA.
        self.device = torch.device(device)
        self.model = (
            torch.hub.load("facebookresearch/dinov2", "dinov2_" + name)
            .to(self.device)
            .eval()
            .requires_grad_(False)
        )
        self.normalize = transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        )
        self.size = (224, 224)

    @classmethod
    def available_models(cls):
        """List available DINOv2 models."""
        return ["vits14", "vitb14", "vitl14", "vitg14"]

    def forward(self, x):
        """
        Forward pass: normalize input, run through DINOv2, and L2-normalize outputs.

        Args:
            x (torch.Tensor): Image batch with the channel axis at dim 1;
                single-channel input is repeated to three channels.

        Returns:
            torch.Tensor: float32 model outputs, L2-normalized by
            ``F.normalize`` with its default arguments.
        """
        # Move to the model's device first so every later op runs there.
        x = x.to(self.device)
        if x.shape[1] == 1:
            x = torch.cat([x] * 3, dim=1)
        x = self.normalize(x)
        if self.device.type == "cuda":
            # Half-precision autocast is only applied on CUDA, as before.
            with torch.autocast("cuda", dtype=torch.float16):
                x = self.model(x)
        else:
            x = self.model(x)
        x = F.normalize(x.float())  # L2 normalize
        return x


def cosine_sim(a, b):
    """
    Compute cosine similarity between two 1D vectors.

    Args:
        a: 1D NumPy array or array-like (e.g. a list of numbers).
        b: 1D NumPy array or array-like of the same length as ``a``.

    Returns:
        The dot product of ``a`` and ``b`` divided by the product of their
        Euclidean norms. No guard is applied for zero-norm inputs, so the
        division is by zero in that case.
    """
    # Coerce so plain sequences work with `@` as well as ndarrays.
    a = np.asarray(a)
    b = np.asarray(b)
    return (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))


class VendiRunner(MetricRunner):
    """
    A specialized MetricRunner that uses the VENDI score on DINOv2 features.

    Attributes:
        feature_extractor (callable): Maps an image batch tensor to features.
        img_size (tuple): Size every image is resized to when loaded.
        features (list[torch.Tensor]): CPU feature batches gathered by `eval`
            during the most recent `eval_multi` call.
    """
    multi = True

    def __init__(self, feature_extractor, img_size=(224, 224)):
        """
        Args:
            feature_extractor (callable): Module used to embed image batches.
            img_size (tuple): Size passed to `load_img` for every image.
        """
        self.feature_extractor = feature_extractor
        self.img_size = img_size
        # Instance-level loader: each item is opened, resized and converted
        # to a tensor by `load_img`, so `eval` receives tensors, not PIL images.
        self.img_load_func = lambda x: load_img(x, size=self.img_size)
        self.features = []

    def _target_device(self):
        """Return the device of the extractor's parameters, defaulting to CUDA."""
        try:
            return next(self.feature_extractor.parameters()).device
        except (AttributeError, StopIteration):
            # Extractor exposes no parameters: keep the historical behavior.
            return torch.device("cuda")

    @torch.no_grad()
    def eval(self, images, ref_texts=None, is_ref=False):
        """
        Embed one batch and stash its features; returns None.

        Args:
            images (list[torch.Tensor]): Equally-sized image tensors.
            ref_texts: Unused.
            is_ref (bool): Unused; reference batches are stored in the same
                feature list as regular batches.
        """
        # Stack the per-image tensors and place them where the extractor lives.
        batch = torch.stack(images).to(self._target_device())
        features = self.feature_extractor(batch)
        # Keep features on the CPU so accelerator memory is not held per batch.
        self.features.append(features.cpu())

    @torch.no_grad()
    def eval_multi(self, images, ref_texts=None, ref_images=None, batch_size=32):
        """
        Override eval_multi to combine all extracted features and compute the final VENDI score.

        Args:
            images (list[str]): Image paths to score.
            ref_texts (list[str], optional): Forwarded to the base class.
            ref_images (list[str], optional): If given, their features are
                pooled with those of `images` before scoring.
            batch_size (int): Number of images embedded per batch.

        Returns:
            The value returned by ``vendi.score_K`` for the pairwise
            similarity matrix of all collected features.
        """
        # Start from a clean slate so earlier calls do not leak into this score.
        self.features = []
        super().eval_multi(images, ref_texts, ref_images, batch_size)
        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # NOTE: torch.cat raises if no batch was evaluated (empty `images`).
        all_features = torch.cat(self.features, dim=0)
        normed_all_features = F.normalize(all_features)
        similarities = normed_all_features @ normed_all_features.T
        # Calculate VENDI score on the similarity matrix
        result = vendi.score_K(similarities.cpu().numpy())
        return result


def _list_webp_images(folder):
    """Return sorted paths of the '.webp' entries in `folder`, or None if it is not a directory."""
    if not os.path.isdir(folder):
        return None
    return sorted(
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if entry.endswith(".webp")
    )


def main():
    """
    Command-line entry point.

    Parses the arguments, scores the ".webp" images found in the "best" and
    "worst" subdirectories of every subfolder of --directory, prints a
    summary, and writes the scores to a JSON file.

    Exits with status 1 on an invalid model name, a non-positive batch size,
    or a --directory that is not a directory. Returns without writing
    anything when the directory has no subfolders.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Compute VENDI scores on images in subfolders using DINOv2 features."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="vitb14",
        help="Which DINOv2 model to use. Options: vits14, vitb14, vitl14, vitg14.",
    )
    parser.add_argument(
        "--directory",
        type=str,
        default="/images",
        help="Parent directory containing subfolders of images.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="Batch size for image feature extraction.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to the output JSON file. Defaults to '<model>_vendi.json'.",
    )

    args = parser.parse_args()
    model_name = args.model
    root_dir = args.directory
    batch_size = args.batch_size
    output_file = args.output or f"{model_name}_vendi.json"

    # Validate all inputs before the (expensive) model load.
    if model_name not in DINOv2FeatureExtractor.available_models():
        print(
            f"[ERROR] Invalid model name '{model_name}'. "
            f"Choose from {DINOv2FeatureExtractor.available_models()}."
        )
        sys.exit(1)
    if batch_size < 1:
        print(f"[ERROR] --batch_size must be a positive integer, got {batch_size}.")
        sys.exit(1)
    if not os.path.isdir(root_dir):
        print(f"[ERROR] '{root_dir}' is not a directory.")
        sys.exit(1)

    print(f"Using DINOv2 model: {model_name}")
    print(f"Searching for subfolders in directory: {root_dir}")

    # Collect valid subfolders; sorted so processing and output order are
    # reproducible regardless of the order os.listdir returns.
    PATHS = sorted(
        os.path.join(root_dir, subfolder)
        for subfolder in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, subfolder))
    )
    if not PATHS:
        print(f"[WARNING] No subfolders found in {root_dir}. Exiting.")
        return

    # Instantiate the feature extractor and the VendiRunner
    extractor = DINOv2FeatureExtractor(model_name)
    runner = VendiRunner(extractor, (224, 224))

    results = {}

    # Top-level progress bar over subfolders. Messages inside the loop go
    # through tqdm.write so they do not garble the progress bars.
    for folder_base in tqdm(PATHS, desc="Processing subfolders", unit="folder"):
        for subfolder in ["best", "worst"]:
            folder = f"{folder_base}/{subfolder}"
            images = _list_webp_images(folder)

            if images is None:
                # A missing split must not abort the run and discard the
                # scores computed so far.
                tqdm.write(f"[WARNING] Folder {folder} does not exist. Skipping.")
                continue
            if not images:
                tqdm.write(f"[WARNING] No '.webp' images found in {folder}. Skipping.")
                continue

            tqdm.write(f"  -> Evaluating VENDI for {folder} with {len(images)} images.")

            # Compute vendi score for images in the folder; store it as a
            # plain float so it prints and serializes directly.
            results[folder] = float(runner.eval_multi(images, batch_size=batch_size))

    # Print summary of results
    print("=" * 40)
    print("Summary of VENDI scores:")
    for folder, score in results.items():
        print(f"  {folder}: {score}")

    # Save results to JSON
    print(f"Saving results to {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    print("[INFO] Done. Exiting.")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
