#!/usr/bin/env python3
"""
Two estimators in one script:

1) IF-based quadratic upper bound of I(parameters phi; U | X^n)
   - U indicates whether a specific sample i is in validation (1) or training (0).
   - We use the diagonal empirical Fisher (EF) to form H^{-1} ≈ diag(1 / (EF + damping)) and evaluate
       Upper ≈ (1 / (2 n_train)) * E_{i,j} [ (g_i - g_j)^T H^{-1} (g_i - g_j) ].

2) Maximum-entropy covariance upper bound of I(Z_U; U | phi, X^n)
   - Z_U is the latent of the training sample with uniformly drawn index U (here U indexes training samples; it is not the membership indicator of part 1).
   - Let q_i(z) = N(z | mu_i, diag(var_i)) with (mu_i, logvar_i) from the encoder.
   - Define the mixture q_mix(z) = (1/n) Σ_i q_i(z). Then
       I(Z_U;U | phi,X^n) = H[q_mix] - (1/n) Σ_i H[q_i]
     and we upper-bound H[q_mix] by a Gaussian with covariance Σ_mix of q_mix:
       H[q_mix] ≤ 1/2 [ d log(2πe) + log det Σ_mix ].
   - This yields the practical bound
       I ≤ 1/2 [ log det Σ_mix - (1/n) Σ_i Σ_d log var_{i,d} ].

Select the mode via --mode {both, if_params_u, zu_upper} (default: both).
The script writes per-split results and a cross-split summary to JSON.
"""

import os
import sys
import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader

# Make the repository root (two directory levels above this file) importable so
# that the project-local packages imported below (`utils`, `model`) resolve when
# this file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if os.fspath(PROJECT_ROOT) not in sys.path:
    sys.path.append(os.fspath(PROJECT_ROOT))

from utils.model_utils import load_trained_model, load_experiment_data_splits
from utils.data_utils import get_mnist_data, get_fashion_mnist_data
from model.vae_models.vae import vae_loss


@dataclass
class EstimatorConfig:
    """
    Run-time options for both MI estimators (normally filled in by ``parse_args``).

    Field order is significant: it defines the positional constructor order and
    the key order of ``asdict(cfg)``, which is written verbatim into the output JSON.
    """

    # ---- shared options ----
    experiment_dir: str                      # experiment root; the JSON report is written here
    device: str = "cpu"                      # requested device: "cpu" or "cuda"
    max_splits: Optional[int] = None         # process at most this many splits; None = all
    # ---- IF-based bound on I(params; U | X^n) ----
    max_samples_per_split: int = 200         # M: samples used for the pairwise term; negative = all
    ef_max_train_samples: int = 1000         # samples for the EF diagonal; negative = all
    damping: float = 1e-3                    # added to the EF diagonal before inverting it
    save_filename: str = "mi_params_u.json"  # output filename; the default is replaced per mode in main()
    param_scope: str = "all"                 # parameters included in the IF bound: "all", "encoder", "decoder"
    # ---- estimator selection and Z_U bound (zu_upper) options ----
    mode: str = "both"                       # one of "both", "if_params_u", "zu_upper"
    z_batch_size: int = 256                  # encoder batch size for the Z_U bound
    z_max_train_samples: int = -1            # cap on samples for the Z_U bound; -1 = all
    z_cov_jitter: float = 1e-6               # initial diagonal jitter for log det(Sigma_mix)


def parse_args() -> EstimatorConfig:
    """
    Build the command-line interface and return the parsed values as an EstimatorConfig.

    Every option's destination name matches an ``EstimatorConfig`` field one-to-one,
    so the parsed namespace is handed to the constructor by keyword.
    """
    import argparse

    option_specs = [
        ("--experiment_dir", dict(type=str, required=True,
                                  help="Path to experiment directory under results/experiments/.../")),
        ("--device", dict(type=str, default="cpu", choices=["cpu", "cuda"],
                          help="Computation device")),
        ("--max_splits", dict(type=int, default=None,
                              help="Process at most this many splits (from split_0 upward)")),
        ("--mode", dict(type=str, choices=["both", "if_params_u", "zu_upper"], default="both",
                        help="Select estimator: both, IF-based params MI only, or Z_U MI upper bound only")),
        ("--max_samples_per_split", dict(type=int, default=200,
                                         help="Number of training samples per split used in the bound")),
        ("--ef_max_train_samples", dict(type=int, default=1000,
                                        help="Samples to estimate EF diagonal (-1: use all training samples)")),
        ("--damping", dict(type=float, default=1e-3,
                           help="EF diagonal damping added to denominator for stability")),
        ("--save_filename", dict(type=str, default="mi_params_u.json",
                                 help="Output JSON filename under experiment_dir")),
        ("--param_scope", dict(type=str, choices=["all", "encoder", "decoder"], default="all",
                               help="Parameter scope to include in IF (default: all)")),
        # Options for the Z_U upper bound (zu_upper mode)
        ("--z_batch_size", dict(type=int, default=256,
                                help="Batch size for encoder forward pass in zu_upper mode")),
        ("--z_max_train_samples", dict(type=int, default=-1,
                                       help="Max train samples for Z_U MI (-1: use all)")),
        ("--z_cov_jitter", dict(type=float, default=1e-6,
                                help="Initial jitter added to Σ_mix for stable logdet")),
    ]

    parser = argparse.ArgumentParser(
        description="Estimate MI: IF-based I(params;U|X^n) or max-ent upper bound for I(Z_U;U|phi,X^n)"
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    namespace = parser.parse_args()
    return EstimatorConfig(**vars(namespace))


def zero_like_parameters(model: nn.Module) -> torch.Tensor:
    """
    Return a float32 vector of zeros with one entry per parameter element of ``model``.

    The vector is allocated on the device of the model's first parameter; a model
    without parameters raises StopIteration.
    """
    first_param = next(model.parameters())
    n_elems = sum(param.numel() for param in model.parameters())
    return torch.zeros(n_elems, dtype=torch.float32, device=first_param.device)


def flatten_grad(model: nn.Module) -> torch.Tensor:
    """
    Concatenate the gradients of all parameters of ``model`` into one 1-D tensor.

    Parameters whose ``.grad`` is ``None`` (e.g. not reached by the last backward
    pass) contribute zeros, so the result always has one entry per parameter
    element, laid out in ``model.parameters()`` order (the same order
    ``build_param_mask`` uses).

    ``reshape`` is used instead of ``view`` because gradients (and ``zeros_like``
    copies of parameters) are not guaranteed to be contiguous, e.g. for
    channels_last weights or manually assigned grads, and ``view(-1)`` raises on
    such tensors.
    """
    return torch.cat([
        torch.zeros_like(p).reshape(-1) if p.grad is None else p.grad.detach().reshape(-1)
        for p in model.parameters()
    ])


def compute_empirical_fisher_diagonal(
    model: nn.Module,
    beta: float,
    dataset_subset: Subset,
    max_samples: int,
    device: torch.device,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Approximate the diagonal of the empirical Fisher as the mean squared per-sample gradient:

        EF_diag = (1/N) * sum_i grad(l_i) ⊙ grad(l_i)

    Samples are visited in the order of ``dataset_subset.indices`` (no shuffling).
    Only the first ``max_samples`` are used; a negative or ``None`` value means
    "use every index". Each sample is fed to the model as a batch of one, and its
    loss is ``vae_loss(..., beta)``.

    Args:
        model: the VAE. It is switched to eval mode, and its ``.grad`` buffers are
            overwritten by each backward pass.
        beta: weight forwarded to ``vae_loss``.
        dataset_subset: training subset. Items are read from the underlying
            ``dataset_subset.dataset`` using the subset's own indices.
        max_samples: cap on the number of samples (negative / None = all).
        device: device the inputs are moved to.
        mask: optional 0/1 vector over the flattened parameters (as built by
            ``build_param_mask``). Entries outside the scope are zeroed.

    Returns:
        A 1-D tensor laid out like ``flatten_grad(model)``. It is all zeros if no
        sample was used.
    """
    model.eval()  # freeze BN stats and disable dropout during the gradient passes
    # NOTE(review): if the model samples z in forward() even in eval mode, these
    # gradients are stochastic rather than deterministic — confirm.

    chosen: List[int] = list(dataset_subset.indices)
    if max_samples is not None and max_samples >= 0:
        chosen = chosen[:max_samples]

    running_sq: Optional[torch.Tensor] = None
    for sample_idx in chosen:
        x, _ = dataset_subset.dataset[sample_idx]
        batch = x.to(device).unsqueeze(0)  # prepend a batch dimension of size 1

        recon, mu, logvar = model(batch)
        loss, _, _ = vae_loss(recon, batch, mu, logvar, beta)
        model.zero_grad(set_to_none=True)
        loss.backward()

        sq_grad = flatten_grad(model).pow(2)
        if mask is not None:
            sq_grad = sq_grad * mask  # keep only the selected parameter scope
        if running_sq is None:
            running_sq = sq_grad
        else:
            running_sq.add_(sq_grad)

    if running_sq is None:
        # Reached only when no index was selected (empty subset or max_samples == 0).
        return zero_like_parameters(model)
    running_sq /= len(chosen)
    return running_sq


def compute_sample_gradient(
    model: nn.Module,
    beta: float,
    x: torch.Tensor,
) -> torch.Tensor:
    """
    Return the flattened gradient of ``vae_loss`` with respect to all model parameters at ``x``.

    ``x`` is passed to the model unchanged; callers in this file pass one sample
    with a leading batch dimension of 1. The model is put in eval mode, and its
    ``.grad`` buffers are reset and then overwritten by this call.
    """
    model.eval()
    reconstruction, z_mean, z_logvar = model(x)
    objective, _, _ = vae_loss(reconstruction, x, z_mean, z_logvar, beta)
    model.zero_grad(set_to_none=True)
    objective.backward()
    return flatten_grad(model)


def build_param_mask(model: nn.Module, scope: str) -> Optional[torch.Tensor]:
    """
    Build a 0/1 float32 mask over the flattened parameters selecting a given scope.

    The layout follows ``model.named_parameters()`` order, which is the same order
    ``flatten_grad`` uses.

    Args:
        model: the VAE. Parameters whose names start with "encoder." / "decoder."
            form the encoder / decoder scopes.
        scope: one of "all", "encoder", "decoder".

    Returns:
        ``None`` for scope "all" (fast path: no masking) or for a model without
        parameters. Otherwise a 1-D tensor on the device of the first parameter.

    Raises:
        ValueError: if ``scope`` is not one of the supported values. Previously an
            unknown scope silently produced an all-zero mask, i.e. a bound of 0.
    """
    if scope == "all":
        return None
    prefixes = {"encoder": "encoder.", "decoder": "decoder."}
    if scope not in prefixes:
        raise ValueError(
            f"Unknown param scope {scope!r}; expected one of 'all', 'encoder', 'decoder'"
        )
    prefix = prefixes[scope]

    named = list(model.named_parameters())
    if not named:
        return None
    device = named[0][1].device

    mask = torch.cat([
        torch.full((p.numel(),), 1.0 if name.startswith(prefix) else 0.0,
                   dtype=torch.float32, device=device)
        for name, p in named
    ])
    if not bool(mask.any()):
        # An empty scope zeroes every gradient, so the IF bound would silently be 0.
        print(f"[WARN] param_scope={scope!r} matched no parameters "
              f"(expected names starting with {prefix!r}); the IF bound will be 0.")
    return mask


def compute_if_upper_bound_for_split(
    model: nn.Module,
    beta: float,
    train_subset: Subset,
    ef_diag: torch.Tensor,
    damping: float,
    max_samples_per_split: int,
    device: torch.device,
    mask: Optional[torch.Tensor] = None,
) -> Dict:
    """
    IF-based quadratic upper bound approximation per the provided formula:

    UpperBound ≈ (1 / (2 n_train)) * E_{i,j} [ (g_i - g_j)^T H^{-1} (g_i - g_j) ]

    where H^{-1} is approximated by W = diag(1 / (EF_diag + damping)). ``damping``
    should be > 0 so that W is finite on in-scope entries.

    The pairwise expectation is taken deterministically over the first M samples
    (M ≤ max_samples_per_split; a negative value means "all") of the training
    subset, using the streaming identity

      sum_{i,j} (g_i - g_j)^T W (g_i - g_j)
        = 2 M * sum_i g_i^T W g_i - 2 * (sum_i g_i)^T W (sum_i g_i).

    The sum includes the i == j pairs, which contribute zero.

    The two terms on the right are large and nearly equal whenever the per-sample
    gradients share a common component, so all accumulation is done in float64
    on the CPU to avoid catastrophic cancellation.

    Args:
        model, beta: forwarded to ``compute_sample_gradient``.
        train_subset: training subset; n_train = len(train_subset.indices).
        ef_diag: diagonal empirical Fisher, laid out like ``flatten_grad(model)``.
        damping: added to ``ef_diag`` before inversion.
        max_samples_per_split: M cap (negative / None = all).
        device: device the inputs are moved to.
        mask: optional 0/1 parameter-scope mask. Out-of-scope entries get weight 0.

    Returns:
        A dict with "num_train_samples" and "M". When M > 0 it also contains
        "sum_qi", "sumG2_w", "mean_pair" and "upper_bound". "upper_bound" is None
        when there are no samples.
    """
    n_train = len(train_subset.indices)
    if n_train == 0:
        return {
            "num_train_samples": 0,
            "M": 0,
            "upper_bound": None,
        }

    # Deterministic selection of the first M training samples (negative = all).
    indices: List[int] = list(train_subset.indices)
    if max_samples_per_split is not None and max_samples_per_split >= 0:
        indices = indices[:max_samples_per_split]
    M = len(indices)
    if M == 0:
        return {"num_train_samples": n_train, "M": 0, "upper_bound": None}

    mask_cpu = mask.detach().to("cpu", torch.float64) if mask is not None else None

    # W = diag(H^{-1}) in float64 on the CPU.
    H_inv_diag = (ef_diag.detach().to("cpu", torch.float64) + damping).reciprocal()
    if mask_cpu is not None:
        # torch.where rather than a plain multiply: out-of-scope entries have EF == 0,
        # so with damping == 0 their weight is inf, and inf * 0 would be NaN.
        H_inv_diag = torch.where(mask_cpu != 0, H_inv_diag * mask_cpu, torch.zeros_like(H_inv_diag))

    sum_qi = 0.0                           # Σ_i g_i^T W g_i
    sumG = torch.zeros_like(H_inv_diag)    # Σ_i g_i

    for idx in indices:
        x, _ = train_subset.dataset[idx]
        x = x.to(device).unsqueeze(0)  # batch of one
        g_i = compute_sample_gradient(model, beta, x).detach().to("cpu", torch.float64)
        if mask_cpu is not None:
            g_i = g_i * mask_cpu

        # g_i^T W g_i = Σ_k W[k] * g_i[k]^2
        sum_qi += torch.dot(H_inv_diag, g_i * g_i).item()
        sumG.add_(g_i)

    # (Σ_i g_i)^T W (Σ_i g_i)
    sumG2_w = torch.dot(H_inv_diag, sumG * sumG).item()

    # Σ_{i,j} qf = 2 M * Σ_i g_i^T W g_i - 2 * (Σ_i g_i)^T W (Σ_i g_i)
    sum_pair = 2.0 * M * sum_qi - 2.0 * sumG2_w

    # E_{i,j}[qf] ≈ sum_pair / M^2
    mean_pair = sum_pair / (M * M)

    # Upper bound ≈ (1 / (2 n_train)) * E_{i,j}[qf]
    upper_bound = mean_pair / (2.0 * n_train)

    return {
        "num_train_samples": n_train,
        "M": M,
        "sum_qi": sum_qi,
        "sumG2_w": sumG2_w,
        "mean_pair": mean_pair,
        "upper_bound": upper_bound,
    }


def _stable_logdet_psd(cov: torch.Tensor, init_jitter: float = 1e-6, max_tries: int = 6) -> Tuple[float, float]:
    """
    Compute logdet of a symmetric PSD matrix via Cholesky with increasing jitter.
    Returns (logdet, jitter_used).
    """
    assert cov.ndim == 2 and cov.shape[0] == cov.shape[1]
    d = cov.shape[0]
    eye = torch.eye(d, dtype=cov.dtype, device=cov.device)
    jitter = init_jitter
    for _ in range(max_tries):
        try:
            L = torch.linalg.cholesky(cov + jitter * eye)
            logdet = 2.0 * torch.log(torch.diag(L)).sum().item()
            return logdet, jitter
        except RuntimeError:
            jitter *= 10.0
            continue
    # Fallback: eigenvalue-based (more expensive but robust)
    evals = torch.linalg.eigvalsh(cov)
    eps = max(jitter, 1e-12)
    logdet = torch.log(torch.clamp(evals, min=eps)).sum().item()
    return logdet, jitter


def compute_mi_zu_upper_for_split(
    model: nn.Module,
    train_subset: Subset,
    device: torch.device,
    batch_size: int,
    max_train_samples: int,
    cov_jitter: float,
) -> Dict:
    """
    Maximum-entropy (covariance-only) upper bound on I(Z_U; U | phi, X^n).

        I_upper = 0.5 * [ log det Sigma_mix - (1/n) * sum_i sum_d logvar_{i,d} ]

    Here Sigma_mix = Cov_U[mu_U] + E_U[diag(var_U)] is the covariance of the
    Gaussian-mixture aggregate posterior. The required moments are gathered in a
    single encoder pass, accumulated in float64.

    Args:
        model: the VAE; ``model(x)`` must return ``(recon, mu, logvar)``.
        train_subset: training subset; its first ``max_train_samples`` indices are
            used (negative / None = all), in order.
        device: device the batches are moved to.
        batch_size: encoder batch size (values < 1 are treated as 1).
        max_train_samples: cap on the number of samples.
        cov_jitter: initial jitter for the stable log det.

    Returns:
        A dict with "num_train_samples" and "mi_upper_bound" (None if there are no
        samples). When samples exist, it also reports the latent dim, the log det,
        the average summed logvar and the jitter used.
    """
    model.eval()

    chosen: List[int] = list(train_subset.indices)
    if max_train_samples is not None and max_train_samples >= 0:
        chosen = chosen[:max_train_samples]
    if not chosen:
        return {"num_train_samples": 0, "mi_upper_bound": None}

    loader = DataLoader(
        Subset(train_subset.dataset, chosen),
        batch_size=max(1, batch_size),
        shuffle=False,
    )

    first_moment: Optional[torch.Tensor] = None   # Σ_i mu_i
    second_moment: Optional[torch.Tensor] = None  # Σ_i mu_i mu_i^T
    var_total: Optional[torch.Tensor] = None      # Σ_i var_i (per latent dim)
    logvar_total: float = 0.0                     # Σ_i Σ_d logvar_{i,d}
    count: int = 0

    with torch.no_grad():
        for x, _ in loader:
            _, mu, logvar = model(x.to(device))
            # mu, logvar: (B, d)
            mu64 = mu.to(torch.float64)
            var64 = torch.exp(logvar).to(torch.float64)
            if first_moment is None:
                dim = mu.shape[1]
                first_moment = torch.zeros(dim, dtype=torch.float64, device=device)
                second_moment = torch.zeros(dim, dim, dtype=torch.float64, device=device)
                var_total = torch.zeros(dim, dtype=torch.float64, device=device)
            first_moment = first_moment + mu64.sum(dim=0)
            second_moment = second_moment + mu64.T @ mu64
            var_total = var_total + var64.sum(dim=0)
            logvar_total += logvar.sum().item()
            count += mu.shape[0]

    latent_dim = first_moment.shape[0]  # type: ignore[union-attr]
    denom = max(1, count)
    mu_bar = first_moment / denom
    cov_mu = second_moment / denom - torch.outer(mu_bar, mu_bar)
    cov_mu = 0.5 * (cov_mu + cov_mu.T)  # remove round-off asymmetry
    sigma_mix = cov_mu + torch.diag(var_total / denom)

    # The log det runs on the CPU to spare accelerator memory.
    logdet, jitter_used = _stable_logdet_psd(
        sigma_mix.detach().cpu().to(torch.float64), init_jitter=cov_jitter
    )

    avg_sum_logvar = logvar_total / denom
    mi_upper = 0.5 * (logdet - float(avg_sum_logvar))

    return {
        "num_train_samples": count,
        "latent_dim": latent_dim,
        "logdet_sigma_mix": float(logdet),
        "avg_sum_logvar": float(avg_sum_logvar),
        "jitter_used": float(jitter_used),
        "mi_upper_bound": float(mi_upper),
    }


def _summarize_upper_bounds(values: List[Optional[float]]) -> Optional[Dict]:
    """Return mean / population std / count over the non-None ``values``, or None if there are none."""
    present = [v for v in values if v is not None]
    if not present:
        return None
    return {
        "mean_upper_bound": float(np.mean(present)),
        "std_upper_bound": float(np.std(present)),
        "num_splits": len(present),
    }


def _resolve_output_filename(cfg: EstimatorConfig) -> str:
    """
    Choose the output JSON filename.

    Any ``save_filename`` other than the default "mi_params_u.json" is used verbatim.
    The default is replaced by a name that encodes the mode (and, for the IF
    estimator, the parameter scope).
    """
    if cfg.save_filename != "mi_params_u.json":
        return cfg.save_filename
    if cfg.mode == "both":
        return f"mi_both_{cfg.param_scope}.json"
    if cfg.mode == "if_params_u":
        return f"mi_params_u_{cfg.param_scope}.json"
    if cfg.mode == "zu_upper":
        return "mi_zu_upper.json"
    return cfg.save_filename


def _load_train_dataset(dataset_name: object, cache: Dict[str, object]):
    """
    Return the training dataset for ``dataset_name``, calling the loader at most once per dataset.

    Fashion-MNIST aliases ("fashion_mnist", "fashion", "fmnist", case-insensitive)
    use ``get_fashion_mnist_data``. Any other value, including None, uses
    ``get_mnist_data``.
    """
    name = str(dataset_name).lower()
    key = "fashion_mnist" if name in ("fashion_mnist", "fashion", "fmnist") else "mnist"
    if key not in cache:
        loader = get_fashion_mnist_data if key == "fashion_mnist" else get_mnist_data
        cache[key], _ = loader()
    return cache[key]


def main():
    """
    Run the selected estimator(s) over the experiment's splits and save a JSON report.

    For each split (up to ``--max_splits``), the trained model is loaded. Splits whose
    model fails to load are skipped with a warning. The report holds the config,
    per-split results and a mean/std summary for each requested estimator. When both
    estimators run, it also holds their per-split sum.
    """
    cfg = parse_args()
    use_cuda = cfg.device == 'cuda' and torch.cuda.is_available()
    if cfg.device == 'cuda' and not use_cuda:
        # The saved config still records the *requested* device, so make the fallback visible.
        print("[WARN] CUDA requested but not available; falling back to CPU.")
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Load splits and config
    splits, metadata = load_experiment_data_splits(cfg.experiment_dir)
    num_splits = metadata['num_splits']
    if cfg.max_splits is not None:
        num_splits = min(num_splits, cfg.max_splits)

    want_if = cfg.mode in ("both", "if_params_u")
    want_zu = cfg.mode in ("both", "zu_upper")

    per_split_if_upper: List[Dict] = []
    per_split_zu_upper: List[Dict] = []
    per_split_combined: List[Dict] = []
    # Each split only selects indices into the full training set, so each dataset is
    # loaded once and reused. Assumes get_*_data() returns the same data on every
    # call — TODO confirm.
    dataset_cache: Dict[str, object] = {}

    for split_idx in range(num_splits):
        try:
            model, model_config = load_trained_model(cfg.experiment_dir, split_idx, device)
        except Exception as e:  # best-effort: a missing/broken split must not abort the run
            print(f"[WARN] Skip split {split_idx}: {e}")
            continue

        # Load the dataset matching this experiment's config.
        dataset_name = (model_config.get('data') or {}).get('dataset', 'mnist')
        train_dataset = _load_train_dataset(dataset_name, dataset_cache)

        train_indices, _ = splits[split_idx]
        train_subset = Subset(train_dataset, train_indices)

        this_if_val: Optional[float] = None
        this_zu_val: Optional[float] = None

        if want_if:
            beta = float(model_config['training']['beta'])
            mask = build_param_mask(model, scope=cfg.param_scope)
            ef_diag = compute_empirical_fisher_diagonal(
                model=model,
                beta=beta,
                dataset_subset=train_subset,
                max_samples=cfg.ef_max_train_samples,
                device=device,
                mask=mask,
            )
            if_upper = compute_if_upper_bound_for_split(
                model=model,
                beta=beta,
                train_subset=train_subset,
                ef_diag=ef_diag,
                damping=cfg.damping,
                max_samples_per_split=cfg.max_samples_per_split,
                device=device,
                mask=mask,
            )
            if_upper["split"] = split_idx
            per_split_if_upper.append(if_upper)
            this_if_val = if_upper.get("upper_bound")

        if want_zu:
            mi_info = compute_mi_zu_upper_for_split(
                model=model,
                train_subset=train_subset,
                device=device,
                batch_size=cfg.z_batch_size,
                max_train_samples=cfg.z_max_train_samples,
                cov_jitter=cfg.z_cov_jitter,
            )
            mi_info["split"] = split_idx
            per_split_zu_upper.append(mi_info)
            this_zu_val = mi_info.get("mi_upper_bound")

        combined_val = None
        if (this_if_val is not None) and (this_zu_val is not None):
            combined_val = float(this_if_val) + float(this_zu_val)
        per_split_combined.append({
            "split": split_idx,
            "if_params_u_upper": (float(this_if_val) if this_if_val is not None else None),
            "zu_upper": (float(this_zu_val) if this_zu_val is not None else None),
            "combined_upper": (float(combined_val) if combined_val is not None else None),
        })

    # Aggregations across splits (None when no split produced a value).
    if_summary = (
        _summarize_upper_bounds([d.get("upper_bound") for d in per_split_if_upper])
        if want_if else None
    )
    zu_summary = (
        _summarize_upper_bounds([d.get("mi_upper_bound") for d in per_split_zu_upper])
        if want_zu else None
    )
    combined_summary = (
        _summarize_upper_bounds([d.get("combined_upper") for d in per_split_combined])
        if (want_if and want_zu) else None
    )

    result: Dict[str, object] = {"estimator_config": asdict(cfg)}
    if want_if:
        result["if_upper_bounds_per_split"] = per_split_if_upper
        result["if_upper_bounds_summary"] = if_summary
    if want_zu:
        result["mi_zu_upper_per_split"] = per_split_zu_upper
        result["mi_zu_upper_summary"] = zu_summary
    if want_if and want_zu:
        result["combined_per_split"] = per_split_combined
        result["combined_summary"] = combined_summary

    save_path = os.path.join(cfg.experiment_dir, _resolve_output_filename(cfg))
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    print(f"Saved MI results to: {save_path}")


if __name__ == "__main__":
    # Script entry point: estimators run only when executed directly, not on import.
    main()


