"""
FastIF-style curvature subset selection.

Implements the k-NN based method from "Fast Influence Functions" 
(https://arxiv.org/pdf/2012.15781) to select a high-curvature subset
of training samples for efficient influence computation.

Goal: Pick ~10k training samples with largest curvature contribution.

DDP-aware: Only rank 0 computes the subset and saves it to disk; the other ranks wait at a barrier and then load the saved indices.
"""

import torch
import numpy as np
from typing import Optional, Tuple
import os
from tqdm import tqdm

from .vit_full import ViTWithHooks
from .imagenet_loader import ImageNetDataset
from .logging_utils import (
    get_logger, log_tensor_stats, log_dict, StageTimer,
    is_rank0, is_ddp, barrier, broadcast_object, rank_tqdm,
)

logger = get_logger("ifc_vit.fastif")


def compute_embeddings_chunked(
    model: ViTWithHooks,
    dataset: ImageNetDataset,
    chunk_size: int = 200_000,
    batch_size: int = 64,
    num_workers: int = 8,
    save_path: Optional[str] = None,
) -> torch.Tensor:
    """
    Compute penultimate embeddings for all samples in dataset.

    Batches are pushed through the model under ``torch.no_grad()``, moved to
    the CPU and concatenated into a single tensor. When ``save_path`` is given
    the result is additionally written to disk as a float16 ``.npy`` array.

    Args:
        model: ViT model with hooks
        dataset: ImageNet dataset
        chunk_size: Currently unused (all embeddings are accumulated in memory
            before saving); kept for interface compatibility
        batch_size: Batch size for forward pass
        num_workers: Data loading workers
        save_path: Optional path to save embeddings as a float16 .npy file

    Returns:
        Embeddings tensor (N, D) on the CPU. It is returned whether or not
        ``save_path`` is set.

    Raises:
        RuntimeError: If the loader yields no batches.
    """
    model.model.eval()

    loader = dataset.get_loader(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    n_samples = len(dataset)

    logger.info(f"Computing embeddings for {n_samples:,} samples (batch_size={batch_size})")

    all_embeddings = []
    total_processed = 0

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(tqdm(loader, desc="Computing embeddings")):
            images = images.to(model.device)
            emb = model.get_penultimate_features(images)
            # Move each batch off the accelerator right away so only one
            # batch of features lives on the device at a time.
            all_embeddings.append(emb.cpu())
            total_processed += len(images)

            # Log progress periodically
            if (batch_idx + 1) % 500 == 0:
                logger.debug(f"Processed {total_processed:,}/{n_samples:,} samples")

    if not all_embeddings:
        # torch.cat would fail on an empty list with a far less helpful message.
        raise RuntimeError(
            "No embeddings were computed: the dataset loader yielded no batches"
        )

    embeddings = torch.cat(all_embeddings, dim=0)

    # Log embedding statistics
    log_tensor_stats(logger, "Embeddings", embeddings)
    logger.info(f"Embedding dimension: {embeddings.shape[1]}")

    if save_path:
        # Stored as float16 to halve the on-disk footprint; the returned
        # tensor keeps its original dtype.
        half_embeddings = embeddings.numpy().astype(np.float16)
        np.save(save_path, half_embeddings)
        # Report the size of what was actually written (the float16 array),
        # not the size of the in-memory tensor.
        logger.info(f"Saved embeddings to {save_path} ({half_embeddings.nbytes / 1e9:.2f} GB)")

    return embeddings


def chunked_knn_search(
    query_embeddings: torch.Tensor,
    train_embeddings_path: str,
    k: int = 50,
    chunk_size: int = 200_000,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Perform chunked k-NN search over training embeddings.

    The training embeddings are memory-mapped and processed ``chunk_size``
    rows at a time. A running top-k per query is merged with the top-k of
    every chunk, so only one chunk is resident on ``device`` at once.

    Args:
        query_embeddings: (Q, D) query embeddings (e.g., validation samples)
        train_embeddings_path: Path to training embeddings .npy file
        k: Number of nearest neighbors per query
        chunk_size: Number of training samples to load at once
        device: Device for computation

    Returns:
        distances: (Q, k') squared L2 distances to the nearest neighbors
        indices: (Q, k') indices of nearest neighbors in training set

        where ``k' = min(k, N)`` and N is the number of training embeddings,
        so the result never contains placeholder entries.

    Raises:
        ValueError: If ``k`` or ``chunk_size`` is not positive, or the
            training embedding file contains no rows.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # Load training embeddings with memmap for memory efficiency
    train_emb = np.load(train_embeddings_path, mmap_mode='r')
    n_train = train_emb.shape[0]
    if n_train == 0:
        raise ValueError(f"No training embeddings found in {train_embeddings_path}")

    # Chunks are converted to float32 below; align the queries so the matmul
    # does not fail on a dtype mismatch.
    query_embeddings = query_embeddings.to(device=device, dtype=torch.float32)
    n_queries = query_embeddings.shape[0]

    # With fewer than k candidates the running top-k would keep its
    # (inf, index 0) initial values; cap k so every returned entry is real.
    k_eff = min(k, n_train)
    if k_eff < k:
        logger.info(f"Only {n_train:,} candidates available; reducing k from {k} to {k_eff}")

    # Initialize results
    best_distances = torch.full((n_queries, k_eff), float('inf'), device=device)
    best_indices = torch.zeros((n_queries, k_eff), dtype=torch.long, device=device)

    logger.info(f"Running k-NN search: {n_queries} queries, {n_train:,} candidates, k={k_eff}")
    logger.debug(f"Query embedding shape: {query_embeddings.shape}, chunk_size={chunk_size}")

    # ||q||^2 does not depend on the chunk, so compute it once.
    q_norm = (query_embeddings ** 2).sum(dim=1, keepdim=True)

    # Process training embeddings in chunks
    for start_idx in tqdm(range(0, n_train, chunk_size), desc="k-NN chunks"):
        end_idx = min(start_idx + chunk_size, n_train)

        # Load chunk and convert to tensor
        chunk_emb = torch.from_numpy(
            train_emb[start_idx:end_idx].astype(np.float32)
        ).to(device)

        # Squared L2 distances: ||q - t||^2 = ||q||^2 + ||t||^2 - 2<q,t>
        t_norm = (chunk_emb ** 2).sum(dim=1, keepdim=True).T
        distances = q_norm + t_norm - 2 * query_embeddings @ chunk_emb.T
        # Round-off in the expansion can yield tiny negative values; a
        # squared distance cannot be negative.
        distances.clamp_(min=0)

        # Get top-k for this chunk
        chunk_k = min(k_eff, end_idx - start_idx)
        chunk_dists, chunk_idx = distances.topk(chunk_k, dim=1, largest=False)

        # Adjust indices to global training indices
        chunk_idx = chunk_idx + start_idx

        # Merge with current best
        combined_dists = torch.cat([best_distances, chunk_dists], dim=1)
        combined_idx = torch.cat([best_indices, chunk_idx], dim=1)

        # Keep only top-k overall
        _, top_k_idx = combined_dists.topk(k_eff, dim=1, largest=False)
        best_distances = torch.gather(combined_dists, 1, top_k_idx)
        best_indices = torch.gather(combined_idx, 1, top_k_idx)

        # Release the chunk before the next one is loaded
        del chunk_emb, distances
        torch.cuda.empty_cache()

    final_distances = best_distances.cpu().numpy()

    # Log k-NN statistics (reductions are undefined on an empty result)
    if final_distances.size:
        log_dict(logger, "k-NN distance statistics", {
            'min_distance': float(final_distances.min()),
            'max_distance': float(final_distances.max()),
            'mean_distance': float(final_distances.mean()),
            'std_distance': float(final_distances.std()),
            'median_distance': float(np.median(final_distances)),
        })

    return final_distances, best_indices.cpu().numpy()


def _save_npy_atomically(path: str, array: np.ndarray) -> None:
    """Write ``array`` to ``path`` through a temporary file followed by a rename."""
    # The temporary name keeps the ".npy" suffix so np.save does not append another one.
    tmp_path = f"{path}.tmp.npy"
    np.save(tmp_path, array)
    # A reader either sees no file or the complete file, never a partial write.
    os.replace(tmp_path, path)


def select_curvature_subset(
    model: ViTWithHooks,
    train_dataset: ImageNetDataset,
    val_dataset: ImageNetDataset,
    output_dir: str,
    n_anchors: int = 200,
    k_neighbors: int = 50,
    max_subset_size: int = 10_000,
    chunk_size: int = 200_000,
    batch_size: int = 64,
    num_workers: int = 8,
    device: str = "cuda",
) -> np.ndarray:
    """
    Select high-curvature training subset using FastIF-style k-NN.

    DDP behavior:
    - Only rank 0 performs the computation and saves to disk
    - The other ranks wait at a barrier and then load ``curv_idx.npy`` from
      ``output_dir``, which must therefore be readable from every rank
    - If ``curv_idx.npy`` already exists, every rank loads it instead of
      recomputing. Outside DDP the subset is always recomputed and the file
      overwritten (cached training embeddings are still reused).

    Algorithm:
    1. Compute penultimate embeddings for all training samples
    2. Randomly sample n_anchors from validation set
    3. For each anchor, find k_neighbors nearest training samples
    4. Take union of all neighbors → curvature subset
    5. If the union exceeds max_subset_size, keep the training samples that
       are a neighbor of the most anchors

    Args:
        model: ViT model
        train_dataset: Training dataset
        val_dataset: Validation dataset
        output_dir: Directory to save outputs
        n_anchors: Number of validation anchors to sample
        k_neighbors: Number of neighbors per anchor
        max_subset_size: Maximum size of curvature subset
        chunk_size: Chunk size for embedding computation and k-NN search
        batch_size: Batch size for forward passes
        num_workers: Data loading workers
        device: Device used for the k-NN search. Forward passes run on
            ``model.device``, as in ``compute_embeddings_chunked``.

    Returns:
        curv_idx: Array of training indices in curvature subset (same on all ranks)

    Raises:
        ValueError: If no validation anchors can be sampled.
    """
    curv_idx_path = os.path.join(output_dir, "curv_idx.npy")
    train_emb_path = os.path.join(output_dir, "train_embeddings.npy")

    # ==========================================================================
    # DDP: Only rank 0 performs computation, others wait and load result
    # ==========================================================================
    # Every code path below executes exactly one barrier() per rank, which
    # keeps the collective calls matched across ranks.

    if is_ddp():
        # Check if already computed
        if os.path.exists(curv_idx_path):
            # All ranks load from disk
            barrier()
            curv_idx = np.load(curv_idx_path)
            logger.info(f"Loaded existing curvature subset: {len(curv_idx):,} samples")
            return curv_idx

        # Only rank 0 computes
        if not is_rank0():
            logger.info("Waiting for rank 0 to compute curvature subset...")
            barrier()  # Wait for rank 0 to finish
            curv_idx = np.load(curv_idx_path)
            logger.info(f"Loaded curvature subset from rank 0: {len(curv_idx):,} samples")
            return curv_idx

    # Rank 0 (or single GPU) performs computation
    if is_rank0():
        os.makedirs(output_dir, exist_ok=True)

    logger.info("=" * 50)
    logger.info("FastIF Curvature Subset Selection")
    logger.info("=" * 50)
    log_dict(logger, "Configuration", {
        'n_anchors': n_anchors,
        'k_neighbors': k_neighbors,
        'max_subset_size': max_subset_size,
        'chunk_size': chunk_size,
        'output_dir': output_dir,
    })

    # Step 1: Compute training embeddings (if not already done)
    if not os.path.exists(train_emb_path):
        logger.info("Step 1: Computing training embeddings...")
        with StageTimer(logger, "Training embedding computation"):
            compute_embeddings_chunked(
                model, train_dataset,
                chunk_size=chunk_size,
                batch_size=batch_size,
                num_workers=num_workers,
                save_path=train_emb_path,
            )
    else:
        logger.info(f"Step 1: Using cached training embeddings from {train_emb_path}")

    # Step 2: Sample validation anchors
    logger.info(f"Step 2: Sampling {n_anchors} validation anchors...")
    n_val = len(val_dataset)
    n_selected = min(n_anchors, n_val)
    if n_selected < 1:
        raise ValueError(
            f"Cannot sample validation anchors (n_anchors={n_anchors}, n_val={n_val})"
        )
    anchor_indices = np.random.choice(n_val, size=n_selected, replace=False)
    logger.debug(f"Anchor indices (first 10): {anchor_indices[:10].tolist()}")

    # Get anchor embeddings
    anchor_loader = val_dataset.get_subset_loader(
        anchor_indices,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    anchor_batches = []
    model.model.eval()
    with torch.no_grad():
        for images, _ in anchor_loader:
            logger.debug(f"Anchor batch shape: {tuple(images.shape)}")
            # Inputs must live where the model lives, independent of the
            # device chosen for the k-NN search.
            images = images.to(model.device)
            emb = model.get_penultimate_features(images)
            anchor_batches.append(emb.cpu())

    anchor_embeddings = torch.cat(anchor_batches, dim=0)
    logger.info(f"Anchor embeddings shape: {anchor_embeddings.shape}")
    log_tensor_stats(logger, "Anchor embeddings", anchor_embeddings)

    # Step 3: k-NN search
    logger.info("Step 3: Running k-NN search for curvature subset...")
    with StageTimer(logger, "k-NN search"):
        distances, indices = chunked_knn_search(
            anchor_embeddings,
            train_emb_path,
            k=k_neighbors,
            chunk_size=chunk_size,
            device=device,
        )

    # Step 4: Take union of all neighbors, counting how many anchors
    # selected each training sample
    unique, counts = np.unique(indices.ravel(), return_counts=True)
    curv_idx = unique
    logger.info(f"Step 4: Initial curvature subset size: {len(curv_idx):,}")

    log_dict(logger, "Neighbor frequency statistics", {
        'unique_neighbors': len(unique),
        'max_frequency': int(counts.max()),
        'mean_frequency': float(counts.mean()),
        'median_frequency': float(np.median(counts)),
        'samples_appearing_once': int((counts == 1).sum()),
        'samples_appearing_5+_times': int((counts >= 5).sum()),
    })

    # Limit to max size if needed (take samples that appear most frequently)
    if len(curv_idx) > max_subset_size:
        # A stable sort breaks frequency ties by ascending training index,
        # so the cut is deterministic for a given set of neighbors.
        sorted_idx = np.argsort(-counts, kind="stable")
        curv_idx = unique[sorted_idx[:max_subset_size]]
        logger.info(f"Limited to top {max_subset_size:,} most frequent neighbors")
        logger.debug(f"Min frequency in selected: {counts[sorted_idx[max_subset_size-1]]}")

    # Save curvature subset indices. The write is atomic so that an
    # interrupted run cannot leave a truncated file for a later run to load.
    _save_npy_atomically(curv_idx_path, curv_idx)
    logger.info(f"Saved curvature subset ({len(curv_idx):,} samples) to {curv_idx_path}")

    # Final validation logging
    log_dict(logger, "Curvature subset summary", {
        'subset_size': len(curv_idx),
        'fraction_of_train': len(curv_idx) / len(train_dataset),
        'index_range': [int(curv_idx.min()), int(curv_idx.max())],
    })

    # DDP: Signal other ranks that computation is complete
    if is_ddp():
        barrier()

    return curv_idx


def load_curvature_subset(output_dir: str) -> np.ndarray:
    """
    Read the cached curvature-subset indices stored under ``output_dir``.

    Args:
        output_dir: Directory expected to contain ``curv_idx.npy``

    Returns:
        The array stored in ``curv_idx.npy``

    Raises:
        FileNotFoundError: If ``curv_idx.npy`` is not present in ``output_dir``
    """
    curv_idx_path = os.path.join(output_dir, "curv_idx.npy")
    if os.path.exists(curv_idx_path):
        return np.load(curv_idx_path)
    raise FileNotFoundError(f"Curvature subset not found at {curv_idx_path}")


if __name__ == "__main__":
    # Command-line entry point: select_curvature_subset on an ImageNet root,
    # writing results into the given output directory.
    import sys

    if len(sys.argv) < 3:
        # This module uses package-relative imports, so it can only be started
        # with `python -m <package>.<module>`; running the file directly fails
        # on the relative imports before this point is reached.
        module_name = __spec__.name if __spec__ is not None else "<package>.fastif_select"
        print(
            f"Usage: python -m {module_name} /path/to/imagenet /path/to/output",
            file=sys.stderr,
        )
        sys.exit(1)

    imagenet_root, output_dir = sys.argv[1], sys.argv[2]

    from .vit_full import load_vit

    model = load_vit(pretrained=True, device="cuda")

    train_dataset = ImageNetDataset(imagenet_root, split="train")
    val_dataset = ImageNetDataset(imagenet_root, split="val")

    curv_idx = select_curvature_subset(
        model,
        train_dataset,
        val_dataset,
        output_dir,
        n_anchors=200,
        k_neighbors=50,
        max_subset_size=10_000,
    )

    print(f"Selected {len(curv_idx):,} training samples for curvature computation")
