#!/usr/bin/env python3
import torch
import torch.nn.functional as F
from torch import Tensor
from PIL.Image import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from typing import List, Tuple

@torch.no_grad()
def compute_validity(
    images: List[Image],
    classifier_name: str,
    label: int,
    batch_size: int,
    device: str,
    cache_dir: str
) -> Tuple[float, Tensor]:
    """Classify ``images`` with a Hugging Face image classifier and score them against ``label``.

    Args:
        images: PIL images to classify.
        classifier_name: Hugging Face model id; used to load both the image
            processor and the classification model.
        label: Expected class index, expressed in the (possibly remapped)
            label space produced by ``_collapse_probs``.
        batch_size: Number of images per forward pass; must be >= 1.
        device: Device the model and its inputs are moved to (e.g. ``"cpu"``).
        cache_dir: Cache directory passed to ``from_pretrained``.

    Returns:
        ``(validity, preds)``. ``preds`` is a 1-D CPU tensor holding the
        predicted class index of every image. ``validity`` is the fraction of
        those predictions equal to ``label``.

    Raises:
        ValueError: if ``images`` is empty or ``batch_size`` is less than 1.
    """
    # Validate before loading the model so bad input fails fast and clearly.
    if not images:
        raise ValueError("compute_validity() requires at least one image")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    processor = AutoImageProcessor.from_pretrained(classifier_name, cache_dir=cache_dir)
    classifier = AutoModelForImageClassification.from_pretrained(classifier_name, cache_dir=cache_dir)
    classifier.eval()
    classifier.to(device)

    batch_preds = []
    for start in range(0, len(images), batch_size):
        inputs = processor(images[start : start + batch_size], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        logits = classifier(**inputs).logits
        # Softmax first: the remapping below averages/reorders probabilities.
        probs = _collapse_probs(F.softmax(logits, dim=1), classifier_name)
        # Only the argmax is needed, so reduce and move it off the device now
        # instead of keeping every batch's probabilities on the accelerator.
        batch_preds.append(probs.argmax(dim=-1).cpu())

    preds = torch.cat(batch_preds)
    validity = (preds == label).float().mean().item()
    return validity, preds


def _collapse_probs(probs: Tensor, classifier_name: str) -> Tensor:
    """Map a classifier's per-class probabilities onto the label space ``label`` uses.

    Only the two classifiers handled below are remapped; for any other
    ``classifier_name`` the probabilities are returned unchanged.
    """
    if classifier_name == "ibombonato/swin-age-classifier":
        # 0 for young, 1 for old: average the first three classes and every
        # class from index 5 on; classes 3 and 4 are dropped entirely.
        # NOTE(review): this averages rather than sums each group, which
        # matters if the groups have different sizes — confirm intended.
        return torch.cat(
            [
                probs[:, :3].mean(dim=1, keepdim=True),
                probs[:, 5:].mean(dim=1, keepdim=True),
            ],
            dim=1,
        )
    if classifier_name == "londe33/hair_v02":
        # hair color (3->1 red, 0->2 black, 1->3 blond); original column 2
        # moves to index 0.
        return probs[:, [2, 3, 0, 1]]
    return probs