import argparse
import json
import math
import os
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, List, Tuple

import numpy as np
from tqdm import tqdm

from utils import generate_omega_grid, get_experiment_configs


def check_existing_result(
    omega: float, bc_type: str, method: str, seed: int, output_dir: str
) -> bool:
    """Report whether a result file for this configuration already exists.

    The expected file name is ``<omega>_<bc_type>_<method>_seed<seed>.json``
    inside ``output_dir``, where ``<omega>`` is the angle formatted with four
    decimals and the decimal point replaced by ``p`` (e.g. ``1p5708``).

    Args:
        omega: Wedge angle used to build the file name.
        bc_type: Boundary-condition label embedded in the file name.
        method: Method label embedded in the file name.
        seed: Seed embedded in the file name.
        output_dir: Directory that is searched for the file.

    Returns:
        True if a filesystem entry with the expected name exists.
    """
    # NOTE(review): the name carries no "omega" prefix, unlike the exp_id label
    # built in run_single_experiment_subprocess -- confirm it matches what
    # train_single.py actually writes.
    angle_tag = format(omega, ".4f").replace(".", "p")
    result_name = "{}_{}_{}_seed{}.json".format(angle_tag, bc_type, method, seed)
    candidate = os.path.join(output_dir, result_name)
    return os.path.exists(candidate)


def run_single_experiment_subprocess(
    omega: float,
    bc_type: str,
    method: str,
    seed: int,
    gpu_id: int,
    output_dir: str,
    K: int = 6,
    total_steps: int = 20000,
    warmup_steps: int = 3000,
    ramp_steps: int = 5000,
    quiet: bool = True,
    skip_existing: bool = True,
    timeout: float = 1200,
) -> Tuple[str, bool, float, str]:
    """Run one experiment by launching ``train_single.py`` in a child process.

    Args:
        omega: Wedge angle, forwarded as ``--omega``.
        bc_type: Boundary-condition label, forwarded as ``--bc_type``.
        method: Method label, forwarded as ``--method``.
        seed: Random seed, forwarded as ``--seed``.
        gpu_id: GPU index, forwarded as ``--gpu``.
        output_dir: Result directory, forwarded as ``--output_dir``. A relative
            path is resolved against the current working directory of *this*
            process before being used.
        K: Forwarded as ``--K``.
        total_steps: Forwarded as ``--total_steps``.
        warmup_steps: Forwarded as ``--warmup_steps``.
        ramp_steps: Forwarded as ``--ramp_steps``.
        quiet: If true, pass ``--quiet`` to the child and suppress the
            diagnostic messages this function would otherwise print.
        skip_existing: If true, return immediately when
            ``check_existing_result`` finds a result file.
        timeout: Maximum run time of the child in seconds (default 1200).

    Returns:
        ``(exp_id, success, elapsed_seconds, status)`` where ``status`` is one
        of ``"completed"``, ``"failed"``, ``"timeout"`` or ``"skipped"``.
        A skipped experiment is reported as successful with ``elapsed == 0.0``.
    """
    exp_id = f"omega{omega:.4f}_{bc_type}_{method}_seed{seed}"

    # The child runs with cwd set to this script's directory, so a relative
    # output_dir would point somewhere else there than it does here (where the
    # skip check looks). Pin it to one absolute location for both.
    output_dir = os.path.abspath(output_dir)

    if skip_existing and check_existing_result(
        omega, bc_type, method, seed, output_dir
    ):
        return exp_id, True, 0.0, "skipped"

    script_dir = os.path.dirname(os.path.abspath(__file__))
    cmd = [
        sys.executable,
        "train_single.py",
        "--omega",
        str(omega),
        "--bc_type",
        bc_type,
        "--method",
        method,
        "--seed",
        str(seed),
        "--gpu",
        str(gpu_id),
        "--output_dir",
        output_dir,
        "--K",
        str(K),
        "--total_steps",
        str(total_steps),
        "--warmup_steps",
        str(warmup_steps),
        "--ramp_steps",
        str(ramp_steps),
    ]

    if quiet:
        cmd.append("--quiet")

    start_time = time.time()
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=script_dir,  # "train_single.py" is given relative to this directory
            timeout=timeout,
        )
        success = result.returncode == 0
        status = "completed" if success else "failed"
        if not success and not quiet:
            # Only the tail of stderr is shown to keep the message short.
            print(
                f"[ERROR] {exp_id}: {result.stderr[-500:] if result.stderr else 'No error message'}"
            )
    except subprocess.TimeoutExpired:
        success = False
        status = "timeout"
        if not quiet:
            print(f"[TIMEOUT] {exp_id}")
    except Exception as e:
        # Deliberately broad: a launch problem for one experiment must be
        # reported as a failed result rather than propagate to the caller.
        success = False
        status = "failed"
        if not quiet:
            print(f"[ERROR] {exp_id}: {e}")

    elapsed = time.time() - start_time
    return exp_id, success, elapsed, status


def worker(args):
    """Pool entry point: run one experiment from a packed argument tuple.

    ``args`` is unpacked positionally into
    ``run_single_experiment_subprocess`` and that function's result is
    returned unchanged.
    """
    outcome = run_single_experiment_subprocess(*args)
    return outcome


def run_all_experiments(
    n_omega: int = 30,
    bc_types: Tuple[str, ...] = ("DD", "NN", "DN", "ND"),
    methods: Tuple[str, ...] = ("naive", "constraint"),
    seeds: Tuple[int, ...] = (0, 1, 2),
    n_gpus: int = 2,
    output_dir: str = "results",
    K: int = 6,
    total_steps: int = 20000,
    warmup_steps: int = 3000,
    ramp_steps: int = 5000,
    omega_min: float = math.pi / 2,
    omega_max: float = 11 * math.pi / 6,
    max_workers_per_gpu: int = 2,
    skip_existing: bool = True,
) -> Dict:
    """Run the whole experiment sweep in a process pool and write a summary.

    The configuration list comes from ``get_experiment_configs``. Each
    configuration is assigned a GPU index round-robin (``index % n_gpus``) and
    submitted to a ``ProcessPoolExecutor`` with
    ``n_gpus * max_workers_per_gpu`` workers. When all tasks are done, a
    summary is written to ``<output_dir>/run_summary.json`` and printed.

    Args:
        n_omega: Number of wedge angles, passed to ``get_experiment_configs``.
        bc_types: Boundary-condition labels, passed to ``get_experiment_configs``.
        methods: Method labels, passed to ``get_experiment_configs``.
        seeds: Seeds, passed to ``get_experiment_configs``.
        n_gpus: Number of GPU indices to distribute tasks over (>= 1).
        output_dir: Directory for results and the summary; created if missing.
        K: Forwarded to every experiment.
        total_steps: Forwarded to every experiment.
        warmup_steps: Forwarded to every experiment.
        ramp_steps: Forwarded to every experiment.
        omega_min: Smallest angle in radians (printed in degrees).
        omega_max: Largest angle in radians (printed in degrees).
        max_workers_per_gpu: Concurrent workers per GPU index (>= 1).
        skip_existing: Forwarded to every experiment; existing results are
            then reported as ``"skipped"`` instead of being re-run.

    Returns:
        The summary dictionary that was written to ``run_summary.json``.

    Raises:
        ValueError: If ``n_gpus`` or ``max_workers_per_gpu`` is less than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # round-robin below or a ValueError from the executor.
    if n_gpus < 1 or max_workers_per_gpu < 1:
        raise ValueError(
            "n_gpus and max_workers_per_gpu must both be >= 1 "
            f"(got n_gpus={n_gpus}, max_workers_per_gpu={max_workers_per_gpu})"
        )

    configs = get_experiment_configs(
        n_omega=n_omega,
        bc_types=bc_types,
        methods=methods,
        seeds=seeds,
        omega_min=omega_min,
        omega_max=omega_max,
    )

    total_experiments = len(configs)
    print(f"\n{'='*70}")
    print("Exp7: Large-Scale Wedge/Corner Sweep")
    print(f"{'='*70}")
    print(f"Total experiments: {total_experiments}")
    print(
        f"  - Angles: {n_omega} (from {np.degrees(omega_min):.1f}° to {np.degrees(omega_max):.1f}°)"
    )
    print(f"  - BC types: {bc_types}")
    print(f"  - Methods: {methods}")
    print(f"  - Seeds: {seeds}")
    print(f"  - GPUs: {n_gpus}")
    print(
        f"  - K: {K}, steps: {total_steps}, warmup: {warmup_steps}, ramp: {ramp_steps}"
    )
    print(f"  - Output: {output_dir}")
    print(f"  - Skip existing: {skip_existing}")
    print(f"{'='*70}\n")

    os.makedirs(output_dir, exist_ok=True)

    # Positional argument tuples for run_single_experiment_subprocess; the
    # literal True is its `quiet` flag.
    tasks = [
        (
            config["omega"],
            config["bc_type"],
            config["method"],
            config["seed"],
            i % n_gpus,  # round-robin GPU assignment
            output_dir,
            K,
            total_steps,
            warmup_steps,
            ramp_steps,
            True,
            skip_existing,
        )
        for i, config in enumerate(configs)
    ]

    results = []
    n_workers = n_gpus * max_workers_per_gpu
    start_time = time.time()

    n_completed = 0
    n_skipped = 0
    n_failed = 0

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {executor.submit(worker, task): task for task in tasks}

        with tqdm(total=total_experiments, desc="Running experiments") as pbar:
            for future in as_completed(futures):
                try:
                    exp_id, success, elapsed, status = future.result()
                except Exception as exc:
                    # A worker that died or raised must not abort the sweep and
                    # lose the summary; record the experiment as failed.
                    omega, bc_type, method, seed = futures[future][:4]
                    exp_id = f"omega{omega:.4f}_{bc_type}_{method}_seed{seed}"
                    success, elapsed, status = False, 0.0, "failed"
                    tqdm.write(f"[ERROR] {exp_id}: worker crashed: {exc!r}")

                results.append(
                    {
                        "exp_id": exp_id,
                        "success": success,
                        "elapsed": elapsed,
                        "status": status,
                    }
                )

                if status == "completed":
                    n_completed += 1
                elif status == "skipped":
                    n_skipped += 1
                else:
                    # Covers both "failed" and "timeout".
                    n_failed += 1

                pbar.update(1)
                pbar.set_postfix(
                    {"completed": n_completed, "skipped": n_skipped, "failed": n_failed}
                )

    total_time = time.time() - start_time

    # Average only over experiments that actually ran to completion; skipped
    # ones report 0 s and would distort the figure.
    new_experiments = [r for r in results if r["status"] == "completed"]
    avg_time = (
        float(np.mean([r["elapsed"] for r in new_experiments]))
        if new_experiments
        else 0.0
    )

    # An empty sweep has no meaningful rate; report 0 instead of dividing by 0.
    success_rate = (
        (n_completed + n_skipped) / total_experiments * 100
        if total_experiments
        else 0.0
    )

    summary = {
        "total_experiments": total_experiments,
        "completed": n_completed,
        "skipped": n_skipped,
        "failed": n_failed,
        "success_rate": success_rate,
        "total_time_seconds": total_time,
        "avg_time_per_experiment": avg_time,
        "n_gpus": n_gpus,
        "K": K,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "ramp_steps": ramp_steps,
        "timestamp": datetime.now().isoformat(),
    }

    summary_path = os.path.join(output_dir, "run_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n{'='*70}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*70}")
    print(f"  Total: {total_experiments}")
    print(f"  Completed (new): {n_completed}")
    print(f"  Skipped (existing): {n_skipped}")
    print(f"  Failed: {n_failed}")
    print(f"  Success rate: {summary['success_rate']:.1f}%")
    print(f"  Total time: {total_time/60:.1f} minutes")
    if new_experiments:
        print(f"  Avg time per new experiment: {avg_time:.1f} seconds")
    print(f"  Summary saved to: {summary_path}")
    print(f"{'='*70}\n")

    return summary


def main():
    """Parse the command line and launch the full sweep.

    Every option maps onto a keyword argument of ``run_all_experiments``;
    ``--workers_per_gpu`` becomes ``max_workers_per_gpu`` and ``--no_skip``
    is inverted into ``skip_existing``.

    Returns:
        The summary dictionary returned by ``run_all_experiments``.
    """
    parser = argparse.ArgumentParser(
        description="Run all Exp7 experiments with multi-GPU"
    )
    add = parser.add_argument

    add("--n_omega", type=int, default=30, help="Number of wedge angles")
    add("--n_gpus", type=int, default=2, help="Number of GPUs to use")

    # Multi-valued options restricted to a fixed vocabulary; the default is
    # the full vocabulary.
    for flag, allowed in (
        ("--bc_types", ["DD", "NN", "DN", "ND"]),
        ("--methods", ["naive", "constraint"]),
    ):
        add(flag, nargs="+", default=list(allowed), choices=list(allowed))

    add("--seeds", nargs="+", type=int, default=[0, 1, 2])
    add("--output_dir", type=str, default="results")
    add("--K", type=int, default=6, help="Number of MSN terms")

    for flag, default_steps in (
        ("--total_steps", 20000),
        ("--warmup_steps", 3000),
        ("--ramp_steps", 5000),
    ):
        add(flag, type=int, default=default_steps)

    for flag, default_angle in (
        ("--omega_min", math.pi / 2),
        ("--omega_max", 11 * math.pi / 6),
    ):
        add(flag, type=float, default=default_angle)

    add("--workers_per_gpu", type=int, default=2, help="Max concurrent workers per GPU")
    add("--no_skip", action="store_true", help="Do not skip existing experiments")

    opts = parser.parse_args()

    # Options whose name and value pass through unchanged.
    passthrough = (
        "n_omega",
        "n_gpus",
        "output_dir",
        "K",
        "total_steps",
        "warmup_steps",
        "ramp_steps",
        "omega_min",
        "omega_max",
    )
    sweep_kwargs = {name: getattr(opts, name) for name in passthrough}

    # argparse yields lists for nargs="+"; the sweep takes tuples.
    for name in ("bc_types", "methods", "seeds"):
        sweep_kwargs[name] = tuple(getattr(opts, name))

    sweep_kwargs["max_workers_per_gpu"] = opts.workers_per_gpu
    sweep_kwargs["skip_existing"] = not opts.no_skip

    return run_all_experiments(**sweep_kwargs)


# Start the sweep only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
