"""
Gradient sketching and clustering.

Computes per-sample gradient sketches using random Johnson-Lindenstrauss (JL)
projections and clusters them with scikit-learn's (full-batch) k-means.

DDP Support:
    - In DDP mode, clustering runs only on rank 0.
    - All ranks call barrier() after rank 0 finishes.
    - cluster_id.npy, centroids.pt, jl_projector.pt, etc. are written by rank 0 only.
    - Other ranks can load the saved files after the barrier.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Tuple, List, List
import numpy as np
import os
from tqdm import tqdm
from torch.func import functional_call, vmap, grad
from sklearn.cluster import KMeans
from .vit_full import ViTWithHooks
from .logging_utils import (
    get_logger, log_tensor_stats, log_dict, StageTimer,
    is_rank0, is_ddp, barrier, rank0_print,
)
logger = get_logger("ifc_vit.cluster")


class JLProjector:
    """
    Sparse random (CountSketch-style) projector used for JL-type dimension reduction.

    Every input coordinate ``i`` is hashed to one output bucket ``indices[i]`` and
    multiplied by a random sign scaled by ``1 / sqrt(output_dim)``. Projecting a
    vector sums the signed coordinates that fall into each bucket. This
    approximately preserves distances and inner products, up to a global scale.

    The hash and the signs are drawn from a CPU generator seeded with ``seed``.
    Projectors built with the same ``(input_dim, output_dim, seed)`` therefore agree.
    """

    def __init__(self, input_dim: int, output_dim: int, seed: int = 42,
                 device: str = "cuda", chunk_size: int = 2_000_000, dtype=torch.float16):
        """
        Args:
            input_dim: Length P of the vectors to project.
            output_dim: Sketch dimension d.
            seed: Seed of the CPU generator that draws the hash and the signs.
            device: Device on which ``indices`` and ``signs`` are stored.
            chunk_size: Number of input coordinates scattered per step in ``project``.
            dtype: dtype of the tensors returned by ``project``.
        """
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seed = seed
        self.device = device
        self.chunk_size = chunk_size
        self.dtype = dtype

        g = torch.Generator(device="cpu")
        g.manual_seed(seed)

        # Drawn on the CPU so the hash does not depend on the target device's RNG.
        idx = torch.randint(0, output_dim, (input_dim,), generator=g, dtype=torch.int32)
        sgn = torch.randint(0, 2, (input_dim,), generator=g, dtype=torch.int8) * 2 - 1  # {-1, +1}

        # Optional CountSketch scaling (downstream code could also omit it).
        scale = 1.0 / (output_dim ** 0.5)
        # Kept as int32 to halve memory; converted to int64 per slice in _scatter_.
        self.indices = idx.to(device, non_blocking=True)
        self.signs = sgn.to(device, non_blocking=True).to(torch.float32) * scale

    def _scatter_(self, out: torch.Tensor, x: torch.Tensor, start: int) -> None:
        """Add the signed projection of ``x`` (coordinates ``start:start+n``) into ``out`` in place."""
        rows, n = x.shape
        # scatter_add_ only accepts int64 indices; convert just this slice so the
        # stored hash stays compact.
        idx = self.indices[start:start + n].long()
        sgn = self.signs[start:start + n]
        # expand() is a view, so the index tensor is not materialized B times.
        out.scatter_add_(1, idx.view(1, -1).expand(rows, -1), x.to(torch.float32) * sgn.view(1, -1))

    def project(self, v: torch.Tensor) -> torch.Tensor:
        """
        Project ``v`` to the sketch space.

        Args:
            v: (P,) or (B, P) tensor on the same device as ``indices``.

        Returns:
            (d,) or (B, d) tensor of dtype ``self.dtype``; accumulation is done in float32.
        """
        single = v.dim() == 1
        if single:
            v = v.unsqueeze(0)
        B, P = v.shape
        assert P == self.input_dim, f"expected {self.input_dim} features, got {P}"

        out = torch.zeros((B, self.output_dim), device=v.device, dtype=torch.float32)

        # Chunking bounds the peak memory of the float32 copy and index slice.
        for start in range(0, P, self.chunk_size):
            end = min(start + self.chunk_size, P)
            self._scatter_(out, v[:, start:end], start)

        if single:
            out = out.squeeze(0)
        return out.to(self.dtype)

    @torch.no_grad()
    def add_slice_(self, out: torch.Tensor, x_flat: torch.Tensor, start: int):
        """
        Accumulate the projection of one slice of the flattened vector into ``out``.

        Args:
            out: (B, d) float32 accumulator, modified in place.
            x_flat: (B, n) slice of the flattened parameter vector (e.g. a gradient slice).
            start: Offset of ``x_flat``'s first column in the flattened vector.
        """
        self._scatter_(out, x_flat, start)

    def __call__(self, v: torch.Tensor) -> torch.Tensor:
        return self.project(v)

    def save(self, path: str):
        """Save projector state (including the output dtype) with ``torch.save``."""
        state = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'seed': self.seed,
            'indices': self.indices.cpu(),
            'signs': self.signs.cpu(),
            'chunk_size': self.chunk_size,
            # Stored as a plain string (e.g. "float16") so it loads under weights_only.
            'dtype': str(self.dtype).replace("torch.", ""),
        }
        torch.save(state, path)

    @classmethod
    def load(cls, path: str, device: str = "cuda") -> "JLProjector":
        """Load a projector written by ``save`` and move it to ``device``."""
        state = torch.load(path, map_location='cpu')
        projector = cls.__new__(cls)
        projector.input_dim = state['input_dim']
        projector.output_dim = state['output_dim']
        projector.seed = state['seed']
        projector.device = device
        projector.indices = state['indices'].to(device, non_blocking=True)
        projector.signs = state['signs'].to(device, non_blocking=True)
        projector.chunk_size = state.get('chunk_size', 2_000_000)
        # Files written before the dtype was recorded used the constructor default.
        projector.dtype = getattr(torch, state.get('dtype', 'float16'))

        return projector

@torch.no_grad()
def flatten_params(params):
    """
    Flatten and concatenate every tensor of a name -> tensor mapping.

    Args:
        params: Mapping whose values are tensors (e.g. ``dict(model.named_parameters())``).
            Each value must support ``view(-1)``.

    Returns:
        A 1-D tensor holding the values in the mapping's iteration order.
    """
    flat_views = tuple(tensor.view(-1) for tensor in params.values())
    return torch.cat(flat_views)

def compute_gradient_sketches(
    vitwrap,
    dataloader,
    sketch_dim: int = 1024,
    output_path: str = "train_sketch.dat",
    seed: int = 42,
    device: str = "cuda",
    out_dtype=np.float32,
):
    """
    Compute a random-projection sketch of every training example's loss gradient.

    Per-sample gradients are taken of the cross-entropy loss with respect to all
    parameters of ``vitwrap.model``, using ``torch.func`` (``vmap`` over ``grad``).
    They are flattened in ``named_parameters()`` order and multiplied by a dense
    Gaussian matrix ``R``. ``R`` has shape ``(total_params, sketch_dim)`` and is
    scaled by ``1 / sqrt(sketch_dim)``.

    Sketches are written, in dataloader iteration order, into a raw ``np.memmap``
    at ``output_path``. A companion file ``output_path + ".meta.npy"`` holds a
    pickled ``{'shape', 'dtype'}`` dict that is needed to reopen the memmap.

    Notes:
        - ``R`` is materialized on ``device`` in float32, i.e.
          ``4 * total_params * sketch_dim`` bytes.
        - Seeds torch's global RNG with ``seed`` (``torch.manual_seed``).
        - Row ``i`` corresponds to dataset item ``i`` only if the dataloader
          iterates the dataset in order (no shuffling). TODO: confirm at the call site.

    Args:
        vitwrap: Wrapper exposing ``.model`` and ``.enable_full_grads``.
        dataloader: Yields ``(images, labels)`` batches. ``len(dataloader.dataset)``
            sets the number of rows.
        sketch_dim: Sketch dimension.
        output_path: Raw memmap file to (over)write.
        seed: Seed used to draw ``R``.
        device: Device for the model, the gradients and ``R``.
        out_dtype: numpy dtype of the stored sketches.

    Returns:
        ``{"shape": (N, sketch_dim), "path": output_path}``.
    """
    model = vitwrap.model.eval().to(device)
    vitwrap.enable_full_grads(True)

    # Detached parameter tensors: torch.func.grad differentiates w.r.t. these
    # explicitly. Detaching keeps the returned per-sample gradients off the outer
    # autograd graph; live nn.Parameters would make them differentiable and keep
    # each batch's graph alive while the gradients exist.
    params = {name: p.detach() for name, p in model.named_parameters()}
    # Count parameters without materializing a flattened copy of the whole model.
    total_dim = sum(p.numel() for p in params.values())

    # Dense Gaussian JL projection matrix.
    torch.manual_seed(seed)
    R = torch.randn(total_dim, sketch_dim, device=device, dtype=torch.float32) / np.sqrt(sketch_dim)

    N = len(dataloader.dataset)
    mm = np.memmap(output_path, dtype=out_dtype, mode="w+", shape=(N, sketch_dim))
    np.save(output_path + ".meta.npy", {'shape': (N, sketch_dim), 'dtype': np.dtype(out_dtype).name})

    def loss_fn(params, x, y):
        # Loss of a single example; vmap feeds one (x, y) pair per call.
        logits = functional_call(model, params, (x.unsqueeze(0),))
        return F.cross_entropy(logits, y.unsqueeze(0))

    per_sample_grad = vmap(grad(loss_fn), in_dims=(None, 0, 0))

    offset = 0
    for images, labels in tqdm(dataloader, desc="Gradient sketches"):
        xb, yb = images.to(device), labels.to(device)
        batch_size = xb.size(0)

        grads_dict = per_sample_grad(params, xb, yb)
        grads_flat = torch.cat([g.contiguous().view(batch_size, -1) for g in grads_dict.values()], dim=1)

        # Project on device, then move the whole batch to the host at once.
        sketch = grads_flat @ R
        mm[offset:offset + batch_size] = sketch.detach().cpu().numpy().astype(out_dtype, copy=False)
        offset += batch_size

        del grads_dict, grads_flat, sketch

    mm.flush()
    if offset != N:
        # e.g. drop_last=True: the unwritten tail rows keep their initial contents.
        logger.warning(
            f"Gradient sketches: wrote {offset:,} rows but the dataset has {N:,} items; "
            f"the remaining rows were not written."
        )
    return {"shape": (N, sketch_dim), "path": output_path}

class KMeans:
    """
    K-means clustering of gradient sketches stored on disk.

    ``fit`` delegates the clustering to scikit-learn's ``KMeans``. That estimator
    uses Euclidean distance on the raw (un-normalized) sketches, k-means++
    initialization and a single init. The fitted centroids are kept on ``device``
    and diagnostic statistics are logged.

    NOTE: this class has the same name as ``sklearn.cluster.KMeans``. Because it
    is defined later in the module, it replaces that import at module level, so
    ``fit`` imports scikit-learn's estimator under a private alias.
    """

    def __init__(
        self,
        n_clusters: int,
        max_iter: int = 100,
        tol: float = 1e-4,
        seed: int = 42,
        device: str = "cuda",
        batch_size: int = 50_000,  # For memory-efficient distance computation
    ):
        """
        Args:
            n_clusters: Number of clusters
            max_iter: Maximum number of k-means iterations (passed to scikit-learn)
            tol: Convergence tolerance with scikit-learn's semantics: a relative
                tolerance on the change of the cluster centers between iterations
            seed: Random seed for torch, numpy and scikit-learn's ``random_state``
            device: Device for the centroids and the diagnostic computations
            batch_size: Rows per batch in ``_assign_all_batched`` (memory management)
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.device = device
        self.batch_size = batch_size

        self.centroids: Optional[torch.Tensor] = None

    @staticmethod
    def _open_sketches(sketches_path: str) -> np.ndarray:
        """Open sketches read-only: a raw memmap plus ``.meta.npy`` if present, else a ``.npy`` file."""
        meta_path = sketches_path + ".meta.npy"
        if os.path.exists(meta_path):
            # The meta file is a pickled dict written by compute_gradient_sketches.
            meta = np.load(meta_path, allow_pickle=True).item()
            return np.memmap(sketches_path, dtype=np.dtype(meta['dtype']), mode='r', shape=tuple(meta['shape']))
        # Fallback for .npy files (backwards compatibility)
        return np.load(sketches_path, mmap_mode='r', allow_pickle=True)

    def fit(
        self,
        sketches_path: str,
        verbose: bool = True,
    ) -> np.ndarray:
        """
        Cluster the sketches at ``sketches_path`` and return one cluster id per row.

        Args:
            sketches_path: Raw memmap written by ``compute_gradient_sketches`` (with
                its ``.meta.npy`` companion), or a plain ``.npy`` array.
            verbose: Forwarded to scikit-learn's progress output.

        Returns:
            cluster_ids: (N,) array of cluster assignments in ``[0, n_clusters)``.
        """
        # Alias required: the module-level name ``KMeans`` is THIS class.
        from sklearn.cluster import KMeans as _SklearnKMeans

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        sketches = self._open_sketches(sketches_path)
        n_samples, sketch_dim = sketches.shape

        logger.info(f"Fitting K-Means: {n_samples:,} samples, {self.n_clusters} clusters")
        log_dict(logger, "K-Means configuration", {
            'n_clusters': self.n_clusters,
            'max_iter': self.max_iter,
            'tol': self.tol,
            'n_samples': n_samples,
            'sketch_dim': sketch_dim,
            # scikit-learn's KMeans below is Euclidean on the raw sketches.
            'distance_metric': 'euclidean',
        })

        logger.info("Loading sketches...")
        # One float32 host copy; scikit-learn consumes it directly, so no
        # device -> host round-trip is needed.
        features = sketches[:].astype(np.float32)
        sketches_original = torch.from_numpy(features).to(self.device)

        # Normalized copy is only used for the cosine diagnostics below.
        logger.info("Normalizing sketches for cosine distance...")
        sketches_normalized = F.normalize(sketches_original, p=2, dim=1)

        est = _SklearnKMeans(
            n_clusters=self.n_clusters,
            init='k-means++',
            n_init=1,
            max_iter=self.max_iter,
            tol=self.tol,
            random_state=int(self.seed),
            verbose=1 if verbose else 0,
        ).fit(features)

        self.centroids = torch.from_numpy(est.cluster_centers_).to(self.device, dtype=torch.float32)
        cluster_ids = est.labels_
        self._log_final_statistics(
            sketches_original=sketches_original,
            sketches_normalized=sketches_normalized,
            cluster_ids=cluster_ids,
            n_clusters=self.n_clusters,
            inertia=est.inertia_,
            logger=logger,
        )

        return cluster_ids

    def _assign_all_batched(self, sketches: torch.Tensor) -> torch.Tensor:
        """
        Assign every row to the centroid with the highest cosine similarity, in batches.

        Centroids are L2-normalized here, so the argmax is a true cosine assignment
        whether or not ``sketches`` rows are normalized. Row norms are positive
        per-row constants and cannot change a row's argmax.

        Returns:
            (N,) int64 tensor of cluster ids on ``self.device``.
        """
        n_samples = sketches.shape[0]
        assignments = torch.zeros(n_samples, dtype=torch.long, device=self.device)
        centroid_dirs = F.normalize(self.centroids, p=2, dim=1)

        for start in range(0, n_samples, self.batch_size):
            end = min(start + self.batch_size, n_samples)
            batch = sketches[start:end]

            # Cosine similarity up to a positive per-row factor.
            similarities = batch @ centroid_dirs.T  # (batch, n_clusters)

            # Maximum similarity == minimum cosine distance.
            assignments[start:end] = similarities.argmax(dim=1)

        return assignments

    def _log_iteration_stats(self, assignments: torch.Tensor, iteration: int, inertia: float):
        """Log cluster-size statistics for the assignments of one iteration."""
        unique, counts = torch.unique(assignments, return_counts=True)
        counts_np = counts.cpu().numpy()

        log_dict(logger, f"Iteration {iteration} cluster statistics", {
            'inertia': inertia,
            'n_active_clusters': len(unique),
            'empty_clusters': self.n_clusters - len(unique),
            'min_cluster_size': int(counts_np.min()),
            'max_cluster_size': int(counts_np.max()),
            'mean_cluster_size': float(counts_np.mean()),
            'std_cluster_size': float(counts_np.std()),
        })

    def _log_final_statistics(self,
            sketches_original: torch.Tensor,      # (N, d) in original JL space
            sketches_normalized: torch.Tensor,    # (N, d) L2-normalized (for cosine stats)
            cluster_ids: np.ndarray,              # (N,) int in [0, K-1]
            n_clusters: int,                      # K
            inertia: float,                       # sum_i ||x_i - mu_{c(i)}||^2 (should match sklearn)
            logger,
        ):
        """
        Compute, log and return diagnostic statistics of a final clustering.

        Centroids are recomputed as per-cluster means of ``sketches_original``.
        From them, the variance decomposition (total = between + within) and the
        cosine diagnostics are derived.

        Raises:
            ValueError: if any cluster id is outside ``[0, n_clusters)``.

        Returns:
            dict of the logged statistics.
        """
        assert sketches_original.ndim == 2
        assert sketches_normalized.ndim == 2
        N, d = sketches_original.shape
        assert sketches_normalized.shape == (N, d)

        device = sketches_original.device
        dtype = sketches_original.dtype

        # ---- cluster ids tensor ----
        cid = torch.from_numpy(cluster_ids).to(device=device, dtype=torch.long)
        if cid.min().item() < 0 or cid.max().item() >= n_clusters:
            raise ValueError(f"cluster_ids out of range [0, {n_clusters-1}]")

        # ---- counts ----
        counts = torch.bincount(cid, minlength=n_clusters)  # (K,)
        nonempty = counts > 0
        n_nonempty = int(nonempty.sum().item())
        empty_clusters = int((~nonempty).sum().item())
        safe_counts = counts.clamp_min(1)  # avoids division by zero for empty clusters

        # ---- centroids in ORIGINAL JL space: mu_c = mean_{i in c} x_i ----
        sums = torch.zeros((n_clusters, d), device=device, dtype=dtype)
        sums.scatter_add_(0, cid[:, None].expand(N, d), sketches_original)
        centroids_orig = sums / safe_counts[:, None]

        # ---- per-point squared distance to assigned centroid ----
        diffs = sketches_original - centroids_orig[cid]
        sq_dists = (diffs * diffs).sum(dim=1)  # (N,)

        # Recomputed inertia; logged next to the passed-in value for comparison.
        inertia_torch = float(sq_dists.sum().item())

        # tr(Sigma_bar^Pi) = (1/N) sum_i ||x_i - mu_{c(i)}||^2 = inertia / N
        tr_sigma_bar_pi = inertia_torch / float(N)
        rms_scatter_global = float(np.sqrt(tr_sigma_bar_pi))

        # ---- within-cluster mean squared distance: (1/n_c) sum_{i in c} ||x_i - mu_c||^2 ----
        within_cluster_sse = torch.zeros((n_clusters,), device=device, dtype=dtype)
        within_cluster_sse.scatter_add_(0, cid, sq_dists)
        within_cluster_var = within_cluster_sse / safe_counts

        # ---- within-cluster mean distance (diagnostic only) ----
        dists = torch.sqrt(sq_dists.clamp_min(0))
        within_cluster_sum_dist = torch.zeros((n_clusters,), device=device, dtype=dtype)
        within_cluster_sum_dist.scatter_add_(0, cid, dists)
        within_cluster_mean_dist = within_cluster_sum_dist / safe_counts

        # ---- cosine similarity of each point to its (normalized) original centroid ----
        centroids_norm = F.normalize(centroids_orig, p=2, dim=1)  # (K, d), safe for empty
        sims = (sketches_normalized * centroids_norm[cid]).sum(dim=1)  # (N,)

        within_cluster_sum_sim = torch.zeros((n_clusters,), device=device, dtype=dtype)
        within_cluster_sum_sim.scatter_add_(0, cid, sims)
        within_cluster_mean_sim = within_cluster_sum_sim / safe_counts

        # ---- global mean + total variance in ORIGINAL space ----
        global_mean = sketches_original.mean(dim=0)  # (d,)
        total_var = ((sketches_original - global_mean) ** 2).sum(dim=1).mean().item()

        # ---- between-cluster variance (point-weighted): sum_c (n_c/N) * ||mu_c - global_mean||^2 ----
        centroid_sq = ((centroids_orig - global_mean) ** 2).sum(dim=1)  # (K,)
        between_cluster_var = ((counts.to(dtype) / float(N)) * centroid_sq).sum().item()

        # With centroids as means: total_var ≈ between + within (within = inertia/N).
        within_global = tr_sigma_bar_pi
        variance_explained = between_cluster_var / total_var if total_var > 0 else float("nan")
        variance_ratio_between_within = between_cluster_var / within_global if within_global > 0 else float("nan")

        # ---- per-cluster stats (non-empty clusters only) ----
        nonempty_counts_np = counts[nonempty].detach().cpu().numpy()
        within_cluster_var_np = within_cluster_var[nonempty].detach().cpu().numpy()
        within_cluster_mean_dist_np = within_cluster_mean_dist[nonempty].detach().cpu().numpy()
        within_cluster_mean_sim_np = within_cluster_mean_sim[nonempty].detach().cpu().numpy()

        # ---- centroid directions vs. the mean centroid direction ----
        global_centroid_norm = F.normalize(centroids_norm[nonempty].mean(dim=0, keepdim=True), p=2, dim=1)  # (1,d)
        between_cluster_sims = (centroids_norm[nonempty] * global_centroid_norm).sum(dim=1)  # (K_nonempty,)
        between_cluster_cos_mean = float(between_cluster_sims.mean().item())
        between_cluster_cos_std = float(between_cluster_sims.std(unbiased=False).item()) if n_nonempty > 1 else 0.0

        # Named ``stats`` so the module-level ``log_dict`` helper is not shadowed.
        stats = {
            "n_clusters": int(n_clusters),
            "nonempty_clusters": int(n_nonempty),
            "empty_clusters": int(empty_clusters),

            "cluster_size_min": int(nonempty_counts_np.min()) if n_nonempty else 0,
            "cluster_size_max": int(nonempty_counts_np.max()) if n_nonempty else 0,
            "cluster_size_mean": float(nonempty_counts_np.mean()) if n_nonempty else 0.0,
            "cluster_size_std": float(nonempty_counts_np.std()) if n_nonempty else 0.0,

            # THEORY-ALIGNED core quantities
            "final_inertia_passed": float(inertia),
            "final_inertia_recomputed": float(inertia_torch),
            "trace_sigma_bar_Pi": float(tr_sigma_bar_pi),                 # = inertia/N
            "rms_scatter_global": float(rms_scatter_global),              # = sqrt(inertia/N)

            # Cluster-weighted (not point-weighted) diagnostics in ORIGINAL space.
            "within_cluster_variance_mean": float(within_cluster_var_np.mean()) if n_nonempty else None,
            "within_cluster_variance_std": float(within_cluster_var_np.std()) if n_nonempty else None,

            # Within-cluster mean distance (diagnostic)
            "within_cluster_scatter_mean": float(within_cluster_mean_dist_np.mean()) if n_nonempty else None,
            "within_cluster_scatter_median": float(np.median(within_cluster_mean_dist_np)) if n_nonempty else None,
            "within_cluster_scatter_min": float(within_cluster_mean_dist_np.min()) if n_nonempty else None,
            "within_cluster_scatter_max": float(within_cluster_mean_dist_np.max()) if n_nonempty else None,

            # Cosine similarity (normalized space)
            "within_cluster_cosine_sim_mean": float(within_cluster_mean_sim_np.mean()) if n_nonempty else None,
            "within_cluster_cosine_sim_std": float(within_cluster_mean_sim_np.std()) if n_nonempty else None,
            "between_cluster_cosine_sim_mean": between_cluster_cos_mean,
            "between_cluster_cosine_sim_std": between_cluster_cos_std,

            # Global variance decomposition in ORIGINAL JL space (point-weighted)
            "total_variance": float(total_var),
            "within_global_variance": float(within_global),               # = inertia/N
            "between_cluster_variance": float(between_cluster_var),
            "variance_explained": float(variance_explained),
            "variance_ratio_between_within": float(variance_ratio_between_within),
            "total_minus_between_minus_within": float(total_var - between_cluster_var - within_global),
        }

        logger.info({"msg": "Final clustering statistics (all clusters, theory-aligned)", **stats})
        return stats

    def save(self, path: str):
        """Save centroids and ``n_clusters`` with ``torch.save``.

        Raises:
            RuntimeError: if no centroids are available (neither fitted nor loaded).
        """
        if self.centroids is None:
            raise RuntimeError("KMeans.save: no centroids to save; call fit() or load() first")
        torch.save({
            'centroids': self.centroids.cpu(),
            'n_clusters': self.n_clusters,
        }, path)

    def load(self, path: str):
        """Load centroids and ``n_clusters`` written by ``save``; centroids go to ``self.device``."""
        state = torch.load(path, map_location='cpu')
        self.centroids = state['centroids'].to(self.device)
        self.n_clusters = state['n_clusters']


def _compute_sketches_with_math_sdp(model, dataloader, sketch_dim, sketches_path, device):
    """Run compute_gradient_sketches with only the math SDPA backend, restoring the previous backend flags."""
    cuda_backends = torch.backends.cuda
    previous_flags = (
        cuda_backends.flash_sdp_enabled(),
        cuda_backends.mem_efficient_sdp_enabled(),
        cuda_backends.math_sdp_enabled(),
    )
    # The fused SDPA kernels are switched off while sketching (presumably for
    # torch.func vmap/grad compatibility; confirm). They are restored afterwards
    # so later training is not silently forced onto the slower math kernel.
    cuda_backends.enable_flash_sdp(False)
    cuda_backends.enable_mem_efficient_sdp(False)
    cuda_backends.enable_math_sdp(True)
    try:
        compute_gradient_sketches(
            model, dataloader, sketch_dim,
            output_path=sketches_path,
            device=device,
        )
    finally:
        cuda_backends.enable_flash_sdp(previous_flags[0])
        cuda_backends.enable_mem_efficient_sdp(previous_flags[1])
        cuda_backends.enable_math_sdp(previous_flags[2])


def cluster_gradients(
    model: ViTWithHooks,
    dataloader: DataLoader,
    output_dir: str,
    n_clusters: int = 500,
    sketch_dim: int = 1024,
    max_iter: int = 100,
    tol: float = 1e-4,
    device: str = "cuda",
) -> Tuple[np.ndarray, np.ndarray, JLProjector]:
    """
    Full pipeline: compute gradient sketches and cluster them with k-means.

    Results are cached in ``output_dir``:
        - ``train_sketch.dat`` (+ ``.meta.npy``): sketches
        - ``jl_projector.pt``: projector
        - ``cluster_id.npy``: cluster ids
        - ``centroids.pt``: centroids

    Each stage is skipped when all of its cached files exist.

    DDP:
        - In DDP mode, only rank 0 runs the computation.
        - All ranks call barrier() after rank 0 finishes.
        - Other ranks load the saved files after the barrier.
        - NOTE(review): rank 0 only reaches its barrier on success. If it raises,
          the other ranks presumably block in barrier(); confirm the barrier
          implementation.

    NOTE(review): ``compute_gradient_sketches`` projects with its own dense
    Gaussian matrix, not with the returned ``projector``. The projector
    therefore does not reproduce the stored sketches; confirm which is intended.

    Args:
        model: ViT model
        dataloader: Training data loader
        output_dir: Output directory for all files
        n_clusters: Number of clusters
        sketch_dim: Gradient sketch dimension
        max_iter: Maximum K-means iterations
        tol: Convergence tolerance
        device: Computation device

    Returns:
        cluster_ids: (N,) cluster assignments
        centroids: (C, sketch_dim) cluster centroids
        projector: JL projector used
    """
    sketches_path = os.path.join(output_dir, "train_sketch.dat")  # raw memmap file
    cluster_ids_path = os.path.join(output_dir, "cluster_id.npy")
    centroids_path = os.path.join(output_dir, "centroids.pt")
    projector_path = os.path.join(output_dir, "jl_projector.pt")

    # ===========================================================================
    # DDP: Only rank 0 runs clustering. Other ranks wait at barrier, then load.
    # ===========================================================================
    if is_ddp() and not is_rank0():
        logger.info("DDP: Rank != 0, waiting at barrier for clustering to complete")
        barrier()

        # Load results saved by rank 0
        cluster_ids = np.load(cluster_ids_path)
        kmeans = KMeans(n_clusters, max_iter=max_iter, tol=tol, device=device)
        kmeans.load(centroids_path)
        projector = JLProjector.load(projector_path, device)

        return cluster_ids, kmeans.centroids.cpu().numpy(), projector

    # Rank 0 (or non-DDP) does the actual work
    os.makedirs(output_dir, exist_ok=True)
    sketches_meta_path = sketches_path + ".meta.npy"

    # Step 1: Compute gradient sketches
    if os.path.exists(sketches_path) and os.path.exists(sketches_meta_path):
        rank0_print(f"Loading cached sketches from {sketches_path}")
        if os.path.exists(projector_path):
            projector = JLProjector.load(projector_path, device)
        else:
            projector = JLProjector(model.num_params, sketch_dim, device=device)
            # Persist it: non-zero DDP ranks load projector_path after the barrier.
            projector.save(projector_path)
    else:
        projector = JLProjector(model.num_params, sketch_dim, device=device)
        _compute_sketches_with_math_sdp(model, dataloader, sketch_dim, sketches_path, device)
        projector.save(projector_path)

    # Step 2: Cluster sketches with K-Means. The cache is used only if both files
    # exist; cached ids without centroids would leave kmeans.centroids as None.
    kmeans = KMeans(n_clusters, max_iter=max_iter, tol=tol, device=device)
    if os.path.exists(cluster_ids_path) and os.path.exists(centroids_path):
        rank0_print(f"Loading cached cluster IDs from {cluster_ids_path}")
        cluster_ids = np.load(cluster_ids_path)
        kmeans.load(centroids_path)
    else:
        cluster_ids = kmeans.fit(sketches_path)

        np.save(cluster_ids_path, cluster_ids)
        kmeans.save(centroids_path)

    # Print cluster statistics
    unique, counts = np.unique(cluster_ids, return_counts=True)
    rank0_print(f"\nCluster statistics:")
    rank0_print(f"  Number of clusters: {len(unique)}")
    rank0_print(f"  Min cluster size: {counts.min()}")
    rank0_print(f"  Max cluster size: {counts.max()}")
    rank0_print(f"  Mean cluster size: {counts.mean():.1f}")

    # ===========================================================================
    # DDP: Barrier after rank 0 finishes so other ranks can load files
    # ===========================================================================
    if is_ddp():
        logger.info("DDP: Rank 0 finished clustering, calling barrier")
        barrier()

    return cluster_ids, kmeans.centroids.cpu().numpy(), projector


def get_cluster_indices(cluster_ids: np.ndarray, cluster_id: int) -> np.ndarray:
    """
    Return the sample indices assigned to one cluster.

    Args:
        cluster_ids: (N,) array of per-sample cluster ids.
        cluster_id: The cluster to select.

    Returns:
        Ascending integer array of positions ``i`` with ``cluster_ids[i] == cluster_id``.
    """
    in_cluster = cluster_ids == cluster_id
    return np.nonzero(in_cluster)[0]


def get_cluster_sizes(cluster_ids: np.ndarray, n_clusters: Optional[int] = None) -> np.ndarray:
    """
    Count the samples in each cluster.

    Args:
        cluster_ids: Non-negative integer cluster ids (any shape; flattened).
        n_clusters: Optional minimum length of the result, so that trailing empty
            clusters are reported as 0. The result is longer if ``cluster_ids``
            contains an id >= ``n_clusters``.

    Returns:
        int32 array ``sizes`` with ``sizes[c]`` = number of samples in cluster ``c``.
        Its length is ``max(n_clusters or 0, max_id + 1)``. For empty input it is
        ``n_clusters`` (or 0).

    Raises:
        ValueError: if any cluster id is negative.
    """
    ids = np.asarray(cluster_ids).ravel()
    min_length = 0 if n_clusters is None else int(n_clusters)
    if ids.size == 0:
        return np.zeros(min_length, dtype=np.int32)
    if ids.min() < 0:
        # Negative ids would otherwise wrap around when used as indices.
        raise ValueError(f"cluster ids must be non-negative, got min {ids.min()}")
    return np.bincount(ids, minlength=min_length).astype(np.int32)
