#!/usr/bin/env python3
"""
Truncated K-fold experiment: Exact IF subset ground-truth vs CiF.

This script:
1. Partitions ImageNet-1k training set into K folds
2. Samples K_eval folds as evaluation/query folds
3. For each evaluation fold v:
   - Computes fold gradient g_v (mean gradient)
   - Computes ONE IHVP: s_v = (G + λI)^{-1} g_v using PCG
   - Scores every other fold t != v using dot products (no more IHVPs)
   - Compares Exact IF vs CiF via Spearman rank correlation

Key insight: Because H = (G+λI) is symmetric, we have:
    g_v^T H^{-1} g_t == g_t^T H^{-1} g_v
So one solve s_v = H^{-1} g_v lets us score all targets via dot products.

Usage:
    python -m ifc_vit.truncated_kfold \
        --imagenet /path/to/imagenet \
        --ifc-path /path/to/ifc_output \
        --output ./kfold_results \
        --k-folds 5000 \
        --k-eval 50
"""

import argparse
import json
import os
import time
from datetime import datetime
from typing import Tuple, Optional

import numpy as np
import torch
from torch.backends.cuda import sdp_kernel
from tqdm import tqdm

from .vit_full import load_vit, ViTWithHooks
from .imagenet_loader import ImageNetDataset, make_distributed_subset_loader
from .fastif_select import load_curvature_subset
from .ggn_ops import GGNOperator
from .pcg import pcg_solve
from .ifc_build import load_ifc
from .query import IFCQuery
from .logging_utils import get_logger, log_dict, rank0_print

# Module-level logger created through the project's logging helper.
logger = get_logger(__name__)


def parse_args():
    """
    Parse command line arguments for the truncated K-fold experiment.

    Arguments are read from ``sys.argv`` via ``argparse``. ``--imagenet``,
    ``--ifc-path`` and ``--output`` are required; all other options have
    defaults. Help strings interpolate ``%(default)s`` so the default shown
    in ``--help`` is always the value actually used.

    Returns:
        argparse.Namespace with one attribute per option (dashes become
        underscores, e.g. ``--k-folds`` -> ``k_folds``).
    """
    parser = argparse.ArgumentParser(
        description="Truncated K-fold experiment: Exact IF vs CiF comparison",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Required paths
    parser.add_argument(
        "--imagenet",
        type=str,
        required=True,
        help="Path to ImageNet root directory",
    )
    parser.add_argument(
        "--ifc-path",
        type=str,
        required=True,
        help="Path to CiF cache directory (IFCBuilder output)",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for experiment results",
    )

    # Experiment knobs
    parser.add_argument(
        "--k-folds",
        type=int,
        default=5000,
        help="Number of folds to partition training set into (default: %(default)s)",
    )
    parser.add_argument(
        "--k-eval",
        type=int,
        default=50,
        help="Number of evaluation/query folds to sample (default: %(default)s)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed (default: %(default)s)",
    )

    # Exact IF knobs
    parser.add_argument(
        "--damping",
        type=float,
        default=0.5,
        help="GGN damping parameter λ (should match CiF cache) (default: %(default)s)",
    )
    parser.add_argument(
        "--max-cg-iter",
        type=int,
        default=30,
        help="Maximum CG iterations (default: %(default)s)",
    )
    parser.add_argument(
        "--cg-tol",
        type=float,
        default=2e-3,
        help="CG convergence tolerance (default: %(default)s)",
    )
    parser.add_argument(
        "--curv-batch-size",
        type=int,
        default=400,
        help="Batch size for curvature computation (default: %(default)s)",
    )
    parser.add_argument(
        "--grad-batch-size",
        type=int,
        default=4096,
        help="Batch size for gradient computation (default: %(default)s)",
    )
    parser.add_argument(
        "--microbatch-size",
        type=int,
        default=32,
        help="Microbatch size for gradient accumulation (default: %(default)s)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=8,
        help="Number of data loading workers (default: %(default)s)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (default: %(default)s)",
    )
    parser.add_argument(
        "--use-tiny",
        action="store_true",
        help="Use DeiT-Tiny instead of ViT-B/16",
    )

    # Scoring convention
    parser.add_argument(
        "--no-target-sum",
        action="store_true",
        help="If set, do NOT multiply target fold gradient by fold size",
    )

    # Debug
    parser.add_argument(
        "--debug-small",
        action="store_true",
        help="Run small debug check (K=10, K_eval=2) to verify symmetry",
    )
    parser.add_argument(
        "--subset-n",
        type=int,
        default=None,
        help="Limit dataset to first N samples (for debugging)",
    )

    return parser.parse_args()


# =============================================================================
# Fold partitioning utilities
# =============================================================================

def create_kfold_partition(N: int, K: int, seed: int) -> np.ndarray:
    """
    Create a deterministic K-fold partition of N samples.

    Sample indices ``0..N-1`` are shuffled with a generator seeded by
    ``seed`` and then dealt out in contiguous runs: fold 0 receives the first
    run, fold 1 the next, and so on. Fold sizes differ by at most one; the
    first ``N % K`` folds receive the extra sample. If ``K > N`` the trailing
    folds are empty.

    Args:
        N: Total number of samples
        K: Number of folds (must be positive)
        seed: Random seed for shuffling

    Returns:
        fold_ids: int32 array of shape (N,) with values 0..K-1

    Raises:
        ValueError: If ``K`` is not positive.
    """
    if K <= 0:
        # Previously K == 0 surfaced as a bare ZeroDivisionError and K < 0
        # silently produced an all-zero assignment.
        raise ValueError(f"K must be a positive integer, got {K}")

    rng = np.random.default_rng(seed)
    order = np.arange(N)
    rng.shuffle(order)

    # Per-fold sizes: the first `extra` folds are one sample larger.
    base, extra = divmod(N, K)
    sizes = np.full(K, base, dtype=np.int64)
    sizes[:extra] += 1

    # Label the shuffled positions run by run, then scatter the labels back
    # to the original sample order.
    fold_ids = np.empty(N, dtype=np.int32)
    fold_ids[order] = np.repeat(np.arange(K, dtype=np.int32), sizes)
    return fold_ids


def get_fold_indices(fold_ids: np.ndarray, k: int) -> np.ndarray:
    """Return the positions in ``fold_ids`` whose fold label equals ``k``."""
    membership = fold_ids == k
    return np.nonzero(membership)[0]


def get_fold_sizes(fold_ids: np.ndarray, K: int) -> np.ndarray:
    """Count the samples in each fold; the result has at least ``K`` entries (int32)."""
    counts = np.bincount(fold_ids, minlength=K)
    return counts.astype(np.int32)


# =============================================================================
# Gradient computation
# =============================================================================

def compute_fold_grad_mean(
    model: ViTWithHooks,
    dataset: ImageNetDataset,
    indices: np.ndarray,
    batch_size: int = 64,
    microbatch_size: int = 8,
    num_workers: int = 8,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute mean gradient for a fold (subset of samples).

    Each loaded batch is split into microbatches. For every microbatch the
    summed cross-entropy loss is backpropagated (forward pass under fp16
    autocast), the parameter gradients are flattened and added to an fp32
    accumulator, and the accumulator is finally divided by the number of
    samples seen. No per-sample gradients are materialised.

    Side effects: the wrapped model is put in eval mode, ``requires_grad`` is
    enabled on all of its parameters, and the ``.grad`` fields hold the last
    microbatch's gradients when the function returns.

    Args:
        model: ViT model with hooks (the underlying module is ``model.model``)
        dataset: ImageNet dataset providing ``get_subset_loader``
        indices: Indices of samples in this fold
        batch_size: Batch size for loading
        microbatch_size: Microbatch size for gradient accumulation
        num_workers: Data loading workers
        device: Computation device

    Returns:
        Mean gradient vector (flattened, fp32)

    Raises:
        ValueError: If the fold is empty, since a mean over zero samples would
            otherwise silently produce an all-NaN gradient.
    """
    if len(indices) == 0:
        raise ValueError("Cannot compute a mean gradient for an empty fold")

    model.model.eval()
    for p in model.model.parameters():
        p.requires_grad_(True)

    # Every parameter was just made trainable, so this is the full list.
    params = [p for p in model.model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in params)

    # Get subset loader
    loader = dataset.get_subset_loader(
        indices,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
        persistent_workers=False
    )

    grad_sum = torch.zeros(num_params, device=device, dtype=torch.float32)
    total_samples = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        bs = images.shape[0]

        # Process in microbatches
        for start in range(0, bs, microbatch_size):
            end = min(start + microbatch_size, bs)
            mb_images = images[start:end]
            mb_labels = labels[start:end]

            # Gradients are reset per microbatch so that p.grad holds exactly
            # this microbatch's contribution when it is read below.
            model.model.zero_grad()

            # Forward pass
            with torch.amp.autocast('cuda', dtype=torch.float16):
                logits = model.model(mb_images)
                # reduction='sum' so that accumulating over microbatches and
                # dividing once at the end yields the per-sample mean.
                loss = torch.nn.functional.cross_entropy(
                    logits, mb_labels, reduction='sum'
                )

            # Backward pass
            loss.backward()

            # Accumulate flattened gradients; parameters that received no
            # gradient contribute zeros so the layout stays fixed.
            grad_flat = torch.cat([
                p.grad.flatten() if p.grad is not None else torch.zeros(p.numel(), device=device)
                for p in params
            ])
            grad_sum += grad_flat
            total_samples += end - start

    if total_samples == 0:
        # The loader yielded nothing even though indices were supplied.
        raise ValueError("Data loader produced no samples for this fold")

    # Return mean gradient
    return grad_sum / total_samples


def load_or_compute_fold_grad(
    model: ViTWithHooks,
    dataset: ImageNetDataset,
    fold_ids: np.ndarray,
    fold_k: int,
    cache_dir: str,
    batch_size: int = 64,
    microbatch_size: int = 8,
    num_workers: int = 8,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Load fold gradient from cache or compute and save it.

    The cache entry is ``<cache_dir>/fold_<fold_k>.pt`` and stores the mean
    gradient as an fp16 CPU tensor. The value returned is always the
    fp16-rounded gradient converted back to fp32, so a cold-cache call and a
    warm-cache call give the same numbers.

    The cache file is written to a temporary path and then moved into place,
    so an interrupted run cannot leave a truncated file that a later run
    would treat as a valid cache entry.

    Args:
        model: ViT model
        dataset: ImageNet dataset
        fold_ids: Fold assignments for all samples
        fold_k: Which fold to get gradient for
        cache_dir: Directory to cache gradients (created if missing)
        batch_size, microbatch_size, num_workers, device: Computation params

    Returns:
        Mean gradient for fold k (fp32, on specified device)
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"fold_{fold_k}.pt")

    if os.path.exists(cache_path):
        # weights_only=True: the cache holds a plain tensor, so there is no
        # need to allow arbitrary pickled objects to be deserialised.
        cached = torch.load(cache_path, map_location=device, weights_only=True)
        return cached.float()

    # Compute gradient
    indices = get_fold_indices(fold_ids, fold_k)
    grad_mean = compute_fold_grad_mean(
        model, dataset, indices,
        batch_size=batch_size,
        microbatch_size=microbatch_size,
        num_workers=num_workers,
        device=device,
    )

    # Quantise once; the same fp16 values are both cached and returned.
    grad_fp16 = grad_mean.half()

    # Save to cache (fp16 on CPU) atomically: write aside, then rename.
    tmp_path = f"{cache_path}.tmp.{os.getpid()}"
    try:
        torch.save(grad_fp16.cpu(), tmp_path)
        os.replace(tmp_path, cache_path)
    finally:
        # Only present if the save or the rename failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return grad_fp16.float()


# =============================================================================
# Rank correlation (no scipy dependency)
# =============================================================================

def spearmanr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Spearman rank correlation between two equal-length 1-D sequences.

    Both inputs are converted to ranks (tied values share the average of the
    ranks they span) and the Pearson correlation of the ranks is returned.

    Args:
        x, y: Arrays of same length

    Returns:
        Spearman rho; ``nan`` when fewer than two observations are given and
        ``0.0`` when either rank vector has zero variance.

    Raises:
        ValueError: If the inputs differ in length.
    """
    def _average_ranks(values):
        """Return 1-based ranks of ``values`` with ties set to their mean rank."""
        order = np.argsort(values)
        position = np.empty_like(order)
        position[order] = np.arange(len(values))

        # Mark where a new run of equal values begins in sorted order and
        # give every element the (1-based) id of the run it belongs to.
        ordered = values[order]
        starts_run = np.r_[True, ordered[1:] != ordered[:-1]]
        run_id = starts_run.cumsum()[position]

        # run_bounds[j] is the sorted offset at which run j+1 begins, with the
        # total length appended as a sentinel; the mean rank of a run is the
        # midpoint of its first and last 1-based positions.
        run_bounds = np.r_[np.nonzero(starts_run)[0], len(values)]
        return 0.5 * (run_bounds[run_id] + run_bounds[run_id - 1] + 1)

    x = np.asarray(x)
    y = np.asarray(y)

    if len(x) != len(y):
        raise ValueError("x and y must have same length")

    if len(x) < 2:
        return np.nan

    ranks_x = _average_ranks(x)
    ranks_y = _average_ranks(y)

    # Pearson correlation of the two rank vectors.
    dev_x = ranks_x - ranks_x.mean()
    dev_y = ranks_y - ranks_y.mean()

    covariance = np.sum(dev_x * dev_y)
    scale = np.sqrt(np.sum(dev_x ** 2) * np.sum(dev_y ** 2))

    return 0.0 if scale == 0 else covariance / scale


def pearsonr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Pearson correlation between two equal-length sequences.

    Returns ``nan`` for fewer than two observations and ``0.0`` when either
    input has zero variance. Raises ValueError if the lengths differ.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    if len(x) != len(y):
        raise ValueError("x and y must have same length")

    if len(x) < 2:
        return np.nan

    # Centre both series, then normalise the cross-product sum.
    dev_x = x - x.mean()
    dev_y = y - y.mean()

    covariance = np.sum(dev_x * dev_y)
    scale = np.sqrt(np.sum(dev_x ** 2) * np.sum(dev_y ** 2))

    return 0.0 if scale == 0 else covariance / scale




# =============================================================================
# Main experiment
# =============================================================================

def main():
    """
    Run the truncated K-fold experiment and write results to ``--output``.

    Steps (numbered as in the progress output):
        1. Load the model and the ImageNet training split.
        2. Load the CiF cache and the per-sample cluster ids.
        3. Build the damped GGN operator on the stored curvature subset.
        4. Partition the N samples into K folds and save the assignment.
        5. Compute and cache the mean gradient of every fold not yet cached.
        6. Sample K_eval query folds.
        7. For each query fold v: one PCG solve for s_v, then score every
           other fold t with the exact formula and with CiF, and correlate
           the two score vectors (Spearman and Pearson).
        8. Write ``summary.json`` and print a final report.

    Files written under ``--output``: ``fold_ids.npy``, ``fold_sizes.npy``,
    ``query_folds.npy``, ``per_query.jsonl`` (one JSON object per query fold),
    ``summary.json`` and the gradient cache in ``cache/fold_grads``.

    Raises:
        ValueError: If the fold settings are unusable (K < 1, K_eval outside
            [1, K], or K larger than the dataset) or if ``cluster_id.npy``
            does not match the dataset length.
    """
    args = parse_args()

    K = args.k_folds
    K_eval = args.k_eval

    # Fail fast on unusable settings before any expensive loading happens.
    if K < 1:
        raise ValueError(f"--k-folds must be >= 1, got {K}")
    if not 1 <= K_eval <= K:
        raise ValueError(f"--k-eval must be between 1 and --k-folds ({K}), got {K_eval}")
    # NOTE(review): --debug-small is parsed but nothing in this function acts
    # on it; the run always uses --k-folds / --k-eval as given.

    print("=" * 60)
    print("Truncated K-fold Experiment: Exact IF vs CiF")
    print("=" * 60)
    print(f"Start time: {datetime.now().isoformat()}")

    # Create output directory
    os.makedirs(args.output, exist_ok=True)

    # -------------------------------------------------------------------------
    # Step 1: Load model + dataset
    # -------------------------------------------------------------------------
    print("\n[Step 1] Loading model and dataset...")
    model = load_vit(pretrained=True, device=args.device, use_tiny=args.use_tiny)
    model.register_kfac_hooks()

    dataset = ImageNetDataset(args.imagenet, split="train", subset_n=args.subset_n)
    N = len(dataset.dataset)
    print(f"  Dataset size: N = {N:,}")
    if args.subset_n:
        print(f"  (Limited to {args.subset_n} samples for debugging)")

    if K > N:
        # More folds than samples would create empty folds, whose mean
        # gradient is undefined.
        raise ValueError(f"--k-folds ({K}) exceeds dataset size ({N})")

    # -------------------------------------------------------------------------
    # Step 2: Load CiF cache + cluster IDs
    # -------------------------------------------------------------------------
    print("\n[Step 2] Loading CiF cache...")
    ifc = load_ifc(args.ifc_path, args.device)
    query = IFCQuery(ifc, model, args.device)

    cluster_ids_path = os.path.join(args.ifc_path, "cluster_id.npy")
    cluster_ids = np.load(cluster_ids_path)
    print(f"  Loaded cluster_ids: {len(cluster_ids):,} samples")
    print(f"  Unique clusters: {len(np.unique(cluster_ids)):,}")
    # If using subset, truncate cluster_ids to match
    if args.subset_n and args.subset_n < len(cluster_ids):
        cluster_ids = cluster_ids[:args.subset_n]
        print(f"  Truncated cluster_ids to {len(cluster_ids):,} samples")

    if len(cluster_ids) != N:
        raise ValueError(f"cluster_ids length ({len(cluster_ids)}) != dataset size ({N})")

    # -------------------------------------------------------------------------
    # Step 3: Build GGN operator (reuse curvature subset)
    # -------------------------------------------------------------------------
    print("\n[Step 3] Building GGN operator...")
    curv_idx = load_curvature_subset(args.ifc_path)
    print(f"  Curvature subset size: {len(curv_idx):,}")

    # Only the loader is needed; the two sample counts are not used here.
    curv_loader, _local_samples, _global_samples = make_distributed_subset_loader(
        args.imagenet,
        split="train",
        indices=curv_idx,
        batch_size=args.curv_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
    )

    ggn_operator = GGNOperator(
        model,
        curv_loader,
        damping=args.damping,
        device=args.device,
    )

    # Prepare operator (attention restricted to the math SDP kernel)
    with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        ggn_operator.prepare_for_solve()

    print(f"  GGN operator ready (damping={args.damping})")

    # -------------------------------------------------------------------------
    # Step 4: Create folds
    # -------------------------------------------------------------------------
    print(f"\n[Step 4] Creating {K} folds...")
    fold_ids = create_kfold_partition(N, K, seed=args.seed)
    fold_sizes = get_fold_sizes(fold_ids, K)

    print(f"  Fold size stats: min={fold_sizes.min()}, max={fold_sizes.max()}, "
          f"mean={fold_sizes.mean():.1f}")

    # Save fold info
    np.save(os.path.join(args.output, "fold_ids.npy"), fold_ids)
    np.save(os.path.join(args.output, "fold_sizes.npy"), fold_sizes)

    # -------------------------------------------------------------------------
    # Step 5: Cache fold gradients for all folds
    # -------------------------------------------------------------------------
    print(f"\n[Step 5] Computing/loading fold gradients for all {K} folds...")
    cache_dir = os.path.join(args.output, "cache", "fold_grads")
    os.makedirs(cache_dir, exist_ok=True)

    # Only folds without a cache file need work; already-cached folds are
    # skipped entirely rather than loaded and thrown away.
    missing_folds = [
        k for k in range(K)
        if not os.path.exists(os.path.join(cache_dir, f"fold_{k}.pt"))
    ]
    cached_count = K - len(missing_folds)
    print(f"  Already cached: {cached_count}/{K}")

    fold_grad_start = time.time()
    for k in tqdm(missing_folds, desc="Fold gradients"):
        load_or_compute_fold_grad(
            model, dataset, fold_ids, k, cache_dir,
            batch_size=args.grad_batch_size,
            microbatch_size=args.microbatch_size,
            num_workers=args.num_workers,
            device=args.device,
        )
    fold_grad_time = time.time() - fold_grad_start
    print(f"  Fold gradient computation time: {fold_grad_time:.1f}s")

    # -------------------------------------------------------------------------
    # Step 6: Sample evaluation folds
    # -------------------------------------------------------------------------
    print(f"\n[Step 6] Sampling {K_eval} evaluation folds...")
    rng = np.random.default_rng(args.seed)
    query_folds = rng.choice(K, size=K_eval, replace=False)
    np.save(os.path.join(args.output, "query_folds.npy"), query_folds)
    print(f"  Query folds: {query_folds[:10]}{'...' if len(query_folds) > 10 else ''}")

    # -------------------------------------------------------------------------
    # Step 7: For each query fold, compute exact and CiF scores
    # -------------------------------------------------------------------------
    print(f"\n[Step 7] Running experiment ({K_eval} query folds)...")

    # Honour --no-target-sum: by default the target fold's mean gradient is
    # scaled by the fold size (i.e. turned into a sum).
    use_target_sum = not args.no_target_sum
    print(f"  Target sum weighting: {use_target_sum}")

    results = []
    all_spearman = []
    all_pearson = []
    total_exact_time = 0.0
    total_cif_time = 0.0

    # Open JSONL file for streaming results
    jsonl_path = os.path.join(args.output, "per_query.jsonl")
    with open(jsonl_path, "w") as jsonl_file:
        for i, v in enumerate(tqdm(query_folds, desc="Query folds")):
            # -------------------------------------------------------------
            # Load query fold gradient (mean gradient over the fold)
            # -------------------------------------------------------------
            g_v = load_or_compute_fold_grad(
                model, dataset, fold_ids, v, cache_dir,
                batch_size=args.grad_batch_size,
                microbatch_size=args.microbatch_size,
                num_workers=args.num_workers,
                device=args.device,
            )

            # -------------------------------------------------------------
            # Exact IF: ONE PCG solve for s_v = (G + λI)^(-1) g_v
            # Only the solve itself is timed.
            # -------------------------------------------------------------
            exact_start = time.time()
            with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                s_v, solve_info = pcg_solve(
                    ggn_operator, g_v.float(),
                    max_iter=args.max_cg_iter,
                    tol=args.cg_tol,
                    device=args.device,
                    verbose=(i == 0),
                )
            exact_solve_time = time.time() - exact_start
            total_exact_time += exact_solve_time

            # -------------------------------------------------------------
            # Exact scores: score(t) = (g_t_eff)^T s_v / (N - |t|)
            # where g_t_eff is |t| * mean-gradient (if use_target_sum)
            # or the mean-gradient itself (if not)
            # -------------------------------------------------------------
            exact_scores = []
            target_folds = []

            for t in tqdm(range(K), desc="IF scores"):
                if t == v:
                    continue

                g_t = load_or_compute_fold_grad(
                    model, dataset, fold_ids, t, cache_dir,
                    batch_size=args.grad_batch_size,
                    microbatch_size=args.microbatch_size,
                    num_workers=args.num_workers,
                    device=args.device,
                )

                # fold_sizes already holds the fold length, so there is no
                # need to rescan fold_ids for every target.
                n_t = int(fold_sizes[t])
                g_t_eff = g_t * n_t if use_target_sum else g_t
                score_exact = torch.dot(g_t_eff, s_v).item() / (N - n_t)

                exact_scores.append(score_exact)
                target_folds.append(t)

            exact_scores = np.array(exact_scores, dtype=np.float64)

            # -------------------------------------------------------------
            # CiF scores: score(t) = (g_v)^T Δθ_t, with Δθ_t supplied by
            # IFCQuery.compute_delta_theta for the samples of fold t.
            # NOTE(review): the sign/scale convention of Δθ_t is defined
            # outside this file, and --no-target-sum only changes the exact
            # side above; confirm both conventions agree.
            # -------------------------------------------------------------
            cif_start = time.time()
            cif_scores = []

            for t in tqdm(target_folds, desc="CiF scores"):
                indices_t = get_fold_indices(fold_ids, t)
                delta_theta_t = query.compute_delta_theta(indices_t, cluster_ids)

                score_cif = torch.dot(g_v, delta_theta_t).item()
                cif_scores.append(score_cif)

            query.clear_cache()

            cif_scores = np.array(cif_scores, dtype=np.float64)
            cif_time = time.time() - cif_start
            total_cif_time += cif_time

            # -------------------------------------------------------------
            # Correlations via the local helpers (they return plain floats)
            # -------------------------------------------------------------
            rho = spearmanr(exact_scores, cif_scores)
            r = pearsonr(exact_scores, cif_scores)

            all_spearman.append(rho)
            all_pearson.append(r)

            result = {
                "query_fold": int(v),
                "spearman": float(rho),
                "pearson": float(r),
                "exact_solve_sec": float(exact_solve_time),
                "cif_total_sec": float(cif_time),
                "cg_iterations": int(solve_info.get("iterations", -1)),
                "cg_converged": bool(solve_info.get("converged", False)),
                "cif_scores": cif_scores.tolist(),
                "exact_scores": exact_scores.tolist(),
            }
            results.append(result)

            # Flush after every query fold so partial results survive a crash.
            jsonl_file.write(json.dumps(result) + "\n")
            jsonl_file.flush()

            print(
                f"  [{i+1}/{K_eval}] Fold {v}: Spearman={rho:.4f}, "
                f"Pearson={r:.4f}, exact={exact_solve_time:.1f}s, cif={cif_time:.1f}s"
            )

    # -------------------------------------------------------------------------
    # Step 8: Save summary
    # -------------------------------------------------------------------------
    print("\n[Step 8] Saving results...")

    all_spearman = np.array(all_spearman)
    all_pearson = np.array(all_pearson)

    summary = {
        'mode': 'truncated',
        'k_folds': K,
        'k_eval': K_eval,
        'seed': args.seed,
        'damping': args.damping,
        'max_cg_iter': args.max_cg_iter,
        'cg_tol': args.cg_tol,
        'use_target_sum': use_target_sum,
        'use_tiny': args.use_tiny,
        'dataset_size': N,
        'spearman_mean': float(all_spearman.mean()),
        'spearman_std': float(all_spearman.std()),
        'spearman_min': float(all_spearman.min()),
        'spearman_max': float(all_spearman.max()),
        'pearson_mean': float(all_pearson.mean()),
        'pearson_std': float(all_pearson.std()),
        'total_exact_time_sec': float(total_exact_time),
        'total_cif_time_sec': float(total_cif_time),
        'fold_grad_time_sec': float(fold_grad_time),
        'timestamp': datetime.now().isoformat(),
    }

    with open(os.path.join(args.output, "summary.json"), 'w') as f:
        json.dump(summary, f, indent=2)

    # -------------------------------------------------------------------------
    # Final report
    # -------------------------------------------------------------------------
    print("\n" + "=" * 60)
    print("EXPERIMENT COMPLETE")
    print("=" * 60)
    print(f"  K folds: {K}")
    print(f"  K eval: {K_eval}")
    print(f"  Spearman: {all_spearman.mean():.4f} ± {all_spearman.std():.4f}")
    print(f"  Pearson:  {all_pearson.mean():.4f} ± {all_pearson.std():.4f}")
    print(f"  Total exact IF time: {total_exact_time:.1f}s ({K_eval} solves)")
    print(f"  Total CiF time: {total_cif_time:.1f}s")
    print(f"  Fold gradient time: {fold_grad_time:.1f}s")
    print(f"  Output: {args.output}")
    print("=" * 60)


# Run the experiment only when this file is executed as a script/module,
# not when it is imported.
if __name__ == "__main__":
    main()
