import argparse
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def compute_entropy(model: nn.Module, loader, device: Optional[str] = "cpu") -> float:
    """Return the mean predictive (Shannon) entropy of ``model`` over ``loader``.

    The model's output is treated as logits over dim 1; entropy is measured in
    nats (natural log) per sample and averaged over every sample yielded.

    Args:
        model: Classifier producing logits. An ``nn.DataParallel`` wrapper is
            unwrapped and its inner module is used.
        loader: Iterable yielding ``(images, targets)`` pairs; targets are ignored.
        device: Device to evaluate on. ``None`` (or ``""``) selects ``"cuda"``
            when available, otherwise ``"cpu"``.

    Returns:
        Mean entropy as a Python float, or ``0.0`` if the loader yields nothing.

    Note:
        ``nn.Module.to`` moves the model to ``device`` in place. The model's
        train/eval mode is restored before returning, even on error.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(model, nn.DataParallel):
        model = model.module
    model = model.to(device)
    was_training = model.training
    model.eval()
    entropies = []
    try:
        for images, _ in loader:
            # Work in float32: under fp16 a 1e-12 floor underflows to 0 and
            # low-precision softmax/log produce NaNs.
            logits = model(images.to(device)).float()
            log_probs = F.log_softmax(logits, dim=1)
            probs = log_probs.exp()
            # Use the convention 0 * log 0 = 0 so classes masked to -inf
            # contribute nothing instead of 0 * -inf = nan.
            plogp = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
            entropies.append(-plogp.sum(dim=1))
    finally:
        # Metric computation must not leave a training model stuck in eval mode.
        model.train(was_training)
    if not entropies:
        return 0.0
    return float(torch.cat(entropies, dim=0).mean().item())


@torch.no_grad()
def compute_symmetric_kl(model_a: nn.Module, model_b: nn.Module, loader, device: Optional[str] = "cpu") -> float:
    """Return the mean symmetric KL divergence between two models' predictions.

    For each sample, computes ``0.5 * (KL(pa || pb) + KL(pb || pa))``, where
    ``pa``/``pb`` are the softmax distributions (over dim 1) of each model's
    logits. The per-sample values are then averaged over the whole loader.
    Values are in nats.

    Args:
        model_a: First classifier producing logits; ``nn.DataParallel`` is unwrapped.
        model_b: Second classifier producing logits; ``nn.DataParallel`` is unwrapped.
        loader: Iterable yielding ``(images, targets)`` pairs; targets are ignored.
        device: Device to evaluate on. ``None`` (or ``""``) selects ``"cuda"``
            when available, otherwise ``"cpu"``.

    Returns:
        Mean symmetric KL as a Python float, or ``0.0`` if the loader yields nothing.

    Note:
        Both models are moved to ``device`` in place (``nn.Module.to``). Their
        train/eval modes are restored before returning, even on error.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(model_a, nn.DataParallel):
        model_a = model_a.module
    if isinstance(model_b, nn.DataParallel):
        model_b = model_b.module
    model_a = model_a.to(device)
    model_b = model_b.to(device)
    was_training_a = model_a.training
    was_training_b = model_b.training
    model_a.eval()
    model_b.eval()
    vals = []
    try:
        for images, _ in loader:
            images = images.to(device)
            # log_softmax in float32 keeps exact log-probabilities. Clamping
            # softmax output at a floor instead caps the KL whenever one model
            # assigns near-zero mass where the other is confident.
            log_pa = F.log_softmax(model_a(images).float(), dim=1)
            log_pb = F.log_softmax(model_b(images).float(), dim=1)
            pa = log_pa.exp()
            pb = log_pb.exp()
            # KL(a||b) + KL(b||a) == sum((pa - pb) * (log pa - log pb)).
            terms = (pa - pb) * (log_pa - log_pb)
            # Classes with zero mass under both models (e.g. masked to -inf)
            # contribute nothing; avoid 0 * (-inf - -inf) = nan.
            terms = torch.where((pa == 0) & (pb == 0), torch.zeros_like(terms), terms)
            vals.append(0.5 * terms.sum(dim=1))
    finally:
        # Metric computation must not leave training models stuck in eval mode.
        model_a.train(was_training_a)
        model_b.train(was_training_b)
    if not vals:
        return 0.0
    return float(torch.cat(vals, dim=0).mean().item())


if __name__ == "__main__":
    # Minimal CLI: no options yet, but argparse still provides --help and
    # rejects unexpected arguments before the informational message is shown.
    cli = argparse.ArgumentParser(description="Disagreement metrics sanity check")
    cli.parse_args()
    print("This module provides compute_entropy() and compute_symmetric_kl().")


