"""ARCOS-specific metrics computation."""

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm

from .ot import compute_w1_feature_space, normalize_features as normalize_features_fn


def compute_output_discrepancy(
    model_Q: nn.Module,
    model_Qt: nn.Module,
    loader: DataLoader,
    mode: str = "logits_l2_mean",
    device: str = "cuda"
) -> float:
    """Compute the output discrepancy between two models on the same inputs.

    Both models are switched to eval mode and run without gradient tracking
    on every batch of ``loader``; labels yielded by the loader are ignored.

    Args:
        model_Q: First model (frozen)
        model_Qt: Second model (fine-tuned)
        loader: Data loader yielding ``(data, label)`` pairs
        mode: Discrepancy mode:
            - ``logits_l2_mean``: mean over samples of the L2 norm (taken
              along dim 1) of the output difference
            - ``logits_l2_max``: maximum over all samples of that L2 norm
            - ``probs_kl``: mean over samples of KL(softmax(Q) || softmax(Qt)),
              with the softmax taken along dim 1
        device: Device to use

    Returns:
        Output discrepancy value, or NaN when the loader yields no samples

    Raises:
        ValueError: If ``mode`` is not one of the supported modes
    """
    # Reject a bad mode before any forward pass, and even for an empty loader.
    if mode not in ("logits_l2_mean", "logits_l2_max", "probs_kl"):
        raise ValueError(f"Unknown discrepancy mode: {mode}")

    model_Q.eval()
    model_Qt.eval()

    # One flat CPU tensor of per-sample values per batch.
    per_batch_values = []

    with torch.no_grad():
        for data, _ in tqdm(loader, desc="Computing output discrepancy"):
            data = data.to(device)

            # Get outputs from both models
            output_Q = model_Q(data)
            output_Qt = model_Qt(data)

            if mode == "probs_kl":
                # Work in log space: log_softmax stays finite where softmax
                # underflows to exactly 0, so the 0 * log(0) = NaN case of
                # the naive p * log(p / q) formula cannot occur.
                log_p = torch.log_softmax(output_Q, dim=1)
                log_q = torch.log_softmax(output_Qt, dim=1)
                values = torch.sum(log_p.exp() * (log_p - log_q), dim=1)
            else:
                # Both L2 modes share the same per-sample norm; they differ
                # only in how the values are aggregated below.
                values = torch.norm(output_Q - output_Qt, p=2, dim=1)

            per_batch_values.append(values.reshape(-1).cpu())

    if not per_batch_values:
        return float("nan")

    # Aggregate in float64 so the result does not depend on batch order.
    all_values = torch.cat(per_batch_values).double()
    if all_values.numel() == 0:
        return float("nan")

    if mode == "logits_l2_max":
        # True maximum over the dataset, not a mean of per-batch maxima
        # (which would change with the batch size).
        return all_values.max().item()
    return all_values.mean().item()


def estimate_Lx(
    model: nn.Module,
    loader: DataLoader,
    method: str = "p99",
    device: str = "cuda"
) -> float:
    """Dispatch to one of the L_x (input-gradient norm) estimators.

    The model is put into eval mode before the method name is checked.

    Args:
        model: Model whose L_x is estimated
        loader: Data loader supplying the evaluation batches
        method: Which estimator to run: "max", "p99" (99th percentile)
            or "ema_smooth" (exponential moving average, alpha=0.9)
        device: Device to use

    Returns:
        Estimated L_x value

    Raises:
        ValueError: If ``method`` is not a known estimator name
    """
    # Lambdas keep the lookup lazy: only the selected estimator is resolved
    # and executed.
    estimators = {
        "max": lambda: _estimate_Lx_max(model, loader, device),
        "p99": lambda: _estimate_Lx_percentile(
            model, loader, device, percentile=99
        ),
        "ema_smooth": lambda: _estimate_Lx_ema(
            model, loader, device, alpha=0.9
        ),
    }

    model.eval()

    run_estimator = estimators.get(method)
    if run_estimator is None:
        raise ValueError(f"Unknown L_x estimation method: {method}")
    return run_estimator()


def _estimate_Lx_max(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda"
) -> float:
    """Estimate L_x as the largest per-sample input-gradient norm.

    For every batch, the gradient of the sum of all model outputs with
    respect to the input is computed, and its L2 norm is taken per sample
    over all non-batch dimensions.

    Args:
        model: Model to differentiate (not switched to eval mode here)
        loader: Data loader yielding ``(data, label)`` pairs; labels are
            ignored
        device: Device to use

    Returns:
        Maximum per-sample gradient norm over the loader, or 0.0 when the
        loader yields no batches
    """
    max_grad_norm = 0.0

    for data, _ in tqdm(loader, desc="Estimating L_x (max)"):
        data = data.to(device)

        # Force autograd on so this also works when the caller is inside a
        # torch.no_grad() block.
        with torch.enable_grad():
            data.requires_grad_(True)

            # Forward pass
            output = model(data)

            # Gradient of sum(output) w.r.t. the input
            gradients = torch.autograd.grad(
                output, data, grad_outputs=torch.ones_like(output),
                create_graph=False, retain_graph=False
            )[0]

        # Flatten everything but the batch axis so the norm is per sample;
        # a plain dim=1 norm would only reduce the channel axis for inputs
        # with more than two dimensions.
        grad_norm = gradients.flatten(start_dim=1).norm(p=2, dim=1)
        max_grad_norm = max(max_grad_norm, grad_norm.max().item())

    return max_grad_norm


def _estimate_Lx_percentile(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda",
    percentile: int = 99
) -> float:
    """Estimate L_x as a percentile of the per-sample input-gradient norms.

    For every batch, the gradient of the sum of all model outputs with
    respect to the input is computed, and its L2 norm is taken per sample
    over all non-batch dimensions. The norms of all samples are pooled and
    the requested percentile is returned.

    Args:
        model: Model to differentiate (not switched to eval mode here)
        loader: Data loader yielding ``(data, label)`` pairs; labels are
            ignored
        device: Device to use
        percentile: Percentile passed to ``np.percentile``

    Returns:
        The requested percentile of the gradient norms, or 0.0 when no
        samples were seen (matching the other L_x estimators)
    """
    batch_norms = []

    for data, _ in tqdm(loader, desc=f"Estimating L_x (p{percentile})"):
        data = data.to(device)

        # Force autograd on so this also works when the caller is inside a
        # torch.no_grad() block.
        with torch.enable_grad():
            data.requires_grad_(True)

            # Forward pass
            output = model(data)

            # Gradient of sum(output) w.r.t. the input
            gradients = torch.autograd.grad(
                output, data, grad_outputs=torch.ones_like(output),
                create_graph=False, retain_graph=False
            )[0]

        # Flatten everything but the batch axis so the norm is per sample;
        # a plain dim=1 norm would only reduce the channel axis for inputs
        # with more than two dimensions.
        grad_norm = gradients.flatten(start_dim=1).norm(p=2, dim=1)
        batch_norms.append(grad_norm.cpu().numpy())

    if not batch_norms:
        return 0.0

    all_norms = np.concatenate(batch_norms)
    if all_norms.size == 0:
        return 0.0

    # Cast so callers get a plain Python float, as annotated.
    return float(np.percentile(all_norms, percentile))


def _estimate_Lx_ema(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda",
    alpha: float = 0.9
) -> float:
    """Estimate L_x as an exponential moving average of per-batch maxima.

    For every batch, the gradient of the sum of all model outputs with
    respect to the input is computed, and its L2 norm is taken per sample
    over all non-batch dimensions. The largest norm in the batch updates
    the running average ``ema = alpha * ema + (1 - alpha) * batch_max``;
    the first batch initialises the average directly.

    Args:
        model: Model to differentiate (not switched to eval mode here)
        loader: Data loader yielding ``(data, label)`` pairs; labels are
            ignored
        device: Device to use
        alpha: Weight kept on the previous average at each update

    Returns:
        Final value of the moving average, or 0.0 when the loader yields
        no batches
    """
    ema_Lx = None  # None until the first batch has been processed

    for data, _ in tqdm(loader, desc="Estimating L_x (EMA)"):
        data = data.to(device)

        # Force autograd on so this also works when the caller is inside a
        # torch.no_grad() block.
        with torch.enable_grad():
            data.requires_grad_(True)

            # Forward pass
            output = model(data)

            # Gradient of sum(output) w.r.t. the input
            gradients = torch.autograd.grad(
                output, data, grad_outputs=torch.ones_like(output),
                create_graph=False, retain_graph=False
            )[0]

        # Flatten everything but the batch axis so the norm is per sample;
        # a plain dim=1 norm would only reduce the channel axis for inputs
        # with more than two dimensions.
        grad_norm = gradients.flatten(start_dim=1).norm(p=2, dim=1)
        batch_max = grad_norm.max().item()

        # Update EMA (seeded with the first batch's maximum)
        if ema_Lx is None:
            ema_Lx = batch_max
        else:
            ema_Lx = alpha * ema_Lx + (1 - alpha) * batch_max

    return 0.0 if ema_Lx is None else ema_Lx


def compute_bound_proxy(
    Lx: float,
    W1: float,
    output_discrepancy: float
) -> float:
    """Combine the three ARCOS terms into the bound proxy.

    The proxy is ``Lx * W1 + output_discrepancy``.

    Args:
        Lx: Estimated Lipschitz constant
        W1: Wasserstein-1 distance
        output_discrepancy: Output discrepancy term

    Returns:
        Bound proxy value
    """
    transport_term = Lx * W1
    return transport_term + output_discrepancy


def evaluate_risk(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda"
) -> float:
    """Measure the classification error rate of a model over a loader.

    The model is put into eval mode and run without gradient tracking.
    The prediction for each sample is the argmax of the output along dim 1.

    Args:
        model: Model to evaluate
        loader: Data loader yielding ``(data, target)`` pairs
        device: Device to use

    Returns:
        Risk (error rate) as a percentage, i.e. 100 minus accuracy

    Raises:
        ZeroDivisionError: If the loader yields no samples
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Evaluating risk"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            predicted = model(inputs).argmax(dim=1, keepdim=True)
            # Reshape labels to the prediction's shape before comparing.
            hits = predicted.eq(labels.view_as(predicted))
            n_correct += hits.sum().item()
            n_seen += labels.size(0)

    accuracy = 100. * n_correct / n_seen
    return 100. - accuracy


def compute_trace_metrics(
    Q: nn.Module,
    Qt: nn.Module,
    source_loader: DataLoader,
    target_loader: DataLoader,
    anchor_loader: DataLoader,
    device: str = "cuda",
    ot_method: str = "max-sliced",
    Lx_method: str = "p99",
    normalize_features: bool = True,
    **kwargs
) -> Dict[str, float]:
    """Run the full ARCOS metric pipeline and collect the results.

    Steps, in order: extract features of ``Q`` on the source loader and of
    ``Qt`` on the target loader, optionally normalize them, compute their
    W1 distance, compute the output discrepancy between ``Q`` and ``Qt`` on
    the anchor loader (default mode), estimate L_x of ``Qt`` on the anchor
    loader, combine these into the bound proxy, and evaluate the risk of
    both models on the anchor loader.

    Args:
        Q: Frozen model Q
        Qt: Fine-tuned model Q_tilde
        source_loader: Source domain data loader
        target_loader: Target domain data loader
        anchor_loader: Anchor evaluation data loader
        device: Device to use
        ot_method: Optimal transport method, forwarded as ``method`` to
            ``compute_w1_feature_space``
        Lx_method: L_x estimation method
        normalize_features: Whether to normalize features before the W1
            computation
        **kwargs: Extra arguments forwarded to ``compute_w1_feature_space``

    Returns:
        Dictionary with keys 'W1', 'output_discrepancy', 'Lx',
        'bound_proxy', 'risk_Q', 'risk_Qt' and 'delta_R' (absolute risk
        difference between the two models)
    """
    print("Computing ARCOS metrics...")

    # Feature extraction: Q on the source domain, Qt on the target domain.
    print("Extracting source features...")
    src_feats = extract_features(Q, source_loader, device)
    print("Extracting target features...")
    tgt_feats = extract_features(Qt, target_loader, device)

    if normalize_features:
        src_feats = normalize_features_fn(src_feats)
        tgt_feats = normalize_features_fn(tgt_feats)

    print("Computing W1 distance...")
    w1_distance = compute_w1_feature_space(
        src_feats, tgt_feats, method=ot_method, **kwargs
    )

    print("Computing output discrepancy...")
    discrepancy = compute_output_discrepancy(
        Q, Qt, anchor_loader, device=device
    )

    print("Estimating L_x...")
    lipschitz = estimate_Lx(
        Qt, anchor_loader, method=Lx_method, device=device
    )

    proxy = compute_bound_proxy(lipschitz, w1_distance, discrepancy)

    print("Computing risks...")
    frozen_risk = evaluate_risk(Q, anchor_loader, device=device)
    tuned_risk = evaluate_risk(Qt, anchor_loader, device=device)

    results = {
        'W1': w1_distance,
        'output_discrepancy': discrepancy,
        'Lx': lipschitz,
        'bound_proxy': proxy,
        'risk_Q': frozen_risk,
        'risk_Qt': tuned_risk,
        'delta_R': abs(tuned_risk - frozen_risk)
    }

    print("ARCOS metrics computed successfully!")
    return results


def extract_features(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda"
) -> torch.Tensor:
    """Collect the model's features for every batch of a loader.

    The model is put into eval mode and run without gradient tracking.
    If the model exposes ``get_features`` it is used; otherwise the output
    of ``model.backbone`` is flattened to one row per sample.

    Args:
        model: Model to extract features from
        loader: Data loader yielding ``(data, label)`` pairs; labels are
            ignored
        device: Device to use

    Returns:
        CPU tensor of all batch features concatenated along dim 0
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for batch, _ in tqdm(loader, desc="Extracting features"):
            batch = batch.to(device)

            if not hasattr(model, 'get_features'):
                # No dedicated accessor: fall back to the backbone output
                # (presumably the penultimate representation - confirm
                # against the model definition).
                raw = model.backbone(batch)
                batch_feats = raw.view(raw.size(0), -1)
            else:
                batch_feats = model.get_features(batch)

            collected.append(batch_feats.cpu())

    return torch.cat(collected, dim=0)

