"""
Cluster mean gradient accumulation (RHS builder).

Computes the mean gradient for each cluster:
    ḡ_c = (1/n_c) Σ_{i∈c} ∇_θ ℓ_i(θ)

These serve as the right-hand sides for CG solves.

DDP Support:
    - In DDP mode, RHS building runs only on rank 0.
    - All ranks call barrier() after rank 0 finishes.
    - Other ranks can load the saved files after the barrier.
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Dict
import numpy as np
import os
from tqdm import tqdm

from .vit_full import ViTWithHooks
from .imagenet_loader import ImageNetDataset, get_indexed_loader
from .logging_utils import (
    get_logger, log_tensor_stats, log_dict,
    is_rank0, is_ddp, barrier, rank0_print,
)
from torch.func import functional_call, grad, vmap
# Module-level logger obtained from the project's logging_utils.get_logger helper.
logger = get_logger(__name__)


def accumulate_cluster_gradients(
    model: ViTWithHooks,
    dataloader: DataLoader,
    cluster_ids: np.ndarray,
    output_dir: str,
    n_clusters: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
):
    """
    Accumulate the mean loss gradient of every cluster in a single data pass.

    For each sample yielded by ``dataloader`` the cross-entropy gradient is
    computed with a batch-of-one forward/backward, flattened through
    ``model.get_grad_vector()`` and added to a float32 running sum for the
    sample's cluster. After the pass each sum is divided by the cluster size
    and written to ``<output_dir>/rhs_cluster/<c>.pt`` (clusters that received
    no samples get an all-zero vector). Per-cluster sample counts are written
    to ``<output_dir>/cluster_counts.npy``.

    Accumulators live on CPU in float32 and are allocated lazily, one dense
    vector of ``model.num_params`` entries per cluster that actually receives
    a sample, so peak host memory is about ``4 * num_params`` bytes per
    non-empty cluster.

    Args:
        model: ViT model
        dataloader: Indexed DataLoader returning (idx, image, label)
        cluster_ids: Array mapping sample index to cluster ID
        output_dir: Directory to save cluster gradients
        n_clusters: Number of clusters
        device: Computation device
        dtype: Storage dtype for the saved mean gradients

    Raises:
        ValueError: if a sample's cluster id is outside ``[0, n_clusters)``.
    """
    rhs_dir = os.path.join(output_dir, "rhs_cluster")
    os.makedirs(rhs_dir, exist_ok=True)

    model.model.eval()
    params = list(model.model.parameters())

    for p in params:
        p.requires_grad = True

    num_params = model.num_params

    # Float32 CPU accumulators, created on first use so empty clusters cost nothing.
    grad_sums: Dict[int, torch.Tensor] = {}
    cluster_counts = np.zeros(n_clusters, dtype=np.int32)

    logger.info(f"Accumulating gradients for {n_clusters} clusters, {num_params:,} params")
    log_dict(logger, "RHS accumulation config", {
        'n_clusters': n_clusters,
        'num_params': num_params,
        'output_dir': output_dir,
        'dtype': str(dtype),
    })

    grad_norms = []  # Per-sample gradient norms, for statistics only

    for idx, images, labels in tqdm(dataloader, desc="Accumulating gradients"):
        images = images.to(device)
        labels = labels.to(device)

        # Samples of one batch can belong to different clusters, so each one
        # is back-propagated on its own.
        for i in range(images.shape[0]):
            sample_idx = idx[i].item()
            cluster = int(cluster_ids[sample_idx])
            # Reject bad ids explicitly; a negative id would otherwise wrap
            # around silently when indexing cluster_counts.
            if not 0 <= cluster < n_clusters:
                raise ValueError(
                    f"Sample {sample_idx} has cluster id {cluster}, "
                    f"expected a value in [0, {n_clusters})"
                )

            # Forward pass
            logits = model(images[i:i+1])
            loss = F.cross_entropy(logits, labels[i:i+1])

            # Backward pass
            model.model.zero_grad()
            loss.backward()

            # Get gradient and accumulate
            grad_vec = model.get_grad_vector()
            if cluster not in grad_sums:
                grad_sums[cluster] = torch.zeros(num_params, dtype=torch.float32)
            grad_sums[cluster] += grad_vec.detach().cpu().float()
            cluster_counts[cluster] += 1
            grad_norms.append(grad_vec.norm().item())

    # Log gradient statistics (guarded: min/max of an empty array would raise)
    if grad_norms:
        norms = np.asarray(grad_norms)
        log_dict(logger, "Individual gradient statistics", {
            'n_samples': len(norms),
            'grad_norm_min': float(norms.min()),
            'grad_norm_max': float(norms.max()),
            'grad_norm_mean': float(norms.mean()),
            'grad_norm_std': float(norms.std()),
            'grad_norm_median': float(np.median(norms)),
        })
    else:
        logger.warning("Dataloader yielded no samples; all cluster gradients will be zero")

    # Compute means and save
    logger.info("Computing means and saving...")

    mean_grad_norms = []
    empty_clusters = 0

    for c in tqdm(range(n_clusters), desc="Saving cluster gradients"):
        count = int(cluster_counts[c])
        if count > 0:
            # pop() releases the accumulator as soon as its mean is formed.
            mean_grad = grad_sums.pop(c) / count
            mean_grad_norms.append(mean_grad.norm().item())
        else:
            mean_grad = torch.zeros(num_params)
            empty_clusters += 1

        # Save in the requested storage dtype
        save_path = os.path.join(rhs_dir, f"{c}.pt")
        torch.save(mean_grad.to(dtype), save_path)

    # Log cluster mean gradient statistics
    if mean_grad_norms:
        mean_norms = np.asarray(mean_grad_norms)
        log_dict(logger, "Cluster mean gradient statistics", {
            'n_clusters': n_clusters,
            'empty_clusters': empty_clusters,
            'mean_grad_norm_min': float(mean_norms.min()),
            'mean_grad_norm_max': float(mean_norms.max()),
            'mean_grad_norm_mean': float(mean_norms.mean()),
            'mean_grad_norm_std': float(mean_norms.std()),
        })

    # Log cluster size distribution over non-empty clusters
    nonempty = cluster_counts[cluster_counts > 0]
    has_nonempty = nonempty.size > 0
    log_dict(logger, "Cluster size distribution", {
        'min_size': int(nonempty.min()) if has_nonempty else 0,
        'max_size': int(nonempty.max()) if has_nonempty else 0,
        'mean_size': float(nonempty.mean()) if has_nonempty else 0,
        'std_size': float(nonempty.std()) if has_nonempty else 0,
    })

    # Save cluster counts
    counts_path = os.path.join(output_dir, "cluster_counts.npy")
    np.save(counts_path, cluster_counts)

    logger.info(f"Saved {n_clusters} cluster mean gradients to {rhs_dir}")
    logger.info(f"Saved cluster counts to {counts_path}")

def _flatten_grads(params):
    # robust: handles None grads
    chunks = []
    for p in params:
        if p.grad is None:
            chunks.append(torch.zeros(p.numel(), device=p.device, dtype=torch.float32))
        else:
            chunks.append(p.grad.detach().reshape(-1).to(torch.float32))
    return torch.cat(chunks, dim=0)
def accumulate_cluster_gradients_streaming(
    model,                    # ViTWithHooks
    dataset,                  # ImageNetDataset-like, must have get_subset_loader
    cluster_ids: np.ndarray,
    output_dir: str,
    n_clusters: int,
    batch_size: int = 64,
    microbatch_size: int = 8,     # IMPORTANT: keeps activation memory low
    num_workers: int = 8,
    device: str = "cuda",
    store_dtype: torch.dtype = torch.float16,  # on-disk dtype
    use_amp: bool = True,                      # currently not applied, see docstring
):
    """
    Streaming mean-gradient (RHS) builder, one cluster at a time.

    Key idea: the gradient of the *summed* loss over a set of samples equals
    the sum of their per-sample gradients. So for each cluster we run ONE
    backward per microbatch with ``reduction="sum"`` (no per-sample grads),
    add the flattened gradient to a float32 accumulator on ``device``, and
    divide by the number of samples seen. Only one cluster's accumulator
    (``num_params`` float32 values) is alive at a time.

    Outputs:
        <output_dir>/rhs_cluster/<c>.pt : mean gradient of cluster c stored as
            ``store_dtype`` (all zeros for clusters with no samples).
        <output_dir>/cluster_counts.npy : int32 sample count per cluster,
            length ``n_clusters``.

    Args:
        model: Wrapper whose ``.model`` is the underlying nn.Module and whose
            ``__call__`` returns logits for a batch.
        dataset: Object exposing ``get_subset_loader(indices, batch_size=...,
            shuffle=..., num_workers=...)`` yielding ``(images, labels)``.
        cluster_ids: Cluster id for every dataset index.
        output_dir: Root output directory.
        n_clusters: Number of clusters.
        batch_size: Loader batch size.
        microbatch_size: Samples per forward/backward (bounds activation memory).
        num_workers: Loader worker count.
        device: Computation device.
        store_dtype: dtype of the saved RHS vectors.
        use_amp: Accepted for API compatibility but currently NOT applied; the
            forward/backward run in the model's own dtype. NOTE(review): enabling
            fp16 autocast here without loss scaling could underflow small
            gradients, so it should be wired up deliberately if needed.

    Raises:
        ValueError: if ``microbatch_size < 1`` or any cluster id lies outside
            ``[0, n_clusters)``.
    """
    # A non-positive step would make the microbatch loop empty and silently
    # save all-zero RHS vectors.
    if microbatch_size < 1:
        raise ValueError(f"microbatch_size must be >= 1, got {microbatch_size}")

    cluster_ids = np.asarray(cluster_ids)
    # Out-of-range ids would be skipped by the per-cluster loop and would make
    # np.bincount return a counts array longer than n_clusters.
    if cluster_ids.size and (cluster_ids.min() < 0 or cluster_ids.max() >= n_clusters):
        raise ValueError(
            f"cluster_ids must lie in [0, {n_clusters}); got values in "
            f"[{cluster_ids.min()}, {cluster_ids.max()}]"
        )

    rhs_dir = os.path.join(output_dir, "rhs_cluster")
    os.makedirs(rhs_dir, exist_ok=True)

    # eval mode, but grads enabled
    model.model.eval().to(device)
    for p in model.model.parameters():
        p.requires_grad_(True)

    params = [p for p in model.model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in params)

    # counts (length n_clusters, guaranteed by the range check above)
    cluster_counts = np.bincount(cluster_ids, minlength=n_clusters).astype(np.int32)
    logger.info(f"Processing {n_clusters} clusters (streaming mode)...")

    for c in tqdm(range(n_clusters), desc="Processing clusters"):
        save_path = os.path.join(rhs_dir, f"{c}.pt")
        idx = np.where(cluster_ids == c)[0]
        if idx.size == 0:
            torch.save(torch.zeros(num_params, dtype=store_dtype), save_path)
            continue

        loader = dataset.get_subset_loader(
            idx,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

        grad_sum = torch.zeros(num_params, device=device, dtype=torch.float32)
        count = 0

        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # microbatch to control activation memory
            for s in range(0, images.size(0), microbatch_size):
                xb = images[s:s + microbatch_size]
                yb = labels[s:s + microbatch_size]

                # Clear grads so the flattened vector holds only this microbatch.
                model.model.zero_grad(set_to_none=True)

                # forward + backward once per microbatch
                logits = model(xb)  # wrapper forward; should return (mb, num_classes)
                loss = F.cross_entropy(logits, yb, reduction="sum")
                loss.backward()

                grad_sum.add_(_flatten_grads(params))  # float32 vector
                count += xb.size(0)

        # Divide by the number of samples actually seen (guarded against 0).
        mean_grad = grad_sum / max(count, 1)
        # Cast before the device->host copy so only store_dtype bytes move.
        torch.save(mean_grad.to(store_dtype).cpu(), save_path)

        # No empty_cache() in the hot loop; it slows you down.
        del grad_sum, mean_grad

    np.save(os.path.join(output_dir, "cluster_counts.npy"), cluster_counts)
    logger.info(f"Saved {n_clusters} cluster mean gradients to {rhs_dir}")


def load_cluster_gradient(output_dir: str, cluster_id: int) -> torch.Tensor:
    """Read ``<output_dir>/rhs_cluster/<cluster_id>.pt`` as a float32 CPU tensor."""
    tensor_path = os.path.join(output_dir, "rhs_cluster", f"{cluster_id}.pt")
    stored = torch.load(tensor_path, map_location='cpu')
    return stored.float()


def load_cluster_counts(output_dir: str) -> np.ndarray:
    """Read the per-cluster sample counts saved as ``<output_dir>/cluster_counts.npy``."""
    return np.load(os.path.join(output_dir, "cluster_counts.npy"))


def check_rhs_completeness(output_dir: str, n_clusters: int) -> bool:
    """
    Return True iff ``rhs_cluster/<c>.pt`` exists for every c in ``range(n_clusters)``.

    Only existence is checked; file contents are not opened or validated.
    """
    rhs_dir = os.path.join(output_dir, "rhs_cluster")
    return os.path.exists(rhs_dir) and all(
        os.path.exists(os.path.join(rhs_dir, f"{cid}.pt"))
        for cid in range(n_clusters)
    )


class RHSManager:
    """
    Manager for cluster mean gradients (RHS vectors).

    Loads ``<output_dir>/rhs_cluster/<c>.pt`` on demand, converts it to float32
    on ``device`` and keeps up to ``cache_size`` of the most recently used
    vectors in an LRU cache (``cache_size <= 0`` disables caching).

    NOTE: ``get`` returns the cached tensor itself, not a copy; modifying it in
    place also modifies the cached value.
    """

    def __init__(
        self,
        output_dir: str,
        n_clusters: int,
        device: str = "cuda",
        cache_size: int = 10,
    ):
        """
        Args:
            output_dir: Directory containing rhs_cluster/
            n_clusters: Number of clusters
            device: Device to load tensors to
            cache_size: Number of RHS vectors to cache (<= 0 disables caching)
        """
        self.output_dir = output_dir
        self.n_clusters = n_clusters
        self.device = device
        self.cache_size = cache_size

        self.rhs_dir = os.path.join(output_dir, "rhs_cluster")

        # LRU cache: dicts preserve insertion order, so the first key is the
        # least recently used entry and re-inserting a key marks it most recent.
        self._cache: Dict[int, torch.Tensor] = {}

        # Load cluster counts (None when cluster_counts.npy is absent)
        counts_path = os.path.join(output_dir, "cluster_counts.npy")
        if os.path.exists(counts_path):
            self.counts = np.load(counts_path)
        else:
            self.counts = None

    def get(self, cluster_id: int) -> torch.Tensor:
        """Get the float32 RHS vector for a cluster (with LRU caching)."""
        rhs = self._cache.pop(cluster_id, None)
        if rhs is not None:
            # Re-insert at the end to mark it most recently used.
            self._cache[cluster_id] = rhs
            return rhs

        # Load from disk
        path = os.path.join(self.rhs_dir, f"{cluster_id}.pt")
        rhs = torch.load(path, map_location='cpu').float().to(self.device)

        if self.cache_size > 0:
            self._cache[cluster_id] = rhs
            # Evict least recently used entries while over capacity.
            while len(self._cache) > self.cache_size:
                del self._cache[next(iter(self._cache))]

        return rhs

    def get_count(self, cluster_id: int) -> int:
        """Get number of samples in a cluster (0 if counts were not found)."""
        if self.counts is None:
            return 0
        return int(self.counts[cluster_id])

    def clear_cache(self):
        """Clear the RHS cache."""
        self._cache.clear()

    def __getitem__(self, cluster_id: int) -> torch.Tensor:
        return self.get(cluster_id)

    @property
    def is_complete(self) -> bool:
        """Check if all RHS files exist."""
        return check_rhs_completeness(self.output_dir, self.n_clusters)


def build_rhs_vectors(
    model: ViTWithHooks,
    imagenet_root: str,
    cluster_ids: np.ndarray,
    output_dir: str,
    n_clusters: int,
    batch_size: int = 64,
    num_workers: int = 8,
    device: str = "cuda",
    use_streaming: bool = True,
):
    """
    Build all cluster RHS vectors (skipped if they are already on disk).

    High-level function that dispatches to the streaming builder
    (``accumulate_cluster_gradients_streaming``) or the single-pass indexed
    builder (``accumulate_cluster_gradients``).

    DDP:
        - In DDP mode, only rank 0 runs the computation.
        - Every rank calls barrier() exactly once: non-zero ranks immediately,
          rank 0 after the files are built (or found to be complete).
        - Other ranks can load the saved files after the barrier returns.

    Args:
        model: ViT model
        imagenet_root: Path to ImageNet
        cluster_ids: Cluster assignments
        output_dir: Output directory
        n_clusters: Number of clusters
        batch_size: Batch size
        num_workers: Data loading workers
        device: Computation device
        use_streaming: Use streaming mode (lower memory)
    """
    # ===========================================================================
    # DDP: Only rank 0 builds RHS. Other ranks wait at barrier.
    # ===========================================================================
    if is_ddp() and not is_rank0():
        logger.info("DDP: Rank != 0, waiting at barrier for RHS building to complete")
        barrier()
        return

    # Rank 0 (or non-DDP) does the actual work, unless it is already done.
    if check_rhs_completeness(output_dir, n_clusters):
        rank0_print(f"RHS vectors already computed in {output_dir}/rhs_cluster/")
    else:
        if use_streaming:
            # The dataset object is only needed by the streaming builder.
            dataset = ImageNetDataset(imagenet_root, split="train")
            accumulate_cluster_gradients_streaming(
                model, dataset, cluster_ids, output_dir, n_clusters,
                batch_size=batch_size,
                num_workers=num_workers,
                device=device,
            )
        else:
            # Use indexed loader
            dataloader = get_indexed_loader(
                imagenet_root,
                split="train",
                batch_size=batch_size,
                num_workers=num_workers,
            )
            accumulate_cluster_gradients(
                model, dataloader, cluster_ids, output_dir, n_clusters,
                device=device,
            )

        if is_ddp():
            logger.info("DDP: Rank 0 finished RHS building, calling barrier")

    # ===========================================================================
    # DDP: single release point so other ranks know the files are ready.
    # ===========================================================================
    if is_ddp():
        barrier()


if __name__ == "__main__":
    # Smoke test for RHSManager. This module uses package-relative imports,
    # so it must be run as a module (python -m <package>.rhs_build), not as a
    # plain script.
    import sys
    import tempfile

    if len(sys.argv) < 2:
        print("Usage: python -m <package>.rhs_build /path/to/output_dir [n_clusters]")
        print("\nTesting with mock data...")

        # Create mock RHS files in a fresh temporary directory, so repeated or
        # concurrent runs never see stale files from an earlier run.
        output_dir = tempfile.mkdtemp(prefix="ifc_test_")
        rhs_dir = os.path.join(output_dir, "rhs_cluster")
        os.makedirs(rhs_dir, exist_ok=True)

        n_clusters = 5
        num_params = 1000

        for c in range(n_clusters):
            rhs = torch.randn(num_params)
            torch.save(rhs.half(), os.path.join(rhs_dir, f"{c}.pt"))

        np.save(os.path.join(output_dir, "cluster_counts.npy"),
                np.array([100, 200, 150, 80, 170]))

        # Test manager
        manager = RHSManager(output_dir, n_clusters, device="cpu", cache_size=3)

        print(f"Is complete: {manager.is_complete}")

        for c in range(n_clusters):
            rhs = manager[c]
            count = manager.get_count(c)
            print(f"Cluster {c}: shape={rhs.shape}, count={count}")
    else:
        output_dir = sys.argv[1]
        n_clusters = int(sys.argv[2]) if len(sys.argv) > 2 else 500

        manager = RHSManager(output_dir, n_clusters)
        print(f"Is complete: {manager.is_complete}")
