from typing import Dict, Iterator, Tuple

import torch
from PIL import Image

from sae.hooked_vit import HookedVisionTransformer


class ImageProcessor:
    """Static helpers that build ViT processor inputs and read SAE activations."""

    @staticmethod
    def _load_rgb(image_source) -> Image.Image:
        """Open ``image_source`` with PIL and return an RGB-converted copy.

        The image is opened in a ``with`` block so the underlying file is
        closed as soon as the converted copy exists, instead of staying open
        until garbage collection.
        """
        with Image.open(image_source) as opened:
            return opened.convert("RGB")

    @staticmethod
    def process_model_inputs(
        batch: Dict,
        vit: HookedVisionTransformer,
        device: str,
        process_labels: bool = False,
        image_key: str = "image",
    ) -> torch.Tensor:
        """Process input images through the ViT processor.

        Args:
            batch: Mapping holding the images under ``image_key`` and, when
                ``process_labels`` is true, the class names under ``"label"``.
            vit: Model wrapper whose ``processor`` attribute is called to
                build the inputs.
            device: Target passed to ``.to(...)`` on the processor output.
            process_labels: When true, each label is turned into the prompt
                ``"A photo of a {label}"``; otherwise an empty string is used
                as the text input.
            image_key: Key in ``batch`` under which the images are stored.

        Returns:
            The result of ``vit.processor(...)`` moved to ``device``.
            NOTE(review): the signature says ``torch.Tensor``, but
            ``get_sae_activations`` treats its ``inputs`` as a mapping with a
            ``"pixel_values"`` entry; presumably this is that mapping, so
            confirm and tighten the annotation.

        Raises:
            IndexError: If ``batch[image_key]`` is an empty sequence, because
                the first element is inspected to decide how to load images.
        """
        raw_images = batch[image_key]

        # The first element decides for the whole batch: either everything is
        # already a PIL image, or every entry is something Image.open accepts.
        if isinstance(raw_images[0], Image.Image):
            images = raw_images
        else:
            images = [ImageProcessor._load_rgb(image_path) for image_path in raw_images]

        if process_labels:
            text = [f"A photo of a {label}" for label in batch["label"]]
        else:
            text = ""

        return vit.processor(images=images, text=text, return_tensors="pt", padding=True).to(device)

    @staticmethod
    def get_sae_activations(
        sae, vit: HookedVisionTransformer, inputs: dict, block_layer, module_name, class_token, get_mean=True
    ) -> torch.Tensor:
        """Extract SAE activations for one hooked location of the vision transformer.

        The ViT is run with a single hook at ``(block_layer, module_name)``,
        the cached activations are passed to the SAE, and the SAE's
        ``"hook_hidden_post"`` cache entry is returned.

        Args:
            sae: Object exposing ``run_with_cache(activations)`` that returns
                ``(output, cache)`` with a ``"hook_hidden_post"`` entry.
            vit: Model wrapper exposing ``run_with_cache(hooks, **inputs)``.
            inputs: Keyword arguments for the ViT; must contain
                ``"pixel_values"``, whose leading dimension is taken as the
                batch size.
            block_layer: First element of the hook location.
            module_name: Second element of the hook location.
            class_token: When truthy, only index 0 along the first dimension
                of the (possibly transposed) activations is kept.
            get_mean: When true, the SAE activations are averaged over
                dimension 1 before being returned.

        Returns:
            The SAE hidden activations, optionally averaged over dimension 1.
        """
        hook_location = (block_layer, module_name)

        _, vit_cache = vit.run_with_cache([hook_location], **inputs)
        activations = vit_cache[hook_location]

        # If the leading dimension is not the batch size, swap the first two
        # dimensions (presumably a sequence-first layout; ambiguous when the
        # two sizes happen to be equal).
        batch_size = inputs["pixel_values"].shape[0]
        if activations.shape[0] != batch_size:
            activations = activations.transpose(0, 1)

        if class_token:
            # NOTE(review): after the swap above dim 0 is presumably the batch,
            # so this keeps the first sample rather than token 0 of every
            # sample (which would be ``activations[:, 0, :]``). Left unchanged
            # because callers may depend on the current output shape; confirm.
            activations = activations[0, :, :]

        _, sae_cache = sae.run_with_cache(activations)
        sae_act = sae_cache["hook_hidden_post"]
        if get_mean:
            sae_act = sae_act.mean(1)

        return sae_act


class BatchIterator:
    """Iterator for batching dataset."""

    @staticmethod
    def get_batches(
        dataset: Dict, batch_size: int, image_key: str = "image", max_sample=None
    ) -> Iterator[Tuple[int, int, Dict]]:
        """Yield consecutive ``(start_idx, end_idx, batch)`` slices of ``dataset``.

        Args:
            dataset: Either an object exposing ``num_rows`` and slice indexing,
                or a plain mapping whose ``image_key`` entry is a sliceable
                sequence.
            batch_size: Maximum number of samples per batch; must be positive.
            image_key: Key used to size and slice the dataset when it is a
                plain mapping.
            max_sample: Optional cap on the total number of samples visited.

        Yields:
            ``(start_idx, end_idx, batch)`` where ``end_idx`` is exclusive and
            ``batch`` is ``dataset[start_idx:end_idx]`` when the dataset
            supports slicing, otherwise a dict holding only the sliced
            ``image_key`` column.

        Raises:
            ValueError: If ``batch_size`` is not positive. Because this is a
                generator, the error surfaces on first iteration.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Objects without ``num_rows`` (e.g. plain dicts) are sized by the
        # length of their image column. Only AttributeError is caught so that
        # unrelated failures are not silently swallowed.
        try:
            num_samples = dataset.num_rows
        except AttributeError:
            num_samples = len(dataset[image_key])

        if max_sample is not None:
            num_samples = min(num_samples, max_sample)

        for start_idx in range(0, num_samples, batch_size):
            end_idx = min(start_idx + batch_size, num_samples)
            try:
                batch = dataset[start_idx:end_idx]
            except (TypeError, KeyError):
                # A plain dict rejects slice keys (TypeError for an unhashable
                # slice, KeyError on Pythons where slices are hashable), so the
                # image column is sliced directly. Other columns are not
                # carried over in this fallback.
                batch = {image_key: dataset[image_key][start_idx:end_idx]}
            yield start_idx, end_idx, batch
