import argparse
import os
import warnings
from typing import Callable, List, Optional

import torch


def _pairwise_dists2_cpu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the matrix of squared Euclidean distances between rows of x and y.

    CUDA inputs are detached and copied to the CPU first so the (n, m) result
    never occupies GPU memory. Distances are formed via the expansion
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, and tiny negative values caused
    by floating-point cancellation are clamped to zero.
    """

    def _to_host(t: torch.Tensor) -> torch.Tensor:
        # Only CUDA tensors are moved; other devices are used as-is.
        return t.detach().cpu() if t.is_cuda else t

    lhs = _to_host(x).contiguous()
    rhs = _to_host(y).contiguous()
    lhs_sq = (lhs * lhs).sum(dim=1, keepdim=True)  # column vector (n, 1)
    rhs_sq = (rhs * rhs).sum(dim=1, keepdim=True).t()  # row vector (1, m)
    inner = lhs @ rhs.t()  # (n, m) dot products
    squared = lhs_sq + rhs_sq - 2.0 * inner
    return squared.clamp(min=0.0)


def _rbf_kernel_from_dists2(dists2: torch.Tensor, sigma2: float) -> torch.Tensor:
    """Map squared distances d2 to Gaussian kernel values exp(-d2 / (2 * sigma2))."""
    denom = 2.0 * sigma2
    return (-dists2 / denom).exp()


def _parse_sigma_multipliers(default: Optional[List[float]] = None) -> List[float]:
    """Read RBF bandwidth multipliers from the ``trace_MMD_SIGMA_MULTS`` env var.

    The variable holds comma-separated numbers, e.g. ``"0.5, 1, 2"``. Blank
    entries (including whitespace-only ones such as in ``"1, ,2"``) are
    skipped. Values that are not finite and strictly positive are dropped:
    zero/negative/NaN bandwidths are invalid, and an infinite one collapses
    the kernel to the constant 1.

    Args:
        default: Returned when the variable is unset/blank or cannot be
            parsed. ``None`` means an empty list.

    Returns:
        The parsed positive, finite multipliers, or ``default``.
    """
    if default is None:
        default = []
    env_val = os.environ.get("trace_MMD_SIGMA_MULTS", "").strip()
    if not env_val:
        return default
    # Strip each token so whitespace-only entries are skipped instead of
    # making float() fail and discarding the whole setting.
    tokens = [tok.strip() for tok in env_val.split(",")]
    try:
        parts = [float(tok) for tok in tokens if tok]
    except ValueError:
        # Best effort: a malformed setting must not abort the metric, but it
        # should not be ignored silently either.
        warnings.warn(
            f"Ignoring malformed trace_MMD_SIGMA_MULTS={env_val!r}; "
            f"using default {default!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return default
    # `0 < p < inf` also rejects NaN, since every comparison with NaN is False.
    return [p for p in parts if 0.0 < p < float("inf")]


def _rbf_kernel_mixture(x: torch.Tensor, y: torch.Tensor, base_sigma2: float, mults: List[float]) -> torch.Tensor:
    """Return the (equally weighted) mean of RBF kernels between rows of x and y.

    Each component uses bandwidth ``base_sigma2 * mult`` for one entry of
    ``mults``. With an empty ``mults`` a single RBF kernel at ``base_sigma2``
    is returned. Averaging (rather than summing) keeps the kernel's scale,
    e.g. its value 1 at zero distance, independent of the number of components.
    """
    d2 = _pairwise_dists2_cpu(x, y)
    if not mults:
        return _rbf_kernel_from_dists2(d2, base_sigma2)
    components = (_rbf_kernel_from_dists2(d2, base_sigma2 * mult) for mult in mults)
    total = sum(components)
    return total / float(len(mults))


def _median_heuristic_sigma2(Z: torch.Tensor, max_samples: int = 2000) -> float:
    """Pick an RBF bandwidth as the squared median pairwise Euclidean distance.

    When Z has more than ``max_samples`` rows, a random subset of that size is
    used, drawn with ``torch.randperm`` from the global RNG. Falls back to 1.0
    when there are no pairs, i.e. fewer than two rows, or when the median
    distance is not positive.
    """
    total_rows = Z.size(0)
    if total_rows > max_samples:
        keep = torch.randperm(total_rows, device=Z.device)[:max_samples]
        Z = Z[keep]
    with torch.no_grad():
        # pdist returns the condensed upper triangle of the distance matrix.
        pair_dists = torch.pdist(Z, p=2)
        if pair_dists.numel() == 0:
            return 1.0
        med = torch.median(pair_dists).item()
    return med ** 2 if med > 0 else 1.0


def _mmd_subsample_rows(T: torch.Tensor, max_samples: int) -> torch.Tensor:
    """Return at most ``max_samples`` rows of T, chosen uniformly at random."""
    if T.size(0) > max_samples:
        idx = torch.randperm(T.size(0), device=T.device)[:max_samples]
        T = T[idx]
    return T


def _mmd_cross_kernel_sum(A: torch.Tensor, B: torch.Tensor, sigma2: float, mults: List[float], chunk_size: int) -> float:
    """Return sum_{a in A, b in B} k(a, b), built from chunk_size x chunk_size blocks."""
    total = 0.0
    for i in range(0, A.size(0), chunk_size):
        a_blk = A[i:i + chunk_size]
        for j in range(0, B.size(0), chunk_size):
            total += _rbf_kernel_mixture(a_blk, B[j:j + chunk_size], sigma2, mults).sum().item()
    return total


def _mmd_within_kernel_sum(A: torch.Tensor, sigma2: float, mults: List[float], chunk_size: int) -> float:
    """Return sum_{i != j} k(a_i, a_j), visiting each symmetric block pair only once."""
    n = A.size(0)
    total = 0.0
    for i in range(0, n, chunk_size):
        a_blk = A[i:i + chunk_size]
        for j in range(i, n, chunk_size):
            K = _rbf_kernel_mixture(a_blk, A[j:j + chunk_size], sigma2, mults)
            if i == j:
                # Diagonal block: drop the k(a_i, a_i) self-similarity terms.
                total += (K.sum() - torch.diagonal(K).sum()).item()
            else:
                # The kernel is symmetric, so block (j, i) equals block (i, j)^T.
                total += 2.0 * K.sum().item()
    return total


def calculate_mmd(
    source_val_loader,
    target_train_loader,
    feature_extractor: Callable,
    device: str = "cpu",
    max_samples: int = 1000,
    chunk_size: int = 500,
) -> float:
    """Compute the unbiased RBF-kernel MMD^2 between source and target features.

    ``feature_extractor`` is called once per loader and must return a tensor
    whose first dimension indexes samples; any trailing dimensions are
    flattened into one feature vector per sample. The bandwidth comes from
    the median heuristic on the pooled features. It is optionally turned into
    a kernel mixture via multipliers from ``_parse_sigma_multipliers``.

    Args:
        source_val_loader: Passed unchanged to ``feature_extractor``.
        target_train_loader: Passed unchanged to ``feature_extractor``.
        feature_extractor: Callable mapping a loader to a feature tensor.
        device: Kept for backward compatibility. Kernel work always runs on
            the CPU to avoid GPU memory pressure, so features are moved
            straight there.
        max_samples: Each side is randomly subsampled (global torch RNG) to at
            most this many rows, bounding the O(n^2) kernel cost. Must be >= 2.
        chunk_size: Side length of the kernel blocks materialised at once.
            Must be >= 1.

    Returns:
        The unbiased MMD^2 estimate clamped at 0, as a Python float. Returns
        0.0 if either feature set is empty or has fewer than two samples.

    Raises:
        ValueError: If ``max_samples < 2`` or ``chunk_size < 1``.
    """
    if max_samples < 2:
        raise ValueError(f"max_samples must be >= 2, got {max_samples}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    with torch.no_grad():
        X = feature_extractor(source_val_loader)
        Y = feature_extractor(target_train_loader)
        if X.numel() == 0 or Y.numel() == 0:
            return 0.0

        # float64 on CPU: the estimator subtracts large, nearly equal kernel
        # sums, and half/integer features would break mean/matmul on CPU.
        X = X.detach().to("cpu", torch.float64).reshape(X.size(0), -1)
        Y = Y.detach().to("cpu", torch.float64).reshape(Y.size(0), -1)

        X = _mmd_subsample_rows(X, max_samples)
        Y = _mmd_subsample_rows(Y, max_samples)
        n, m = X.size(0), Y.size(0)
        if n < 2 or m < 2:
            # The unbiased estimator needs at least two samples per side.
            return 0.0

        # Centre both sets with ONE shared shift. RBF distances are invariant
        # to a common translation, so the MMD is unchanged, but cancellation
        # in ||x||^2 + ||y||^2 - 2x.y is reduced. Centring each set on its own
        # mean would instead erase any mean shift between the domains.
        Z = torch.cat([X, Y], dim=0)
        shift = Z.mean(dim=0, keepdim=True)
        X, Y, Z = X - shift, Y - shift, Z - shift

        sigma2 = _median_heuristic_sigma2(Z)
        mults = _parse_sigma_multipliers(default=[])

        kxx = _mmd_within_kernel_sum(X, sigma2, mults, chunk_size)
        kyy = _mmd_within_kernel_sum(Y, sigma2, mults, chunk_size)
        kxy = _mmd_cross_kernel_sum(X, Y, sigma2, mults, chunk_size)

    # Unbiased MMD^2 estimator: self-similarity terms are excluded from Kxx/Kyy.
    mmd2 = kxx / (n * (n - 1)) + kyy / (m * (m - 1)) - 2.0 * kxy / (n * m)
    # The unbiased estimate can dip below zero; report it clamped at 0.
    return float(max(mmd2, 0.0))


if __name__ == "__main__":
    # No options are defined yet; parsing still provides --help and rejects
    # unknown flags.
    cli = argparse.ArgumentParser(description="Divergences sanity check")
    cli.parse_args()
    print("This module provides calculate_mmd(). Implemented in Part 4.")
