#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model evaluation utilities for point cloud classification
- Overall accuracy calculation
- Per-class accuracy calculation
- Misclassified sample logging
"""
import torch
from collections import defaultdict
from tqdm import tqdm


def evaluate(model: torch.nn.Module,
             dataloader: torch.utils.data.DataLoader,
             device: torch.device) -> tuple:
    """
    Evaluate a trained classifier on a test/validation dataset.

    The model is switched to evaluation mode and is left in that mode when
    the function returns; callers that continue training afterwards must
    call ``model.train()`` themselves.

    Args:
        model: Trained classification model; ``model(pts)`` must return
            per-class scores with the class dimension at ``dim=1``.
        dataloader: Iterable of batches. Each batch is a sequence whose first
            three items are ``(points, labels, file_names)``; any further
            items (e.g. an object ID) are ignored. Labels may be shaped
            ``(B,)`` or ``(B, 1)``.
        device: Computation device (cuda/cpu)
    Returns:
        Tuple of (overall_accuracy, per_class_accuracy, misclassified_samples)
        - overall_accuracy: Float (0 ~ 1); 0.0 when the dataloader is empty
        - per_class_accuracy: Dict {class_idx: accuracy (0 ~ 1)}, in
          ascending class order; only classes present in the labels appear
        - misclassified_samples: List of (file_path, true_label, pred_label)
    """
    # Set model to evaluation mode (affects layers such as BatchNorm/Dropout)
    model.eval()
    total_samples = 0
    correct_predictions = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    misclassified = []

    # Disable gradient computation for evaluation (speed up and save memory)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            # Only the first three items are used; extras are discarded.
            pts, labels, fnames, *_ = batch

            pts = pts.to(device)
            # Flatten labels to 1-D. Without this, (B, 1) labels would
            # broadcast against the (B,) predictions into a (B, B) matrix
            # and inflate the number of "correct" predictions.
            labels = labels.to(device).reshape(-1)

            # Model forward pass and prediction
            logits = model(pts)
            preds = torch.argmax(logits, dim=1)

            # Update total and correct counts
            total_samples += labels.numel()
            correct_predictions += (preds == labels).sum().item()

            # Update per-class counts and collect misclassified samples.
            # tolist() copies each tensor to the host once per batch.
            for pred, true, fname in zip(preds.tolist(), labels.tolist(), fnames):
                pred, true = int(pred), int(true)
                class_total[true] += 1
                if pred == true:
                    class_correct[true] += 1
                else:
                    misclassified.append((fname, true, pred))

    # Calculate overall accuracy (guard against an empty dataloader)
    overall_acc = correct_predictions / total_samples if total_samples > 0 else 0.0

    # Calculate per-class accuracy. Every key in class_total was incremented
    # at least once, so the denominator is never zero.
    per_class_acc = {
        cls_idx: class_correct[cls_idx] / class_total[cls_idx]
        for cls_idx in sorted(class_total)
    }

    return overall_acc, per_class_acc, misclassified