#!/usr/bin/env python3
"""
Main entry point for IFC-ViT.

Provides CLI for running all stages of the Influence Function Cache pipeline.

Supports DDP (Distributed Data Parallel) for multi-GPU training:
    torchrun --nproc_per_node=2 -m ifc_vit.main --stage build_ifc ...

Single-GPU usage (run as a module, because this file uses package-relative imports):
    python -m ifc_vit.main --stage build_ifc --imagenet /path/to/imagenet --output /path/to/output
"""

import argparse
import os
import sys
import torch
import torch.distributed as dist
import numpy as np
from pathlib import Path
import time

from .logging_utils import (
    setup_logger, get_logger, log_dict, StageTimer,
    DDPState, is_rank0, is_ddp, barrier, broadcast_object,
    rank0_print, rank_tqdm,
)

# Module-level logger. It is None until main() configures logging via
# setup_logger() and rebinds this name with ``global logger``; any code that
# logs through it (e.g. stage_query's non-rank-0 path) must run after that.
logger = None


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use). The
            default ``None`` keeps the usual command-line behavior.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Notes:
        Help texts use argparse's ``%(default)s`` interpolation so the
        advertised defaults cannot drift from the actual ``default=`` values.
    """
    parser = argparse.ArgumentParser(
        description="Influence Function Cache for ViT on ImageNet",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Stages:
  curvature_subset   Select high-curvature training subset using FastIF-style k-NN
  cluster            Compute gradient sketches and cluster training samples
  build_rhs          Accumulate cluster mean gradients (RHS for CG)
  build_precond      Build KFAC/EKFAC preconditioner
  build_ifc          Full IFC build: GGN operator + CG solves for all clusters
  query              Test query interface with sample influence computation
  all                Run all stages in sequence

Examples:
  # Build full IFC (single GPU)
  python -m ifc_vit.main --stage build_ifc --imagenet /data/imagenet --output ./ifc_output

  # Build full IFC (multi-GPU with DDP)
  torchrun --nproc_per_node=2 -m ifc_vit.main --stage build_ifc --imagenet /data/imagenet --output ./ifc_output

  # Run specific stage
  python -m ifc_vit.main --stage cluster --imagenet /data/imagenet --output ./ifc_output

  # Query existing IFC
  python -m ifc_vit.main --stage query --imagenet /data/imagenet --output ./ifc_output
        """
    )

    parser.add_argument(
        "--stage",
        type=str,
        required=True,
        choices=[
            "curvature_subset",
            "cluster",
            "build_rhs",
            "build_precond",
            "build_ifc",
            "query",
            "all",
        ],
        help="Stage to run"
    )

    parser.add_argument(
        "--imagenet",
        type=str,
        required=True,
        help="Path to ImageNet root directory"
    )

    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for all data"
    )

    # DDP parameters
    parser.add_argument(
        "--ddp",
        action="store_true",
        help="Enable DDP (auto-detected from WORLD_SIZE env var if not specified)"
    )

    # Both spellings are accepted: older launchers pass --local_rank, newer
    # torch launchers pass --local-rank. dest stays "local_rank".
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=-1,
        help="Local rank for DDP (typically set by torchrun)"
    )

    parser.add_argument(
        "--ddp-backend",
        type=str,
        default="nccl",
        choices=["nccl", "gloo"],
        help="DDP backend (default: %(default)s)"
    )

    parser.add_argument(
        "--ddp-verbose",
        action="store_true",
        help="Enable verbose DDP logging on all ranks (debugging)"
    )

    # Curvature subset parameters
    parser.add_argument(
        "--n-anchors",
        type=int,
        default=200,
        help="Number of validation anchors for k-NN (default: %(default)s)"
    )

    parser.add_argument(
        "--k-neighbors",
        type=int,
        default=1000,
        help="Number of neighbors per anchor (default: %(default)s)"
    )

    parser.add_argument(
        "--max-curv-subset",
        type=int,
        default=10000,
        help="Maximum curvature subset size (default: %(default)s)"
    )

    # Clustering parameters
    parser.add_argument(
        "--n-clusters",
        type=int,
        default=100,
        help="Number of clusters (default: %(default)s)"
    )

    parser.add_argument(
        "--sketch-dim",
        type=int,
        default=256,
        help="Gradient sketch dimension (default: %(default)s)"
    )

    parser.add_argument(
        "--kmeans-max-iter",
        type=int,
        default=100,
        help="K-means maximum iterations (default: %(default)s)"
    )

    parser.add_argument(
        "--kmeans-tol",
        type=float,
        default=1e-5,
        help="K-means convergence tolerance (default: %(default)s)"
    )

    # CG parameters
    parser.add_argument(
        "--damping",
        type=float,
        default=1.0,  # float literal so the default has the same type as parsed values
        help="GGN damping parameter λ (default: %(default)s)"
    )

    parser.add_argument(
        "--precond-type",
        type=str,
        default="none",
        choices=["kfac", "ekfac", "diagonal", "none"],
        help="Preconditioner type (default: %(default)s)"
    )

    parser.add_argument(
        "--max-cg-iter",
        type=int,
        default=30,
        help="Maximum CG iterations (default: %(default)s)"
    )

    parser.add_argument(
        "--cg-tol",
        type=float,
        default=2e-2,
        help="CG convergence tolerance (default: %(default)s)"
    )

    # Hardware parameters
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
        help="Batch size for data loading (default: %(default)s)"
    )

    parser.add_argument(
        "--curv-batch-size",
        type=int,
        default=40,
        help="Batch size for curvature computation (default: %(default)s)"
    )

    parser.add_argument(
        "--num-workers",
        type=int,
        default=20,
        help="Number of data loading workers per GPU (default: %(default)s). "
             "Consider setting to CPU_count / num_GPUs for optimal throughput."
    )

    # NOTE(review): args.prefetch_factor is not read by any stage function in
    # this module; confirm it is consumed elsewhere or wire it through.
    parser.add_argument(
        "--prefetch-factor",
        type=int,
        default=8,
        help="Batches to prefetch per worker (default: %(default)s). "
             "Higher values use more RAM but improve GPU utilization."
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (default: %(default)s)"
    )

    # Other options
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from previous run"
    )

    parser.add_argument(
        "--use_tiny",
        action="store_true",
        help="Use DeiT-Tiny instead of ViT-B/16"
    )

    parser.add_argument(
        "--use_streaming_rhs",
        action="store_true",
        help="Use streaming dataset for RHS vector accumulation"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed (default: %(default)s)"
    )

    return parser.parse_args(argv)


def stage_curvature_subset(args):
    """Stage 1: pick a high-curvature subset of the training data.

    Loads the ViT plus a training dataset and an anchor dataset, then hands
    the FastIF-style k-NN selection to ``select_curvature_subset`` (which
    receives ``args.output`` as its output location). Every rank executes
    this function; DDP coordination is delegated to
    ``select_curvature_subset``.
    """
    banner = "=" * 60
    for line in (banner, "Stage: Curvature Subset Selection", banner):
        rank0_print(line)

    from .vit_full import load_vit
    from .imagenet_loader import ImageNetDataset
    from .fastif_select import select_curvature_subset

    vit = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)

    train_ds = ImageNetDataset(args.imagenet, split="train")
    # NOTE(review): the anchor ("validation") dataset is built from the
    # *train* split, although --n-anchors is described as validation anchors.
    # Confirm this is intentional.
    anchor_ds = ImageNetDataset(args.imagenet, split="train")

    selection_options = dict(
        n_anchors=args.n_anchors,
        k_neighbors=args.k_neighbors,
        max_subset_size=args.max_curv_subset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
    )
    selected = select_curvature_subset(
        vit, train_ds, anchor_ds, args.output, **selection_options
    )

    rank0_print(f"\nCurvature subset: {len(selected):,} samples")


def stage_cluster(args):
    """Stage 2: compute per-sample gradient sketches and cluster them.

    Streams an ImageNet training subset, in a fixed order, through
    ``cluster_gradients`` using the sketch/k-means settings from ``args``,
    with ``args.output`` as the output location. ``stage_build_rhs`` later
    loads ``cluster_id.npy`` from that directory and combines it with
    ImageNet data, so this stage must cluster real ImageNet samples (not
    synthetic tensors).

    DDP: every rank calls this function; DDP coordination is delegated to
    ``cluster_gradients``.
    """
    rank0_print("=" * 60)
    rank0_print("Stage: Gradient Sketching & Clustering")
    rank0_print("=" * 60)

    from .vit_full import load_vit
    from .imagenet_loader import ImageNetDataset
    from .cluster import cluster_gradients

    # Load model
    model = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)

    # Real ImageNet training data: the cluster IDs produced here are consumed
    # together with ImageNet by later stages.
    # TODO: expose the subset size (10000) as a CLI flag.
    train_dataset = ImageNetDataset(args.imagenet, split="train", subset_n=10000)
    train_loader = train_dataset.get_loader(
        batch_size=args.batch_size,
        # Fixed order: cluster IDs are presumably assigned by loader position,
        # which shuffling would scramble — TODO confirm in cluster.py.
        shuffle=False,
        num_workers=args.num_workers,
        # NOTE(review): drop_last=True skips the trailing partial batch, so
        # those samples may get no cluster ID; confirm downstream stages
        # tolerate this.
        drop_last=True,
    )

    # Cluster gradients with cosine distance (handles DDP internally)
    cluster_gradients(
        model,
        train_loader,
        args.output,
        n_clusters=args.n_clusters,
        sketch_dim=args.sketch_dim,
        max_iter=args.kmeans_max_iter,
        tol=args.kmeans_tol,
        device=args.device,
    )

    rank0_print(f"\nClustering complete: {args.n_clusters} clusters")


def stage_build_rhs(args):
    """Stage 3: accumulate per-cluster mean gradients (the CG right-hand sides).

    Needs ``cluster_id.npy`` (written by the ``cluster`` stage) in
    ``args.output``. If it is missing, rank 0 prints an error and the process
    exits with status 1. Otherwise every rank calls ``build_rhs_vectors``,
    which handles DDP coordination itself.
    """
    separator = "=" * 60
    rank0_print(separator)
    rank0_print("Stage: Cluster Mean Gradient Accumulation")
    rank0_print(separator)

    from .vit_full import load_vit
    from .rhs_build import build_rhs_vectors

    ids_file = os.path.join(args.output, "cluster_id.npy")
    if os.path.exists(ids_file):
        assignments = np.load(ids_file)
    else:
        rank0_print("Error: cluster_id.npy not found. Run 'cluster' stage first.")
        sys.exit(1)

    vit = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)

    build_rhs_vectors(
        vit,
        args.imagenet,
        assignments,
        args.output,
        args.n_clusters,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
        use_streaming=args.use_streaming_rhs,
    )

    rank0_print("\nRHS vectors complete")


def stage_build_precond(args):
    """Stage 4: build the KFAC/EKFAC/diagonal preconditioner.

    With ``--precond-type none`` (the default) this only prints a skip
    message. Otherwise it registers KFAC hooks on the ViT and loads the
    curvature subset saved by the ``curvature_subset`` stage. It then passes
    that subset to ``build_kfac_preconditioner`` (which receives
    ``args.output`` as its output location). The hooks are removed again
    even if the build raises.

    DDP: every rank calls this function; DDP coordination is delegated to
    ``build_kfac_preconditioner``.
    """
    # The "none" check is the only skip path: build_ifc is handed the same
    # precond_type, so a requested preconditioner must actually be built here.
    if args.precond_type == "none":
        rank0_print("Preconditioner type 'none' selected, skipping preconditioner build.")
        return

    rank0_print("=" * 60)
    rank0_print("Stage: KFAC Preconditioner Construction")
    rank0_print("=" * 60)

    from .vit_full import load_vit
    from .imagenet_loader import ImageNetDataset
    from .fastif_select import load_curvature_subset
    from .precond import build_kfac_preconditioner

    # Load model with hooks
    model = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)
    model.register_kfac_hooks()
    try:
        curv_idx = load_curvature_subset(args.output)

        dataset = ImageNetDataset(args.imagenet, split="train")
        curv_loader = dataset.get_subset_loader(
            curv_idx,
            batch_size=args.curv_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )

        build_kfac_preconditioner(
            model,
            curv_loader,
            args.output,
            precond_type=args.precond_type,
            # NOTE(review): preconditioner damping is 10x smaller than the
            # damping passed to build_ifc; the rationale is not visible here.
            damping=args.damping * 0.1,
            device=args.device,
        )
    finally:
        # Never leave hooks attached, even if the build fails.
        model.clear_hooks()

    rank0_print(f"\n{args.precond_type.upper()} preconditioner complete")


def stage_build_ifc(args):
    """Stage 5: build the full Influence Function Cache.

    Forwards the solver, preconditioner and data-loading options from
    ``args`` to ``ifc_build.build_ifc``. Every rank calls this function.
    Per the original design notes, all ranks take part in the PCG solves
    (GGN mat-vec all-reduce) and only rank 0 saves solutions; that logic
    lives inside ``build_ifc`` and is not visible here.
    """
    rule = "=" * 60
    rank0_print(rule)
    rank0_print("Stage: Build Influence Function Cache")
    rank0_print(rule)

    from .ifc_build import build_ifc

    solver_options = {
        "n_clusters": args.n_clusters,
        "damping": args.damping,
        "precond_type": args.precond_type,
        "max_cg_iter": args.max_cg_iter,
        "cg_tol": args.cg_tol,
    }
    run_options = {
        "curv_batch_size": args.curv_batch_size,
        "num_workers": args.num_workers,
        "device": args.device,
        "resume": args.resume,
    }
    build_ifc(args.imagenet, args.output, **solver_options, **run_options)


def _has_scores(scores):
    """Return True when ``scores`` holds at least one value.

    Uses ``len`` rather than truthiness so NumPy arrays or tensors with more
    than one element do not raise "truth value ... is ambiguous".
    """
    return scores is not None and len(scores) > 0


def _query_kfold(query, val_loader, cluster_ids, k_folds=5, max_samples=100):
    """Print predicted loss changes for removing each of ``k_folds`` contiguous folds."""
    n_train = len(cluster_ids)
    rank0_print(f"\n{'='*50}")
    rank0_print(f"K-Fold CV Influence Estimation (k={k_folds})")
    rank0_print(f"{'='*50}")

    fold_size = n_train // k_folds
    mean_changes = []

    for fold in range(k_folds):
        fold_start = fold * fold_size
        # The last fold absorbs the remainder of n_train / k_folds.
        fold_end = (fold + 1) * fold_size if fold < k_folds - 1 else n_train
        fold_indices = np.arange(fold_start, fold_end)

        rank0_print(f"\nFold {fold+1}/{k_folds}: removing samples {fold_start:,}-{fold_end:,} ({len(fold_indices):,} samples)")

        mean_loss_change, scores = query.compute_fold_loss_change(
            val_loader,
            fold_indices,
            cluster_ids,
            max_samples=max_samples,  # Limit for speed
        )
        mean_changes.append(mean_loss_change)

        rank0_print(f"  Predicted mean loss change: {mean_loss_change:.6f}")
        if _has_scores(scores):
            rank0_print(f"  Std: {np.std(scores):.6f}")
        else:
            rank0_print("  No scores")

    rank0_print(f"\n{'='*50}")
    rank0_print("K-Fold CV Summary:")
    rank0_print(f"  Mean across folds: {np.mean(mean_changes):.6f}")
    rank0_print(f"  Std across folds: {np.std(mean_changes):.6f}")


def _query_random_subsets(query, val_loader, cluster_ids, seed,
                          fractions=(0.01, 0.05, 0.10), n_trials=5, max_samples=50):
    """Print predicted loss changes for removing random training subsets of several sizes."""
    n_train = len(cluster_ids)
    rank0_print(f"\n{'='*50}")
    rank0_print("Random Subset Influence Estimation")
    rank0_print(f"{'='*50}")

    for frac in fractions:
        subset_size = int(n_train * frac)
        rank0_print(f"\n--- Subset size: {frac*100:.0f}% ({subset_size:,} samples) ---")

        trial_changes = []
        for trial in range(n_trials):
            # A per-trial RandomState draws the same indices as the legacy
            # np.random.seed(seed + trial) + np.random.choice pair, without
            # resetting NumPy's global RNG as a side effect.
            rng = np.random.RandomState(seed + trial)
            subset_indices = rng.choice(n_train, size=subset_size, replace=False)

            mean_loss_change, _ = query.compute_fold_loss_change(
                val_loader,
                subset_indices,
                cluster_ids,
                max_samples=max_samples,  # Limit for speed
            )
            trial_changes.append(mean_loss_change)
            rank0_print(f"  Trial {trial+1}: predicted loss change = {mean_loss_change:.6f}")

        rank0_print(f"  Mean across trials: {np.mean(trial_changes):.6f}")
        rank0_print(f"  Std across trials: {np.std(trial_changes):.6f}")


def _query_cluster_removal(query, val_loader, cluster_ids, n_to_test=3, max_samples=50):
    """Print predicted loss changes for removing a few distinct whole clusters.

    Clusters stand in for classes (no class labels are loaded here). The
    tested clusters are spread evenly over the sorted cluster IDs actually
    present in ``cluster_ids``, so they are distinct regardless of how many
    clusters were built.
    """
    rank0_print(f"\n{'='*50}")
    rank0_print("Class-Specific Removal Test")
    rank0_print(f"{'='*50}")

    present = np.unique(cluster_ids)
    if present.size == 0:
        return

    n_pick = min(n_to_test, present.size)
    # Spacing between linspace points is >= 1 when n_pick <= present.size,
    # so flooring (astype) yields distinct positions.
    positions = np.linspace(0, present.size - 1, num=n_pick).astype(int)

    for cluster in present[positions]:
        members = np.where(cluster_ids == cluster)[0]
        rank0_print(f"\nRemoving cluster {cluster} ({len(members):,} samples)")

        mean_loss_change, _ = query.compute_fold_loss_change(
            val_loader,
            members,
            cluster_ids,
            max_samples=max_samples,
        )
        rank0_print(f"  Predicted loss change: {mean_loss_change:.6f}")


def stage_query(args):
    """Stage 6: exercise the query interface with influence estimates.

    Loads the saved IFC, the ViT and ``cluster_id.npy`` from ``args.output``.
    It then prints predicted validation-loss changes for:

    1. removing each of 5 contiguous folds of the training indices,
    2. removing random subsets of 1%, 5% and 10% of the training indices
       (5 seeded trials each),
    3. removing a few distinct whole clusters.

    DDP: only rank 0 does the work. Other ranks log, call ``barrier()`` and
    return; rank 0 calls the matching ``barrier()`` when it finishes.
    """
    # DDP: Query stage runs only on rank 0
    if is_ddp() and not is_rank0():
        logger.info("DDP: Rank != 0, skipping query stage")
        barrier()
        return

    rank0_print("=" * 60)
    rank0_print("Stage: Query Test - K-Fold CV & Random Subset Influence")
    rank0_print("=" * 60)

    from .vit_full import load_vit
    from .imagenet_loader import ImageNetDataset
    from .ifc_build import load_ifc
    from .query import IFCQuery

    ifc = load_ifc(args.output, args.device)
    model = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)
    query = IFCQuery(ifc, model, args.device)

    cluster_ids = np.load(os.path.join(args.output, "cluster_id.npy"))

    # NOTE(review): the "validation" loader is built from the train split.
    val_dataset = ImageNetDataset(args.imagenet, split="train")
    val_loader = val_dataset.get_loader(batch_size=1, shuffle=False, num_workers=args.num_workers)

    _query_kfold(query, val_loader, cluster_ids)
    _query_random_subsets(query, val_loader, cluster_ids, args.seed)
    _query_cluster_removal(query, val_loader, cluster_ids)

    query.clear_cache()

    rank0_print(f"\n{'='*50}")
    rank0_print("Query test complete!")
    rank0_print(f"{'='*50}")

    # Matches the barrier() taken by the non-zero ranks above.
    if is_ddp():
        barrier()


def stage_all(args):
    """Run every pipeline stage in order, synchronizing ranks after each.

    Order: curvature_subset, cluster, build_rhs, build_precond, build_ifc,
    query. ``barrier()`` is called after every stage.
    """
    banner = "=" * 60
    rank0_print(banner)
    rank0_print("Running All Stages")
    rank0_print(banner)

    pipeline = (
        ("curvature_subset", stage_curvature_subset),
        ("cluster", stage_cluster),
        ("build_rhs", stage_build_rhs),
        ("build_precond", stage_build_precond),
        ("build_ifc", stage_build_ifc),
        ("query", stage_query),
    )

    for stage_name, run_stage in pipeline:
        rank0_print(f"\n{banner}")
        rank0_print(f"Starting stage: {stage_name}")
        rank0_print(f"{banner}\n")
        run_stage(args)
        barrier()  # keep all ranks in lockstep between stages

    rank0_print("\n" + banner)
    rank0_print("All stages complete!")
    rank0_print(banner)


def main():
    """Main entry point: parse CLI args, set up DDP/logging/seeds, run a stage.

    Steps:
      1. Parse arguments. Initialize DDP when ``--ddp`` is given or the
         ``WORLD_SIZE`` environment variable is greater than 1.
      2. Rank 0 creates the output directory while the other ranks wait at a
         barrier. All ranks then configure logging (file logging on rank 0
         only) and seed torch/NumPy with ``--seed``.
      3. Fall back to CPU if CUDA was requested but is unavailable, then
         dispatch to the requested stage under a ``StageTimer``.
      4. When running under DDP, ``DDPState.cleanup()`` is always called on
         exit, even on failure.

    Raises:
        Any exception raised by a stage, after logging it with its traceback.
    """
    global logger

    args = parse_args()

    # ==========================================================================
    # DDP Initialization (before any other setup)
    # ==========================================================================

    # Auto-detect DDP from environment
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if args.ddp or world_size > 1:
        DDPState.initialize(backend=args.ddp_backend)
        args.device = DDPState.device()

    # Store DDP state in args for convenience
    args.is_ddp = is_ddp()
    args.is_rank0 = is_rank0()
    args.world_size = DDPState.world_size()
    args.rank = DDPState.rank()
    args.local_rank = DDPState.local_rank()

    # ==========================================================================
    # Setup (rank 0 creates output dir, others wait)
    # ==========================================================================

    if is_rank0():
        os.makedirs(args.output, exist_ok=True)
    barrier()  # Wait for rank 0 to create dir

    # Setup logging (rank0_only by default unless ddp_verbose)
    log_file = os.path.join(args.output, "ifc.log") if is_rank0() else None
    setup_logger("ifc_vit", log_dir=args.output if is_rank0() else None,
                 log_file=log_file, rank0_only=not args.ddp_verbose)
    logger = get_logger(__name__)

    logger.info("=" * 60)
    logger.info("IFC-ViT Pipeline Started")
    if is_ddp():
        logger.info(f"DDP Mode: world_size={args.world_size}, rank={args.rank}, local_rank={args.local_rank}")
    logger.info("=" * 60)

    # Set random seed (same on all ranks for reproducibility)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # DDPState.device() may return a torch.device rather than a str (its
    # type is not visible here); a substring test on a torch.device raises
    # TypeError, so normalize with str() first.
    if "cuda" in str(args.device) and not torch.cuda.is_available():
        logger.warning("CUDA not available, falling back to CPU")
        args.device = "cpu"

    if "cuda" in str(args.device):
        # Honor an explicit index such as "cuda:1"; for a bare "cuda" use this
        # rank's GPU under DDP, otherwise the current default CUDA device.
        device_idx = torch.device(args.device).index
        if device_idx is None:
            device_idx = args.local_rank if is_ddp() else torch.cuda.current_device()
        logger.info(f"Using GPU {device_idx}: {torch.cuda.get_device_name(device_idx)}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(device_idx).total_memory / 1e9:.1f} GB")

    # Log configuration
    log_dict(logger, "Configuration", {
        'imagenet': args.imagenet,
        'output': args.output,
        'n_clusters': args.n_clusters,
        'damping': args.damping,
        'precond_type': args.precond_type,
        'max_cg_iter': args.max_cg_iter,
        'cg_tol': args.cg_tol,
        'device': args.device,
        'seed': args.seed,
        'stage': args.stage,
        'is_ddp': args.is_ddp,
        'world_size': args.world_size,
        'rank': args.rank,
    })

    # Route to appropriate stage (argparse `choices` guarantees a valid key)
    stage_map = {
        "curvature_subset": stage_curvature_subset,
        "cluster": stage_cluster,
        "build_rhs": stage_build_rhs,
        "build_precond": stage_build_precond,
        "build_ifc": stage_build_ifc,
        "query": stage_query,
        "all": stage_all,
    }

    stage_timer = StageTimer()

    try:
        stage_func = stage_map[args.stage]
        with stage_timer.stage(args.stage):
            stage_func(args)

        # Ensure all ranks complete before summarizing
        barrier()

        stage_timer.log_summary(logger)
        logger.info("Pipeline completed successfully")

    except Exception as e:
        logger.exception(f"Pipeline failed: {e}")
        raise
    finally:
        # Clean up DDP
        if is_ddp():
            DDPState.cleanup()


if __name__ == "__main__":
    # Direct-execution entry point. The module uses package-relative imports
    # (``from .logging_utils import ...``), so launch it as
    # ``python -m ifc_vit.main`` (or ``torchrun ... -m ifc_vit.main``) rather
    # than as a plain script path.
    main()
