# file: user_extensions/experiments/run_ablation.py
import argparse
from collections import deque
import copy
import os
from pathlib import Path
import subprocess
import sys
import time
import yaml

import optuna
import torch

from prism.utils.config import apply_dict_overrides, AttrDict, load_config, process_derived_config
from user_extensions.experiments.utils import hpo_objective, run_trial, run_post_hoc_evaluation


def _lookup_dotted(config, dotted_path):
    """Return the value stored at `dotted_path` (e.g. ``'loss.weights.gamma_gan'``) in a nested mapping.

    Raises:
        KeyError: if a path segment is missing.
        TypeError: if an intermediate value is not subscriptable (e.g. ``None``).
    """
    value = config
    for key in dotted_path.split('.'):
        value = value[key]
    return value


def _build_dynamic_search_space(search_space, config):
    """Return the subset of `search_space` worth tuning for an ablation with this `config`.

    Core optimizer hyperparameters are kept whenever they appear in `search_space`.
    Each conditional hyperparameter is kept only if the loss weight that gates it is a
    positive number in `config`. A missing or non-positive gate presumably means the
    corresponding loss term is disabled in this ablation, so tuning it would waste trials.
    """
    core_hparams = (
        'training.optimizer.main.lr',
        'training.optimizer.adversarial.lr',
        'training.optimizer.weight_decay',
    )
    # Maps each conditional hparam to the dotted path of the loss weight that gates it.
    conditional_hparams = {
        'loss.weights.gamma_rec': 'loss.weights.gamma_rec',
        'loss.weights.gamma_cls': 'loss.weights.gamma_cls',
        'loss.weights.gamma_l': 'loss.weights.gamma_l',
        'loss.weights.gamma_info': 'loss.weights.gamma_info',
        'loss.weights.gamma_proto': 'loss.weights.gamma_proto',
        'loss.weights.gamma_gan': 'loss.weights.gamma_gan',
        'loss.prototype_momentum': 'loss.weights.gamma_proto',
        'loss.r1_penalty.gamma_r1': 'loss.weights.gamma_gan',
        'loss.weights.gamma_prior': 'loss.weights.gamma_info',
    }

    dynamic_space = {name: search_space[name] for name in core_hparams if name in search_space}
    for param, gate_path in conditional_hparams.items():
        if param not in search_space:
            continue
        try:
            gate_value = _lookup_dotted(config, gate_path)
        except (KeyError, TypeError):
            # Gate is absent from this config: treat the hparam as inactive.
            continue
        if isinstance(gate_value, (int, float)) and gate_value > 0:
            dynamic_space[param] = search_space[param]
    return dynamic_space


def run_worker(args):
    """Run HPO for a single ablation, then run the final multi-seed training runs.

    Invoked in a subprocess when the script is run with ``--worker``. The Optuna study is
    stored in ``<study_dir>/_hpo_dbs/<ablation_name>.db`` and loaded if it already exists.
    An interrupted worker therefore resumes: it only runs the trials still missing, counting
    COMPLETE and PRUNED trials as done. The best trial's params are then applied on top of
    the ablation's config. Finally, ``args.num_final_runs`` runs are launched with
    consecutive seeds.

    Args:
        args: Parsed CLI namespace with study_config, hpo_config, ablation_name, device_id,
            study_dir, hpo_trials_per_ablation, max_retries and num_final_runs. If it also
            carries a non-empty ``cli_overrides`` list, that list is passed to ``load_config``.

    Returns:
        True after all final runs have been executed. False when the study has no
        completed trial to take parameters from.
    """
    with open(args.study_config, 'r') as f:
        study_plan = yaml.safe_load(f)
    with open(args.hpo_config, 'r') as f:
        hpo_plan = yaml.safe_load(f)

    study_dir = Path(args.study_dir)
    hpo_temp_dir = study_dir / "_hpo_temp_logs"

    base_config_path = study_plan.get('base_config', 'configs/base.yaml')
    cli_overrides = getattr(args, 'cli_overrides', None)
    # Apply the same command-line config overrides as the orchestrator when any were forwarded.
    if cli_overrides:
        base_config_raw = load_config(base_config_path, cli_overrides)
    else:
        base_config_raw = load_config(base_config_path)
    ablation_details = study_plan['trials'][args.ablation_name]
    search_space = hpo_plan['search_space']

    print(f"WORKER (PID:{os.getpid()}) for '{args.ablation_name}' on GPU:{args.device_id} started.")

    # Work on a deep copy so applying the ablation overrides cannot mutate base_config_raw.
    ablation_base_config_raw = AttrDict(copy.deepcopy(base_config_raw.to_dict()))
    apply_dict_overrides(ablation_base_config_raw, ablation_details.get('overrides', {}))
    ablation_base_config = process_derived_config(ablation_base_config_raw)

    dynamic_search_space = _build_dynamic_search_space(search_space, ablation_base_config)

    hpo_db_dir = study_dir / "_hpo_dbs"
    hpo_db_dir.mkdir(parents=True, exist_ok=True)
    storage_name = f"sqlite:///{hpo_db_dir}/{args.ablation_name}.db"

    study = optuna.create_study(
        study_name=args.ablation_name,
        storage=storage_name,
        load_if_exists=True,
        direction="maximize",
    )

    finished_states = (optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED)
    n_completed = len(study.get_trials(states=finished_states))
    n_trials_to_run = max(0, args.hpo_trials_per_ablation - n_completed)

    if n_trials_to_run > 0:
        print(f"[{args.ablation_name}] HPO study found. Completed {n_completed}/{args.hpo_trials_per_ablation} trials. Running {n_trials_to_run} more.")
        logger_info = {'save_dir': hpo_temp_dir, 'name': args.ablation_name, 'progress_bar': True}

        def objective_fn(trial):
            return hpo_objective(
                trial, ablation_base_config, dynamic_search_space, logger_info, args.device_id
            )

        study.optimize(objective_fn, n_trials=n_trials_to_run, show_progress_bar=True)
    else:
        print(f"[{args.ablation_name}] HPO study already has {n_completed} trials. Skipping HPO phase.")

    print(f"[{args.ablation_name}] HPO complete. Launching final runs with best parameters.")

    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError (it does not return None) when no trial has completed.
        print(f"ERROR: No successful trials for '{args.ablation_name}'. Cannot perform final run.", file=sys.stderr)
        return False

    print(f"  [{args.ablation_name}] Best trial ({best_trial.number}) objective: {best_trial.value}")
    print(f"  [{args.ablation_name}] Using parameters: {best_trial.params}")

    final_config_dict = copy.deepcopy(ablation_base_config.to_dict())
    apply_dict_overrides(final_config_dict, best_trial.params)
    # setdefault: the 'run' section is treated as optional below, so do not assume it here either.
    final_config_dict.setdefault('run', {})['log_dir'] = str(args.study_dir)

    base_seed = final_config_dict['run'].get('seed')
    if not isinstance(base_seed, int):
        print("Warning: Base seed not found or invalid in config (path: run.seed). Defaulting to 2024.", file=sys.stderr)
        base_seed = 2024

    num_final_runs = args.num_final_runs
    print(f"[{args.ablation_name}] Launching {num_final_runs} final runs with different seeds (base seed: {base_seed}).")

    for i in range(num_final_runs):
        run_seed = base_seed + i
        print(f"\n--- Starting final run {i + 1}/{num_final_runs} for '{args.ablation_name}' with seed {run_seed} ---")

        # Fresh copy per run so the seed/sweep_name edits do not leak between runs.
        run_config_dict = copy.deepcopy(final_config_dict)
        run_config_dict.setdefault('run', {})['seed'] = run_seed
        run_config_dict['run']['sweep_name'] = f"{args.ablation_name}/seed_{run_seed}"

        run_trial(run_config_dict, args.max_retries, device_id=args.device_id)

    print(f"--- All {num_final_runs} final runs for '{args.ablation_name}' completed. ---")
    print(f"WORKER for '{args.ablation_name}' finished successfully.")
    return True


def _wait_for_any_worker(active_workers, poll_interval=5.0):
    """Block until one worker in `active_workers` has exited, remove it, and return its entry.

    Entries are ``(Popen, ablation_name, device_id)`` tuples. Polling every worker, instead
    of waiting on the oldest one, frees a GPU as soon as *any* worker finishes. This keeps
    one long ablation from leaving the other GPUs idle.
    """
    while True:
        for entry in list(active_workers):
            if entry[0].poll() is not None:
                active_workers.remove(entry)
                return entry
        time.sleep(poll_interval)


def _launch_workers(args, study_plan, study_dir, num_final_runs, cli_overrides):
    """Run one ``--worker`` subprocess per ablation, with at most one worker per GPU at a time.

    Each worker's stdout and stderr go to ``<study_dir>/_worker_logs/worker_<name>.log``.
    `cli_overrides` are appended to every worker command line. If this function is
    interrupted or raises, any worker still running is terminated before the exception
    propagates.

    Returns:
        A list of ``(ablation_name, returncode)`` pairs for every worker that exited non-zero.
    """
    ablation_names = list(study_plan.get('trials') or {})
    hpo_trials_per_ablation = study_plan['hpo_trials_per_ablation']
    worker_logs_dir = study_dir / "_worker_logs"
    worker_logs_dir.mkdir(exist_ok=True)

    free_devices = deque(range(args.num_gpus))
    active_workers = []
    failed_workers = []

    def reap(entry):
        # Record a finished worker's exit status and return its GPU to the pool.
        proc, name, device_id = entry
        if proc.returncode != 0:
            failed_workers.append((name, proc.returncode))
        free_devices.append(device_id)

    try:
        for name in ablation_names:
            if not free_devices:
                reap(_wait_for_any_worker(active_workers))
            device_id = free_devices.popleft()

            cmd = [
                sys.executable, "-m", "user_extensions.experiments.run_ablation",
                "--worker",
                "--study_config", args.study_config,
                "--hpo_config", args.hpo_config,
                "--ablation_name", name,
                "--device_id", str(device_id),
                "--study_dir", str(study_dir),
                "--hpo_trials_per_ablation", str(hpo_trials_per_ablation),
                "--max_retries", str(args.max_retries),
                "--num_final_runs", str(num_final_runs),
            ]
            # Forward the orchestrator's config overrides so every worker trains the same base config.
            cmd.extend(cli_overrides)

            log_path = worker_logs_dir / f"worker_{name}.log"
            # The child gets its own copy of the log file descriptor, so closing ours after Popen is safe.
            with open(log_path, 'w') as log_file:
                print(f"\n  🚀 Launching worker for '{name}' on GPU {device_id}. Log: tail -f {log_path}")
                proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
            active_workers.append((proc, name, device_id))

            # Brief pause between launches, presumably to stagger worker start-up.
            time.sleep(2)

        print("\n--- All workers launched. Waiting for the final batch to complete. ---")
        for entry in active_workers:
            entry[0].wait()
            reap(entry)
    except BaseException:
        # Do not leave orphaned GPU jobs behind if the orchestrator is interrupted or crashes.
        for proc, _, _ in active_workers:
            if proc.poll() is None:
                proc.terminate()
        raise

    return failed_workers


def main():
    """Entry point: orchestrate the ablation study, or run a single worker with ``--worker``.

    In orchestrator mode, spawns one worker subprocess per ablation in the study config's
    ``trials``, running at most ``--num_gpus`` at once with one per GPU. It waits for all of
    them, reports any that failed, and then runs post-hoc evaluation over the study directory.
    Unrecognised command-line tokens are treated as config overrides. They are applied to the
    base config here and forwarded to every worker.
    """
    parser = argparse.ArgumentParser(description="Run a parallel, fair ablation study.")
    # --- User-facing arguments for the orchestrator ---
    parser.add_argument("--study_config", type=str, required=True, help="Path to the ablation study YAML configuration file.")
    parser.add_argument("--hpo_config", type=str, required=True, help="Path to the HPO search space YAML file (e.g., hpo_loss.yaml).")
    parser.add_argument("--num_gpus", type=int, default=2, help="Total number of GPUs to use for parallel workers.")
    parser.add_argument("--max_retries", type=int, default=3, help="Retries for a failed final training run.")
    parser.add_argument("--skip_training", action='store_true', help="Skip all training, run only evaluation and analysis.")

    # --- Internal arguments for worker processes ---
    parser.add_argument("--worker", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--ablation_name", type=str, help=argparse.SUPPRESS)
    parser.add_argument("--device_id", type=int, help=argparse.SUPPRESS)
    parser.add_argument("--study_dir", type=str, help=argparse.SUPPRESS)
    parser.add_argument("--hpo_trials_per_ablation", type=int, help=argparse.SUPPRESS)
    parser.add_argument("--num_final_runs", type=int, help=argparse.SUPPRESS)

    args, cli_overrides = parser.parse_known_args()

    if args.worker:
        # Make the forwarded config overrides reachable from the worker's namespace.
        args.cli_overrides = cli_overrides
        # Exit non-zero if run_worker reports failure by returning False, so the orchestrator sees it.
        if run_worker(args) is False:
            sys.exit(1)
        return

    if not args.skip_training and args.num_gpus < 1:
        parser.error("--num_gpus must be at least 1.")

    with open(args.study_config, 'r') as f:
        study_plan = yaml.safe_load(f)

    num_final_runs = study_plan.get('num_final_runs', 1)
    if num_final_runs < 1:
        print("Warning: num_final_runs in study config is less than 1. Setting to 1.", file=sys.stderr)
        num_final_runs = 1

    # Same default base config path as the workers use, so both sides resolve the same file.
    base_config_raw = load_config(study_plan.get('base_config', 'configs/base.yaml'), cli_overrides)
    base_config = process_derived_config(base_config_raw)
    study_dir = Path(base_config.run.log_dir) / study_plan['study_name']
    study_dir.mkdir(parents=True, exist_ok=True)

    if not args.skip_training:
        print("\n" + "=" * 60)
        print(f"STEP 1: LAUNCHING PARALLEL ABLATION WORKERS ACROSS {args.num_gpus} GPUS")
        print("=" * 60)

        failed_workers = _launch_workers(args, study_plan, study_dir, num_final_runs, cli_overrides)
        if failed_workers:
            print(f"\nWARNING: {len(failed_workers)} worker(s) exited with a non-zero status:", file=sys.stderr)
            for name, returncode in failed_workers:
                log_path = study_dir / "_worker_logs" / f"worker_{name}.log"
                print(f"  - '{name}' (exit code {returncode}); see {log_path}", file=sys.stderr)

        print("\n--- All ablation workers have completed. ---")

    print("\n" + "=" * 60)
    print("STEP 2: RUNNING POST-HOC EVALUATION")
    print("=" * 60)
    run_post_hoc_evaluation(study_dir)


if __name__ == "__main__":
    # Per the PyTorch docs, "high" lets float32 matmuls use faster reduced-precision
    # internal math (e.g. TF32) where the hardware supports it.
    matmul_precision = "high"
    torch.set_float32_matmul_precision(matmul_precision)
    main()