#!/usr/bin/env python

import os
import sys
import argparse
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_tensor
from vendi_score import vendi
from PIL import Image
from tqdm import trange, tqdm
import json

# Shared process pool used by `batch_load` to load images in parallel.
# It is created once, at import time, with a fixed cap of 24 worker processes.
pool = ProcessPoolExecutor(max_workers=24)


def load(image):
    """
    Return an RGB PIL image for a file path; pass any other input through.

    Args:
        image (str or Image.Image): Path of an image file, or an object that
            is already loaded.

    Returns:
        Image.Image: A fully decoded RGB image when `image` is a path.
        Otherwise `image` itself, unchanged.
    """
    if isinstance(image, str):
        # Image.open is lazy and keeps the file open. convert() decodes the
        # pixels and returns an independent image, so the handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(image) as img:
            return img.convert("RGB")
    return image


def batch_load(images, loading_func=load):
    """
    Apply `loading_func` to every item of `images` through the shared pool.

    A transient progress bar tracks the loading; results keep the order of
    `images`.

    Args:
        images (list[str or Image.Image]): Image paths or already loaded images.
        loading_func (callable): Loader applied to each item; its return
            values make up the result list.

    Returns:
        list: One loaded item per input, as returned by `loading_func`.
    """
    loaded = pool.map(loading_func, images)
    progress = tqdm(
        loaded,
        total=len(images),
        leave=False,
        desc="Loading images",
    )
    return list(progress)


class MetricRunner:
    """
    Base class defining how a metric is run on single items and on batches.

    Subclasses must override `eval`. `eval_single` and `eval_multi` wrap,
    load and batch the inputs before delegating to it.

    Attributes:
        single (bool): Capability flag set by subclasses; not read here.
        multi (bool): Capability flag set by subclasses; not read here.
        img_load_func (callable): Loader applied to every input item.
    """
    single = False
    multi = False
    # Wrapped in staticmethod so that `self.img_load_func` resolves to the
    # plain function. A bare function stored on the class is bound to the
    # instance on attribute access and would be called as load(self, image).
    img_load_func = staticmethod(load)

    def eval(self, images, ref_texts=None, is_ref=False):
        """
        Evaluate a batch of images. Override this in subclasses.

        Args:
            images (list): Loaded images, as produced by `img_load_func`.
            ref_texts (list[str], optional): Reference texts if needed.
            is_ref (bool): Indicates if these are reference images.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def eval_single(self, image, ref_text=None, is_ref=False):
        """
        Load and evaluate one image, or one batch of images.

        Args:
            image (str, Image.Image or list): A single path/image, or a list
                of them. A single item is wrapped into a one-element list.
            ref_text (str or list[str], optional): Reference text(s). A
                single string is wrapped into a one-element list.
            is_ref (bool): Forwarded to `eval`.

        Returns:
            Whatever `eval` returns for the loaded batch.
        """
        if isinstance(image, (Image.Image, str)):
            image = [image]
        if isinstance(ref_text, str):
            ref_text = [ref_text]
        image = batch_load(image, self.img_load_func)

        return self.eval(image, ref_text, is_ref=is_ref)

    def eval_multi(self, images, ref_texts=None, ref_images=None, batch_size=32):
        """
        Evaluate multiple images in batches.
        If reference images are given, evaluate them as well.

        Args:
            images (list[str]): List of image paths.
            ref_texts (list[str], optional): Reference texts, if any.
            ref_images (list[str], optional): Reference image paths, if any.
            batch_size (int): Batch size for evaluation loops.

        Returns:
            tuple[list, list]: Per-batch results for `images`, and per-batch
            results for `ref_images` (empty when none are given).
        """
        if ref_texts is None:
            # Placeholder texts, long enough to be sliced alongside either
            # the images or the reference images.
            ref_texts = [None] * max(
                len(images), 0 if ref_images is None else len(ref_images)
            )
        results = []
        ref_results = []

        # Evaluate images
        for i in trange(0, len(images), batch_size, desc="Evaluating images"):
            results.append(
                self.eval_single(
                    images[i : i + batch_size],
                    ref_texts[i : i + batch_size],
                    is_ref=False,
                )
            )

        # Evaluate reference images, if provided
        if ref_images is not None:
            for i in trange(0, len(ref_images), batch_size, desc="Evaluating reference images"):
                ref_results.append(
                    self.eval_single(
                        ref_images[i : i + batch_size],
                        ref_texts[i : i + batch_size],
                        is_ref=True,
                    )
                )

        return results, ref_results


def load_img(file, size):
    """
    Open an image file and return it as a tensor resized to `size`.

    The image is converted to RGB and resized with bicubic resampling before
    being turned into a tensor by `to_tensor`.
    """
    with Image.open(file) as source:
        rgb = source.convert("RGB")
        resized = rgb.resize(size, Image.BICUBIC)
        return to_tensor(resized)


class DINOv2FeatureExtractor(nn.Module):
    """
    A simple wrapper around the DINOv2 models from Facebook's PyTorch Hub.

    Attributes:
        model (nn.Module): The loaded DINOv2 model, frozen and in eval mode.
        device (torch.device): Device requested at construction time.
        normalize (callable): Per-channel normalization applied to inputs.
        size (tuple): Image size stored as (224, 224); not read in this class.
    """

    def __init__(self, name="vitl14", device="cuda"):
        """
        Load `dinov2_<name>` from PyTorch Hub and place it on `device`.

        Args:
            name (str): Model suffix, one of `available_models()`.
            device (str or torch.device): Device the model is moved to.
        """
        super().__init__()
        # Remembered so forward() can send inputs to the same device as the
        # weights instead of assuming the default CUDA device.
        self.device = torch.device(device)
        self.model = (
            torch.hub.load("facebookresearch/dinov2", "dinov2_" + name)
            .to(self.device)
            .eval()
            .requires_grad_(False)
        )
        self.normalize = transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        )
        self.size = (224, 224)

    @classmethod
    def available_models(cls):
        """List available DINOv2 models."""
        return ["vits14", "vitb14", "vitl14", "vitg14"]

    def _target_device(self):
        """Return the device the wrapped model currently lives on."""
        # Prefer the live parameter device so a later `.to(...)` on this
        # module is honoured. Fall back to the device given at construction.
        first_param = next(self.model.parameters(), None)
        return self.device if first_param is None else first_param.device

    def forward(self, x):
        """
        Forward pass: normalize input, run through DINOv2, and L2-normalize outputs.

        Args:
            x (torch.Tensor): Image batch with the channel axis at dim 1.
                Single-channel input is repeated to three channels.

        Returns:
            torch.Tensor: float32 features, L2-normalized by `F.normalize`.
        """
        if x.shape[1] == 1:
            x = torch.cat([x] * 3, dim=1)
        x = self.normalize(x)
        device = self._target_device()
        x = x.to(device)
        if device.type == "cuda":
            # Half-precision autocast is only applied on CUDA.
            with torch.autocast("cuda", dtype=torch.float16):
                x = self.model(x).float()
        else:
            x = self.model(x).float()
        x = F.normalize(x)  # L2 normalize
        return x


def cosine_sim(a, b):
    """Return the cosine of the angle between two 1D vectors (arrays)."""
    inner = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    return inner / (norm_a * norm_b)


class VendiRunner(MetricRunner):
    """
    A specialized MetricRunner that computes the VENDI score on image features.

    Images are loaded as tensors resized to `img_size`, passed through
    `feature_extractor` batch by batch, and the collected features are turned
    into a similarity matrix that is scored with `vendi.score_K`.

    Attributes:
        feature_extractor (callable): Maps an image batch tensor to features.
        img_size (tuple): Size every image is resized to on load.
        img_load_func (callable): `load_img` with `size` fixed to `img_size`.
        features (list[torch.Tensor]): CPU feature batches from the last run.
    """
    multi = True

    def __init__(self, feature_extractor, img_size=(224, 224)):
        self.feature_extractor = feature_extractor
        self.img_size = img_size
        self.img_load_func = partial(load_img, size=self.img_size)
        self.features = []

    def _batch_device(self):
        """Return the device input batches are moved to before extraction."""
        # Follow the extractor's own weights when it exposes parameters, so a
        # model on a non-default device gets its inputs on that same device.
        parameters = getattr(self.feature_extractor, "parameters", None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @torch.no_grad()
    def eval(self, images, ref_texts=None, is_ref=False):
        """
        Extract features for one batch and store them on the CPU.

        Args:
            images (list[torch.Tensor]): Image tensors from `img_load_func`.
            ref_texts: Unused.
            is_ref (bool): Unused; reference batches are stored alongside the
                regular ones.
        """
        # Stack the per-image tensors into a single batch tensor
        batch = torch.stack(images).to(self._batch_device())
        features = self.feature_extractor(batch)
        self.features.append(features.cpu())

    @torch.no_grad()
    def eval_multi(self, images, ref_texts=None, ref_images=None, batch_size=32):
        """
        Override eval_multi to combine all extracted features and compute the final VENDI score.

        Args:
            images (list[str]): List of image paths.
            ref_texts (list[str], optional): Forwarded to the base class.
            ref_images (list[str], optional): Reference image paths. Their
                features are pooled with those of `images`.
            batch_size (int): Batch size for feature extraction.

        Returns:
            The score returned by `vendi.score_K` for the similarity matrix.

        Raises:
            ValueError: If no features were extracted (no input images).
        """
        self.features = []
        super().eval_multi(images, ref_texts, ref_images, batch_size)
        # Clear CUDA cache
        torch.cuda.empty_cache()

        if not self.features:
            raise ValueError(
                "Cannot compute a VENDI score: no image features were extracted."
            )

        all_features = torch.cat(self.features, dim=0)
        # Re-normalize so the similarity matrix holds cosine similarities even
        # if the extractor did not return unit-length features.
        normed_all_features = F.normalize(all_features)
        similarities = normed_all_features @ normed_all_features.T
        # Calculate VENDI score on the similarity matrix
        result = vendi.score_K(similarities.cpu().numpy())
        return result


def main():
    """
    Command-line entry point.

    Computes one VENDI score per subfolder of `--directory`, using the
    '.webp' images found directly inside each subfolder. Prints a summary and
    writes the scores to a JSON file keyed by subfolder path.

    Exits with status 1 on an invalid model name, batch size or directory.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Compute VENDI scores on images in subfolders using DINOv2 features."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="vitb14",
        help="Which DINOv2 model to use. Options: vits14, vitb14, vitl14, vitg14.",
    )
    parser.add_argument(
        "--directory",
        type=str,
        default="/ROOT",
        help="Parent directory containing subfolders of images.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="Batch size for image feature extraction.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to the output JSON file. Defaults to '<model>_vendi.json'.",
    )

    args = parser.parse_args()
    model_name = args.model
    root_dir = args.directory
    batch_size = args.batch_size
    output_file = args.output or f"{model_name}_vendi.json"

    # Validate all inputs before the (expensive) model is loaded
    if model_name not in DINOv2FeatureExtractor.available_models():
        print(
            f"[ERROR] Invalid model name '{model_name}'. "
            f"Choose from {DINOv2FeatureExtractor.available_models()}."
        )
        sys.exit(1)
    if batch_size < 1:
        print(f"[ERROR] Invalid batch size {batch_size}. It must be at least 1.")
        sys.exit(1)
    if not os.path.isdir(root_dir):
        print(f"[ERROR] '{root_dir}' does not exist or is not a directory.")
        sys.exit(1)

    print(f"Using DINOv2 model: {model_name}")
    print(f"Searching for subfolders in directory: {root_dir}")

    # Collect valid subfolders, sorted so runs are reproducible
    subfolders = sorted(
        path
        for path in (os.path.join(root_dir, entry) for entry in os.listdir(root_dir))
        if os.path.isdir(path)
    )
    if not subfolders:
        print(f"[WARNING] No subfolders found in {root_dir}. Exiting.")
        return

    # Instantiate the feature extractor and the VendiRunner
    extractor = DINOv2FeatureExtractor(model_name)
    runner = VendiRunner(extractor, (224, 224))

    results = {}

    # Top-level progress bar over subfolders
    for folder in tqdm(subfolders, desc="Processing subfolders", unit="folder"):
        images = sorted(
            os.path.join(folder, entry)
            for entry in os.listdir(folder)
            if entry.endswith(".webp")
        )

        # tqdm.write keeps messages from breaking the active progress bar
        if not images:
            tqdm.write(f"[WARNING] No '.webp' images found in {folder}. Skipping.")
            continue

        tqdm.write(f"  -> Evaluating VENDI for {folder} with {len(images)} images.")

        # Compute vendi score for images in the folder
        results[folder] = runner.eval_multi(images, batch_size=batch_size)

    # float() accepts Python floats, NumPy scalars and 0-d tensors alike
    scores = {folder: float(score) for folder, score in results.items()}

    # Print results in summary
    print("=" * 40)
    print("Summary of VENDI scores:")
    for folder, score in scores.items():
        print(f"  {folder}: {score}")

    # Save results to JSON
    print(f"Saving results to {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(scores, f, indent=4)

    print("[INFO] Done. Exiting.")


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
