from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import logging

import time


import torch
from torch import nn
from torch.utils.data import DataLoader

from ..utils.seed import set_seed
from .helpers import BaseObjective, CGInfluenceModule, make_jl_matrix, project, kmeans_cluster
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np

@dataclass
class IFCompressedCfg:
    """Configuration for building an IF-Compressed (IFC) influence cache.

    Consumed by ``build_ifc_cache``. Besides the declared fields, that function
    also reads an optional ``jl_block_in`` attribute via ``getattr`` (default
    1_000_000), which sets the column-block size of the streaming JL projection.
    Fields marked NOTE(review) are not read by ``build_ifc_cache`` or
    ``delta_ifc_from_cache``; they may be consumed elsewhere.
    """
    name: str  # NOTE(review): not read by build_ifc_cache / delta_ifc_from_cache
    jl_dim: int  # JL projection width; <= 0 disables projection (full parameter dim is used)
    clusters: int  # requested number of clusters (clamped to >= 1; forced to 1 for "none"-style methods)
    damping: float  # passed to CGInfluenceModule as ``damp``
    max_cg_iters: int  # passed to CGInfluenceModule as ``maxiter``
    tol: float  # passed to CGInfluenceModule as ``tol``
    cluster_method: str = "kmeans"  # "kmeans", "kmedian"/"k-median"/"kmedoids", "random", or "none"/"no"/"single"
    fisher: bool = True  # passed to CGInfluenceModule as ``gnh`` (presumably Gauss-Newton Hessian; confirm)
    normalize: bool = False  # L2-normalise each projected gradient row before clustering
    line_search: bool = False  # NOTE(review): not read by build_ifc_cache / delta_ifc_from_cache
    recourse_steps: int = 0  # NOTE(review): not read by build_ifc_cache / delta_ifc_from_cache
    cache_grads: bool = True  # NOTE(review): not read by build_ifc_cache / delta_ifc_from_cache
    seed: int = 0  # seeds set_seed, the JL projection, KMeans random_state and the numpy RNG
    # UMAP controls: build_ifc_cache overwrites use_umap with False, so these are
    # currently inert there (umap_n_components only feeds the "umap_used" flag).
    use_umap: bool = False
    umap_n_components: int = 25
    umap_n_neighbors: int = 15
    umap_min_dist: float = 0.1
    umap_metric: str = 'cosine'
    # Misc (NOTE(review): random_state and n_jobs are not read by build_ifc_cache)
    random_state: int = 0
    n_jobs: int = 1
    collect_diagnostics: bool = True  # keep per-cluster CG residuals/stability info in solver_stats


class _SimpleObjective(BaseObjective):
    """Objective adapter around a plain loss module for ``(inputs, targets)`` batches.

    Training and test losses are both ``loss_fn(model(inputs), targets)``; there is
    no regularisation term, and ``test_loss`` ignores its ``params`` argument.
    """

    def __init__(self, loss_fn: nn.Module):
        self.loss_fn = loss_fn

    def train_outputs(self, model: nn.Module, batch):
        inputs, _targets = batch
        return model(inputs)

    def train_loss_on_outputs(self, outputs: torch.Tensor, batch):
        _inputs, targets = batch
        return self.loss_fn(outputs, targets)

    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        # Zero-valued scalar on the same device as the parameters.
        return torch.zeros((), device=params.device)

    def test_loss(self, model: nn.Module, params: torch.Tensor, batch):
        inputs, targets = batch
        predictions = model(inputs)
        return self.loss_fn(predictions, targets)


def _per_sample_grads_stream(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    P: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute per-sample gradients of ``loss_fn`` w.r.t. every model parameter.

    Each example is forwarded and back-propagated on its own. The ``.grad`` of
    every tensor in ``model.parameters()`` is then flattened and concatenated in
    parameter order; parameters that received no gradient contribute zeros.
    Gradient computation is force-enabled, so this also works when called under
    ``torch.no_grad()``.

    Args:
        model: Network to differentiate. It is switched to eval mode, and its
            ``.grad`` fields are overwritten as a side effect.
        loader: Iterable of ``(x, y)`` batches.
        loss_fn: Loss evaluated as ``loss_fn(model(x_i), y_i)`` on single-example
            slices.
        P: Optional projection matrix of shape ``(k, D)``, applied as ``P @ g``,
            where ``D`` is the total number of parameter elements.

    Returns:
        A tensor of shape ``(N, D)``, or ``(N, k)`` when ``P`` is given, with one
        row per example in loader order. An empty loader yields a ``(0, D)`` or
        ``(0, k)`` tensor instead of raising.
    """
    model.eval()
    first_param = next(model.parameters())
    device = first_param.device
    if P is not None:
        out_dim, out_dtype = P.shape[0], P.dtype
    else:
        out_dim = sum(p.numel() for p in model.parameters())
        out_dtype = first_param.dtype
    rows = []

    # backward() needs an autograd graph; enable grad even if the caller is in no_grad mode.
    with torch.enable_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            for i in range(x.size(0)):
                model.zero_grad()
                loss = loss_fn(model(x[i : i + 1]), y[i : i + 1])
                loss.backward()
                grad_vec = torch.cat([
                    p.grad.detach().reshape(-1) if p.grad is not None else torch.zeros_like(p).reshape(-1)
                    for p in model.parameters()
                ])
                if P is not None:
                    grad_vec = P @ grad_vec
                rows.append(grad_vec.detach().clone())

    if not rows:
        # torch.stack([]) raises, so build the empty result explicitly.
        return torch.empty((0, out_dim), dtype=out_dtype, device=device)
    return torch.stack(rows)


def _iter_sample_grads(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    params: Optional[List[torch.nn.Parameter]] = None,
):
    """Yield flattened per-sample gradients in loader order.

    For each example, ``torch.autograd.grad`` computes the gradient of
    ``loss_fn(model(x_i), y_i)`` w.r.t. ``params``. The per-parameter gradients
    are flattened and concatenated in ``params`` order; unused parameters
    contribute zeros.

    Grad mode is enabled only around the forward/grad computation, not across
    ``yield``. The caller's grad mode therefore stays untouched while the
    generator is suspended.

    Args:
        model: Network to differentiate; it is switched to eval mode.
        loader: Iterable of ``(x, y)`` batches. If it exposes a truthy
            ``pin_memory`` attribute, host-to-device copies are non-blocking.
        loss_fn: Loss evaluated on single-example slices.
        params: Parameters to differentiate. Defaults to all parameters with
            ``requires_grad``.

    Yields:
        One detached 1-D gradient tensor per example.

    Raises:
        ValueError: If there are no trainable parameters. As with any generator,
            this is raised on the first ``next()`` call.
    """
    device = next(model.parameters()).device
    if params is None:
        params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("Model has no trainable parameters")

    model.eval()
    pin = bool(getattr(loader, "pin_memory", False))
    try:
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=pin)
            yb = yb.to(device, non_blocking=pin)
            for i in range(xb.shape[0]):
                model.zero_grad(set_to_none=True)
                with torch.enable_grad():
                    out = model(xb[i : i + 1])
                    loss = loss_fn(out, yb[i : i + 1])
                    grads = torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)
                flat = torch.cat([
                    (g if g is not None else torch.zeros_like(p)).reshape(-1)
                    for g, p in zip(grads, params)
                ])
                yield flat.detach()
    finally:
        # Runs on exhaustion, on error, and when the consumer closes the generator early.
        model.zero_grad(set_to_none=True)


def _jl_project_vector_streaming(
    vec: torch.Tensor,
    proj_dim: int,
    seed: int,
    block_in: int,
    device: Union[str, torch.device],
) -> torch.Tensor:
    """
    Streaming Gaussian JL projection for a single vector.

    Implements y = P x where P ∈ R^{proj_dim × d_in},
    P_ij ~ N(0, 1/proj_dim), without materializing P.

    Column blocks of P are regenerated from a generator seeded with ``seed`` on
    every call. The same ``(seed, block_in, proj_dim, device, dtype)`` therefore
    always reproduces the same P, which keeps different vectors consistently
    projected.

    Args:
        vec: Input tensor of any shape and memory layout; it is flattened to
            ``d_in`` elements.
        proj_dim: Output dimension.
        seed: Seed for the projection generator.
        block_in: Number of input columns of P generated at a time. Must be > 0.
        device: Device on which the projection is computed.

    Returns:
        A 1-D tensor of length ``proj_dim`` with ``vec``'s dtype, on ``device``.

    Raises:
        ValueError: If ``block_in`` is not positive. Without this check the
            block loop would never advance.
    """
    if block_in <= 0:
        raise ValueError(f"block_in must be a positive integer, got {block_in}")

    dev = torch.device(device)
    # reshape (not view) so non-contiguous inputs are accepted.
    v = vec.detach().reshape(-1)
    d_in = v.numel()

    if v.device != dev:
        v = v.to(dev, non_blocking=True)

    y = torch.zeros(proj_dim, device=dev, dtype=v.dtype)

    gen = torch.Generator(device=dev)
    gen.manual_seed(seed)

    start = 0
    while start < d_in:
        end = min(start + block_in, d_in)
        v_block = v[start:end]
        L = v_block.numel()

        # Columns of P for indices [start:end]
        R_block = torch.randn(
            proj_dim,
            L,
            generator=gen,
            device=dev,
            dtype=v.dtype,
        ) / (proj_dim ** 0.5)

        # y += R_block @ v_block
        y += R_block.matmul(v_block)

        start = end

    return y

def _stream_projected_grad_matrix(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    params: List[torch.nn.Parameter],
    proj_dim: int,
    jl_seed: Optional[int],
    jl_block_in: int,
    device: Union[str, torch.device],
    dtype: torch.dtype = torch.float32,
):
    """
    Return (optionally JL-projected) per-sample gradients without forming the full
    (n_samples × param_dim) matrix or a dense JL matrix.

    If jl_seed is None or proj_dim >= param_dim:
        -> no projection, just flatten per-sample grads. The output width is
           then param_dim, whatever proj_dim was requested.

    Else:
        -> apply streaming Gaussian JL with proj_dim.

    Rows are moved to CPU and cast to ``dtype`` one at a time. When the loader's
    dataset reports a length, rows are written into a preallocated buffer;
    otherwise they are collected and stacked.

    Returns:
        ``(mat, elapsed_seconds)``, where ``mat`` has one row per yielded sample,
        in loader order.

    Raises:
        RuntimeError: If the loader yields more samples than ``len(dataset)``.
    """
    param_dim = sum(p.numel() for p in params)
    use_jl = (jl_seed is not None) and (proj_dim < param_dim)
    # Unprojected rows are raw flattened gradients, so size the output accordingly.
    out_dim = proj_dim if use_jl else param_dim

    dataset = getattr(loader, "dataset", None)
    try:
        total = len(dataset) if dataset is not None else None
    except TypeError:
        # e.g. an IterableDataset without __len__: fall back to collecting rows.
        total = None

    store = torch.empty((total, out_dim), dtype=dtype) if total is not None else None
    rows: List[torch.Tensor] = []

    idx = 0
    start_time = time.time()

    for grad_vec in _iter_sample_grads(model, loader, loss_fn, params=params):
        if use_jl:
            vec = _jl_project_vector_streaming(
                grad_vec,
                proj_dim=proj_dim,
                seed=jl_seed,
                block_in=jl_block_in,
                device=device,
            )
        else:
            # no projection, just flatten
            vec = grad_vec.reshape(-1)

        vec_cpu = vec.detach().to("cpu", dtype=dtype)

        if store is None:
            rows.append(vec_cpu)
        else:
            if idx >= store.shape[0]:
                raise RuntimeError(
                    "DataLoader produced more samples than reported by its dataset length"
                )
            store[idx] = vec_cpu
        idx += 1

    elapsed = time.time() - start_time

    if store is None:
        mat = torch.stack(rows) if rows else torch.empty((0, out_dim), dtype=dtype)
    else:
        # The loader may yield fewer samples than len(dataset) (e.g. drop_last).
        mat = store[:idx]

    return mat, float(elapsed)



def _compute_cluster_means(
    model: nn.Module,
    loader: DataLoader,
    loss_fn: nn.Module,
    params: List[torch.nn.Parameter],
    labels: np.ndarray,
    num_clusters: int,
    device: torch.device,
):
    """Average the full-dimensional per-sample gradients within each cluster.

    The i-th gradient yielded by ``_iter_sample_grads`` is paired with
    ``labels[i]``. Pairing stops as soon as either side runs out, so no gradient
    is computed for samples beyond ``len(labels)``. Labels outside
    ``[0, num_clusters)`` are ignored.

    Returns:
        ``(mus, seconds)``. ``mus[c]`` is the mean gradient of cluster ``c``: a
        length-``P`` tensor on ``device``, all zeros for an empty cluster.
        ``seconds`` is the wall-clock time of the pass.

    NOTE(review): correctness relies on ``loader`` yielding samples in the same
    order as when ``labels`` were computed (i.e. no shuffling); confirm with callers.
    """
    P = sum(p.numel() for p in params)
    sums = [torch.zeros(P, device=device) for _ in range(num_clusters)]
    counts = np.zeros(num_clusters, dtype=np.int64)
    start = time.time()
    grad_iter = _iter_sample_grads(model, loader, loss_fn, params=params)
    try:
        # zip() evaluates left to right, so ``labels`` is checked before the
        # generator computes another (otherwise discarded) gradient.
        for raw_label, grad_vec in zip(labels, grad_iter):
            lbl = int(raw_label)
            if 0 <= lbl < num_clusters:
                sums[lbl] += grad_vec.detach()
                counts[lbl] += 1
    finally:
        # Close the generator deterministically, even on early exit or error.
        grad_iter.close()
    # Empty clusters keep their all-zero accumulator as the "mean".
    mus = [sums[c] / int(counts[c]) if counts[c] > 0 else sums[c] for c in range(num_clusters)]
    return mus, float(time.time() - start)


def _relative_distortion(full: np.ndarray, proj: np.ndarray, *, num_pairs: int = 2000, rng: Optional[np.random.Generator] = None) -> Tuple[float, float]:
    """Measure how well a projection preserves pairwise Euclidean distances.

    Draws ``num_pairs`` random row-index pairs ``(i, j)``. Self-pairs are
    dropped, as are pairs whose distance in ``full`` is ~0. For each remaining
    pair the relative error ``|d_full - d_proj| / d_full`` is computed.

    Returns:
        ``(mean, 95th percentile)`` of the relative error, or ``(nan, nan)`` when
        ``full`` has fewer than two rows or no usable pair was drawn.
    """
    undefined = (float("nan"), float("nan"))
    n_rows = full.shape[0]
    if n_rows < 2:
        return undefined
    gen = np.random.default_rng() if rng is None else rng
    left = gen.integers(0, n_rows, size=num_pairs)
    right = gen.integers(0, n_rows, size=num_pairs)
    distinct = left != right
    if not distinct.any():
        return undefined
    left, right = left[distinct], right[distinct]
    d_full = np.linalg.norm(full[left] - full[right], axis=1)
    d_proj = np.linalg.norm(proj[left] - proj[right], axis=1)
    usable = d_full > 1e-12
    if not usable.any():
        return undefined
    errors = np.abs(d_full[usable] - d_proj[usable]) / d_full[usable]
    return float(errors.mean()), float(np.quantile(errors, 0.95))


def _between_within_ratio(centroids: np.ndarray, labels: np.ndarray, data: np.ndarray) -> float:
    """Ratio of mean inter-centroid distance to mean intra-cluster spread.

    ``between`` is the mean Euclidean distance over all unordered centroid pairs.
    ``within`` is the mean, over clusters, of the average distance from each
    member to its centroid.

    Returns:
        ``between / within``, or ``nan`` when there are fewer than two
        centroids, any cluster is empty, or either quantity is non-finite or
        ``within`` is ~0.
    """
    n_centroids = centroids.shape[0]
    sizes = np.bincount(labels, minlength=n_centroids)
    if n_centroids <= 1 or np.any(sizes == 0):
        return float("nan")

    # Between-cluster: mean over the strict upper triangle of the centroid distance matrix.
    pairwise = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=-1)
    upper_i, upper_j = np.triu_indices(n_centroids, k=1)
    between = np.mean(pairwise[upper_i, upper_j])

    # Within-cluster: per-cluster mean distance to the centroid, then averaged.
    spreads = []
    for c, center in enumerate(centroids):
        members = np.flatnonzero(labels == c)
        if members.size:
            spreads.append(np.linalg.norm(data[members] - center, axis=1).mean())
    within = float(np.mean(spreads)) if spreads else float("nan")

    if not (np.isfinite(between) and np.isfinite(within)) or within <= 1e-12:
        return float("nan")
    return between / within


def _coeff_variation(counts: np.ndarray) -> float:
    """Coefficient of variation (population std / mean) of ``counts``.

    Returns ``nan`` when the mean is ~0 or not larger than 1e-12.
    """
    values = np.asarray(counts, dtype=float)
    avg = values.mean()
    if avg <= 1e-12:
        return float("nan")
    return float(values.std() / avg)



def _kmedian_assign(data: np.ndarray, n_clusters: int, rng: np.random.Generator, max_iter: int = 50) -> np.ndarray:
    """Cluster rows of ``data`` under the L1 distance using medoid centres.

    Medoids start as ``min(n_clusters, n)`` distinct rows drawn with ``rng``.
    Each iteration assigns every row to its nearest medoid (L1). Each medoid is
    then moved to the member closest (L1) to the coordinate-wise median of its
    cluster. Iteration stops when the assignment no longer changes, or after
    ``max_iter`` rounds.

    Returns:
        An integer label array of length ``n``; all zeros when ``n_clusters <= 1``
        or ``n <= 1``.
    """
    n_points = data.shape[0]
    if n_points <= 1 or n_clusters <= 1:
        return np.zeros(n_points, dtype=np.int64)
    k = min(n_clusters, n_points)
    seeds = rng.choice(n_points, size=k, replace=False)
    medoids = data[seeds].copy()
    assignment = np.zeros(n_points, dtype=np.int64)
    iteration = 0
    while iteration < max_iter:
        iteration += 1
        l1 = np.abs(data[:, None, :] - medoids[None, :, :]).sum(axis=2)
        candidate = l1.argmin(axis=1)
        if np.array_equal(candidate, assignment):
            break
        assignment = candidate
        for cluster in range(k):
            members = data[assignment == cluster]
            if len(members) == 0:
                continue
            center = np.median(members, axis=0)
            medoids[cluster] = members[np.abs(members - center).sum(axis=1).argmin()]
    return assignment


def _assign_clusters(
    embeddings: np.ndarray,
    method: str,
    n_clusters: int,
    rng: np.random.Generator,
    seed: int,
):
    """Assign a cluster label to each row of ``embeddings``.

    ``method`` is case-insensitive (``None`` means "kmeans"):
      * "none" / "no" / "single": every row gets label 0.
      * "random": uniform labels in ``[0, max(1, n_clusters))`` drawn from ``rng``.
      * "kmedian" / "k-median" / "kmedoids": ``_kmedian_assign``.
      * anything else: scikit-learn ``KMeans`` seeded with ``seed``.
    """
    n_points = embeddings.shape[0]
    key = (method or "kmeans").lower()
    if key in ("none", "no", "single"):
        return np.zeros(n_points, dtype=np.int64)
    k = max(1, n_clusters)
    if key == "random":
        return rng.integers(0, k, size=n_points)
    if key in ("kmedian", "k-median", "kmedoids"):
        return _kmedian_assign(embeddings, k, rng)
    # Unknown methods fall back to KMeans.
    return KMeans(n_clusters=k, random_state=seed).fit_predict(embeddings)


def _compute_centroids(embeddings: np.ndarray, labels: np.ndarray, n_clusters: int) -> np.ndarray:
    """Mean embedding of each cluster ``0 .. n_clusters - 1``.

    Clusters with no members get an all-zero centroid. The result always has
    shape ``(n_clusters, d)``, where ``d`` is the embedding width. This also
    holds when ``embeddings`` has zero rows or ``n_clusters`` is 0.

    Args:
        embeddings: ``(n, d)`` array.
        labels: Length-``n`` sequence of integer cluster labels (a list is
            also accepted).
        n_clusters: Number of centroids to return.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    # Use the declared width even for an empty (0, d) array.
    dim = embeddings.shape[1] if embeddings.ndim >= 2 else 0
    zero = np.zeros(dim, dtype=embeddings.dtype)
    centroids = []
    for c in range(n_clusters):
        mask = labels == c
        centroids.append(embeddings[mask].mean(axis=0) if mask.any() else zero.copy())
    if not centroids:
        return np.zeros((0, dim), dtype=embeddings.dtype)
    return np.asarray(centroids)




def build_ifc_cache(
    model: nn.Module,
    data_all: DataLoader,
    loss_fn: nn.Module,
    cfg: IFCompressedCfg,
    logger: Optional[logging.Logger] = None,
    full_grads: Optional[torch.Tensor] = None,
    distortion_pairs: int = 2000,
    Gk_raw: Optional[torch.Tensor] = None,
):
    """Build reusable IF-Compressed cache once.

    Pipeline:
      1. Per-sample gradients over ``data_all``, JL-projected to ``cfg.jl_dim``
         when ``0 < cfg.jl_dim < P``. This pass is skipped when ``Gk_raw`` is
         supplied.
      2. Optional row-wise L2 normalisation (``cfg.normalize``), then
         clustering with ``cfg.cluster_method``.
      3. Full-dimensional per-cluster mean gradients via a second streaming pass.
      4. One CG inverse-Hessian-vector product per non-empty cluster.

    Args:
        model: Model whose trainable parameters define the gradient space. It
            is switched to eval mode.
        data_all: Loader over the samples to cluster.
        loss_fn: Per-sample loss.
        cfg: Configuration. NOTE: ``cfg.use_umap`` is overwritten with False,
            which mutates the caller's object.
        logger: Optional logger. All logging is skipped when it is None.
        full_grads: Ignored; kept for backward compatibility (a warning is logged).
        distortion_pairs: Currently unused; kept for backward compatibility.
        Gk_raw: Optional precomputed ``(N, d)`` projected-gradient matrix. When
            given, the first gradient pass is skipped and
            ``timings["time_grads_proj"]`` is 0.

    Returns a dict with cluster labels, full-dim per-cluster IHVP vectors v_c,
    and diagnostics. ``v_c`` is H^{-1} applied to each cluster's mean gradient
    (hence ``v_is_mean=True``). ``jl_seed`` is None unless a JL projection was
    actually applied. Use delta_ifc_from_cache(cache, indices_S) to get
    per-fold deltas.

    NOTE(review): both gradient passes match samples by position, so
    ``data_all`` must yield samples in the same order each time (no shuffling).
    """
    cfg.use_umap = False  # UMAP is not supported by this builder.
    device = next(model.parameters()).device
    set_seed(cfg.seed)
    start = time.time()
    model.eval()

    if logger:
        logger.info(f"Building IFC cache: JL_dim={cfg.jl_dim}, clusters={cfg.clusters}")

    if full_grads is not None and logger:
        logger.warning("full_grads argument is ignored; gradients are streamed on demand.")

    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("Model has no trainable parameters for IFC cache")

    P = sum(p.numel() for p in params)
    proj_dim = cfg.jl_dim if cfg.jl_dim > 0 else P
    # A JL projection is only applied when it actually reduces the dimension.
    jl_applied = 0 < cfg.jl_dim < P

    # 1) Projected per-sample gradients (first streaming pass) unless precomputed.
    t_grads = 0.0  # stays 0 when the caller supplies Gk_raw
    if Gk_raw is None:
        Gk_raw, t_grads = _stream_projected_grad_matrix(
            model=model,
            loader=data_all,
            loss_fn=loss_fn,
            params=params,
            proj_dim=proj_dim,
            jl_seed=cfg.seed if cfg.jl_dim > 0 else None,  # None → no projection
            jl_block_in=getattr(cfg, "jl_block_in", 1_000_000),
            device=device,
            dtype=torch.float32,
        )
    Gk_raw = Gk_raw.detach().cpu()

    N = int(Gk_raw.shape[0])
    t_grads_full = 0.0  # full-dimensional gradients are never materialised
    Gk = torch.nn.functional.normalize(Gk_raw, p=2, dim=1) if cfg.normalize else Gk_raw

    if logger:
        logger.info(f"Computed gradients: N={N}, P={P}, JL_dim={Gk.shape[1]}")

    collect_diag = bool(getattr(cfg, "collect_diagnostics", False))
    # Placeholders: JL/UMAP distortion diagnostics are not computed by this builder.
    jl_rel_mean = float("nan")
    jl_rel_p95 = float("nan")
    proj_dist_mean, proj_dist_p95 = float("nan"), float("nan")
    trust = float("nan")
    t_umap = 0.0

    Gk_reduced = Gk.detach().cpu().numpy()

    # 2) Clustering on the projected gradients.
    cluster_method = str(getattr(cfg, "cluster_method", "kmeans")).lower()
    num_clusters = max(1, int(cfg.clusters)) if cluster_method not in {"none", "no", "single"} else 1
    rng = np.random.default_rng(int(cfg.seed))
    t_km0 = time.time()
    if cluster_method == "kmeans":
        kmeans = KMeans(n_clusters=num_clusters, random_state=int(cfg.seed))
        labels_np = kmeans.fit_predict(Gk_reduced)
        centroids = kmeans.cluster_centers_
    else:
        labels_np = _assign_clusters(Gk_reduced, cluster_method, num_clusters, rng, int(cfg.seed))
        centroids = _compute_centroids(Gk_reduced, labels_np, num_clusters)
    t_kmeans = time.time() - t_km0
    counts_arr = np.bincount(labels_np, minlength=num_clusters)
    counts = counts_arr.tolist()

    if logger:
        logger.info("Clustering (%s) done. Cluster sizes: %s", cluster_method, counts)

    # Placeholders kept so the "timings" dict keeps its historical keys.
    t_reps = 0.0
    t_grep = 0.0

    # 3) Cluster means in FULL dim via second streaming pass
    mus, t_cluster_means = _compute_cluster_means(
        model,
        data_all,
        loss_fn,
        params,
        labels_np,
        num_clusters,
        device,
    )

    # Within-cluster variance in projected space: mean squared distance to the cluster mean.
    within_vars = []
    for c in range(num_clusters):
        idx_c = np.where(labels_np == c)[0]
        if len(idx_c) == 0:
            within_vars.append(float("nan"))
            continue
        cg = torch.as_tensor(Gk_reduced[idx_c], device=device, dtype=torch.float32)
        cm = cg.mean(dim=0)
        within_vars.append(torch.mean(torch.sum((cg - cm) ** 2, dim=1)).item())
    if logger:
        logger.info(f"Computed within-cluster variances: {within_vars}")

    try:
        silhouette = float(silhouette_score(Gk_reduced, labels_np)) if len(np.unique(labels_np)) > 1 else float("nan")
    except Exception as exc:  # best-effort diagnostic; never fail the build over it
        silhouette = float("nan")
        if logger:
            logger.debug("silhouette_score failed: %s", exc)
    bwr = _between_within_ratio(centroids, labels_np, Gk_reduced)
    counts_cv = _coeff_variation(counts)

    # 4) IHVP per cluster via CG torch_influence
    if logger:
        logger.info("Computing IHVP for each cluster...")

    objective = _SimpleObjective(loss_fn)
    mod = CGInfluenceModule(
        model=model,
        objective=objective,
        train_loader=data_all,
        test_loader=data_all,
        device=device,
        damp=cfg.damping,
        gnh=cfg.fisher,
        maxiter=cfg.max_cg_iters,
        tol=cfg.tol,
    )
    v_c: List[torch.Tensor] = []
    solver_stats = {"iters": []}
    if collect_diag:
        solver_stats.update({
            "residual": [],
            "residual_history": [],
            "stability": [],
            "verification": [],
        })
    t_cg0 = time.time()

    for c, mu in enumerate(mus):
        if counts[c] == 0:
            # Empty cluster: contributes nothing, so skip the solve.
            v_c.append(torch.zeros(P, device=device))
            solver_stats["iters"].append(0)
            if collect_diag:
                solver_stats["residual"].append(float("nan"))
                solver_stats["residual_history"].append([])
                solver_stats["stability"].append(None)
                solver_stats["verification"].append(None)
            continue

        sol = mod.inverse_hvp(mu)

        v = sol.get("ihvp")
        if v is None:
            v = torch.zeros(P, device=device)
        info = sol.get("iterations", 0)
        rr = sol.get("final_residual", float("nan"))
        v_c.append(v.detach())
        solver_stats["iters"].append(int(info))
        if collect_diag:
            solver_stats["residual"].append(float(rr))
            solver_stats["residual_history"].append(sol.get("residual_history", []))
            solver_stats["stability"].append(sol.get("stability_info"))
            solver_stats["verification"].append(sol.get("verification_info"))
        if logger:
            logger.info("Cluster %d: CG converged in %d iters, rr=%.2e", c, int(info), rr)
    t_cg = time.time() - t_cg0

    if not collect_diag:
        solver_stats.setdefault("residual", [])
        solver_stats.setdefault("residual_history", [])
        solver_stats.setdefault("stability", [])
        solver_stats.setdefault("verification", [])

    build_time = time.time() - start
    timings = {
        "time_grads_full": float(t_grads_full),
        "time_grads_proj": float(t_grads),
        "time_umap": float(t_umap),
        "time_kmeans": float(t_kmeans),
        "time_representatives": float(t_reps),
        "time_grads_reps": float(t_grep),
        "time_cluster_means": float(t_cluster_means),
        "time_cg_total": float(t_cg),
    }
    cluster_stats = {
        "size_min": int(min(counts) if counts else 0),
        "size_max": int(max(counts) if counts else 0),
        "size_mean": float(np.mean(counts) if counts else 0.0),
        "n_empty": int(sum(1 for n_c in counts if n_c == 0)),
        "counts_cv": float(counts_cv),
        "between_within_ratio": float(bwr),
        "silhouette": float(silhouette),
        "method": cluster_method,
        "within_vars": within_vars,
    }
    cache = {
        "N": N,
        "P": P,
        "labels": labels_np,
        "counts_all": counts,
        "v_c": v_c,
        "v_is_mean": True,
        "within_var": within_vars,
        "cluster_means": [m.detach() for m in mus],
        "cluster_method": cluster_method,
        "cfg": cfg,
        "build_time": build_time,
        "timings": timings,
        "cluster_stats": cluster_stats,
        "umap_used": bool(cfg.use_umap and int(cfg.umap_n_components) > 0),
        "proj_dim": int(Gk.shape[1]),
        "proj_grads": Gk_reduced,
        "proj_matrix": None,
        "jl_seed": int(cfg.seed) if jl_applied else None,
        "jl_rel_distortion_mean": float(jl_rel_mean),
        "jl_rel_distortion_p95": float(jl_rel_p95),
        "umap_proj_distortion_mean": float(proj_dist_mean),
        "umap_proj_distortion_p95": float(proj_dist_p95),
        "umap_trustworthiness": float(trust),
        "counts_cv": float(counts_cv),
        "between_within_ratio": float(bwr),
        "silhouette": float(silhouette),
        "solver_stats": solver_stats,
    }

    if logger:
        logger.info(f"IFC cache built in {build_time:.2f}s")

    return cache

def delta_ifc_from_cache(cache: Dict, indices_S: List[int], use_exact_den: bool = True, device: Optional[torch.device] = None) -> Dict:
    """Approximate the parameter change from removing the samples ``indices_S``.

    Each removed sample is represented by its cluster. With ``m_c`` removed
    samples in cluster ``c``::

        delta = (1 / denom) * sum_c coeff_c * v_c

    Here ``coeff_c = m_c`` when ``v_c`` was built from cluster-mean gradients
    (``cache["v_is_mean"]``, the default). It is ``m_c / N_c`` when ``v_c`` was
    built from cluster gradient sums, with ``N_c = cache["counts_all"][c]``
    clamped to >= 1. ``denom`` is ``N - |S|`` when ``use_exact_den`` is true,
    else ``N``, clamped to at least 1.

    Args:
        cache: Dict produced by ``build_ifc_cache``. Reads "labels", "N", "v_c",
            and optionally "v_is_mean" and "counts_all".
        indices_S: Sample indices to remove. Indices outside
            ``[0, len(labels))`` are dropped; negative indices are not wrapped.
            Duplicates are counted each time.
        use_exact_den: Whether to use ``N - |S|`` instead of ``N`` as denominator.
        device: If given, the result is moved to this device.

    Returns:
        ``{"delta_theta": tensor shaped like v_c[0], "counts_S": list of m_c}``.

    Raises:
        ValueError: If ``cache["v_c"]`` is empty.
    """
    labels = np.asarray(cache["labels"])
    N = int(cache["N"])
    v_c: List[torch.Tensor] = cache["v_c"]
    C = len(v_c)
    if C == 0:
        raise ValueError("cache['v_c'] is empty; cannot form an IFC delta")
    v_is_mean = bool(cache.get("v_is_mean", True))
    counts_all = np.asarray(cache.get("counts_all", [1] * C), dtype=int)

    idx_S = np.asarray(indices_S, dtype=int).reshape(-1)
    idx_S = idx_S[(idx_S >= 0) & (idx_S < len(labels))]
    # astype keeps bincount working for empty or float-typed label arrays.
    lbl_S = labels[idx_S].astype(np.int64, copy=False)
    m_c = np.bincount(lbl_S, minlength=C)

    ref = v_c[0]
    m_c_t = torch.as_tensor(m_c, device=ref.device, dtype=ref.dtype)
    # coefficients per cluster
    if v_is_mean:
        # v_c ≈ H^{-1} * (mean grad in cluster c)
        coeff = m_c_t
    else:
        # v_c ≈ H^{-1} * (sum grad in cluster c)
        Nc = torch.as_tensor(np.maximum(counts_all, 1), device=ref.device, dtype=ref.dtype)
        coeff = m_c_t / Nc

    # Only clusters that actually lose samples contribute; skipping the rest
    # avoids allocating a full-size zero temporary per untouched cluster.
    delta = torch.zeros_like(ref)
    for c in range(C):
        if m_c[c] != 0:
            delta = delta + coeff[c] * v_c[c]

    denom = (N - int(m_c.sum())) if use_exact_den else N
    denom = max(1, denom)
    delta = delta / denom
    if device is not None:
        delta = delta.to(device)

    return {"delta_theta": delta.detach(), "counts_S": m_c.tolist()}
