import argparse
from typing import Dict, List

import torch
import torch.nn as nn


def _get_backbone(model: nn.Module) -> nn.Module:
    """Return a feature extractor built from every child of ``model`` except the last.

    The returned ``nn.Sequential`` shares (does not copy) the original submodules.
    NOTE(review): this assumes the model's forward pass applies its children in
    registration order and that the final child is the classification head
    (ResNet-style layout) — confirm for other architectures. A model with at most
    one child yields an empty ``Sequential``, i.e. an identity mapping.

    Data-parallel wrappers (``DataParallel`` / ``DistributedDataParallel``) are
    unwrapped first. They hold the real network as their only child, so without
    unwrapping, dropping the "last child" would discard the whole network.
    """
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module
    modules = list(model.children())[:-1]
    return nn.Sequential(*modules)


@torch.no_grad()
def _extract_features(backbone: nn.Module, loader, device: str) -> torch.Tensor:
    """Run ``backbone`` in eval mode over ``loader`` and stack flattened outputs.

    Args:
        backbone: Feature extractor. Its submodules may be shared with a caller's
            model, so each module's train/eval flag is restored on exit.
        loader: Iterable of batches. A batch is either an input tensor or a
            tuple/list whose first element is the input tensor (e.g.
            ``(images, labels)``); remaining elements are ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A ``(num_samples, feature_dim)`` tensor of features flattened from dim 1,
        or an empty 1-D tensor on ``device`` when ``loader`` yields nothing.
    """
    # Remember each module's mode: eval() is recursive and would otherwise leave
    # the caller's shared submodules (dropout, batch-norm, ...) in eval mode.
    saved_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    feats: List[torch.Tensor] = []
    try:
        for batch in loader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            out = backbone(images.to(device))
            feats.append(torch.flatten(out, 1))
    finally:
        # Assign the flag directly so per-module (possibly mixed) modes come back
        # exactly as they were, instead of train(mode) overriding children.
        for module, was_training in saved_modes:
            module.training = was_training
    if not feats:
        return torch.empty(0, device=device)
    return torch.cat(feats, dim=0)


@torch.no_grad()
def compute_mahalanobis(model: nn.Module, source_loader, device: str = "cpu") -> float:
    """Class-agnostic Mahalanobis: distance of source features to their mean using shared covariance.

    Features come from every child of ``model`` except the last, flattened per
    sample. They are centred on their empirical mean and measured against their
    own covariance, regularised with a small diagonal jitter.

    Args:
        model: Network to evaluate; moved to ``device`` in place.
        source_loader: Iterable yielding ``(inputs, target)`` batches.
        device: Device to run on. A falsy value (``None`` or ``""``) selects
            ``"cuda"`` when available, otherwise ``"cpu"``.

    Returns:
        Mean over samples of ``sqrt((x - mu)^T cov^{-1} (x - mu))``. Larger values
        imply more spread/uncertainty. Returns ``0.0`` when fewer than two samples
        are available.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    backbone = _get_backbone(model).to(device)
    X = _extract_features(backbone, source_loader, device)
    if X.numel() == 0 or X.size(0) < 2:
        return 0.0
    # Linear algebra needs at least fp32. With fp16/bf16 features (e.g. AMP models),
    # the fp32 jitter would promote cov to fp32 while X stayed half, and the final
    # matmul would fail on mismatched dtypes.
    X = X.to(torch.promote_types(X.dtype, torch.float32))
    X = X - X.mean(dim=0, keepdim=True)
    # Unbiased shared covariance (X.size(0) >= 2 is guaranteed above).
    cov = (X.t() @ X) / (X.size(0) - 1)
    # Diagonal jitter keeps cov positive definite when features are collinear or
    # the feature dimension exceeds the sample count.
    cov = cov + 1e-3 * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)
    try:
        # cholesky_inverse takes the Cholesky factor and returns cov^{-1} itself.
        chol = torch.linalg.cholesky(cov)
        precision = torch.cholesky_inverse(chol)
    except RuntimeError:
        # torch.linalg.LinAlgError (a RuntimeError subclass) signals a numerically
        # non-PD matrix; fall back to the pseudo-inverse.
        precision = torch.linalg.pinv(cov)
    # Squared Mahalanobis distance per row: x^T cov^{-1} x.
    d2 = torch.sum((X @ precision) * X, dim=1)
    return float(torch.sqrt(torch.clamp(d2, min=0.0)).mean().item())


if __name__ == "__main__":
    # Script entry point: parses no options (so ``--help`` still works), then
    # prints a pointer to the module's public function.
    argparse.ArgumentParser(description="Mahalanobis metric sanity check").parse_args()
    print("This module provides compute_mahalanobis().")


