import glob
from pathlib import Path
from fire import Fire
from collections import defaultdict


def _parse_agent_folder(folder_name):
    """Split an agent folder name into ``(env_name, nb_uncertainty_dim, seed_agent)``.

    The name is split on ``"_"``. The third field is the environment name and
    the fourth is the number of uncertainty dimensions. The agent seed is the
    last field, or the second-to-last one when the last field is not an
    integer (e.g. a trailing tag).

    Raises:
        ValueError: If the name does not follow this layout; the message
            includes the offending folder name.
    """
    parts = folder_name.split("_")
    if len(parts) < 4:
        raise ValueError(
            f"Cannot parse agent folder {folder_name!r}: expected at least "
            "4 '_'-separated fields"
        )
    env_name = parts[2]
    try:
        nb_uncertainty_dim = int(parts[3])
    except ValueError:
        raise ValueError(
            f"Cannot parse uncertainty dimension from agent folder {folder_name!r}"
        ) from None
    try:
        seed_agent = int(parts[-1])
    except ValueError:
        try:
            seed_agent = int(parts[-2])
        except ValueError:
            raise ValueError(
                f"Cannot parse agent seed from agent folder {folder_name!r}"
            ) from None
    return env_name, nb_uncertainty_dim, seed_agent


def generate_taskset_commands(
    download_folder: str,
    project_name: str,
    seed: int,
    radius: float,
    nb_core_max: int,
    core_per_taskset: int,
    filter_2d: bool = False,
) -> str:
    """
    Generate taskset-pinned commands that train adversaries against saved agents.

    Every ``agent.pth``, ``policy.pth`` or ``policies`` entry found recursively
    under ``download_folder`` yields one ``src/main_tc_adversary.py`` command.
    Agents whose folder reports 0 uncertainty dimensions are expanded into a 2D
    and a 3D run. The commands are split into contiguous chunks, and each chunk
    is pinned to its own range of ``core_per_taskset`` CPU cores.

    Args:
        download_folder (str): The directory containing subfolders with agent files.
            In each agent folder name, the third ``_``-separated field is the
            environment, the fourth the number of uncertainty dimensions, and the
            last (or second-to-last) the agent seed. The agent type is ``"m2td3"``
            when the folder name contains it, ``"td3"`` otherwise.
        project_name (str): Project name forwarded as ``--project_name``.
        seed (int): Seed appended to every experiment name. It is not passed to
            the training command itself.
        radius (float): Adversarial attack radius forwarded as ``--radius``.
        nb_core_max (int): Number of CPU cores available (cores
            ``0 .. nb_core_max - 1``).
        core_per_taskset (int): Number of cores in each taskset block.
        filter_2d (bool): If True, skip every 2D run: native 2D agents and the
            2D expansion of 0-dimension agents.

    Returns:
        str: One line per core block, with lines separated by a blank line.
        Within a line, commands are chained with ``" ; "`` so they run one after
        another on that block. The lines themselves are not backgrounded.

    Raises:
        ValueError: If not even one full taskset block fits in ``nb_core_max``
            cores, or if an agent folder name cannot be parsed.
    """
    if core_per_taskset < 1:
        raise ValueError(f"core_per_taskset must be >= 1, got {core_per_taskset}")
    total_core_blocks = nb_core_max // core_per_taskset
    if total_core_blocks < 1:
        raise ValueError(
            f"nb_core_max={nb_core_max} is too small for even one block of "
            f"core_per_taskset={core_per_taskset} cores"
        )

    # Escape the root so glob metacharacters in the folder name ("[", "*", "?")
    # match literally; str() also covers a Path being passed. Each pattern's
    # matches are sorted so the generated script is reproducible.
    root = glob.escape(str(download_folder))
    files = []
    for pattern in ("agent.pth", "policy.pth", "policies"):
        files.extend(sorted(glob.glob(f"{root}/**/{pattern}", recursive=True)))

    agents = []
    for path in map(Path, files):
        folder_name = path.parent.name
        agent_type = "m2td3" if "m2td3" in folder_name else "td3"
        env_name, nb_uncertainty_dim, seed_agent = _parse_agent_folder(folder_name)

        max_steps = 5000000 if nb_uncertainty_dim == 3 else 4000000
        experiment_name = f"{env_name}_{nb_uncertainty_dim}_{seed_agent}_{seed}"

        if nb_uncertainty_dim == 0:
            # This is the vanilla env, we need to run the 2D and 3D versions.
            # NOTE(review): max_steps and experiment_name above are derived from
            # dimension 0, so both expanded runs get 4000000 steps and share one
            # experiment name -- confirm this is intended.
            target_dims = (3,) if filter_2d else (2, 3)
        elif filter_2d and nb_uncertainty_dim == 2:
            continue
        else:
            target_dims = (nb_uncertainty_dim,)

        agents.extend(
            (path, agent_type, env_name, dim, max_steps, experiment_name)
            for dim in target_dims
        )

    # Ceiling division: agents are grouped into contiguous chunks of at most
    # this many, so the block index never exceeds total_core_blocks - 1.
    agents_per_block = max(1, -(-len(agents) // total_core_blocks))

    commands = defaultdict(list)
    for i, (
        agent_path,
        agent_type,
        env_name,
        nb_uncertainty_dim,
        max_steps,
        experiment_name,
    ) in enumerate(agents):
        core_start = (i // agents_per_block) * core_per_taskset
        core_end = min(core_start + core_per_taskset - 1, nb_core_max - 1)
        command = (
            f"taskset -c {core_start}-{core_end} python src/main_tc_adversary.py "
            f'--agent_path="{agent_path}" --agent_type="{agent_type}" --env_name="{env_name}" '
            f"--nb_uncertainty_dim={nb_uncertainty_dim} --max_steps={max_steps} "
            f'--omniscient_adversary=True --project_name="{project_name}" --radius={radius} '
            f'--output_dir="result" --device="cuda:0" --experiment_name="{experiment_name}"'
        )
        commands[f"{core_start}-{core_end}"].append(command)

    # Dict insertion order keeps the core blocks in ascending order.
    return "\n\n".join(" ; ".join(cmds) for cmds in commands.values())


if __name__ == "__main__":
    # Expose generate_taskset_commands as a command-line interface via Fire.
    # Example:
    #   python script/generate_adversary_script.py --download_folder="dl_rarl" \
    #       --project_name="rarl_test" --seed=42 --radius=0.001 \
    #       --nb_core_max=40 --core_per_taskset=4
    Fire(component=generate_taskset_commands)
