import argparse
from datetime import datetime
import math
import multiprocessing as mp
import os
import subprocess
import sys
from typing import Dict, List

import wandb


# Reward component weights held fixed for every sweep run; they are joined
# into a comma-separated string and passed to main.py via
# --reward_weight_config.
# Order: movement, standing, height, action_cost.
FIXED_REWARD_WEIGHT_CONFIG = [1.0, 0.005, 0.005, 0.0001]


def parse_cuda_devices(raw: str) -> List[int]:
    """Parse CUDA device input into a list of device indices.

    The input is a comma-separated list whose items are either a single
    index or an inclusive ``start-end`` range; both forms may be mixed in
    one string. Whitespace around items is ignored and empty items are
    skipped.

    Args:
        raw: String like "0-3", "0,2,3", "0-1,4", or "0".

    Returns:
        List of integer CUDA device indices in the order they were written.
        An empty or whitespace-only input yields an empty list so the caller
        can report "no devices" itself.

    Raises:
        ValueError: If an item is neither an integer nor a ``start-end``
            range of integers.
    """
    devices: List[int] = []
    for item in raw.split(","):
        item = item.strip()
        if not item:
            # Tolerate stray separators such as "0,,1" or a trailing comma.
            continue
        if "-" in item:
            # Ranges are inclusive on both ends: "0-3" -> 0, 1, 2, 3.
            start, end = item.split("-", maxsplit=1)
            devices.extend(range(int(start), int(end) + 1))
        else:
            devices.append(int(item))
    return devices


def chunk_devices(devices: List[int], chunk_size: int) -> List[List[int]]:
    """Partition the device list into consecutive groups of ``chunk_size``.

    Args:
        devices: Flat list of visible GPU indices.
        chunk_size: Number of GPUs to allocate per sweep run.

    Returns:
        Consecutive slices of ``devices``; the final slice holds whatever is
        left over and may be shorter than ``chunk_size``.
    """
    groups = []
    for begin in range(0, len(devices), chunk_size):
        stop = begin + chunk_size
        groups.append(devices[begin:stop])
    return groups


def build_sweep_config() -> Dict[str, Dict[str, Dict[str, List[float]]]]:
    """Assemble the grid-search sweep definition for the training hyperparameters.

    Returns:
        Dictionary with a "method" entry set to "grid" and a "parameters"
        entry mapping each swept hyperparameter to its candidate values,
        suitable for passing to wandb.sweep.
    """
    # NOTE(review): the return annotation does not describe the string-valued
    # "method" entry; it is left unchanged here.
    candidates = {
        "actor_lr": [1e-3, 3e-4, 1e-4],
        "critic_lr": [1e-3, 3e-4, 1e-4],
        "gae_lambda": [0.8, 0.9, 0.95, -1],
        "entropy_weight": [-0.01],
        "discount": [0.99],
    }
    parameters = {name: {"values": options} for name, options in candidates.items()}
    return {"method": "grid", "parameters": parameters}


def count_sweep_runs(sweep_config: Dict[str, Dict]) -> int:
    """Compute total sweep runs from the grid configuration.

    The total is the product of the number of candidates of every parameter.
    A parameter given as a list under ``"values"`` contributes the length of
    that list; a parameter pinned to a single constant under ``"value"``
    contributes a factor of one instead of zeroing the whole count.

    Args:
        sweep_config: Sweep configuration dictionary.

    Returns:
        Total number of runs in the grid. A configuration without a
        ``"parameters"`` entry counts as a single run. A parameter with
        neither ``"values"`` nor ``"value"`` contributes zero.
    """
    total = 1
    for param in sweep_config.get("parameters", {}).values():
        if "values" not in param and "value" in param:
            # A constant parameter does not multiply the grid size.
            continue
        total *= len(param.get("values", []))
    return total


def run_agent_process(
    sweep_id: str,
    args: argparse.Namespace,
    project: str,
    gpu_group: List[int],
    worker_id: int,
    runs_per_worker: int,
    sweep_root_path: str,
) -> None:
    """Drive one W&B agent from a worker process bound to a GPU group.

    The process environment is restricted to ``gpu_group``; the agent then
    executes up to ``runs_per_worker`` trials, each one by launching the
    ``main.py`` script located next to this file as a subprocess.

    Args:
        sweep_id: W&B sweep identifier.
        args: Parsed CLI arguments.
        project: W&B project name.
        gpu_group: GPU indices assigned to this worker process.
        worker_id: Worker index for port offsetting.
        runs_per_worker: Max number of runs this worker should handle.
        sweep_root_path: Directory forwarded to the training script as
            ``--root_path``.
    """
    assigned_gpus = ",".join(str(gpu) for gpu in gpu_group)
    os.environ["CUDA_VISIBLE_DEVICES"] = assigned_gpus
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    # Each worker owns a block of `port_stride` ports above the base ports;
    # successive trials of the same worker take successive ports in it.
    port_stride = 1000
    worker_port_offset = worker_id * port_stride
    main_script_path = os.path.join(os.path.dirname(__file__), "main.py")
    trials_started = 0

    def run_sweep() -> None:
        """Execute one trial: register it with W&B, then run the training script."""
        nonlocal trials_started
        if "WANDB_API_KEY" in os.environ:
            mode = "online"
        else:
            mode = "offline"
        run = wandb.init(project=project, entity=args.entity, mode=mode)
        run_index = trials_started
        trials_started = run_index + 1

        port_offset = worker_port_offset + run_index
        nccl_port = args.base_nccl_port + port_offset
        webserver_port = args.base_webserver_port + port_offset

        # Reward weights stay constant; only training hyperparameters vary.
        reward_weight_config = list(FIXED_REWARD_WEIGHT_CONFIG)
        reward_weight_config_str = ",".join(str(weight) for weight in reward_weight_config)

        # Pull the values chosen by the sweep controller, in a fixed order.
        swept_names = (
            "actor_lr",
            "critic_lr",
            "gae_lambda",
            "entropy_weight",
            "discount",
        )
        actor_lr, critic_lr, gae_lambda, entropy_weight, discount = (
            float(getattr(run.config, name)) for name in swept_names
        )

        run_name = f"{args.run_name_prefix}-{run.id}"
        run.name = run_name

        # Announce the trial settings before the training job starts.
        summary = (
            "[SWEEP] "
            f"run={run_name} "
            f"gpus={assigned_gpus} "
            f"nccl_port={nccl_port} "
            f"webserver_port={webserver_port} "
            f"max_epochs={args.max_epochs_per_sweep} "
            f"reward_weights={reward_weight_config_str} "
            f"actor_lr={actor_lr} "
            f"critic_lr={critic_lr} "
            f"gae_lambda={gae_lambda} "
            f"entropy_weight={entropy_weight} "
            f"discount={discount}"
        )
        print(summary)

        launch_metadata = {
            "allocated_gpus": assigned_gpus,
            "nccl_port": nccl_port,
            "webserver_port": webserver_port,
            "max_epochs_per_sweep": args.max_epochs_per_sweep,
            "reward_component_movement": reward_weight_config[0],
            "reward_component_standing": reward_weight_config[1],
            "reward_component_height": reward_weight_config[2],
            "reward_component_action_cost": reward_weight_config[3],
        }
        run.config.update(launch_metadata, allow_val_change=True)

        # Finish the sweep-side run so the training job can resume it safely.
        run.finish()

        cli_options = [
            ("--run_name", run_name),
            ("--nccl_port", str(nccl_port)),
            ("--webserver_port", str(webserver_port)),
            ("--ray_cuda_devices", assigned_gpus),
            ("--reward_weight_config", reward_weight_config_str),
            ("--actor_lr", str(actor_lr)),
            ("--critic_lr", str(critic_lr)),
            ("--gae_lambda", str(gae_lambda)),
            ("--entropy_weight", str(entropy_weight)),
            ("--discount", str(discount)),
            ("--max_epochs", str(args.max_epochs_per_sweep)),
            ("--root_path", sweep_root_path),
        ]
        cmd = [sys.executable, main_script_path]
        for flag, value in cli_options:
            cmd.extend((flag, value))

        # Hand the run id to the child so it attaches to this same W&B run.
        env = dict(os.environ)
        env["WANDB_RUN_ID"] = run.id
        env["WANDB_RESUME"] = "allow"

        subprocess.run(cmd, check=True, env=env)

    wandb.agent(sweep_id, function=run_sweep, count=runs_per_worker)


def main() -> None:
    """Launch a W&B sweep and run grid search over training hyperparameters.

    Parses the CLI, splits the visible GPUs into groups, registers the sweep
    with W&B, and starts one agent process per GPU group. The grid is divided
    evenly (rounded up) between the agent processes.

    Raises:
        ValueError: If no visible GPU is available, or if
            ``--max_gpu_num_per_sweep`` exceeds the number of visible GPUs.
        SystemExit: If any agent process ends with a non-zero exit code.
    """
    parser = argparse.ArgumentParser(
        description="Training hyperparameter sweep for manual robot parameter search."
    )
    parser.add_argument("--project", type=str, default=None)
    parser.add_argument("--entity", type=str, default=None)
    parser.add_argument("--run_name_prefix", type=str, default="sweep")
    parser.add_argument("--visible_gpus", type=str, default=None)
    parser.add_argument("--max_gpu_num_per_sweep", type=int, default=1)
    parser.add_argument("--max_epochs_per_sweep", type=int, default=200)
    parser.add_argument("--base_nccl_port", type=int, default=12345)
    parser.add_argument("--base_webserver_port", type=int, default=8002)
    args = parser.parse_args()

    # CLI values take precedence over the environment, then built-in defaults.
    project = args.project or os.environ.get(
        "WANDB_PROJECT", "endoskeletal-manual-robot-sweep"
    )
    visible_raw = args.visible_gpus or os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    visible_gpus = parse_cuda_devices(visible_raw)
    if not visible_gpus:
        raise ValueError("No visible GPUs provided for sweep runs.")

    # A non-positive value means "give every visible GPU to a single run".
    max_gpus = args.max_gpu_num_per_sweep
    if max_gpus <= 0:
        max_gpus = len(visible_gpus)
    if max_gpus > len(visible_gpus):
        raise ValueError(
            "max_gpu_num_per_sweep exceeds the available visible GPUs."
        )

    gpu_groups = chunk_devices(visible_gpus, max_gpus)

    sweep_config = build_sweep_config()
    sweep_id = wandb.sweep(
        sweep_config,
        project=project,
        entity=args.entity,
    )
    total_runs = count_sweep_runs(sweep_config)
    # Round up so the workers together cover the whole grid.
    runs_per_worker = int(math.ceil(total_runs / len(gpu_groups)))
    sweep_start_time = datetime.now().strftime("%Y_%m_%d_%H_%M")
    sweep_root_path = os.path.join(
        "data/rl-result", f"sweep_{sweep_start_time}"
    )
    os.makedirs(sweep_root_path, exist_ok=True)
    print(f"[SWEEP] root_path={sweep_root_path}")

    ctx = mp.get_context("spawn")
    processes = []
    for worker_id, gpu_group in enumerate(gpu_groups):
        process = ctx.Process(
            target=run_agent_process,
            args=(
                sweep_id,
                args,
                project,
                gpu_group,
                worker_id,
                runs_per_worker,
                sweep_root_path,
            ),
        )
        process.start()
        processes.append(process)

    for process in processes:
        process.join()

    # Report worker failures only after every worker has finished, so one
    # crashed agent never cuts short the others, yet the launcher does not
    # claim success when an agent died.
    failed = [
        (worker_id, process.exitcode)
        for worker_id, process in enumerate(processes)
        if process.exitcode != 0
    ]
    if failed:
        details = ", ".join(
            f"worker {worker_id} (exit code {exitcode})"
            for worker_id, exitcode in failed
        )
        sys.exit(f"[SWEEP] agent process(es) failed: {details}")


# Run the sweep launcher only when this file is executed as a script. The
# guard also keeps the worker processes started with the "spawn" context
# from re-running main() when they import this module.
if __name__ == "__main__":
    main()
