# file: user_extensions/experiments/run_hpo.py
import argparse
from pathlib import Path
import subprocess
import sys
import time
import yaml

import optuna
import torch

from prism.utils.config import load_config, process_derived_config
from user_extensions.experiments.utils import hpo_objective


def run_worker(args, study_plan):
    """Evaluate this worker's share of trials against the shared Optuna study.

    Called from ``main`` when the module runs with ``--worker`` (the coordinator
    spawns one such subprocess per device). It loads the already-resolved base
    config file and attaches to the study stored in
    ``<study_dir>/hpo_study.db``. It then runs ``args.n_trials_per_worker``
    trials, each scored by ``hpo_objective``.

    Args:
        args: Parsed CLI namespace; reads ``final_base_config``, ``study_dir``,
            ``device_id``, ``resume`` and ``n_trials_per_worker``.
        study_plan: Parsed study YAML; must contain ``study_name`` and
            ``search_space``.
    """
    resolved_config = load_config(args.final_base_config)

    study = optuna.load_study(
        study_name=study_plan['study_name'],
        storage=f"sqlite:///{args.study_dir}/hpo_study.db",
    )

    trial_logger = dict(save_dir=args.study_dir, name='hpo_trials', progress_bar=True)
    space = study_plan['search_space']

    def objective(trial):
        # Every trial in this process shares the same config, search space,
        # logger settings and device index.
        return hpo_objective(
            trial, resolved_config, space, trial_logger, args.device_id, resume=args.resume
        )

    study.optimize(objective, n_trials=args.n_trials_per_worker)


def _split_trials(total_trials, n_workers):
    """Distribute ``total_trials`` across ``n_workers`` as evenly as possible.

    Returns one trial count per worker. The counts sum to exactly
    ``total_trials``, and the first ``total_trials % n_workers`` workers get
    one extra trial. Rounding every worker up instead would run more trials
    than the plan's ``n_trials`` whenever it is not divisible by
    ``n_workers``.
    """
    per_worker, remainder = divmod(total_trials, n_workers)
    return [per_worker + (1 if i < remainder else 0) for i in range(n_workers)]


def _component_kwargs(study_plan, key):
    """Return keyword arguments taken from the study plan's ``key`` section.

    A missing or empty (``None``) section yields ``{}``. The section is copied,
    so ``study_plan`` is not mutated, and its ``type`` entry is removed.
    """
    kwargs = dict(study_plan.get(key) or {})
    # NOTE(review): 'type' is discarded and never consulted -- main() always builds
    # a TPESampler and a MedianPruner, so the remaining keys must be valid for those.
    kwargs.pop('type', None)
    return kwargs


def _terminate_workers(workers):
    """Stop still-running worker processes (``Popen.terminate``) and reap them all."""
    for _, proc, _ in workers:
        if proc.poll() is None:
            proc.terminate()
    for _, proc, _ in workers:
        proc.wait()


def _launch_workers(cli_args, study_dir, final_base_config_path, trial_counts):
    """Start one ``--worker`` subprocess of this module per device with trials to run.

    Worker ``i`` receives ``--device_id i`` and ``trial_counts[i]`` trials.
    Devices assigned zero trials are skipped. Each worker's stdout and stderr go
    to ``<study_dir>/_worker_logs/worker_<i>.log``.

    Returns:
        A list of ``(device_id, Popen, log_path)`` tuples.
    """
    worker_logs_dir = study_dir / "_worker_logs"
    worker_logs_dir.mkdir(exist_ok=True)
    assignments = [(device_id, n) for device_id, n in enumerate(trial_counts) if n > 0]

    workers = []
    try:
        for position, (device_id, n_trials) in enumerate(assignments):
            cmd = [
                sys.executable, "-m", "user_extensions.experiments.run_hpo",
                "--study_config", cli_args.study_config, "--worker",
                "--device_id", str(device_id), "--study_dir", str(study_dir),
                "--n_trials_per_worker", str(n_trials),
                "--final_base_config", str(final_base_config_path),
            ]
            if cli_args.resume:
                cmd.append('--resume')

            log_path = worker_logs_dir / f"worker_{device_id}.log"
            # The child receives its own copy of the file descriptor, so closing
            # the parent's handle when this block exits does not cut off the log.
            with open(log_path, 'w') as log_file:
                print(f"\n  🚀 Launching worker {device_id}. Log: tail -f {log_path}")
                proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
            workers.append((device_id, proc, log_path))

            if position < len(assignments) - 1:
                # Stagger start-up, presumably so workers do not all hit the shared
                # SQLite storage / device initialisation at the same instant.
                time.sleep(2)
    except BaseException:
        # A failed launch or a Ctrl-C must not leave already-started workers running.
        _terminate_workers(workers)
        raise
    return workers


def _wait_for_workers(workers):
    """Wait for every worker to exit and return the ones that failed.

    If interrupted (Ctrl-C) while waiting, still-running workers are terminated
    before the ``KeyboardInterrupt`` is re-raised.

    Returns:
        A list of ``(device_id, returncode, log_path)`` for every worker whose
        exit code was non-zero.
    """
    try:
        for _, proc, _ in workers:
            proc.wait()
    except KeyboardInterrupt:
        _terminate_workers(workers)
        raise
    return [
        (device_id, proc.returncode, log_path)
        for device_id, proc, log_path in workers
        if proc.returncode != 0
    ]


def main():
    """Coordinate a parallel Optuna study, or run as one of its workers.

    With ``--worker`` (passed only by the coordinator to the subprocesses it
    spawns), this delegates to ``run_worker``. Otherwise it:

    1. loads the base config, passing unrecognised CLI args to ``load_config``
       as overrides, and writes the processed result to
       ``<log_dir>/<study_name>/hpo_final_base_config.yaml``;
    2. creates, or re-opens, the SQLite-backed study in that directory;
    3. splits the plan's ``n_trials`` (default 10) across ``--num_gpus`` worker
       subprocesses and waits for them to finish;
    4. prints the best completed trial.

    Exits with an error if ``--num_gpus`` is below 1 or if no trial completed.
    """
    # NOTE(review): the study created below has a single `direction`, so the
    # "multi-objective" wording in this description looks inaccurate.
    parser = argparse.ArgumentParser(description="Run a multi-objective HPO study.")
    parser.add_argument("--study_config", type=str, required=True, help="Path to the HPO study YAML.")
    parser.add_argument("--num_gpus", type=int, default=2, help="Number of parallel workers (GPUs) to use for HPO trials.")
    parser.add_argument("--resume", action='store_true', help="Resume the HPO study from the last state.")

    # Suppressed arguments are for internal worker calls
    parser.add_argument("--worker", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--device_id", type=int, help=argparse.SUPPRESS)
    parser.add_argument("--study_dir", type=str, help=argparse.SUPPRESS)
    parser.add_argument("--n_trials_per_worker", type=int, help=argparse.SUPPRESS)
    parser.add_argument("--final_base_config", type=str, help=argparse.SUPPRESS)
    cli_args, cli_overrides = parser.parse_known_args()

    with open(cli_args.study_config, 'r', encoding='utf-8') as f:
        study_plan = yaml.safe_load(f)

    if cli_args.worker:
        run_worker(cli_args, study_plan)
        return

    if cli_args.num_gpus < 1:
        parser.error("--num_gpus must be at least 1")

    base_config_raw = load_config(study_plan.get('base_config'), cli_overrides)
    base_config = process_derived_config(base_config_raw)
    study_dir = Path(base_config.run.log_dir) / study_plan['study_name']
    study_dir.mkdir(parents=True, exist_ok=True)
    base_config.run.log_dir = str(study_dir)

    # Workers load this already-processed config file rather than re-parsing
    # the CLI, so they all share exactly the same base settings.
    final_base_config_path = study_dir / "hpo_final_base_config.yaml"
    with open(final_base_config_path, 'w', encoding='utf-8') as f:
        yaml.dump(base_config.to_dict(), f)

    storage_name = f"sqlite:///{study_dir}/hpo_study.db"
    # load_if_exists lets a re-run (e.g. with --resume) attach to the existing study.
    optuna.create_study(
        study_name=study_plan['study_name'], storage=storage_name, load_if_exists=True,
        direction=study_plan.get('direction', 'maximize'),
        sampler=optuna.samplers.TPESampler(**_component_kwargs(study_plan, 'sampler')),
        pruner=optuna.pruners.MedianPruner(**_component_kwargs(study_plan, 'pruner')),
    )

    n_gpus = cli_args.num_gpus
    total_trials = study_plan.get('n_trials', 10)
    trial_counts = _split_trials(total_trials, n_gpus)
    print(f"\n--- Starting HPO study across {n_gpus} GPUs ({total_trials} trials, per worker: {trial_counts}) ---")

    workers = _launch_workers(cli_args, study_dir, final_base_config_path, trial_counts)
    failed = _wait_for_workers(workers)

    print("\n--- HPO Finished. ---")
    for device_id, returncode, log_path in failed:
        print(f"  ⚠️  Worker {device_id} exited with code {returncode}. See {log_path}")

    study = optuna.load_study(study_name=study_plan['study_name'], storage=storage_name)
    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when the study has no completed trial.
        sys.exit(f"No completed trials in study '{study_plan['study_name']}'; check the worker logs.")
    print(f"Best trial: {best_trial.value}")
    print(f"Best params: {best_trial.params}")


if __name__ == "__main__":
    # Opt into PyTorch's "high" float32 matmul precision setting, which trades
    # some precision for speed. Workers are spawned with `python -m` on this
    # same module, so they also execute this guard and get the same setting.
    torch.set_float32_matmul_precision("high")
    main()