import argparse
import os
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from tqdm.auto import tqdm
except Exception:  # pragma: no cover
    # Progress bars are purely cosmetic: without tqdm, use a stand-in that
    # hands back the iterable untouched and ignores every display option.
    def tqdm(x, *_args, **_kwargs):
        """Return ``x`` unchanged; progress-bar options are ignored."""
        return x

@torch.no_grad()
def _forward_logits(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    if isinstance(model, nn.DataParallel):
        model = model.module
    return model(images)


@torch.no_grad()
def compute_outdisc_metrics(
    model_q: nn.Module,
    model_tilde_q: nn.Module,
    loader,
    device: str = "cpu",
) -> Dict[str, float]:
    """Compute multiple output discrepancy metrics on the target loader.

    Both models are put in eval mode and run on every batch; per-sample
    statistics are then averaged over all samples seen.

    Args:
        model_q: Reference model Q (an ``nn.DataParallel`` wrapper is unwrapped).
        model_tilde_q: Modified model Q~ (likewise unwrapped).
        loader: Iterable of batches. A batch may be a bare image tensor or a
            tuple/list whose first element is the image tensor; any further
            elements (e.g. labels) are ignored.
        device: Device to run on. A falsy value (``None``/``""``) selects
            CUDA when available, otherwise CPU. Note that ``Module.to`` moves
            the caller's models in place.

    The logit-space metrics (L2, L1, cosine) use logits clamped to
    ``[-c, c]``, where ``c`` is read from the ``trace_OUTDISC_LOGIT_CLIP``
    environment variable (default ``"10"``; empty, unparsable, NaN or
    non-positive values disable clamping). Probability-level metrics always
    use the unclamped logits.

    Returns keys:
      - outdisc_l2_mean
      - outdisc_l1_mean
      - outdisc_cosine_mean (cosine distance between logits)
      - prob_kl_q_p (KL(P_Q || P_Qt))
      - prob_kl_p_q (KL(P_Qt || P_Q))
      - prob_js (Jensen-Shannon divergence, natural log, at most log 2)
      - disagree_rate (argmax disagree)
      - maxprob_abs_shift
      - margin_mean_shift (signed: mean of margin(Q~) - margin(Q))
      - entropy_abs_shift
    Every value is 0.0 when ``loader`` yields no batches.
    """
    import math

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(model_q, nn.DataParallel):
        model_q = model_q.module
    if isinstance(model_tilde_q, nn.DataParallel):
        model_tilde_q = model_tilde_q.module
    model_q = model_q.to(device)
    model_tilde_q = model_tilde_q.to(device)
    model_q.eval()
    model_tilde_q.eval()

    # Optional logit clipping for OutDisc (logit-space) metrics only.
    clip_val_str = os.environ.get("trace_OUTDISC_LOGIT_CLIP", "10")
    try:
        clip_val = float(clip_val_str) if clip_val_str else 0.0
    except ValueError:
        clip_val = 0.0
    use_clip = clip_val > 0  # also False for NaN

    log_two = math.log(2.0)

    def _plogq(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
        # Elementwise p * log_q with the convention 0 * log(.) = 0, so classes
        # whose probability underflows to exactly zero cannot inject NaN
        # through 0 * (-inf).
        return torch.where(p > 0, p * log_q, torch.zeros_like(p))

    def _margin(p: torch.Tensor) -> torch.Tensor:
        # Gap between the two largest class probabilities along dim 1;
        # defined as 0 when there is only a single class.
        if p.size(1) < 2:
            return torch.zeros_like(p[:, 0])
        top2 = torch.topk(p, k=2, dim=1).values
        return top2[:, 0] - top2[:, 1]

    # Per-sample values per metric; insertion order fixes the output key order.
    per_sample: Dict[str, list] = {
        key: []
        for key in (
            "outdisc_l2_mean",
            "outdisc_l1_mean",
            "outdisc_cosine_mean",
            "prob_kl_q_p",
            "prob_kl_p_q",
            "prob_js",
            "disagree_rate",
            "maxprob_abs_shift",
            "margin_mean_shift",
            "entropy_abs_shift",
        )
    }

    for batch in tqdm(loader, desc="[OutDisc] Batches", unit="batch", leave=False):
        # Labels are never needed here, so unlabeled loaders work too.
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        images = images.to(device)
        logits_q = _forward_logits(model_q, images)
        logits_t = _forward_logits(model_tilde_q, images)

        # Use clipped logits for OutDisc distances if requested
        if use_clip:
            lq = torch.clamp(logits_q, min=-clip_val, max=clip_val)
            lt = torch.clamp(logits_t, min=-clip_val, max=clip_val)
        else:
            lq, lt = logits_q, logits_t

        # Logit-space distances
        diff = lq - lt
        per_sample["outdisc_l2_mean"].append(torch.norm(diff, p=2, dim=1))
        per_sample["outdisc_l1_mean"].append(torch.norm(diff, p=1, dim=1))

        # Cosine distance
        cos_sim = torch.sum(
            F.normalize(lq, p=2, dim=1) * F.normalize(lt, p=2, dim=1), dim=1
        ).clamp(min=-1.0, max=1.0)
        per_sample["outdisc_cosine_mean"].append(1.0 - cos_sim)

        # Probability level. Exact log-probabilities (log_softmax) replace
        # log(clamp(softmax, 1e-8)): that clamp capped each log-ratio at about
        # 18.4 nats, underestimating KL whenever the models confidently
        # disagree, and the clamped vectors no longer summed to one.
        log_pa = F.log_softmax(logits_q, dim=1)
        log_pb = F.log_softmax(logits_t, dim=1)
        pa = log_pa.exp()
        pb = log_pb.exp()
        per_sample["prob_kl_q_p"].append(_plogq(pa, log_pa - log_pb).sum(dim=1))
        per_sample["prob_kl_p_q"].append(_plogq(pb, log_pb - log_pa).sum(dim=1))
        # log of the mixture M = (P_Q + P_Qt) / 2, computed in log space so it
        # stays finite wherever either distribution has mass.
        log_m = torch.logaddexp(log_pa, log_pb) - log_two
        js = 0.5 * _plogq(pa, log_pa - log_m).sum(dim=1) + 0.5 * _plogq(
            pb, log_pb - log_m
        ).sum(dim=1)
        per_sample["prob_js"].append(js)

        # Disagreement
        per_sample["disagree_rate"].append(
            (pa.argmax(dim=1) != pb.argmax(dim=1)).float()
        )

        # Confidence and margin shifts
        maxprob_a = pa.max(dim=1).values
        maxprob_b = pb.max(dim=1).values
        per_sample["maxprob_abs_shift"].append(torch.abs(maxprob_b - maxprob_a))
        per_sample["margin_mean_shift"].append(_margin(pb) - _margin(pa))

        # Entropy shift
        ent_a = -_plogq(pa, log_pa).sum(dim=1)
        ent_b = -_plogq(pb, log_pb).sum(dim=1)
        per_sample["entropy_abs_shift"].append(torch.abs(ent_b - ent_a))

    return {
        key: float(torch.cat(chunks, dim=0).mean().item()) if chunks else 0.0
        for key, chunks in per_sample.items()
    }


@torch.no_grad()
def compute_ece_shift(
    model_q: nn.Module,
    model_tilde_q: nn.Module,
    loader,
    device: str = "cpu",
    n_bins: int = 15,
) -> float:
    """Compute absolute difference in ECE between Q and Q~ on labeled loader.

    ECE uses ``n_bins`` equal-width confidence bins over (0, 1]. Each bin
    contributes ``fraction_of_samples * |mean_confidence - accuracy|``, where
    confidence is the max softmax probability and the prediction is its argmax.

    Args:
        model_q: Reference model Q (an ``nn.DataParallel`` wrapper is unwrapped).
        model_tilde_q: Modified model Q~ (likewise unwrapped).
        loader: Iterable of ``(images, labels, ...)`` tuples/lists.
        device: Device to run on. A falsy value (``None``/``""``) selects
            CUDA when available, otherwise CPU. ``Module.to`` moves the
            caller's models in place.
        n_bins: Number of confidence bins.

    Returns:
        ``|ECE(Q) - ECE(Q~)|``. It is 0.0 if the loader is empty or if labels
        are not available: a batch that is not a tuple/list, has fewer than two
        elements, or has ``None`` as its second element.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(model_q, nn.DataParallel):
        model_q = model_q.module
    if isinstance(model_tilde_q, nn.DataParallel):
        model_tilde_q = model_tilde_q.module
    model_q = model_q.to(device)
    model_tilde_q = model_tilde_q.to(device)
    model_q.eval()
    model_tilde_q.eval()

    probs_a = []
    probs_b = []
    labels_all = []
    for batch in tqdm(loader, desc="[ECE] Batches", unit="batch", leave=False):
        if not isinstance(batch, (tuple, list)) or len(batch) < 2 or batch[1] is None:
            # Unlabeled data: ECE is undefined, so honour the documented 0.0.
            return 0.0
        images = batch[0].to(device)
        # Keep labels on CPU alongside the probabilities so the comparison in
        # _ece never mixes devices (e.g. if the loader yields CUDA labels).
        labels_all.append(torch.as_tensor(batch[1]).cpu())
        probs_a.append(F.softmax(model_q(images), dim=1).cpu())
        probs_b.append(F.softmax(model_tilde_q(images), dim=1).cpu())
    if not probs_a:
        return 0.0
    P_a = torch.cat(probs_a, dim=0)
    P_b = torch.cat(probs_b, dim=0)
    y = torch.cat(labels_all, dim=0)

    def _ece(probs: torch.Tensor, labels: torch.Tensor, bins: int) -> float:
        conf, pred = probs.max(dim=1)
        correct = (pred == labels).float()
        edges = torch.linspace(0, 1, bins + 1)
        total = 0.0
        for lo, hi in zip(edges[:-1], edges[1:]):
            in_bin = (conf > lo) & (conf <= hi)
            if not bool(in_bin.any()):
                continue
            gap = torch.abs(conf[in_bin].mean() - correct[in_bin].mean())
            total += (in_bin.float().mean() * gap).item()
        return float(total)

    return abs(_ece(P_a, y, n_bins) - _ece(P_b, y, n_bins))


if __name__ == "__main__":
    # The CLI has no options yet; parsing still provides --help and rejects
    # unexpected arguments before the informational message is printed.
    _cli = argparse.ArgumentParser(description="Model-change metrics sanity check")
    _cli.parse_args()
    print("This module provides compute_outdisc_metrics() and compute_ece_shift().")



