import argparse
import os
from typing import Dict, Any, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from experiments.deployment_gate.model_factory import get_resnet50
# Progress bars are optional: use tqdm when it imports cleanly, otherwise fall
# back to a pass-through so that loops wrapped in tqdm(...) still iterate.
try:
    from tqdm.auto import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, *args, **kwargs):
        """Stand-in for ``tqdm`` that returns the iterable ``x`` unchanged.

        Extra positional and keyword arguments (e.g. ``desc``, ``unit``) are
        accepted and ignored.
        """
        return x

# Module-level cache for the W1 (feature-shift) term. calculate_trace_bound()
# writes the value here after computing it and reads it back on later calls
# while the environment variable trace_CACHE_W1 is "1" (the default).
# The cache is a single float and is not keyed on the data loaders, so it
# assumes every call in this process compares the same source/target data.
_CACHED_W1_VALUE: Optional[float] = None

def _build_backbone(model: nn.Module) -> nn.Module:
    """Return ``model`` without its last child, wrapped in ``nn.Sequential``.

    A ``nn.DataParallel`` wrapper is peeled off first so that the children of
    the wrapped network (not of the wrapper) are used. The last child is
    dropped on the assumption that it is the classification head (``fc``).
    The returned container shares the child modules with ``model``.
    """
    net = model.module if isinstance(model, nn.DataParallel) else model
    layers = list(net.children())
    # Slice deletion is a no-op when there are no children at all.
    del layers[-1:]
    return nn.Sequential(*layers)


def _build_frozen_imagenet_backbone() -> nn.Module:
    """Build a frozen, pretrained ResNet-50 feature extractor.

    The network is obtained from ``get_resnet50(pretrained=True,
    num_classes=1000)``. Every parameter has ``requires_grad`` switched off,
    the last child module is dropped (presumably the ``fc`` head, leaving the
    network up to its pooling layer), and the remaining children are wrapped
    in ``nn.Sequential``.

    A fixed extractor is used for the transport features so that the measured
    shift does not depend on how Q or tilde-Q were trained.
    """
    resnet = get_resnet50(pretrained=True, num_classes=1000)
    resnet.requires_grad_(False)
    trunk = list(resnet.children())
    del trunk[-1:]
    return nn.Sequential(*trunk)

def _extract_features(backbone: nn.Module, data_loader, device: str) -> torch.Tensor:
    """Run ``backbone`` over ``data_loader`` and stack the flattened outputs.

    Args:
        backbone: Feature extractor; it is switched to eval mode here.
        data_loader: Iterable yielding ``(images, _)`` pairs; the second item
            is ignored.
        device: Device each image batch is moved to before the forward pass.

    Returns:
        All batch outputs flattened from dimension 1 onward and concatenated
        along dimension 0, left on whatever device the backbone produced them.
        An empty loader yields an empty 1-D tensor on ``device``.
    """
    backbone.eval()
    chunks: List[torch.Tensor] = []
    with torch.no_grad():
        progress = tqdm(data_loader, desc="[W1] Feats", unit="batch", leave=False)
        for batch, _ in progress:
            activations = backbone(batch.to(device))
            chunks.append(torch.flatten(activations, 1).detach())
    if not chunks:
        return torch.empty(0, device=device)
    return torch.cat(chunks, dim=0)


def _sinkhorn_w1_geomloss(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Sinkhorn divergence with ``p=1`` between two point clouds via geomloss.

    Reads ``trace_SINKHORN_BLUR`` (default ``0.1``) and ``trace_SINKHORN_ITERS``
    (default ``200``) from the environment. Raises whatever geomloss raises,
    including ``ImportError`` when the package is missing.
    """
    from geomloss import SamplesLoss  # type: ignore

    # Keep blur moderate for stability; allow override via env.
    blur = float(os.environ.get("trace_SINKHORN_BLUR", "0.1"))
    max_iters = int(os.environ.get("trace_SINKHORN_ITERS", "200"))
    try:
        loss_fn = SamplesLoss("sinkhorn", p=1, blur=blur, backend="tensorized", max_iter=max_iters)
    except TypeError:
        # NOTE(review): the documented SamplesLoss constructor has no
        # ``max_iter`` argument, so the call above is rejected. Retry without
        # it instead of abandoning geomloss for a cruder fallback.
        loss_fn = SamplesLoss("sinkhorn", p=1, blur=blur, backend="tensorized")

    # Move to GPU if available for speed.
    dev = torch.device("cuda") if torch.cuda.is_available() else X.device
    Xg = X.to(dev, non_blocking=True)
    Yg = Y.to(dev, non_blocking=True)
    return float(loss_fn(Xg, Yg).item())


def _sinkhorn_w1_pot(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Approximate the transport distance with POT's entropic solver on CPU.

    The ground cost is the squared Euclidean distance and the square root of
    the resulting transport cost is returned, so this is only an approximation
    of W1. Raises ``ImportError`` when POT is missing.
    """
    import ot  # type: ignore

    # POT expects CPU numpy inputs; uniform weights on both point clouds.
    Xc = X.detach().cpu()
    Yc = Y.detach().cpu()
    M = torch.cdist(Xc, Yc, p=2) ** 2
    a = torch.full((Xc.size(0),), 1.0 / Xc.size(0), dtype=torch.float64)
    b = torch.full((Yc.size(0),), 1.0 / Yc.size(0), dtype=torch.float64)
    reg = 0.05
    raw = ot.sinkhorn2(a.numpy(), b.numpy(), M.double().numpy(), reg)
    # Depending on the POT version the result is a scalar or a length-1 array;
    # indexing a scalar with [0] raises, so normalise both shapes first.
    cost = torch.as_tensor(raw, dtype=torch.float64).reshape(-1)[0]
    # The cost was accumulated over squared distances; take the square root.
    return float(cost.sqrt().item())


def _compute_sinkhorn_w1(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Estimate the W1 distance between two sets of feature vectors.

    Backends are tried in order and the first one that succeeds wins:

    1. geomloss Sinkhorn divergence with ``p=1``;
    2. POT entropic OT with squared-Euclidean cost, square-rooted;
    3. the mean pairwise Euclidean distance (a proxy, not W1).

    The three backends do not return the same quantity, so a ``RuntimeWarning``
    naming the failure is emitted every time one of them is skipped.

    Args:
        X: Source points, one row per sample.
        Y: Target points, one row per sample.

    Returns:
        The estimate as a Python float.
    """
    import warnings

    try:
        return _sinkhorn_w1_geomloss(X, Y)
    except Exception as exc:
        warnings.warn(
            f"geomloss Sinkhorn unavailable ({type(exc).__name__}: {exc}); trying POT",
            RuntimeWarning,
        )

    try:
        return _sinkhorn_w1_pot(X, Y)
    except Exception as exc:
        warnings.warn(
            f"POT Sinkhorn unavailable ({type(exc).__name__}: {exc}); "
            "using mean pairwise distance, which is not W1",
            RuntimeWarning,
        )

    # Last-resort proxy: mean pairwise distance (not exact W1).
    with torch.no_grad():
        return float(torch.cdist(X, Y, p=2).mean().item())


def _estimate_lipschitz_input(model: nn.Module, data_loader, device: str, quantile: float = 0.99) -> float:
    """Estimate an input-Lipschitz constant from per-sample gradient norms.

    For every sample the L2 norm of the gradient of its cross-entropy loss
    with respect to the input is computed; the requested quantile of those
    norms is returned.

    Args:
        model: Classifier. A ``nn.DataParallel`` wrapper is unwrapped; the
            model is moved to ``device`` and put in eval mode.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device used for the model and every batch.
        quantile: Quantile in ``[0, 1]`` taken over the collected norms.

    Returns:
        The quantile as a float, or ``0.0`` when the loader yields no batch.

    The environment variable ``trace_MAX_LIP_SAMPLES`` (default ``0`` = no
    cap) stops the loop after the first batch that reaches that many samples.
    """
    # Unwrap DP and ensure model is on the correct device.
    if isinstance(model, nn.DataParallel):
        model = model.module
    model = model.to(device)
    model.eval()

    norms: List[torch.Tensor] = []
    max_samples = int(os.environ.get("trace_MAX_LIP_SAMPLES", "0"))
    seen = 0
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        images.requires_grad_(True)

        logits = model(images)
        losses = F.cross_entropy(logits, labels, reduction="none")
        # A vector of ones sums the per-sample losses; samples are independent,
        # so row i of the result is the gradient of loss i w.r.t. image i.
        (grads,) = torch.autograd.grad(
            outputs=losses,
            inputs=images,
            grad_outputs=torch.ones_like(losses),
        )
        # reshape (not view) so non-contiguous gradients are handled as well.
        batch_norms = grads.reshape(grads.size(0), -1).norm(p=2, dim=1)
        norms.append(batch_norms.detach())

        # ``.to`` may return the loader's own tensor; restore its flag.
        images.requires_grad_(False)
        seen += images.size(0)
        if max_samples and seen >= max_samples:
            break

    if not norms:
        return 0.0
    all_norms = torch.cat(norms, dim=0)
    # torch.quantile only accepts float32/float64 input; upcast anything else.
    if all_norms.dtype not in (torch.float32, torch.float64):
        all_norms = all_norms.float()
    # A Python float for q is cast to the input dtype by torch, whereas a
    # default float32 q tensor is rejected for float64 input.
    return float(torch.quantile(all_norms, float(quantile)).item())


def _estimate_lipschitz_multiple(
    model: nn.Module,
    data_loader,
    device: str,
    quantiles: List[float],
) -> Dict[float, float]:
    """Estimate input-Lipschitz constants at several quantiles in one pass.

    Per-sample L2 norms of the cross-entropy loss gradient with respect to
    the input are collected once; each requested quantile is then read from
    the same set of norms.

    Args:
        model: Classifier. A ``nn.DataParallel`` wrapper is unwrapped; the
            model is moved to ``device`` and put in eval mode.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device used for the model and every batch.
        quantiles: Quantiles in ``[0, 1]`` to evaluate.

    Returns:
        A dict mapping each entry of ``quantiles`` to its value. Every value
        is ``0.0`` when the loader yields no batch.

    The environment variable ``trace_MAX_LIP_SAMPLES`` (default ``0`` = no
    cap) stops the loop after the first batch that reaches that many samples.
    """
    # Unwrap and compute norms once for efficiency.
    if isinstance(model, nn.DataParallel):
        model = model.module
    model = model.to(device)
    model.eval()

    norms: List[torch.Tensor] = []
    max_samples = int(os.environ.get("trace_MAX_LIP_SAMPLES", "0"))
    seen = 0
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        images.requires_grad_(True)

        logits = model(images)
        losses = F.cross_entropy(logits, labels, reduction="none")
        # A vector of ones sums the per-sample losses; samples are independent,
        # so row i of the result is the gradient of loss i w.r.t. image i.
        (grads,) = torch.autograd.grad(
            outputs=losses,
            inputs=images,
            grad_outputs=torch.ones_like(losses),
        )
        # reshape (not view) so non-contiguous gradients are handled as well.
        batch_norms = grads.reshape(grads.size(0), -1).norm(p=2, dim=1)
        norms.append(batch_norms.detach())

        # ``.to`` may return the loader's own tensor; restore its flag.
        images.requires_grad_(False)
        seen += images.size(0)
        if max_samples and seen >= max_samples:
            break

    if not norms:
        return {q: 0.0 for q in quantiles}
    all_norms = torch.cat(norms, dim=0)
    # torch.quantile only accepts float32/float64 input; upcast anything else.
    if all_norms.dtype not in (torch.float32, torch.float64):
        all_norms = all_norms.float()
    # A Python float for q is cast to the input dtype by torch, whereas a
    # default float32 q tensor is rejected for float64 input.
    return {q: float(torch.quantile(all_norms, float(q)).item()) for q in quantiles}


def _output_discrepancy(model_q: nn.Module, model_tilde_q: nn.Module, target_loader, device: str) -> float:
    """Mean L2 distance between the two models' logits over ``target_loader``.

    Args:
        model_q: First model. A ``nn.DataParallel`` wrapper is unwrapped; the
            model is moved to ``device`` and put in eval mode.
        model_tilde_q: Second model, handled the same way.
        target_loader: Iterable of ``(images, _)`` batches; the second item is
            ignored.
        device: Device used for both models and every batch.

    Returns:
        The per-sample distances averaged over all processed samples, or
        ``0.0`` when the loader yields no batch.

    The environment variable ``trace_MAX_OUTDISC_SAMPLES`` (default ``0`` =
    no cap) stops the loop after the first batch that reaches that many
    samples.
    """

    def _prepare(net: nn.Module) -> nn.Module:
        # Strip a DataParallel wrapper, then place the network on the device.
        core = net.module if isinstance(net, nn.DataParallel) else net
        return core.to(device)

    net_q = _prepare(model_q)
    net_t = _prepare(model_tilde_q)
    net_q.eval()
    net_t.eval()

    sample_cap = int(os.environ.get("trace_MAX_OUTDISC_SAMPLES", "0"))
    per_sample: List[torch.Tensor] = []
    processed = 0
    with torch.no_grad():
        for batch, _ in target_loader:
            batch = batch.to(device)
            gap = net_q(batch) - net_t(batch)
            per_sample.append(torch.norm(gap, p=2, dim=1))
            processed += batch.size(0)
            if sample_cap and processed >= sample_cap:
                break

    if not per_sample:
        return 0.0
    return float(torch.cat(per_sample, dim=0).mean().item())


def calculate_trace_bound(
    model_q: nn.Module,
    model_tilde_q: nn.Module,
    source_val_loader,
    target_train_loader,
    device: str = "cpu",
    lip_quantile: float = 0.99,
) -> Dict[str, Any]:
    """Compute the TRACE components for a pair of models.

    Args:
        model_q: Model Q.
        model_tilde_q: Model tilde-Q.
        source_val_loader: Source-domain batches of ``(images, labels)``.
        target_train_loader: Target-domain batches of ``(images, labels)``.
        device: Device to run on; an empty value selects CUDA when available,
            otherwise CPU.
        lip_quantile: Quantile of the input-gradient norms reported as the
            Lipschitz estimate and used in the bound.

    Returns:
        A dict with:

        - ``bound``: ``output_dist + lipschitz_tilde_q * w1_term``.
        - ``w1_term``: Sinkhorn approximation of W1 between source and target
          features from a frozen ImageNet backbone.
        - ``output_dist``: mean L2 distance between the logits of Q and
          tilde-Q on ``target_train_loader``.
        - ``lipschitz_q``: ``lip_quantile`` quantile of Q's input-gradient
          norms on ``source_val_loader``.
        - ``lipschitz_tilde_q``: the same for tilde-Q on
          ``target_train_loader``.
        - ``lipschitz_q_q95`` / ``lipschitz_tilde_q_q95``: the 0.95 quantiles.
        - ``lip_quantile_used``: the ``lip_quantile`` argument.

    Environment:
        ``trace_CACHE_W1`` (default ``"1"``): reuse the first computed W1
        value for every later call in this process. The cache is not keyed on
        the loaders.
        ``trace_MAX_W1_SAMPLES`` (default ``6000``): truncate both feature
        sets to at most this many rows before W1; ``0`` or less disables the
        cap.
    """
    global _CACHED_W1_VALUE

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    cache_enabled = os.environ.get("trace_CACHE_W1", "1") == "1"
    if cache_enabled and _CACHED_W1_VALUE is not None:
        # Cache hit: skip backbone construction and feature extraction.
        w1_term = float(_CACHED_W1_VALUE)
    else:
        # Feature extractor: frozen ImageNet-pretrained backbone, decoupled
        # from Q and tilde-Q.
        backbone = _build_frozen_imagenet_backbone().to(device)
        # Features stay on device; _compute_sinkhorn_w1 moves them as needed.
        src_feats = _extract_features(backbone, source_val_loader, device)
        tgt_feats = _extract_features(backbone, target_train_loader, device)

        # Optional cap on the W1 problem size (avoids very slow / OOM OT).
        # Both sets are cut to the same length.
        max_w1 = int(os.environ.get("trace_MAX_W1_SAMPLES", "6000"))
        if max_w1 > 0:
            n = min(src_feats.size(0), tgt_feats.size(0), max_w1)
            src_feats = src_feats[:n]
            tgt_feats = tgt_feats[:n]

        # W1 is computed whether or not the cap is active.
        if src_feats.numel() and tgt_feats.numel():
            w1_term = _compute_sinkhorn_w1(src_feats, tgt_feats)
        else:
            w1_term = 0.0
        if cache_enabled:
            _CACHED_W1_VALUE = float(w1_term)

    # Lipschitz via input-grad norms (several quantiles from a single pass).
    uniq_qs = sorted({0.95, 0.99, float(lip_quantile)})
    lips_qs = _estimate_lipschitz_multiple(model_q, source_val_loader, device, quantiles=uniq_qs)
    # The tilde model is evaluated on the target loader passed to this function.
    lips_ts = _estimate_lipschitz_multiple(model_tilde_q, target_train_loader, device, quantiles=uniq_qs)
    lips_q = lips_qs.get(float(lip_quantile), 0.0)
    lips_t = lips_ts.get(float(lip_quantile), 0.0)

    # Output discrepancy on target.
    out_discrep = _output_discrepancy(model_q, model_tilde_q, target_train_loader, device)

    # TRACE score variant: OutDisc + Lx(tilde_Q) * W1.
    bound = float(out_discrep + lips_t * w1_term)

    return {
        "bound": bound,
        "w1_term": float(w1_term),
        "output_dist": float(out_discrep),
        "lipschitz_q": float(lips_q),
        "lipschitz_tilde_q": float(lips_t),
        "lipschitz_q_q95": float(lips_qs.get(0.95, 0.0)),
        "lipschitz_tilde_q_q95": float(lips_ts.get(0.95, 0.0)),
        "lip_quantile_used": float(lip_quantile),
    }


if __name__ == "__main__":
    # No options are defined; parsing only provides --help and rejects
    # unexpected arguments before the notice is printed.
    argparse.ArgumentParser(description="trace metrics sanity check").parse_args()
    print("This module provides calculate_trace_bound(). Implemented in Part 4.")
