#!/usr/bin/env python3
"""
Sweep EF-based hierarchical MI estimation over train sizes and depths.

This script discovers hierarchical VAE experiment directories and, for each
experiment, runs scripts/estimate_mi_hierarchical_ef.py for one or more depth
values l in {1..L}, where L is inferred from the experiment config unless an
explicit depth list is provided.
"""

import json
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np
import yaml


# Directory one level above the one containing this file; the estimator
# script is located relative to it (PROJECT_ROOT / 'scripts' / ...).
PROJECT_ROOT = Path(__file__).resolve().parent.parent


def load_yaml(path: str) -> dict:
    """
    Load a YAML document from ``path``.

    The file is decoded as UTF-8 explicitly so that parsing does not depend on
    the platform's locale encoding. An empty document, which
    ``yaml.safe_load`` reports as ``None``, is returned as an empty dict so
    the declared return type holds.

    Args:
        path: Filesystem path of the YAML file.

    Returns:
        The parsed document, or ``{}`` when the file is empty.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data


def discover_experiment_dirs(base_cfg: dict, train_size: int) -> List[str]:
    """
    Find all experiment directories whose name starts with train{train_size}_ under
    <results.save_dir>/<dataset>/**, and that are hierarchical (arch in {hmlp, hcnn}).

    A directory qualifies only if it contains both ``aggregated_results.json``
    and ``data_splits/experiment_metadata.json``. The architecture is read from
    ``config.model.arch`` in the aggregated results (treated as 'mlp' when the
    key is absent).

    Args:
        base_cfg: Parsed config; ``results.save_dir`` and ``data.dataset`` are read.
        train_size: Train size used to build the ``train{train_size}_*`` name pattern.

    Returns:
        Sorted list of matching directory paths as strings. Empty if the
        results root does not exist or nothing matches.

    Raises:
        KeyError: If ``base_cfg`` lacks the ``results.save_dir`` or
            ``data.dataset`` entries.
    """
    root = Path(base_cfg['results']['save_dir']) / base_cfg['data']['dataset']
    if not root.is_dir():
        return []

    hierarchical_archs = {'hmlp', 'hcnn'}
    candidates: List[str] = []
    for p in root.rglob(f"train{train_size}_*"):
        if not p.is_dir():
            continue
        aggr = p / 'aggregated_results.json'
        splits_meta = p / 'data_splits' / 'experiment_metadata.json'
        if not (aggr.exists() and splits_meta.exists()):
            continue
        try:
            # The file is JSON, so use the JSON parser rather than the YAML one.
            with open(aggr, 'r', encoding='utf-8') as f:
                agg = json.load(f)
            arch = agg['config']['model'].get('arch', 'mlp')
            is_hierarchical = arch in hierarchical_archs
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
            # Best-effort discovery: a malformed experiment is skipped, but
            # reported so that it does not vanish from the sweep unnoticed.
            print(f"[WARN] Skipping {p}: could not read model arch from {aggr.name} ({exc!r})")
            continue
        if is_hierarchical:
            candidates.append(str(p))
    return sorted(candidates)


def logspaced_integers(min_value: int, max_value: int, num: int) -> List[int]:
    """
    Return sorted, de-duplicated integers spaced logarithmically in a range.

    ``num`` points are placed log-uniformly between ``min_value`` and
    ``max_value`` (both included), rounded to the nearest integer, and
    duplicates created by rounding are dropped, so the result may contain
    fewer than ``num`` values.

    Args:
        min_value: Lower bound; must be positive.
        max_value: Upper bound; must be positive.
        num: Number of log-spaced points to generate before de-duplication.

    Returns:
        Ascending list of distinct integers within ``[min_value, max_value]``.

    Raises:
        ValueError: If either bound is not strictly positive.
    """
    # log10 is undefined for non-positive values; fail with a clear message
    # instead of letting NaNs propagate into the int() conversion below.
    if min_value <= 0 or max_value <= 0:
        raise ValueError(
            f"logspaced_integers requires positive bounds, got "
            f"min_value={min_value}, max_value={max_value}"
        )
    xs = np.logspace(np.log10(min_value), np.log10(max_value), num=num)
    vals = sorted({int(round(x)) for x in xs})
    # Guard against rounding pushing a value just outside the requested range.
    return [v for v in vals if min_value <= v <= max_value]


def run_cmd(cmd: Sequence[str], env: Optional[dict] = None) -> int:
    """
    Echo a command and run it as a subprocess.

    Args:
        cmd: Program and arguments, passed to ``subprocess.run`` as a list
            (no shell is involved).
        env: Environment for the child process; ``None`` inherits the
            current environment.

    Returns:
        The child's exit code.
    """
    # Flush the echo so it appears before the child's own output even when
    # stdout is block-buffered (e.g. redirected to a file or a pipe).
    print("$", " ".join(cmd), flush=True)
    proc = subprocess.run(cmd, env=env)
    return proc.returncode


def infer_depths_from_experiment(experiment_dir: str) -> List[int]:
    """
    Infer the depth values 1..L for an experiment from its saved config.

    Reads ``config.model.latent_dims`` from
    ``<experiment_dir>/aggregated_results.json``; L is the length of that list.

    Args:
        experiment_dir: Path of the experiment directory.

    Returns:
        ``[1, ..., L]`` when ``latent_dims`` is a non-empty list, otherwise
        ``[1]`` (single-layer fallback).

    Raises:
        OSError: If the results file cannot be opened.
        ValueError: If the file is not valid JSON.
        KeyError: If the ``config`` or ``model`` entries are missing.
    """
    aggr = Path(experiment_dir) / 'aggregated_results.json'
    # The file is JSON, so use the JSON parser rather than the YAML one.
    with open(aggr, 'r', encoding='utf-8') as f:
        data = json.load(f)
    latent_dims = data['config']['model'].get('latent_dims')
    if isinstance(latent_dims, list) and latent_dims:
        return list(range(1, len(latent_dims) + 1))
    # Fallback to single-layer
    return [1]


def main() -> None:
    """
    Command-line entry point for the hierarchical MI sweep.

    Parses the arguments, resolves the list of train sizes (explicit
    ``--sizes`` or a log-spaced grid), and for every discovered hierarchical
    experiment directory launches the estimator script once per depth. A
    failing estimator run is reported with a warning and the sweep continues.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Sweep hierarchical EF MI over train sizes and depths',
    )
    add = parser.add_argument
    add('--config', type=str, required=True, help='Path to YAML config (to infer results root/dataset)')
    add('--device', type=str, choices=['cpu', 'cuda'], default='cpu')
    add('--sizes', type=int, nargs='*', default=None, help='Explicit list of train_size values')
    add('--num_points', type=int, default=5)
    add('--min_size', type=int, default=1000)
    add('--max_size', type=int, default=30000)
    # Depth selection
    add('--depths', type=int, nargs='*', default=None, help='Explicit list of l values; defaults to 1..L per exp')
    # Options forwarded to the estimator
    add('--mode', type=str, choices=['both', 'if_params_u', 'zu_upper'], default='both')
    add('--ef_max_splits', type=int, default=None)
    add('--ef_max_samples_per_split', type=int, default=200)
    add('--ef_max_train_samples', type=int, default=1000)
    add('--damping', type=float, default=1e-3)
    add('--z_batch_size', type=int, default=512)
    add('--z_cov_jitter', type=float, default=1e-6)
    add('--save_filename', type=str, default=None, help='Override output filename for estimator')

    opts = parser.parse_args()
    base_cfg = load_yaml(opts.config)

    # Explicit sizes win; otherwise build a log-spaced grid.
    if opts.sizes:
        train_sizes = sorted({int(s) for s in opts.sizes})
    else:
        train_sizes = logspaced_integers(opts.min_size, opts.max_size, opts.num_points)
    print(f"Train sizes (hierarchical MI sweep): {train_sizes}")

    # Child processes also receive the device through the environment.
    child_env = dict(os.environ)
    child_env['VAE_DEVICE'] = opts.device

    estimator_script = str(PROJECT_ROOT / 'scripts' / 'estimate_mi_hierarchical_ef.py')
    explicit_depths = sorted(set(opts.depths)) if opts.depths else None

    # Flags that are identical for every estimator invocation; they follow
    # the per-run flags on the command line.
    trailing_flags = [
        '--mode', opts.mode,
        '--max_samples_per_split', str(opts.ef_max_samples_per_split),
        '--ef_max_train_samples', str(opts.ef_max_train_samples),
        '--damping', str(opts.damping),
        '--z_batch_size', str(opts.z_batch_size),
        '--z_cov_jitter', str(opts.z_cov_jitter),
    ]
    if opts.ef_max_splits is not None:
        trailing_flags.extend(['--max_splits', str(opts.ef_max_splits)])
    if opts.save_filename is not None:
        trailing_flags.extend(['--save_filename', opts.save_filename])

    for train_size in train_sizes:
        exp_dirs = discover_experiment_dirs(base_cfg, train_size)
        if not exp_dirs:
            print(f"[WARN] No hierarchical experiment directories found for train_size={train_size}")
            continue
        print(f"\n=== Hierarchical MI (EF) for train_size={train_size} ({len(exp_dirs)} exp(s)) ===")
        for exp_dir in exp_dirs:
            # Use the explicit depth list if given, else 1..L from the experiment.
            if explicit_depths is not None:
                depths: List[int] = explicit_depths
            else:
                depths = infer_depths_from_experiment(exp_dir)
            print(f" -> {exp_dir} | depths={depths}")
            for depth in depths:
                cmd = [
                    sys.executable,
                    estimator_script,
                    '--experiment_dir', exp_dir,
                    '--device', opts.device,
                    '--depth_l', str(depth),
                    *trailing_flags,
                ]
                rc = run_cmd(cmd, env=child_env)
                if rc != 0:
                    print(f"[WARN] MI estimation failed for exp_dir={exp_dir}, l={depth} (rc={rc}).")


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


