import argparse
from typing import Dict, Optional

import torch
import torch.nn.functional as F


def calculate_ood_scores(
    model,
    source_val_loader,
    device: Optional[str] = "cpu",
    temperature: float = 1.0,
) -> Dict[str, float]:
    """Compute the mean MSP and energy scores of ``model`` over a data loader.

    Args:
        model: A ``torch.nn.Module`` (optionally wrapped in
            ``torch.nn.DataParallel``, which is unwrapped so inference runs on
            a single device). NOTE: ``nn.Module.to`` is in-place, so the
            caller's module ends up on ``device``. Its previous train/eval
            mode is restored before returning.
        source_val_loader: Iterable yielding ``(inputs, target)`` pairs; the
            target is ignored. The model output is treated as logits with the
            class dimension at ``dim=1`` (presumably ``(batch, num_classes)``
            — confirm for your model).
        device: Device to run on. ``None`` or ``""`` auto-selects ``"cuda"``
            when available, otherwise ``"cpu"``.
        temperature: Temperature ``T`` for the energy score. Must be positive.
            The default ``1.0`` reproduces plain ``logsumexp(logits)``.

    Returns:
        ``{"msp_score": float, "energy_score": float}``:
        - ``msp_score``: mean maximum softmax probability.
        - ``energy_score``: mean of ``T * logsumexp(logits / T)``, i.e. the
          negative free energy (higher means more in-distribution).
        Both are ``0.0`` when the loader yields no samples.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # Unwrap DataParallel so inference runs on the single requested device.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    model = model.to(device)
    was_training = model.training
    model.eval()

    # Running float64 sums instead of keeping every per-sample score on the
    # device: constant memory, and more accurate means over large datasets.
    msp_sum = torch.zeros((), dtype=torch.float64, device=device)
    energy_sum = torch.zeros((), dtype=torch.float64, device=device)
    count = 0
    try:
        with torch.no_grad():
            for images, _ in source_val_loader:
                images = images.to(device)
                # Upcast so half-precision (e.g. AMP) logits do not lose
                # precision in softmax / logsumexp.
                logits = model(images).float()
                msp = F.softmax(logits, dim=1).max(dim=1).values
                energy = temperature * torch.logsumexp(
                    logits / temperature, dim=1
                )
                msp_sum += msp.sum(dtype=torch.float64)
                energy_sum += energy.sum(dtype=torch.float64)
                count += msp.numel()
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    if count == 0:
        return {"msp_score": 0.0, "energy_score": 0.0}
    return {
        "msp_score": float(msp_sum.item() / count),
        "energy_score": float(energy_sum.item() / count),
    }


if __name__ == "__main__":
    # No options are defined yet; parsing still provides --help and rejects
    # unexpected arguments.
    argparse.ArgumentParser(description="OOD scores sanity check").parse_args()
    print("This module provides calculate_ood_scores(). Implemented in Part 4.")
