#!/usr/bin/env python3
"""
Sweep training over multiple train_size values and run evaluation for each.

This script:
 1) Loads a base YAML config
 2) Generates a sequence of train_size values (log-spaced by default)
 3) For each train_size, runs the existing training script
 4) Runs the evaluation script on the resulting experiment directory

All outputs follow the existing directory structure used by train_vae_mnist.py.
"""

import os
import sys
import subprocess
from pathlib import Path
from typing import List, Sequence

import yaml
import numpy as np


# Directory two levels above this file (the parent of the folder that contains
# this script). main() resolves the training and evaluation scripts under
# PROJECT_ROOT / 'scripts'.
PROJECT_ROOT = Path(__file__).resolve().parent.parent


def load_yaml(path: str) -> dict:
    """Parse the YAML file at *path* and return its content.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's locale encoding (which is not UTF-8 on every system).
    Parsing uses ``yaml.safe_load``, so only plain YAML data types are built.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed document. Note that ``yaml.safe_load`` yields ``None`` for
        an empty document, so callers expecting a mapping should provide a
        non-empty config.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def compute_experiment_dir(base_cfg: dict, train_size: int) -> str:
    """Rebuild the results directory for one sweep point.

    The returned path has the form
    ``<results.save_dir>/<data.dataset>/<model_name>/<experiment_id>`` where
    ``model_name`` encodes objective family, architecture, latent size(s) and
    layer sizes, and ``experiment_id`` encodes train size, objective, beta and
    learning rate. The naming is meant to match train_vae_mnist.py.

    Args:
        base_cfg: Parsed config with ``data``, ``model``, ``training`` and
            ``results`` sections.
        train_size: Train-set size used for this run.

    Returns:
        The experiment directory as a string.

    Raises:
        KeyError: If a required config key is missing.
    """

    def _join(values) -> str:
        # [512, 256] -> "512_256"
        return '_'.join(str(v) for v in values)

    def _channels(cfg: dict):
        # Missing or empty channel lists are treated as [].
        return cfg.get('encoder_channels') or [], cfg.get('decoder_channels') or []

    def _mlp_tag(cfg: dict) -> str:
        # Per-side hidden dims fall back to the shared 'hidden_dims' entry.
        enc = cfg.get('encoder_hidden_dims') or cfg['hidden_dims']
        dec = cfg.get('decoder_hidden_dims') or cfg['hidden_dims']
        if enc != dec:
            return f"_enc{_join(enc)}_dec{_join(dec)}"
        return f"_hidden{_join(enc)}"

    dataset = base_cfg['data']['dataset']
    model_cfg = base_cfg['model']
    architecture = model_cfg.get('arch', 'mlp')
    train_cfg = base_cfg['training']
    objective = train_cfg.get('objective', 'elbo')
    family = 'iwae' if objective == 'iwae' else 'vae'

    if architecture == 'cnn':
        enc_channels, dec_channels = _channels(model_cfg)
        if enc_channels != dec_channels:
            suffix = f"_encch{_join(enc_channels)}_decch{_join(dec_channels)}"
        else:
            suffix = f"_ch{_join(enc_channels)}"
        model_name = f"{family}_cnn_latent{model_cfg['latent_dim']}{suffix}"
    elif architecture in {'hmlp', 'hcnn'}:
        # Hierarchical models: use 'latent_dims', else wrap the single 'latent_dim'.
        level_dims = model_cfg.get('latent_dims') or [model_cfg.get('latent_dim')]
        levels = _join(level_dims)
        if architecture == 'hmlp':
            suffix = _mlp_tag(model_cfg)
            model_name = f"{family}_hmlp_latent{levels}{suffix}"
        else:
            # hcnn always spells out both channel lists, even when equal.
            enc_channels, dec_channels = _channels(model_cfg)
            suffix = f"_encch{_join(enc_channels)}_decch{_join(dec_channels)}"
            model_name = f"{family}_hcnn_latent{levels}{suffix}"
    else:
        suffix = _mlp_tag(model_cfg)
        model_name = f"{family}_latent{model_cfg['latent_dim']}{suffix}"

    objective_tag = f"_{objective}"
    if objective == 'iwae':
        samples = train_cfg.get('iwae_k', 5)
        if isinstance(samples, list):
            # A list of sample counts is rendered as e.g. "K5x10".
            objective_tag += "K" + 'x'.join(str(s) for s in samples)
        else:
            objective_tag += f"k{int(samples)}"

    experiment_id = (
        f"train{train_size}{objective_tag}"
        f"_beta{train_cfg['beta']}_lr{train_cfg['learning_rate']}"
    )
    save_root = base_cfg['results']['save_dir']
    return str(Path(save_root).joinpath(dataset, model_name, experiment_id))


def logspaced_integers(min_value: int, max_value: int, num: int) -> List[int]:
    """Return sorted, de-duplicated integers log-spaced over [min_value, max_value].

    ``num`` points are placed evenly in log10 space between the two bounds
    (both included), rounded to the nearest integer, and duplicates created
    by the rounding are dropped, so the result may hold fewer than ``num``
    values.

    Args:
        min_value: Smallest value; must be positive because its logarithm is taken.
        max_value: Largest value; must be at least ``min_value``.
        num: Number of log-spaced samples to draw before de-duplication.

    Returns:
        Ascending list of unique integers within the closed range.

    Raises:
        ValueError: If ``min_value`` is not positive or ``max_value`` is
            smaller than ``min_value``.
    """
    # log10 of a non-positive number gives -inf/nan, which would otherwise
    # fail later with an unhelpful "cannot convert float NaN to integer".
    if min_value <= 0:
        raise ValueError(f"min_value must be positive, got {min_value}")
    # An inverted range would silently produce an empty sweep.
    if max_value < min_value:
        raise ValueError(
            f"max_value ({max_value}) must be >= min_value ({min_value})"
        )

    xs = np.logspace(np.log10(min_value), np.log10(max_value), num=num)
    vals = sorted({int(round(x)) for x in xs})
    # Guard against floating-point overshoot at the end points.
    return [v for v in vals if min_value <= v <= max_value]


def run_cmd(cmd: Sequence[str], env: dict = None) -> int:
    """Echo *cmd*, run it as a subprocess, and return its exit code.

    Args:
        cmd: Program and arguments, passed to ``subprocess.run`` as a list
            (no shell involved).
        env: Environment for the child process; ``None`` inherits the
            current environment.

    Returns:
        The child's return code. A non-zero code is returned to the caller
        rather than raised.
    """
    # Flush so the echoed command reaches the shared stdout before the child
    # starts writing; otherwise, when stdout is redirected to a file or pipe,
    # the parent's buffered line shows up after the child's output.
    print("$", " ".join(cmd), flush=True)
    proc = subprocess.run(cmd, env=env)
    return proc.returncode


def main() -> None:
    """Run the train_size sweep: train once per size, then optionally evaluate.

    Command-line arguments select the base YAML config, the device, and the
    train sizes (an explicit ``--sizes`` list, or log-spaced values between
    ``--min_size`` and ``--max_size``). For every size the training script is
    run as a subprocess; unless ``--skip_eval`` is given, the evaluation
    script is then run on the experiment directory derived from the config.

    A failed training run skips its evaluation but does not stop the sweep.

    Raises:
        SystemExit: With status 1 once the sweep has finished if any
            training or evaluation subprocess returned a non-zero code, so
            that callers (shell scripts, CI) can detect a partially failed
            sweep.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Sweep train_size values and evaluate each experiment')
    parser.add_argument('--config', type=str, required=True, help='Path to YAML config')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default=None, help='Device for training')
    parser.add_argument('--sizes', type=int, nargs='*', default=None, help='Explicit list of train_size values')
    parser.add_argument('--num_points', type=int, default=7, help='Number of log-spaced points when --sizes not given')
    parser.add_argument('--min_size', type=int, default=1000, help='Min train_size for log-spacing')
    parser.add_argument('--max_size', type=int, default=30000, help='Max train_size for log-spacing')
    parser.add_argument('--skip_eval', action='store_true', help='If set, skip the evaluation step')
    args = parser.parse_args()

    base_cfg = load_yaml(args.config)

    # An explicit (non-empty) --sizes list wins; otherwise fall back to log spacing.
    if args.sizes:
        train_sizes = sorted({int(s) for s in args.sizes})
    else:
        train_sizes = logspaced_integers(args.min_size, args.max_size, args.num_points)

    # Status lines are flushed so they stay ordered relative to the output of
    # the child processes, which write straight to the shared stdout.
    print(f"Train sizes: {train_sizes}", flush=True)

    # The device choice is handed to the child scripts through VAE_DEVICE.
    env = os.environ.copy()
    if args.device is not None:
        env['VAE_DEVICE'] = args.device

    train_script = str(PROJECT_ROOT / 'scripts' / 'train_vae_mnist.py')
    eval_script = str(PROJECT_ROOT / 'scripts' / 'evaluate_experiment.py')

    failures = []
    for ts in train_sizes:
        print(f"\n=== Training with train_size={ts} ===", flush=True)
        # Train
        rc = run_cmd([
            sys.executable,
            train_script,
            '--config', args.config,
            '--train_size', str(ts),
        ], env=env)
        if rc != 0:
            print(f"Training failed for train_size={ts} (rc={rc}). Skipping evaluation.", flush=True)
            failures.append(f"train_size={ts}: training (rc={rc})")
            continue

        if args.skip_eval:
            continue

        # Evaluate
        exp_dir = compute_experiment_dir(base_cfg, ts)
        print(f"Evaluating experiment: {exp_dir}", flush=True)
        rc = run_cmd([
            sys.executable,
            eval_script,
            '--experiment_dir', exp_dir,
        ], env=env)
        if rc != 0:
            print(f"Evaluation failed for train_size={ts} (rc={rc}).", flush=True)
            failures.append(f"train_size={ts}: evaluation (rc={rc})")

    # Report failures through the exit status instead of always exiting 0.
    if failures:
        print(f"\n{len(failures)} step(s) failed during the sweep:", file=sys.stderr)
        for failure in failures:
            print(f"  - {failure}", file=sys.stderr)
        sys.exit(1)

# Run the sweep only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


