#!/usr/bin/env python3
"""
Estimate an IF-based quadratic upper bound approximation of I(parameters; U | X^n)
using efficient HVP + Conjugate Gradient per Koh & Liang (2017, Sec. 3).

We approximate E_{i,j}[(g_i - g_j)^T H^{-1} (g_i - g_j)] by Monte Carlo over
K random pairs (i, j), computing H^{-1} v with v = (g_i - g_j) via CG and
Hessian-vector products (Pearlmutter's trick). This avoids forming/storing H.
"""
import os
import sys
import json
import random
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader

# Put the project root (the parent of the directory containing this file) on
# sys.path, presumably so the `utils` and `model` imports below resolve when the
# file is run directly as a script. It is appended, so any entry already on
# sys.path takes precedence over it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from utils.model_utils import load_trained_model, load_experiment_data_splits
from utils.data_utils import get_mnist_data
from model.vae_models.vae import vae_loss_per_sample
from utils.influence import s_test


@dataclass
class EstimatorConfig:
    """Settings for one run of the HVP+CG influence-function MI estimator.

    Instances are normally built by ``parse_args`` from the command line and
    are serialized verbatim (via ``asdict``) into the output JSON.
    """

    # Experiment directory: splits and trained models are loaded from it and
    # the result JSON is written into it.
    experiment_dir: str
    # Requested device ("cpu" or "cuda"); main() falls back to CPU when CUDA
    # is requested but not available.
    device: str = "cpu"
    # Process at most this many splits, starting at split 0; None = all splits.
    max_splits: Optional[int] = None
    # Number of leading training samples per split whose per-sample gradients
    # are computed; a negative value means "use every training sample".
    max_grad_samples_per_split: int = 200
    # Number of random (i, j) pairs drawn in 'pairs' mode (capped at M*M).
    num_pair_samples: int = 500
    # Number of leading training samples that define the loss used for
    # Hessian-vector products; a negative value means "use all".
    hvp_max_train_samples: int = 512
    # Forwarded to s_test as `damp` (CLI help: damping added to the Hessian).
    damping: float = 1e-2
    # Forwarded to s_test as `scale` (CLI help: divider to improve CG conditioning).
    scale: float = 25.0
    # Conjugate-gradient tolerance and iteration cap, forwarded to s_test.
    cg_tol: float = 1e-5
    cg_max_iter: int = 500
    # DataLoader batch size used when evaluating the HVP loss.
    batch_size: int = 128
    # Output filename under experiment_dir; when left at this default, main()
    # inserts the parameter scope into the name.
    save_filename: str = "mi_params_u_hvp.json"
    # Which parameters enter the estimate: "all", "encoder" or "decoder".
    param_scope: str = "all"
    # 'pairs' (Monte Carlo over (i, j)) or 'mean' (pair-free estimator).
    # NOTE(review): the CLI default for --estimation_mode in parse_args is
    # 'mean', which differs from this dataclass default -- confirm which one
    # is intended.
    estimation_mode: str = "pairs"


def parse_args() -> EstimatorConfig:
    """Read the command-line options and return them as an ``EstimatorConfig``.

    Only ``--experiment_dir`` is required; every other option has a default.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Estimate IF-based MI upper bound with HVP+CG")
    add = parser.add_argument

    # Input location and execution device.
    add("--experiment_dir", type=str, required=True,
        help="Path to experiment directory under results/experiments/.../")
    add("--device", type=str, default="cpu", choices=["cpu", "cuda"],
        help="Computation device")

    # How much data is used.
    add("--max_splits", type=int, default=None,
        help="Process at most this many splits (from split_0 upward)")
    add("--max_grad_samples_per_split", type=int, default=200,
        help="Number of training samples whose gradients are used to form pair vectors")
    add("--num_pair_samples", type=int, default=500,
        help="Number of random (i,j) pairs to estimate E_{i,j}")
    add("--hvp_max_train_samples", type=int, default=512,
        help="Samples to define HVP loss (averaged over this many train examples)")

    # Inverse-Hessian-vector-product solver settings.
    add("--damping", type=float, default=1e-2,
        help="Damping added to Hessian in CG (H + damping*I)")
    add("--scale", type=float, default=25.0,
        help="Optional scale divider in s_test to improve CG conditioning")
    add("--cg_tol", type=float, default=1e-5,
        help="CG tolerance")
    add("--cg_max_iter", type=int, default=500,
        help="CG maximum iterations")
    add("--batch_size", type=int, default=128,
        help="Batch size for loss_fn_for_hvp evaluation")

    # Output and estimator variant.
    add("--save_filename", type=str, default="mi_params_u_hvp.json",
        help="Output JSON filename under experiment_dir")
    add("--param_scope", type=str, choices=["all", "encoder", "decoder"], default="all",
        help="Parameter scope to include in IF (default: all)")
    add("--estimation_mode", type=str, choices=["pairs", "mean"], default="mean",
        help="'pairs': Monte Carlo over (i,j); 'mean': 2E[g^T H^{-1} g]-2(E[g])^T H^{-1}E[g]")

    namespace = parser.parse_args()
    # Each option's dest is spelled exactly like the EstimatorConfig field it
    # fills, so the parsed namespace can be expanded directly into the dataclass.
    return EstimatorConfig(**vars(namespace))


def flatten_grad(model: nn.Module) -> torch.Tensor:
    """Concatenate the gradients of all parameters of ``model`` into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. A parameter whose
    ``.grad`` is ``None`` contributes a block of zeros of matching size, so the
    result always has one entry per parameter element.

    Args:
        model: Module whose ``.grad`` fields have been populated (e.g. by a
            preceding ``backward()`` call).

    Returns:
        A detached 1-D tensor holding the flattened gradients.
    """
    flat_parts: List[torch.Tensor] = []
    for p in model.parameters():
        if p.grad is None:
            flat_parts.append(torch.zeros_like(p).reshape(-1))
        else:
            # reshape, not view: view(-1) raises on non-contiguous tensors
            # (e.g. channels_last layouts), whereas reshape copies when needed.
            flat_parts.append(p.grad.detach().reshape(-1))
    return torch.cat(flat_parts)


def build_param_mask(model: nn.Module, scope: str) -> Optional[torch.Tensor]:
    """Build a flat 0/1 mask selecting the parameters that belong to ``scope``.

    The mask is laid out in ``model.named_parameters()`` order with one float32
    entry per parameter element, placed on the device of the model's first
    parameter.

    Args:
        model: Module whose parameter names are inspected.
        scope: ``"all"`` for no masking, or ``"encoder"`` / ``"decoder"`` to keep
            only parameters whose name starts with ``"encoder."`` /
            ``"decoder."``. Any other value selects nothing (all-zero mask).

    Returns:
        ``None`` when ``scope == "all"``, otherwise the 1-D mask tensor.
    """
    if scope == "all":
        return None

    target_device = next(model.parameters()).device
    # Name prefix that marks a parameter as in-scope; None selects nothing.
    wanted_prefix = {"encoder": "encoder.", "decoder": "decoder."}.get(scope)

    segments = []
    for param_name, param in model.named_parameters():
        selected = wanted_prefix is not None and param_name.startswith(wanted_prefix)
        make_block = torch.ones if selected else torch.zeros
        segments.append(make_block(param.numel(), dtype=torch.float32, device=target_device))

    if not segments:
        return None
    return torch.cat(segments)


def compute_sample_gradient(model: nn.Module, beta: float, x: torch.Tensor) -> torch.Tensor:
    """Return the flattened gradient of the mean VAE loss on ``x`` w.r.t. the model.

    The model is switched to eval mode, run on ``x``, and the mean of
    ``vae_loss_per_sample`` is back-propagated. Existing gradients are cleared
    before the backward pass, and the model's ``.grad`` fields are left
    populated afterwards.

    Args:
        model: Model whose forward pass yields ``(recon, mu, logvar)``.
        beta: Weight forwarded to ``vae_loss_per_sample``.
        x: Input batch (main() passes a single sample with a leading batch axis).

    Returns:
        1-D tensor of gradients in ``model.parameters()`` order.
    """
    model.eval()
    reconstruction, latent_mu, latent_logvar = model(x)
    mean_loss = vae_loss_per_sample(reconstruction, x, latent_mu, latent_logvar, beta).mean()
    # Drop stale gradients so only this loss contributes to .grad.
    model.zero_grad(set_to_none=True)
    mean_loss.backward()
    return flatten_grad(model)


def _resolve_sample_count(n_available: int, limit: Optional[int]) -> int:
    """Return how many leading samples to use out of ``n_available``.

    ``None`` or a negative ``limit`` means "use all"; otherwise the count is
    clamped to the range ``[1, n_available]``.
    """
    if limit is None or limit < 0:
        return n_available
    return min(n_available, max(1, limit))


def main() -> None:
    """Estimate the IF-based MI upper bound for every split and save it as JSON.

    For each split the trained model is loaded, per-sample gradients ``g_i`` are
    computed for the first ``M`` training samples, and the quantity
    ``E_{i,j}[(g_i - g_j)^T H^{-1} (g_i - g_j)]`` is estimated either by Monte
    Carlo over random pairs ('pairs' mode) or through the pair-free identity
    ``2 E[g^T H^{-1} g] - 2 E[g]^T H^{-1} E[g]`` ('mean' mode). Inverse-Hessian
    vector products are delegated to ``s_test`` (presumably a damped CG solve;
    see ``utils.influence``). The reported bound is ``mean_pair / (2 * n_train)``.

    Splits whose model cannot be loaded are skipped with a warning. Results and
    a mean/std summary are written under ``cfg.experiment_dir``.
    """
    cfg = parse_args()
    cuda_ok = cfg.device == 'cuda' and torch.cuda.is_available()
    if cfg.device == 'cuda' and not cuda_ok:
        # Make the fallback visible instead of silently running on CPU.
        print("[WARN] CUDA requested but not available; falling back to CPU")
    device = torch.device('cuda' if cuda_ok else 'cpu')

    # Load splits and config
    splits, metadata = load_experiment_data_splits(cfg.experiment_dir)
    num_splits = metadata['num_splits']
    if cfg.max_splits is not None:
        num_splits = min(num_splits, cfg.max_splits)

    train_dataset, _ = get_mnist_data()

    results: List[Dict] = []

    for split_idx in range(num_splits):
        # Best effort: a split without a loadable model is skipped, not fatal.
        try:
            model, model_config = load_trained_model(cfg.experiment_dir, split_idx, device)
        except Exception as e:
            print(f"[WARN] Skip split {split_idx}: {e}")
            continue

        beta = float(model_config['training']['beta'])
        mask = build_param_mask(model, scope=cfg.param_scope)

        train_indices, _ = splits[split_idx]
        n_train = len(train_indices)
        if n_train == 0:
            # Same keys as a regular record so downstream readers see one schema.
            results.append({
                "split": split_idx,
                "num_train_samples": 0,
                "M": 0,
                "num_pairs": 0,
                "mean_pair": None,
                "upper_bound": None,
            })
            continue

        # Select grad samples deterministically (leading M indices); -1 means use all
        M = _resolve_sample_count(n_train, cfg.max_grad_samples_per_split)
        sel_indices = train_indices[:M]

        # Precompute gradients g_i on device, optionally mask
        grads: List[torch.Tensor] = []
        for idx in sel_indices:
            x, _ = train_dataset[idx]
            x = x.to(device).unsqueeze(0)
            g_i = compute_sample_gradient(model, beta, x)
            if mask is not None:
                g_i = g_i * mask
            grads.append(g_i.detach())

        # Loss that defines the Hessian: average over a leading subset of the
        # training examples; -1 means use all
        n_hvp = _resolve_sample_count(n_train, cfg.hvp_max_train_samples)
        subset_dataset = Subset(train_dataset, train_indices[:n_hvp])
        hvp_loader = DataLoader(subset_dataset, batch_size=cfg.batch_size, shuffle=False)

        # The two closures below read this iteration's model/loader/mask and are
        # only ever called before the next iteration rebinds those names.
        def loss_fn_for_hvp(model_: nn.Module) -> torch.Tensor:
            weighted_total = 0.0
            n_seen = 0
            for data, _ in hvp_loader:
                data = data.to(device)
                recon, mu, logvar = model_(data)
                per_sample = vae_loss_per_sample(recon, data, mu, logvar, beta)
                # Weight each batch mean by its size so a smaller final batch
                # is not over-represented in the overall average.
                batch_n = data.size(0)
                weighted_total = weighted_total + per_sample.mean() * batch_n
                n_seen += batch_n
            return weighted_total / max(n_seen, 1)

        def inv_hessian_quadratic_form(vec: torch.Tensor) -> float:
            """Return vec^T s, where s is the s_test solution for vec."""
            solved = s_test(
                model,
                loss_fn_for_hvp,
                vec,
                damp=cfg.damping,
                scale=cfg.scale,
                cg_tol=cfg.cg_tol,
                cg_max_iter=cfg.cg_max_iter,
                mask=mask,
            )
            return torch.dot(vec, solved).item()

        if cfg.estimation_mode == "pairs":
            # Monte Carlo over K pairs drawn with a fixed seed
            K = min(cfg.num_pair_samples, M * M)
            rng = random.Random(42)
            sum_qf = 0.0
            for _ in range(K):
                i = rng.randrange(M)
                j = rng.randrange(M)
                if i == j:
                    # g_i - g_i is exactly zero, so the quadratic form is 0.
                    # Skipping avoids a pointless solve and any 0/0 a CG
                    # routine might produce for a zero right-hand side.
                    continue
                v = (grads[i] - grads[j]).detach()
                sum_qf += inv_hessian_quadratic_form(v)
            mean_pair = sum_qf / max(1, K)
            K_used: Optional[int] = K
        else:
            # Pair-free estimator: 2 E[g^T H^{-1} g] - 2 (E[g])^T H^{-1} (E[g])
            sum_qi = sum(inv_hessian_quadratic_form(g) for g in grads)
            mean_self_qf = sum_qi / max(1, M)
            g_mean = torch.stack(grads, dim=0).mean(dim=0)
            mean_pair = 2.0 * mean_self_qf - 2.0 * inv_hessian_quadratic_form(g_mean)
            K_used = None
        upper_bound = (1.0 / (2.0 * n_train)) * mean_pair

        results.append({
            "split": split_idx,
            "num_train_samples": n_train,
            "M": M,
            "num_pairs": K_used,
            "mean_pair": mean_pair,
            "upper_bound": upper_bound,
        })

    # Aggregate over the splits that produced a bound
    summary = None
    vals = [d["upper_bound"] for d in results if d.get("upper_bound") is not None]
    if vals:
        summary = {
            "mean_upper_bound": float(np.mean(vals)),
            "std_upper_bound": float(np.std(vals)),
            "num_splits": len(vals),
        }

    out = {
        "estimator_config": asdict(cfg),
        "mi_upper_bounds_per_split": results,
        "mi_upper_bounds_summary": summary,
    }

    # Disambiguate filename by scope if using default name
    final_filename = cfg.save_filename
    if cfg.save_filename == "mi_params_u_hvp.json":
        final_filename = f"mi_params_u_hvp_{cfg.param_scope}.json"

    save_path = os.path.join(cfg.experiment_dir, final_filename)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"Saved MI (HVP) results to: {save_path}")


# Run the estimator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


