#!/usr/bin/env python3
"""
Sweep MI estimation over multiple train_size values.

Supports two estimators:
  - EF:  scripts/estimate_mi_params_u.py
  - HVP: scripts/estimate_mi_params_u_hvp.py

Enhancements:
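- Robust experiment directory discovery: for each train_size, recursively scan
  <results.save_dir>/<dataset>/**/train{size}_* (so VAE/IWAE and MLP/CNN layouts
  all work), keeping only directories that contain aggregated_results.json and
  data_splits/experiment_metadata.json; if none match, fall back to the legacy
  single-path layout.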
- For EF estimator, allow computing both MI terms in one run via --mi_mode (default: both).
"""
import os
import sys
import subprocess
from pathlib import Path
from typing import List, Sequence, Iterable

import yaml
import numpy as np


# Repository root: two directory levels above this file's resolved (symlink-free)
# location. main() looks up the estimator scripts under PROJECT_ROOT / 'scripts',
# so presumably this file itself lives in <root>/scripts/ — confirm if moved.
PROJECT_ROOT = Path(__file__).resolve().parent.parent


def load_yaml(path: str) -> dict:
    """
    Load a YAML config file and return its top-level mapping.

    Uses ``yaml.safe_load`` (no arbitrary Python object construction) and reads
    the file as UTF-8, so parsing does not depend on the platform's default
    locale encoding.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document is empty or its top level is not a mapping
            (callers index the result like a dict).
    """
    with open(path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file and may return a list/scalar;
    # fail here with a clear message instead of a TypeError later on.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {path!r}, "
            f"got {type(cfg).__name__}"
        )
    return cfg


def _is_complete_experiment_dir(path: Path) -> bool:
    """Return True if ``path`` contains the result and split-metadata files."""
    return (
        (path / 'aggregated_results.json').exists()
        and (path / 'data_splits' / 'experiment_metadata.json').exists()
    )


def _legacy_experiment_dir(base_cfg: dict, root: Path, train_size: int):
    """
    Build the legacy ``<root>/<model_name>/<experiment_id>`` path.

    Returns None when the config lacks (or mis-shapes) the ``model`` /
    ``training`` keys needed to name it.
    """
    try:
        model_cfg = base_cfg['model']
        train_cfg = base_cfg['training']
        hidden = '_'.join(map(str, model_cfg['hidden_dims']))
        model_name = f"vae_latent{model_cfg['latent_dim']}_hidden{hidden}"
        experiment_id = (
            f"train{train_size}_beta{train_cfg['beta']}"
            f"_lr{train_cfg['learning_rate']}"
        )
    except (KeyError, TypeError):
        # Missing keys, or a None/non-iterable value where a mapping/list is
        # expected: the legacy naming scheme simply does not apply.
        return None
    return root / model_name / experiment_id


def discover_experiment_dirs(base_cfg: dict, train_size: int) -> List[str]:
    """
    Find experiment directories for a given training-set size.

    Recursively searches ``<results.save_dir>/<dataset>`` for directories named
    ``train{train_size}_*`` that contain both ``aggregated_results.json`` and
    ``data_splits/experiment_metadata.json``. If none are found, falls back to
    the legacy layout ``<root>/vae_latent{L}_hidden{H1_H2...}/train{N}_beta{B}_lr{LR}``,
    which is accepted whenever it exists as a directory (its contents are not
    validated).

    Args:
        base_cfg: Parsed config; must provide ``results.save_dir`` and
            ``data.dataset``. The ``model`` / ``training`` sections are only used
            for the legacy fallback and may be absent.
        train_size: Training-set size encoded in the directory name.

    Returns:
        Sorted list of matching directory paths as strings (possibly empty).

    Raises:
        KeyError: if ``results.save_dir`` or ``data.dataset`` is missing.
    """
    root = Path(base_cfg['results']['save_dir']) / base_cfg['data']['dataset']
    # The trailing underscore keeps e.g. train100_* from matching train1000_*.
    pattern = f"train{train_size}_*"
    candidates: List[Path] = []
    if root.is_dir():
        # rglob yields files as well as directories, hence the is_dir() filter.
        candidates = [
            p for p in root.rglob(pattern)
            if p.is_dir() and _is_complete_experiment_dir(p)
        ]
    if not candidates:
        legacy = _legacy_experiment_dir(base_cfg, root, train_size)
        if legacy is not None and legacy.is_dir():
            candidates.append(legacy)
    return sorted(str(p) for p in candidates)


def logspaced_integers(min_value: int, max_value: int, num: int) -> List[int]:
    """
    Return up to ``num`` distinct integers spaced logarithmically in
    ``[min_value, max_value]``, in ascending order.

    Points come from ``np.logspace`` and are rounded to the nearest integer;
    rounding collisions are collapsed, so narrow ranges may yield fewer than
    ``num`` values. ``num == 0`` or ``min_value > max_value`` yields ``[]``.

    Raises:
        ValueError: if ``min_value`` or ``max_value`` is not positive (the log
            of a non-positive number is undefined).
    """
    if min_value <= 0 or max_value <= 0:
        raise ValueError(
            "logspaced_integers requires positive bounds, got "
            f"min_value={min_value}, max_value={max_value}"
        )
    xs = np.logspace(np.log10(min_value), np.log10(max_value), num=num)
    # Several points can round to the same integer; a set de-duplicates them.
    rounded = {int(round(x)) for x in xs}
    return sorted(v for v in rounded if min_value <= v <= max_value)


def run_cmd(cmd: Sequence[str], env: dict = None) -> int:
    """
    Echo ``cmd`` and run it as a subprocess, returning its exit code.

    The echoed line is shell-quoted so it can be copy-pasted to reproduce the
    run. It is flushed before the child starts; otherwise, when stdout is
    redirected to a file or pipe, it can appear after the child's own output.

    Args:
        cmd: Program and arguments (strings); executed directly, no shell.
        env: Environment for the child; ``None`` inherits the current one.

    Returns:
        The child's return code (0 on success).
    """
    import shlex

    print("$", " ".join(shlex.quote(part) for part in cmd), flush=True)
    proc = subprocess.run(cmd, env=env)
    return proc.returncode


def _build_arg_parser():
    """Create the command-line parser for the MI sweep."""
    import argparse
    parser = argparse.ArgumentParser(description='Sweep MI estimation (EF/HVP) over train_size values')
    parser.add_argument('--config', type=str, required=True, help='Path to YAML config (to infer experiment_dir)')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cpu', help='Device for MI computation')
    parser.add_argument('--sizes', type=int, nargs='*', default=None, help='Explicit list of train_size values')
    parser.add_argument('--num_points', type=int, default=7, help='Number of log-spaced points when --sizes not given')
    parser.add_argument('--min_size', type=int, default=1000, help='Min train_size for log-spacing')
    parser.add_argument('--max_size', type=int, default=30000, help='Max train_size for log-spacing')
    parser.add_argument('--estimator', type=str, choices=['ef', 'hvp'], default='ef', help='Which MI estimator to run')
    # EF general options
    parser.add_argument('--mi_mode', type=str, choices=['both', 'if_params_u', 'zu_upper'], default='both', help='EF script MI mode')
    parser.add_argument('--ef_save_filename', type=str, default=None, help='Override output filename for EF script')
    # EF options
    parser.add_argument('--ef_max_splits', type=int, default=None)
    parser.add_argument('--ef_max_samples_per_split', type=int, default=200)
    parser.add_argument('--ef_max_train_samples', type=int, default=1000)
    parser.add_argument('--ef_damping', type=float, default=1e-3)
    parser.add_argument('--ef_param_scope', type=str, choices=['all', 'encoder', 'decoder'], default='all')
    # HVP options
    parser.add_argument('--hvp_max_grad_samples_per_split', type=int, default=200)
    parser.add_argument('--hvp_num_pair_samples', type=int, default=500)
    parser.add_argument('--hvp_max_train_samples', type=int, default=512)
    parser.add_argument('--hvp_damping', type=float, default=1e-2)
    parser.add_argument('--hvp_scale', type=float, default=25.0)
    parser.add_argument('--hvp_cg_tol', type=float, default=1e-5)
    parser.add_argument('--hvp_cg_max_iter', type=int, default=500)
    parser.add_argument('--hvp_batch_size', type=int, default=128)
    parser.add_argument('--hvp_param_scope', type=str, choices=['all', 'encoder', 'decoder'], default='all')
    return parser


def _resolve_train_sizes(args) -> List[int]:
    """Return the explicit --sizes (deduplicated, sorted) or log-spaced defaults."""
    if args.sizes:
        return sorted(set(int(s) for s in args.sizes))
    return logspaced_integers(args.min_size, args.max_size, args.num_points)


def _build_ef_cmd(args, exp_dir: str) -> List[str]:
    """Build the argv for one run of the EF estimator script on ``exp_dir``."""
    cmd = [
        sys.executable,
        str(PROJECT_ROOT / 'scripts' / 'estimate_mi_params_u.py'),
        '--experiment_dir', exp_dir,
        '--device', args.device,
        '--mode', args.mi_mode,
        '--max_samples_per_split', str(args.ef_max_samples_per_split),
        '--ef_max_train_samples', str(args.ef_max_train_samples),
        '--damping', str(args.ef_damping),
        '--param_scope', args.ef_param_scope,
        # Z_U options are always passed (harmless if mode != zu_upper). The Z_U
        # sample cap reuses --ef_max_train_samples; batch size and covariance
        # jitter are fixed here rather than exposed as CLI options.
        '--z_max_train_samples', str(args.ef_max_train_samples),
        '--z_batch_size', '512',
        '--z_cov_jitter', '1e-6',
    ]
    if args.ef_max_splits is not None:
        cmd += ['--max_splits', str(args.ef_max_splits)]
    if args.ef_save_filename is not None:
        cmd += ['--save_filename', args.ef_save_filename]
    return cmd


def _build_hvp_cmd(args, exp_dir: str) -> List[str]:
    """Build the argv for one run of the HVP estimator script on ``exp_dir``."""
    return [
        sys.executable,
        str(PROJECT_ROOT / 'scripts' / 'estimate_mi_params_u_hvp.py'),
        '--experiment_dir', exp_dir,
        '--device', args.device,
        '--max_grad_samples_per_split', str(args.hvp_max_grad_samples_per_split),
        '--num_pair_samples', str(args.hvp_num_pair_samples),
        '--hvp_max_train_samples', str(args.hvp_max_train_samples),
        '--damping', str(args.hvp_damping),
        '--scale', str(args.hvp_scale),
        '--cg_tol', str(args.hvp_cg_tol),
        '--cg_max_iter', str(args.hvp_cg_max_iter),
        '--batch_size', str(args.hvp_batch_size),
        '--param_scope', args.hvp_param_scope,
    ]


def main() -> None:
    """
    Parse CLI options and run the chosen MI estimator on every discovered
    experiment directory for each requested train_size.

    Each estimator run is a separate subprocess. A failed run is reported and
    the sweep continues (best-effort). All failures are listed at the end.
    Progress lines are flushed so they stay ordered relative to child-process
    output when stdout is redirected.
    """
    args = _build_arg_parser().parse_args()
    print("[sweep_mi] parsed args", flush=True)

    base_cfg = load_yaml(args.config)
    print("[sweep_mi] loaded YAML", flush=True)

    train_sizes = _resolve_train_sizes(args)
    print(f"Train sizes (MI sweep): {train_sizes}", flush=True)

    # VAE_DEVICE is exported to the child processes alongside --device;
    # presumably read by project code they import (not read in this file).
    env = os.environ.copy()
    env['VAE_DEVICE'] = args.device

    build_cmd = _build_ef_cmd if args.estimator == 'ef' else _build_hvp_cmd
    failures: List[str] = []
    for ts in train_sizes:
        exp_dirs = discover_experiment_dirs(base_cfg, ts)
        if not exp_dirs:
            print(f"[WARN] No experiment directories found for train_size={ts}", flush=True)
            continue
        print(f"\n=== MI estimation for train_size={ts} ({len(exp_dirs)} exp(s)) ===", flush=True)
        for exp_dir in exp_dirs:
            print(f" -> {exp_dir}", flush=True)
            rc = run_cmd(build_cmd(args, exp_dir), env=env)
            if rc != 0:
                print(f"[WARN] MI estimation failed for exp_dir={exp_dir} (rc={rc}).", flush=True)
                failures.append(exp_dir)

    if failures:
        print(f"\n[sweep_mi] {len(failures)} MI estimation run(s) failed:", flush=True)
        for failed_dir in failures:
            print(f"  - {failed_dir}", flush=True)


if __name__ == '__main__':
    # Run the sweep only when executed as a script, not when imported.
    main()


