"""
IFC Builder - Orchestrates CG solves and caching.

Main orchestration module that:
1. Loads curvature subset and builds GGN operator
2. Sets up the preconditioner (currently disabled: PCG runs unpreconditioned)
3. Solves (Ĝ + λI)z_c = ḡ_c for each cluster
4. Caches solutions z_c to disk

DDP Support:
    - Curvature batches are sharded across ranks (each rank sees local samples).
    - GGN matvec all-reduces the result so all ranks get the same vector.
    - PCG runs on all ranks with synchronized scalars (pAp, rz, rz_new).
    - Only rank 0 writes z_cluster/*.pt and index.pkl.
    - Resume status is broadcast so all ranks know which clusters to skip.
"""

import torch
from torch.backends.cuda import sdp_kernel
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Dict, List, Tuple
import numpy as np
import os
import pickle
from tqdm import tqdm
import time
import json
import csv
from datetime import datetime

from .vit_full import ViTWithHooks, load_vit
from .imagenet_loader import ImageNetDataset, make_distributed_subset_loader
from .fastif_select import load_curvature_subset
from .ggn_ops import GGNOperator
from .pcg import PCGSolver, pcg_solve
from .rhs_build import RHSManager, load_cluster_counts
from .logging_utils import (
    get_logger, log_dict, StageTimer,
    is_rank0, is_ddp, barrier, broadcast_object, rank0_print,
)

logger = get_logger(__name__)

# Minimum acceptable Hx_over_lamx for a cluster solve. A result below this means
# the GGN contribution is too small relative to the damping term, so the cluster
# is (re-)solved.
HX_OVER_LAMX_THRESHOLD: float = 0.3

# Alternative damping values tried, in list order, when a solve falls below
# HX_OVER_LAMX_THRESHOLD (see IFCBuilder.solve_cluster_with_damping_retry).
DAMPING_RETRY_GRID: List[float] = [0.05]


def _to_json_safe(v):
    """Convert a value to a JSON-serializable type."""
    if isinstance(v, torch.Tensor):
        return v.item() if v.numel() == 1 else v.tolist()
    elif isinstance(v, np.ndarray):
        return v.item() if v.size == 1 else v.tolist()
    elif isinstance(v, (np.floating, np.integer)):
        return v.item()
    return v


def _load_previous_results(jsonl_path: str) -> Dict[int, dict]:
    """
    Load previous cluster solve results from JSONL file.
    
    Returns:
        Dict mapping cluster_id -> best result dict for that cluster
    """
    results = {}
    if not os.path.exists(jsonl_path):
        return results
    
    with open(jsonl_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
                if entry.get('type') != 'cluster_solve':
                    continue
                cluster_id = entry.get('cluster_id')
                if cluster_id is None:
                    continue
                
                # Keep the best result for each cluster (highest Hx_over_lamx)
                hx_ratio = entry.get('Hx_over_lamx', 0)
                if cluster_id not in results or hx_ratio > results[cluster_id].get('Hx_over_lamx', 0):
                    results[cluster_id] = entry
            except json.JSONDecodeError:
                continue
    
    return results


def _needs_resolve(prev_result: Optional[dict], threshold: float = HX_OVER_LAMX_THRESHOLD) -> bool:
    """
    Check whether a cluster needs to be (re-)solved.

    Returns True if:
    - no previous result exists,
    - the previous ``Hx_over_lamx`` is missing, ``None``, non-numeric, or NaN, or
    - the previous ``Hx_over_lamx`` is below ``threshold``.

    Args:
        prev_result: Previous ``cluster_solve`` record for the cluster, or None.
        threshold: Minimum acceptable ``Hx_over_lamx``.
    """
    if prev_result is None:
        return True
    try:
        hx_ratio = float(prev_result.get('Hx_over_lamx', 0))
    except (TypeError, ValueError):
        # e.g. JSON null -> None; comparing None < float would raise.
        return True
    # NaN compares False against everything, so `nan < threshold` would wrongly
    # mark a broken solve as done.
    if hx_ratio != hx_ratio:
        return True
    return hx_ratio < threshold


class JacobiPreconditioner:
    """
    Diagonal (Jacobi) preconditioner that applies ``x / (diag + eps)`` elementwise.

    Args:
        diag: Diagonal of the operator to precondition. Presumably the same
            shape as (or broadcastable to) the vectors passed to ``__call__``;
            confirm with callers.
        eps: Constant added to ``diag`` before dividing; stored as a float.
    """

    def __init__(self, diag: torch.Tensor, eps: float = 0.0):
        self.diag = diag
        self.eps = float(eps)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        denominator = self.diag + self.eps
        return x / denominator

class IFCBuilder:
    """
    Influence Function Cache Builder.

    Orchestrates the full pipeline for building cluster-level influence
    function solutions:
    - a GGN operator over the curvature subset,
    - an RHS (cluster gradient) manager, and
    - a PCG solver.

    It then solves ``(Ĝ + λI) z_c = ḡ_c`` for every cluster and caches each
    ``z_c`` as ``<output_dir>/z_cluster/<cluster_id>.pt``, with
    ``<output_dir>/index.pkl`` recording which clusters are solved.
    """

    def __init__(
        self,
        model: ViTWithHooks,
        output_dir: str,
        imagenet_root: str,
        n_clusters: int,
        damping: float = 1e-2,
        precond_type: str = "kfac",
        max_cg_iter: int = 50,
        cg_tol: float = 1e-3,
        device: str = "cuda",
    ):
        """
        Args:
            model: ViT model with hooks
            output_dir: Directory for all outputs
            imagenet_root: Path to ImageNet
            n_clusters: Number of clusters
            damping: GGN damping parameter λ
            precond_type: "kfac", "ekfac", "diagonal", or "none". Preconditioning
                is currently disabled (see ``setup_preconditioner``), so this
                value is recorded in the index/logs but does not affect solves.
            max_cg_iter: Maximum CG iterations
            cg_tol: CG convergence tolerance
            device: Computation device
        """
        self.model = model
        self.output_dir = output_dir
        self.imagenet_root = imagenet_root
        self.n_clusters = n_clusters
        self.damping = damping
        self.precond_type = precond_type
        self.max_cg_iter = max_cg_iter
        self.cg_tol = cg_tol
        self.device = device

        # Paths
        self.z_cluster_dir = os.path.join(output_dir, "z_cluster")
        self.index_path = os.path.join(output_dir, "index.pkl")

        # Components (lazy initialization via the setup_* methods)
        self._ggn_operator: Optional[GGNOperator] = None
        self._preconditioner = None
        self._rhs_manager: Optional[RHSManager] = None
        self._pcg_solver: Optional[PCGSolver] = None

        # Tracking
        self._solved_clusters: set = set()
        self._load_index()

        logger.info("IFCBuilder initialized")
        log_dict(logger, "IFCBuilder config", {
            'n_clusters': n_clusters,
            'damping': damping,
            'precond_type': precond_type,
            'max_cg_iter': max_cg_iter,
            'cg_tol': cg_tol,
            'output_dir': output_dir,
        })

    # ------------------------------------------------------------------ index

    def _load_index(self):
        """Load the set of solved cluster ids from ``index.pkl``.

        DDP: rank 0 reads the file and broadcasts the list to the other ranks.
        A missing index yields an empty set.
        """
        solved_set = set()

        if is_rank0():
            if os.path.exists(self.index_path):
                # NOTE: pickle is only safe because this is our own output file.
                with open(self.index_path, 'rb') as f:
                    index = pickle.load(f)
                solved_set = set(index.get('solved', []))
                logger.info(f"Loaded index: {len(solved_set)} clusters already solved")

        if is_ddp():
            solved_set = set(broadcast_object(list(solved_set)))

        self._solved_clusters = solved_set

    def _save_index(self):
        """Write the index of solved clusters to ``index.pkl``.

        DDP: only rank 0 writes. The file is written to a temporary path and
        moved into place with ``os.replace``, so a crash mid-write cannot leave
        a truncated index.
        """
        if is_ddp() and not is_rank0():
            return

        index = {
            'solved': sorted(self._solved_clusters),
            'n_clusters': self.n_clusters,
            'damping': self.damping,
            'precond_type': self.precond_type,
        }
        tmp_path = self.index_path + ".tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(index, f)
        os.replace(tmp_path, self.index_path)

    # ------------------------------------------------------------------ setup

    def setup_ggn_operator(
        self,
        curv_batch_size: int = 16,
        num_workers: int = 8,
    ):
        """
        Set up the GGN operator over the curvature subset.

        DDP: uses a distributed subset loader to shard curvature batches across
        ranks. Each rank processes its local samples.

        Args:
            curv_batch_size: Batch size for GGN computation
            num_workers: Data loading workers
        """
        logger.info("Setting up GGN operator...")

        curv_idx = load_curvature_subset(self.output_dir)
        logger.info(f"  Curvature subset size: {len(curv_idx):,}")

        if is_ddp():
            curv_loader, local_n, global_n = make_distributed_subset_loader(
                root=self.imagenet_root,
                split="train",
                indices=curv_idx,
                batch_size=curv_batch_size,
                num_workers=num_workers,
            )
            logger.info(f"  DDP: local samples={local_n}, global samples={global_n}")
        else:
            dataset = ImageNetDataset(self.imagenet_root, split="train")
            curv_loader = dataset.get_subset_loader(
                curv_idx,
                batch_size=curv_batch_size,
                shuffle=False,
                num_workers=num_workers,
            )

        # NOTE(review): the model is not switched to eval() here. If load_vit
        # leaves it in train mode, dropout would make GGN matvecs stochastic
        # — confirm.
        self._ggn_operator = GGNOperator(
            self.model,
            curv_loader,
            damping=self.damping,
            device=self.device,
        )

        # Force the math SDPA backend while preparing the operator (presumably
        # because the fused flash / mem-efficient kernels lack the autograd
        # support the GGN needs — confirm).
        with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            self._ggn_operator.prepare_for_solve()

        logger.info(f"  GGN operator ready (damping={self.damping})")

    def setup_preconditioner(
        self,
        curv_batch_size: int = 16,
        num_workers: int = 8,
        max_batches: Optional[int] = None,
    ):
        """
        Set up the PCG preconditioner. This is currently disabled.

        ``setup_pcg_solver`` always passes ``M_inv=None``, so PCG runs
        unpreconditioned regardless of ``precond_type``. This method only
        records that state. A previous version wired
        ``JacobiPreconditioner(self._ggn_operator._diag_ggn, eps=1e-6)`` here
        but was short-circuited before it could run.

        Args:
            curv_batch_size: Unused; kept for interface compatibility.
            num_workers: Unused; kept for interface compatibility.
            max_batches: Unused; kept for interface compatibility.
        """
        self._preconditioner = None
        logger.info(
            f"Preconditioner disabled (precond_type={self.precond_type}); "
            "PCG runs unpreconditioned."
        )

    def setup_rhs_manager(self):
        """Set up the RHS (cluster gradient) manager.

        Raises:
            RuntimeError: if the RHS vectors have not all been built yet.
        """
        self._rhs_manager = RHSManager(
            self.output_dir,
            self.n_clusters,
            device=self.device,
            cache_size=20,
        )

        if not self._rhs_manager.is_complete:
            raise RuntimeError(
                "RHS vectors not complete. Run stage 'build_rhs' first."
            )

        logger.info(f"RHS manager ready ({self.n_clusters} clusters)")

    def symmetry_error(self, A, n=3):
        """Print relative symmetry errors of ``A.matvec_ggn_only`` and ``A.matvec``.

        For each of ``n`` random pairs (u, v), prints
        ``|u·Av − v·Au| / (|u·Av| + |v·Au| + 1e-12)``; a symmetric operator gives ~0.
        Debugging aid only.
        """
        def relative_asymmetry(op, u, v):
            Au = op(u)
            Av = op(v)
            uAv = u.dot(Av)
            vAu = v.dot(Au)
            return ((uAv - vAu).abs() / (uAv.abs() + vAu.abs() + 1e-12)).item()

        for _ in range(n):
            u = torch.randn(A.num_params, device=A.device)
            v = torch.randn(A.num_params, device=A.device)
            print("sym_err matvec_ggn_only:", relative_asymmetry(A.matvec_ggn_only, u, v))
            print("sym_err matvec:", relative_asymmetry(A.matvec, u, v))

    def setup_pcg_solver(self):
        """Set up the PCG solver (unpreconditioned: ``M_inv=None``).

        Raises:
            RuntimeError: if ``setup_ggn_operator`` has not been called.
        """
        if self._ggn_operator is None:
            raise RuntimeError("Must call setup_ggn_operator first")

        self._pcg_solver = PCGSolver(
            self._ggn_operator,
            M_inv=None,
            max_iter=self.max_cg_iter,
            tol=self.cg_tol,
            device=self.device,
        )

        logger.info(f"PCG solver ready (max_iter={self.max_cg_iter}, tol={self.cg_tol})")

    # ------------------------------------------------------------------ solves

    @staticmethod
    def _hx_ratio(info: Optional[dict]) -> float:
        """Return ``info['Hx_over_lamx']`` as a float for threshold comparisons.

        A missing, ``None``, non-numeric, or NaN value maps to 0.0. Such a value
        then counts as "below threshold" instead of raising (``None < x``) or
        silently passing (``nan < x`` is False).
        """
        if not info:
            return 0.0
        try:
            ratio = float(info.get('Hx_over_lamx', 0))
        except (TypeError, ValueError, RuntimeError):
            return 0.0
        return 0.0 if ratio != ratio else ratio  # ratio != ratio <=> NaN

    def solve_cluster(
        self,
        cluster_id: int,
        use_warm_start: bool = False,
        verbose: bool = True,
        x0: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Solve ``(Ĝ + λI) z_c = ḡ_c`` for a single cluster.

        DDP: all ranks must call this (the solver's matvec is collective).

        Args:
            cluster_id: Cluster ID
            use_warm_start: Pass the cluster centroid (from
                ``<output_dir>/centroids.pt``, if present) to the solver for
                warm-starting. Ignored when ``x0`` is given.
            verbose: Print progress
            x0: Optional initial guess (e.g., previous solution)

        Returns:
            z_c: Solution vector
            info: Convergence information from the solver

        Raises:
            RuntimeError: if the PCG solver or RHS manager is not set up.
        """
        if self._pcg_solver is None:
            raise RuntimeError("Must call setup_pcg_solver first")
        if self._rhs_manager is None:
            raise RuntimeError("Must call setup_rhs_manager first")

        rhs = self._rhs_manager.get(cluster_id)

        centroids_path = os.path.join(self.output_dir, "centroids.pt")
        centroid = None
        if use_warm_start and x0 is None and os.path.exists(centroids_path):
            centroids = torch.load(centroids_path, map_location='cpu')
            if isinstance(centroids, dict) and 'centroids' in centroids:
                centroids = centroids['centroids']
            if cluster_id < len(centroids):
                centroid = centroids[cluster_id]

        if verbose:
            rank0_print(f"Solving cluster {cluster_id}...")

        return self._pcg_solver.solve(
            rhs,
            cluster_id=cluster_id,
            centroid=centroid,
            use_warm_start=use_warm_start,
            verbose=verbose,
            x0=x0,
        )

    def solve_cluster_with_damping_retry(
        self,
        cluster_id: int,
        damping_grid: List[float],
        threshold: float = HX_OVER_LAMX_THRESHOLD,
        use_warm_start: bool = False,
        verbose: bool = True,
        x0: Optional[torch.Tensor] = None,
        initial_result: Optional[Tuple[torch.Tensor, dict]] = None,
    ) -> Tuple[torch.Tensor, dict, float]:
        """
        Solve a cluster, retrying with other damping values if needed.

        The first attempt uses the current damping (or ``initial_result`` if
        given). It is accepted if it converged with ``Hx_over_lamx >= threshold``.
        Otherwise each value in ``damping_grid`` (skipping the current one) is
        tried in order, warm-started from the best solution so far, until one
        converges above threshold. If none does, the result with the highest
        ``Hx_over_lamx`` is returned. The original damping is always restored,
        even if a solve raises.

        Args:
            cluster_id: Cluster ID
            damping_grid: Damping values to try (in order)
            threshold: Hx_over_lamx threshold
            use_warm_start: Use warm-start
            verbose: Print progress
            x0: Optional initial guess for the first attempt
            initial_result: Optional ``(z_c, info)`` already computed at the
                current damping; used as the first attempt instead of solving
                again.

        Returns:
            z_c: Best solution vector
            info: Best convergence info (with ``'used_damping'`` set)
            used_damping: The damping value that produced ``z_c``
        """
        original_damping = self.damping

        if initial_result is not None:
            z_c, info = initial_result
        else:
            z_c, info = self.solve_cluster(cluster_id, use_warm_start, verbose, x0=x0)
        hx_ratio = self._hx_ratio(info)

        if hx_ratio >= threshold and info.get('converged', False):
            info['used_damping'] = original_damping
            return z_c, info, original_damping

        best_z, best_info, best_damping, best_hx_ratio = z_c, info, original_damping, hx_ratio

        try:
            for new_damping in damping_grid:
                if new_damping == original_damping:
                    continue  # already tried

                if verbose:
                    rank0_print(f"  Retrying cluster {cluster_id} with damping={new_damping} (Hx/λx={hx_ratio:.3f} < {threshold})")

                self._ggn_operator.damping = new_damping
                self.damping = new_damping

                # Warm-start from the best solution found so far.
                z_c, info = self.solve_cluster(cluster_id, use_warm_start, verbose=False, x0=best_z)
                hx_ratio = self._hx_ratio(info)

                if verbose:
                    rank0_print(f"    -> damping={new_damping}: Hx/λx={hx_ratio:.3f}, converged={info.get('converged', False)}")

                if hx_ratio >= threshold and info.get('converged', False):
                    # Acceptable result: take it even if an earlier,
                    # non-converged attempt had a higher ratio.
                    best_z, best_info, best_damping, best_hx_ratio = z_c, info, new_damping, hx_ratio
                    break

                if hx_ratio > best_hx_ratio:
                    best_z, best_info, best_damping, best_hx_ratio = z_c, info, new_damping, hx_ratio
        finally:
            self._ggn_operator.damping = original_damping
            self.damping = original_damping

        best_info['used_damping'] = best_damping
        if verbose:
            rank0_print(f"  Best result for cluster {cluster_id}: damping={best_damping}, Hx/λx={best_hx_ratio:.3f}")

        return best_z, best_info, best_damping

    # --------------------------------------------------------------- storage

    def save_solution(self, cluster_id: int, z_c: torch.Tensor):
        """Save a cluster solution to ``z_cluster/<cluster_id>.pt`` and update the index.

        The tensor is moved to CPU and saved with its dtype unchanged
        (``load_solution`` converts to float32 on load). Files are written to a
        temporary path and renamed into place, so a crash cannot leave a
        partial file.

        DDP: all ranks update the in-memory solved set; only rank 0 writes.
        """
        self._solved_clusters.add(cluster_id)

        if is_ddp() and not is_rank0():
            return

        os.makedirs(self.z_cluster_dir, exist_ok=True)

        save_path = os.path.join(self.z_cluster_dir, f"{cluster_id}.pt")
        tmp_path = save_path + ".tmp"
        torch.save(z_c.cpu(), tmp_path)
        os.replace(tmp_path, save_path)

        self._save_index()

    def load_solution(self, cluster_id: int) -> torch.Tensor:
        """Load a cluster solution from disk as a float32 CPU tensor."""
        path = os.path.join(self.z_cluster_dir, f"{cluster_id}.pt")
        return torch.load(path, map_location='cpu').float()

    def is_solved(self, cluster_id: int) -> bool:
        """Check whether a cluster is recorded as solved in the index."""
        return cluster_id in self._solved_clusters

    # ---------------------------------------------------------------- build

    def _solve_for_build(
        self,
        cluster_id: int,
        is_retry: bool,
        resume: bool,
        prev_solution: Optional[torch.Tensor],
        hx_threshold: float,
        damping_retry_grid: List[float],
        verbose: bool,
    ) -> Tuple[torch.Tensor, dict]:
        """Solve one cluster within ``build_all`` and return ``(z_c, info)``.

        Retries and non-resume runs go through the damping-retry path
        (warm-started from ``prev_solution`` when given). New clusters in
        resume mode get one solve at the current damping. They fall back to
        the damping grid only if Hx/λx is below threshold, reusing that first
        solve instead of repeating it.
        """
        if is_retry or not resume:
            z_c, info, _ = self.solve_cluster_with_damping_retry(
                cluster_id,
                damping_grid=damping_retry_grid,
                threshold=hx_threshold,
                use_warm_start=False,
                verbose=verbose,
                x0=prev_solution,
            )
            return z_c, info

        z_c, info = self.solve_cluster(cluster_id, use_warm_start=False, verbose=verbose)
        info['used_damping'] = self.damping

        hx_ratio = self._hx_ratio(info)
        if hx_ratio < hx_threshold:
            if verbose:
                rank0_print(f"  Initial Hx/λx={hx_ratio:.3f} < {hx_threshold}, trying damping grid...")
            z_c, info, _ = self.solve_cluster_with_damping_retry(
                cluster_id,
                damping_grid=damping_retry_grid,
                threshold=hx_threshold,
                use_warm_start=False,
                verbose=verbose,
                initial_result=(z_c, info),
            )
        return z_c, info

    def _log_cluster_result(self, csv_file, csv_writer, jsonl_file, cluster_id, elapsed, info, cluster_counts):
        """Append one cluster's solve record to the CSV and JSONL logs (rank 0 only)."""
        cluster_size = cluster_counts[cluster_id] if cluster_id < len(cluster_counts) else 0
        residuals_list = info.get('residuals', [])
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'cluster_id': cluster_id,
            'solve_time_sec': round(elapsed, 4),
            'iterations': _to_json_safe(info['iterations']),
            'converged': _to_json_safe(info.get('converged', False)),
            'initial_residual': _to_json_safe(residuals_list[0]) if residuals_list else None,
            'final_residual': _to_json_safe(info.get('residual', None)),
            'cluster_size': int(cluster_size),
            'rhs_norm': _to_json_safe(info.get('rhs_norm', None)),
            'used_damping': _to_json_safe(info.get('used_damping', self.damping)),
        }

        if csv_writer is not None:
            csv_writer.writerow(log_entry)
            csv_file.flush()

        if jsonl_file is not None:
            # JSONL also carries any extra scalar fields from the PCG info dict.
            full_entry = {'type': 'cluster_solve', **log_entry}
            for key, value in info.items():
                if key in log_entry or key == 'residuals':  # skip large lists
                    continue
                safe_value = _to_json_safe(value)
                if isinstance(safe_value, (int, float, bool, str, type(None))):
                    full_entry[key] = safe_value
            jsonl_file.write(json.dumps(full_entry) + '\n')
            jsonl_file.flush()

    def build_all(
        self,
        curv_batch_size: int = 16,
        num_workers: int = 8,
        resume: bool = True,
        verbose: bool = True,
        hx_threshold: float = HX_OVER_LAMX_THRESHOLD,
        damping_retry_grid: Optional[List[float]] = None,
    ):
        """
        Build the IFC for all clusters, largest clusters first.

        DDP: all ranks participate in PCG solves (the matvec is collective).
             Only rank 0 saves solutions, writes logs, and shows the progress bar.

        Resume logic (``resume=True``):
        - previous results are loaded from ``ifc_build_log.jsonl``;
        - clusters with Hx_over_lamx >= ``hx_threshold`` are skipped;
        - clusters below threshold are re-solved via the damping retry grid,
          warm-started from their saved solution when available.

        Args:
            curv_batch_size: Batch size for GGN
            num_workers: Data loading workers
            resume: Skip already solved clusters (with sufficient Hx_over_lamx)
            verbose: Print progress
            hx_threshold: Threshold for Hx_over_lamx to consider a cluster "done"
            damping_retry_grid: Damping values to try if Hx_over_lamx too low
                (defaults to a copy of ``DAMPING_RETRY_GRID``)
        """
        if damping_retry_grid is None:
            damping_retry_grid = list(DAMPING_RETRY_GRID)

        logger.info("=" * 60)
        logger.info("Building Influence Function Cache")
        if is_ddp():
            import torch.distributed as dist
            logger.info(f"  DDP mode: rank {dist.get_rank()} / {dist.get_world_size()}")
        logger.info("=" * 60)

        stage_timer = StageTimer()

        with stage_timer.stage("setup_ggn"):
            self.setup_ggn_operator(curv_batch_size, num_workers)
        if self.precond_type.lower() != "none":
            with stage_timer.stage("setup_preconditioner"):
                self.setup_preconditioner(curv_batch_size, num_workers)
        else:
            logger.info("Skipping preconditioner setup as 'none' was selected.")
        with stage_timer.stage("setup_rhs"):
            self.setup_rhs_manager()
        with stage_timer.stage("setup_pcg"):
            self.setup_pcg_solver()

        cluster_counts = np.asarray(load_cluster_counts(self.output_dir))
        cluster_order = np.argsort(-cluster_counts)  # largest first

        log_csv_path = os.path.join(self.output_dir, "ifc_build_log.csv")
        log_jsonl_path = os.path.join(self.output_dir, "ifc_build_log.jsonl")

        # Rank 0 loads previous results; they are broadcast below.
        prev_results: Dict[int, dict] = {}
        if resume and is_rank0():
            prev_results = _load_previous_results(log_jsonl_path)
            logger.info(f"Loaded {len(prev_results)} previous cluster results from JSONL")

            needs_work = sum(1 for cid in range(self.n_clusters)
                             if _needs_resolve(prev_results.get(cid), hx_threshold))
            already_good = self.n_clusters - needs_work
            logger.info(f"  Clusters with Hx/λx >= {hx_threshold}: {already_good}")
            logger.info(f"  Clusters needing solve/re-solve: {needs_work}")

        if is_ddp():
            prev_results = broadcast_object(prev_results)

        total_time = 0
        total_iters = 0
        solved_count = 0
        converged_count = 0
        solve_times: List[float] = []
        skipped_count = 0
        retried_count = 0

        csv_file = None
        csv_writer = None
        jsonl_file = None
        try:
            if is_rank0():
                run_metadata = {
                    'type': 'run_start',
                    'timestamp': datetime.now().isoformat(),
                    'n_clusters': self.n_clusters,
                    'damping': self.damping,
                    'precond_type': self.precond_type,
                    'max_cg_iter': self.max_cg_iter,
                    'cg_tol': self.cg_tol,
                    'curv_batch_size': curv_batch_size,
                    'num_workers': num_workers,
                    'resume': resume,
                    'hx_threshold': hx_threshold,
                    'damping_retry_grid': damping_retry_grid,
                    'already_solved': len(self._solved_clusters),
                    'prev_results_loaded': len(prev_results),
                }
                jsonl_file = open(log_jsonl_path, 'a', encoding='utf-8')
                jsonl_file.write(json.dumps(run_metadata) + '\n')
                jsonl_file.flush()

                csv_exists = os.path.exists(log_csv_path)
                csv_file = open(log_csv_path, 'a', newline='', encoding='utf-8')
                csv_writer = csv.DictWriter(csv_file, fieldnames=[
                    'timestamp', 'cluster_id', 'solve_time_sec', 'iterations',
                    'converged', 'initial_residual', 'final_residual',
                    'cluster_size', 'rhs_norm', 'used_damping'
                ])
                if not csv_exists:
                    csv_writer.writeheader()
                    csv_file.flush()

            logger.info(f"Solving {self.n_clusters} clusters...")

            # Rank-aware progress bar (only rank 0 shows it).
            from .logging_utils import RankAwareTqdm
            pbar = RankAwareTqdm(cluster_order, desc="Building IFC")

            for cluster_id in pbar:
                cluster_id = int(cluster_id)
                prev_result = prev_results.get(cluster_id)

                if resume and not _needs_resolve(prev_result, hx_threshold):
                    skipped_count += 1
                    solved_count += 1
                    continue

                # A previous result exists but its Hx_over_lamx was too low.
                is_retry = prev_result is not None
                if is_retry:
                    retried_count += 1
                    if verbose:
                        prev_hx = self._hx_ratio(prev_result)
                        rank0_print(f"Re-solving cluster {cluster_id} (prev Hx/λx={prev_hx:.3f} < {hx_threshold})")

                # Best-effort warm start from the saved solution; any load
                # failure just falls back to a cold start.
                prev_solution = None
                if is_retry and self.is_solved(cluster_id):
                    try:
                        prev_solution = self.load_solution(cluster_id).to(self.device)
                    except Exception as e:
                        logger.warning(f"Could not load previous solution for cluster {cluster_id}: {e}")
                    else:
                        if verbose:
                            rank0_print(f"  Using previous solution as warm start (norm={prev_solution.norm().item():.4f})")

                start_time = time.time()
                z_c, info = self._solve_for_build(
                    cluster_id, is_retry, resume, prev_solution,
                    hx_threshold, damping_retry_grid, verbose,
                )
                elapsed = time.time() - start_time

                self.save_solution(cluster_id, z_c)

                total_time += elapsed
                total_iters += info['iterations']
                solved_count += 1
                solve_times.append(elapsed)
                if info.get('converged', False):
                    converged_count += 1

                if is_rank0():
                    self._log_cluster_result(
                        csv_file, csv_writer, jsonl_file,
                        cluster_id, elapsed, info, cluster_counts,
                    )

                if verbose and solved_count % 50 == 0:
                    avg_time = total_time / max(1, len(solve_times))
                    avg_iters = total_iters / max(1, len(solve_times))
                    logger.info(f"  Solved {solved_count}/{self.n_clusters}, "
                          f"avg time: {avg_time:.1f}s, avg iters: {avg_iters:.1f}")

            n_new = len(solve_times)  # clusters solved in this run
            final_stats = {
                'total_clusters': self.n_clusters,
                'newly_solved': n_new,
                'skipped_good': skipped_count,
                'retried_low_hx': retried_count,
                'total_solved': solved_count,
                'total_time_hours': total_time / 3600,
                'avg_time_per_cluster': total_time / n_new if n_new else 0,
                'avg_cg_iterations': total_iters / n_new if n_new else 0,
                'convergence_rate': converged_count / n_new if n_new else 0,
                'median_solve_time': float(np.median(solve_times)) if n_new else 0,
                'min_solve_time': float(np.min(solve_times)) if n_new else 0,
                'max_solve_time': float(np.max(solve_times)) if n_new else 0,
                'std_solve_time': float(np.std(solve_times)) if n_new else 0,
                'hx_threshold': hx_threshold,
            }
            log_dict(logger, "IFC build complete", final_stats)

            if is_rank0():
                if jsonl_file is not None:
                    summary_entry = {
                        'type': 'run_complete',
                        'timestamp': datetime.now().isoformat(),
                        **final_stats,
                    }
                    jsonl_file.write(json.dumps(summary_entry) + '\n')
                logger.info(f"Logs written to {log_csv_path} and {log_jsonl_path}")
        finally:
            # Close log files even if a solve raises mid-run.
            if jsonl_file is not None:
                jsonl_file.close()
            if csv_file is not None:
                csv_file.close()

        if self._pcg_solver:
            self._pcg_solver.get_statistics()

        stage_timer.log_summary(logger)

    def get_all_solutions(self) -> Dict[int, torch.Tensor]:
        """Load every solved cluster's solution as ``{cluster_id: z_c}`` (float32, CPU)."""
        return {
            c: self.load_solution(c)
            for c in range(self.n_clusters)
            if self.is_solved(c)
        }


def build_ifc(
    imagenet_root: str,
    output_dir: str,
    n_clusters: int = 500,
    damping: float = 1e-2,
    precond_type: str = "kfac",
    max_cg_iter: int = 50,
    cg_tol: float = 1e-3,
    curv_batch_size: int = 16,
    num_workers: int = 8,
    device: str = "cuda",
    resume: bool = True,
    use_tiny: bool = True,
    verbose: bool = True,
    hx_threshold: float = HX_OVER_LAMX_THRESHOLD,
    damping_retry_grid: Optional[List[float]] = None,
):
    """
    High-level entry point: load the ViT, create an ``IFCBuilder``, and build the IFC.

    Args:
        imagenet_root: Path to ImageNet
        output_dir: Output directory
        n_clusters: Number of clusters
        damping: GGN damping
        precond_type: Preconditioner type (recorded only; preconditioning is
            currently disabled in ``IFCBuilder``)
        max_cg_iter: Max CG iterations
        cg_tol: CG tolerance
        curv_batch_size: Batch size for curvature
        num_workers: Data loading workers
        device: Computation device
        resume: Resume from previous run
        use_tiny: Forwarded to ``load_vit`` (presumably selects a smaller ViT
            variant — confirm)
        verbose: Print per-cluster progress
        hx_threshold: Hx_over_lamx threshold for considering a cluster done
        damping_retry_grid: Damping values to retry with when below threshold
            (``None`` uses the module default)

    Returns:
        The ``IFCBuilder`` used for the build.
    """
    model = load_vit(pretrained=True, device=device, use_tiny=use_tiny)
    model.register_kfac_hooks()

    builder = IFCBuilder(
        model,
        output_dir,
        imagenet_root,
        n_clusters,
        damping=damping,
        precond_type=precond_type,
        max_cg_iter=max_cg_iter,
        cg_tol=cg_tol,
        device=device,
    )

    builder.build_all(
        curv_batch_size=curv_batch_size,
        num_workers=num_workers,
        resume=resume,
        verbose=verbose,
        hx_threshold=hx_threshold,
        damping_retry_grid=damping_retry_grid,
    )

    return builder


def load_ifc(output_dir: str, device: str = "cuda") -> IFCBuilder:
    """
    Load a pre-built IFC for querying (no model or solver is constructed).

    Reads ``index.pkl`` for ``n_clusters``, ``damping``, ``precond_type`` and
    the solved set. It then returns an ``IFCBuilder`` that can load solutions
    and RHS vectors. Solver components are set to ``None``, so calling
    ``solve_cluster`` raises a clear ``RuntimeError`` rather than an
    ``AttributeError``.

    Args:
        output_dir: Directory with IFC data
        device: Device to use

    Returns:
        IFCBuilder instance with loaded data

    Raises:
        FileNotFoundError: if ``index.pkl`` does not exist.
        KeyError: if the index lacks ``n_clusters``, ``damping`` or ``precond_type``.
    """
    index_path = os.path.join(output_dir, "index.pkl")
    # NOTE: pickle is only safe here because index.pkl is our own output.
    with open(index_path, 'rb') as f:
        index = pickle.load(f)

    n_clusters = index['n_clusters']
    damping = index['damping']
    precond_type = index['precond_type']

    # Bypass __init__ (which would need a model); set every attribute the
    # query / guard paths read.
    builder = IFCBuilder.__new__(IFCBuilder)
    builder.model = None
    builder.imagenet_root = None
    builder.output_dir = output_dir
    builder.n_clusters = n_clusters
    builder.damping = damping
    builder.precond_type = precond_type
    builder.device = device
    builder.z_cluster_dir = os.path.join(output_dir, "z_cluster")
    builder.index_path = index_path
    builder._solved_clusters = set(index.get('solved', []))
    builder._ggn_operator = None
    builder._preconditioner = None
    builder._pcg_solver = None
    builder._rhs_manager = RHSManager(output_dir, n_clusters, device=device, cache_size=20)

    rank0_print(f"Loaded IFC from {output_dir}")
    rank0_print(f"  Clusters: {n_clusters}")
    rank0_print(f"  Solved: {len(builder._solved_clusters)}")

    return builder


if __name__ == "__main__":
    # CLI: python ifc_build.py /path/to/imagenet /path/to/output [n_clusters]
    import sys

    if len(sys.argv) < 3:
        print("Usage: python ifc_build.py /path/to/imagenet /path/to/output [n_clusters]", file=sys.stderr)
        sys.exit(1)

    imagenet_root = sys.argv[1]
    output_dir = sys.argv[2]
    try:
        n_clusters = int(sys.argv[3]) if len(sys.argv) > 3 else 500
    except ValueError:
        print(f"n_clusters must be an integer, got {sys.argv[3]!r}", file=sys.stderr)
        sys.exit(1)

    build_ifc(
        imagenet_root,
        output_dir,
        n_clusters=n_clusters,
        damping=1e-2,
        precond_type="kfac",
    )
