import argparse
import os
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn

# Progress bars are optional: when tqdm cannot be imported, fall back to a
# pass-through so callers can always write `for x in tqdm(iterable, ...)`.
try:
    from tqdm.auto import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, *args, **kwargs):
        """Return *x* unchanged; tqdm's extra arguments are accepted and ignored."""
        return x

def _get_resnet_layers(model: nn.Module) -> Dict[str, nn.Module]:
    if isinstance(model, nn.DataParallel):
        model = model.module
    layers = {}
    # Standard torchvision ResNet-50 components
    layers["conv1"] = model.conv1
    layers["bn1"] = model.bn1
    layers["relu"] = model.relu
    layers["maxpool"] = model.maxpool
    layers["layer1"] = model.layer1
    layers["layer2"] = model.layer2
    layers["layer3"] = model.layer3
    layers["layer4"] = model.layer4
    layers["avgpool"] = model.avgpool
    return layers


@torch.no_grad()
def _extract_layer_features(model: nn.Module, loader, layer_name: str, device: str) -> torch.Tensor:
    """Extract flattened features at a given ResNet layer.

    layer_name in {layer1, layer2, layer3, layer4, avgpool}.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(model, nn.DataParallel):
        model = model.module
    model = model.to(device)
    model.eval()

    # Build a forward up to selected layer, then flatten
    layers = _get_resnet_layers(model)
    if layer_name not in layers:
        raise ValueError(f"Unknown layer {layer_name}")

    # Manually execute the forward pass to tap features at layer_name
    feats: List[torch.Tensor] = []
    for images, _ in loader:
        x = images.to(device)
        # Stem
        x = layers["conv1"](x)
        x = layers["bn1"](x)
        x = layers["relu"](x)
        x = layers["maxpool"](x)
        # Residual blocks
        x = layers["layer1"](x)
        if layer_name == "layer1":
            f = x
        else:
            x = layers["layer2"](x)
            if layer_name == "layer2":
                f = x
            else:
                x = layers["layer3"](x)
                if layer_name == "layer3":
                    f = x
                else:
                    x = layers["layer4"](x)
                    if layer_name == "layer4":
                        f = x
                    else:
                        # avgpool
                        x = layers["avgpool"](x)
                        f = torch.flatten(x, 1)
        if f.dim() > 2:
            f = torch.flatten(f, 1)
        # Move features to CPU immediately to avoid GPU memory accumulation
        feats.append(f.detach().cpu())
        # Proactively free GPU intermediates
        del f, x
        torch.cuda.empty_cache()
    return torch.cat(feats, dim=0) if feats else torch.empty(0)


def _compute_sinkhorn_w1(X: torch.Tensor, Y: torch.Tensor) -> float:
    try:
        from geomloss import SamplesLoss  # type: ignore

        loss_fn = SamplesLoss("sinkhorn", p=1, blur=0.05)
        return float(loss_fn(X, Y).item())
    except Exception:
        try:
            import ot  # type: ignore

            Xc = X.detach().cpu()
            Yc = Y.detach().cpu()
            M = torch.cdist(Xc, Yc, p=2) ** 2
            a = torch.full((Xc.size(0),), 1.0 / Xc.size(0), dtype=torch.float64)
            b = torch.full((Yc.size(0),), 1.0 / Yc.size(0), dtype=torch.float64)
            reg = 0.05
            val = ot.sinkhorn2(a.numpy(), b.numpy(), M.double().numpy(), reg)[0]
            return float(val ** 0.5)
        except Exception:
            with torch.no_grad():
                # Use chunked mean cdist to avoid large memory spikes
                bs = 1024
                total = 0.0
                count = 0
                for i in range(0, X.size(0), bs):
                    x = X[i:min(i+bs, X.size(0))]
                    D = torch.cdist(x, Y, p=2)
                    total += D.sum().item()
                    count += D.numel()
                return float(total / max(count, 1))


def _rbf_kernel(x: torch.Tensor, y: torch.Tensor, sigma2: float) -> torch.Tensor:
    if x.is_cuda:
        x = x.detach().cpu()
    if y.is_cuda:
        y = y.detach().cpu()
    x = x.contiguous()
    y = y.contiguous()
    x_norm = (x * x).sum(dim=1, keepdim=True)
    y_norm = (y * y).sum(dim=1, keepdim=True).t()
    cross = x @ y.t()
    dists2 = torch.clamp(x_norm + y_norm - 2.0 * cross, min=0.0)
    return torch.exp(-dists2 / (2.0 * sigma2))


def _median_heuristic_sigma2(Z: torch.Tensor, max_samples: int = 2000) -> float:
    if Z.size(0) > max_samples:
        idx = torch.randperm(Z.size(0), device=Z.device)[:max_samples]
        Z = Z[idx]
    with torch.no_grad():
        dists = torch.pdist(Z, p=2)
        if dists.numel() == 0:
            return 1.0
        median = torch.median(dists)
        sigma2 = (median.item() ** 2) if median.item() > 0 else 1.0
    return sigma2


def _mmd_unbiased(X: torch.Tensor, Y: torch.Tensor) -> float:
    if X.numel() == 0 or Y.numel() == 0:
        return 0.0
    # Center
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)
    if X.is_cuda:
        X = X.cpu()
    if Y.is_cuda:
        Y = Y.cpu()
    max_samples = 1000
    if X.size(0) > max_samples:
        idx = torch.randperm(X.size(0), device=X.device)[:max_samples]
        X = X[idx]
    if Y.size(0) > max_samples:
        idx = torch.randperm(Y.size(0), device=Y.device)[:max_samples]
        Y = Y[idx]
    Z = torch.cat([X, Y], dim=0)
    sigma2 = _median_heuristic_sigma2(Z)
    chunk = 500
    n, m = X.size(0), Y.size(0)
    Kxx_sum = 0.0
    Kxx_diag_sum = 0.0
    for i in range(0, n, chunk):
        for j in range(0, n, chunk):
            K = _rbf_kernel(X[i:min(i+chunk,n)], X[j:min(j+chunk,n)], sigma2)
            if i == j:
                Kxx_diag_sum += torch.diagonal(K).sum()
            Kxx_sum += K.sum()
    Kyy_sum = 0.0
    Kyy_diag_sum = 0.0
    for i in range(0, m, chunk):
        for j in range(0, m, chunk):
            K = _rbf_kernel(Y[i:min(i+chunk,m)], Y[j:min(j+chunk,m)], sigma2)
            if i == j:
                Kyy_diag_sum += torch.diagonal(K).sum()
            Kyy_sum += K.sum()
    Kxy_sum = 0.0
    for i in range(0, n, chunk):
        for j in range(0, m, chunk):
            K = _rbf_kernel(X[i:min(i+chunk,n)], Y[j:min(j+chunk,m)], sigma2)
            Kxy_sum += K.sum()
    if n < 2 or m < 2:
        return 0.0
    mmd2 = (
        (Kxx_sum - Kxx_diag_sum) / (n * (n - 1))
        + (Kyy_sum - Kyy_diag_sum) / (m * (m - 1))
        - 2.0 * Kxy_sum / (n * m)
    )
    return float(torch.clamp(mmd2, min=0.0).item())


def _energy_distance(X: torch.Tensor, Y: torch.Tensor) -> float:
    if X.numel() == 0 or Y.numel() == 0:
        return 0.0
    if X.is_cuda:
        X = X.cpu()
    if Y.is_cuda:
        Y = Y.cpu()
    max_samples = 2000
    if X.size(0) > max_samples:
        X = X[torch.randperm(X.size(0))[:max_samples]]
    if Y.size(0) > max_samples:
        Y = Y[torch.randperm(Y.size(0))[:max_samples]]
    n = X.size(0)
    m = Y.size(0)
    # Compute pairwise distances in chunks for memory safety
    def mean_pairwise(U: torch.Tensor, V: torch.Tensor) -> float:
        bs = 512
        total = 0.0
        count = 0
        for i in range(0, U.size(0), bs):
            u = U[i:min(i+bs, U.size(0))]
            D = torch.cdist(u, V, p=2)
            total += D.sum().item()
            count += D.numel()
        return total / max(count, 1)
    term1 = 2.0 * mean_pairwise(X, Y)
    term2 = mean_pairwise(X, X)
    term3 = mean_pairwise(Y, Y)
    return float(term1 - term2 - term3)


def _sliced_w2(X: torch.Tensor, Y: torch.Tensor, num_projections: int = 64) -> float:
    if X.numel() == 0 or Y.numel() == 0:
        return 0.0
    if X.is_cuda:
        X = X.cpu()
    if Y.is_cuda:
        Y = Y.cpu()
    d = X.size(1)
    k = num_projections
    # Random directions on unit sphere
    R = torch.randn(d, k)
    R = R / torch.norm(R, dim=0, keepdim=True).clamp(min=1e-8)
    proj_X = X @ R  # (n, k)
    proj_Y = Y @ R  # (m, k)
    vals = []
    for j in range(k):
        xj = torch.sort(proj_X[:, j])[0]
        yj = torch.sort(proj_Y[:, j])[0]
        # Match by quantiles; use min length
        n = min(xj.numel(), yj.numel())
        if n == 0:
            continue
        diff = xj[:n] - yj[:n]
        vals.append(torch.mean(diff * diff))
    if not vals:
        return 0.0
    return float(torch.mean(torch.stack(vals)).item())


def compute_shift_metrics(
    model_q: nn.Module,
    source_val_loader,
    target_train_loader,
    device: str = "cpu",
    layers: List[str] = None,
) -> Dict[str, float]:
    """Compute multi-layer shift metrics between source and target features.

    For each layer, features are extracted from both loaders with
    ``_extract_layer_features`` and compared. Returns a dict with keys:
      - w1_<layer>: ``_compute_sinkhorn_w1`` on the raw (uncentered) features
      - mmd_<layer>: ``_mmd_unbiased`` on mean-centered features
      - energy_<layer>: ``_energy_distance`` on mean-centered features
      - sw2_<layer>: ``_sliced_w2`` on mean-centered features
    A metric is 0.0 when either side yields no features.

    Args:
        model_q: feature extractor passed to ``_extract_layer_features``.
        source_val_loader: loader for the source domain.
        target_train_loader: loader for the target domain.
        device: device passed to ``_extract_layer_features``.
        layers: layer names to tap; defaults to layer2, layer3 and avgpool.

    Environment:
        trace_MAX_SHIFT_SAMPLES: if a positive integer N, keep only the first
            N extracted rows per side. Zero, unset or negative disables the
            cap.
    """
    if layers is None:
        layers = ["layer2", "layer3", "avgpool"]
    out: Dict[str, float] = {}
    # Optional sample cap for speed.
    max_samples = int(os.environ.get("trace_MAX_SHIFT_SAMPLES", "0"))
    use_gpu = torch.cuda.is_available()
    for layer in tqdm(layers, desc="[Shift] Layers", unit="layer", leave=False):
        X = _extract_layer_features(model_q, source_val_loader, layer, device)
        Y = _extract_layer_features(model_q, target_train_loader, layer, device)
        if max_samples > 0:
            # NOTE(review): keeps the *first* N rows; if a loader is not
            # shuffled this may be an unrepresentative subset — confirm.
            X = X[:max_samples]
            Y = Y[:max_samples]

        # W1 is translation-sensitive, so it sees the uncentered features.
        # Only W1 is sent to the GPU; the other metric helpers work on the
        # CPU, so uploading their inputs would just be copied back.
        if X.numel() and Y.numel():
            if use_gpu:
                out[f"w1_{layer}"] = _compute_sinkhorn_w1(
                    X.cuda(non_blocking=True), Y.cuda(non_blocking=True)
                )
            else:
                out[f"w1_{layer}"] = _compute_sinkhorn_w1(X, Y)
        else:
            out[f"w1_{layer}"] = 0.0

        # MMD / energy / SW2 compare mean-centered features.
        Xm = X - X.mean(dim=0, keepdim=True) if X.numel() else X
        Ym = Y - Y.mean(dim=0, keepdim=True) if Y.numel() else Y
        out[f"mmd_{layer}"] = _mmd_unbiased(Xm, Ym)
        out[f"energy_{layer}"] = _energy_distance(Xm, Ym)
        out[f"sw2_{layer}"] = _sliced_w2(Xm, Ym)
    return out


if __name__ == "__main__":
    # No CLI options yet; parsing still gives --help and rejects unknown args.
    argparse.ArgumentParser(description="Shift metrics sanity check").parse_args()
    print("This module provides compute_shift_metrics().")



