import torch
from collections import defaultdict
from utils.logger import log_to_csv

def evaluate_per_class_accuracy(model, classifier, dataloader, device=None):
    """Compute prediction accuracy separately for every label seen in ``dataloader``.

    Each input batch is passed through ``model.backbone`` -> ``model.projector``
    -> ``classifier``; the predicted class is the ``argmax`` over dim 1 of the
    classifier output and is compared element-wise with the target labels.

    Args:
        model: Module exposing ``backbone`` and ``projector`` submodules.
        classifier: Module mapping projector outputs to per-class scores.
        dataloader: Iterable of ``(x, y)`` batch pairs. ``y`` must have the same
            number of elements as the argmax of the classifier output
            (presumably one label per sample, or per position for dense
            outputs — confirm against callers).
        device: Device to evaluate on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        dict mapping each label value that occurs in ``dataloader`` to the
        fraction of its occurrences predicted correctly. Keys are in ascending
        order; labels that never occur are absent. An empty dataloader yields
        an empty dict.

    Note:
        ``model`` and ``classifier`` are switched to eval mode and moved to
        ``device``; neither change is undone on return.
    """
    model.eval()
    classifier.eval()
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    classifier.to(device)

    correct = defaultdict(int)
    total = defaultdict(int)

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            features = model.backbone(x)
            latents = model.projector(features)
            outputs = classifier(latents)
            preds = torch.argmax(outputs, dim=1)

            # Count hits and occurrences for every label in the batch with one
            # unique + scatter on-device, then transfer to the host once per
            # batch instead of calling .item() twice per label (each .item()
            # forces a device synchronisation).
            flat_y = y.reshape(-1)
            hit = (preds.reshape(-1) == flat_y).long()
            labels, inverse, counts = torch.unique(
                flat_y, return_inverse=True, return_counts=True
            )
            hits = torch.zeros_like(counts).scatter_add_(0, inverse.reshape(-1), hit)
            for label, n_hit, n_seen in zip(
                labels.tolist(), hits.tolist(), counts.tolist()
            ):
                correct[label] += n_hit
                total[label] += n_seen

    # A label only enters ``total`` with a positive count, so no zero-division
    # guard is needed. Sorting keeps key (and downstream CSV column) order
    # stable regardless of the order in which labels first appear.
    return {label: correct[label] / total[label] for label in sorted(total)}

def log_per_class_accuracy(csv_path, model_info, acc_dict):
    """Append one CSV row combining model metadata with per-class accuracies.

    Args:
        csv_path: Destination passed straight through to ``log_to_csv``.
        model_info: Mapping of metadata columns; its keys come first in the row.
        acc_dict: Mapping of class label -> accuracy; each entry becomes a
            column named ``acc_class_<label>``. On a name clash these columns
            override entries from ``model_info``.
    """
    record = dict(model_info)
    for cls, score in acc_dict.items():
        record[f"acc_class_{cls}"] = score
    log_to_csv(csv_path, record)
