# ruff: noqa: E501

import logging
import subprocess
import time

import pandas as pd

# Task names to dispatch; rows of the hyperparameter CSV are kept only when
# their `task` column matches one of these entries.
scenarios_to_run = [
    # "con-*" tasks
    "con-10x10x10a",
    "con-15x15x23a",
    # "large-*" tasks
    "large-4ag",
    "large-4ag-hard",
    "large-8ag",
    "large-8ag-hard",
    # "medium-*" tasks
    "medium-4ag",
    "medium-4ag-hard",
    "medium-6ag",
    # "small-*" tasks
    "small-4ag",
    "small-4ag-hard",
    # "tiny-*" tasks
    "tiny-2ag-hard",
    "tiny-4ag-hard",
    # "xlarge-*" tasks
    "xlarge-4ag",
    "xlarge-4ag-hard",
    # "smacv2_*" tasks
    "smacv2_10_units",
    "smacv2_20_units",
]

# System names to dispatch; matched against the CSV's `system_name` column.
systems_to_run = ["rec_ippo", "rec_mappo", "rec_sable"]

# Comma-separated seed list passed verbatim as `system.seed=...`.
seed_string = ",".join(str(seed) for seed in range(5))

# Values passed as `system.num_updates` and `arch.num_evaluation` respectively.
NUM_UPDATES = 12_200
NUM_EVALS = 1_220

# Tag placed inside the `logger.kwargs.neptune_tag=[...]` override.
NEPTUNE_TAG = "to-delete"


def system_name_to_run_file(system_name: str) -> str:
    """Return the repository-relative path of the script that runs a system.

    Args:
        system_name: One of ``"rec_sable"``, ``"rec_mappo"`` or ``"rec_ippo"``.

    Returns:
        The path of the Python entry-point file for that system.

    Raises:
        ValueError: If ``system_name`` is not a known system. Failing here
            avoids silently generating a ``python None ...`` command line.
    """
    run_files = {
        "rec_sable": "mava/systems/sable/anakin/rec_sable.py",
        "rec_mappo": "mava/systems/ppo/anakin/rec_mappo.py",
        "rec_ippo": "mava/systems/ppo/anakin/rec_ippo.py",
    }
    try:
        return run_files[system_name]
    except KeyError:
        known = ", ".join(sorted(run_files))
        raise ValueError(
            f"Unknown system name {system_name!r}; expected one of: {known}"
        ) from None


def get_system_script(hyperparams: pd.Series) -> str:
    """Render the ``system.*`` command-line overrides for one hyperparameter row.

    Args:
        hyperparams: Row holding at least ``num_minibatches``, ``max_grad_norm``,
            ``ppo_epochs``, ``clip_eps``, ``ent_coef``, ``actor_lr`` and
            ``critic_lr`` (the last may be NaN/None).

    Returns:
        One ``system.<name>=<value>`` override per line, each terminated by a
        shell line continuation. ``critic_lr`` is emitted only when it is set.
    """
    # (hyperparameter name, type it is cast to), in the order they are emitted
    always_emitted = (
        ("num_minibatches", int),
        ("max_grad_norm", float),
        ("ppo_epochs", int),
        ("clip_eps", float),
        ("ent_coef", float),
        ("actor_lr", float),
    )
    overrides = [
        f"system.{name}={cast(hyperparams[name])} \\\n" for name, cast in always_emitted
    ]

    # critic_lr is optional: a missing value in the row means "no override"
    critic_lr = hyperparams["critic_lr"]
    if pd.notna(critic_lr):
        overrides.append(f"system.critic_lr={float(critic_lr)} \\\n")

    return "".join(overrides)


def get_env_script(env: str, scenario: str) -> str:
    """Render the environment and scenario command-line overrides.

    Args:
        env: Environment name, emitted as ``env=<env>``.
        scenario: Scenario/task name to run within that environment.

    Returns:
        Two override lines, each terminated by a shell line continuation. For
        ``env == "smax"`` the scenario is set through
        ``env.scenario.task_name``; every other environment uses
        ``env/scenario``.
    """
    # smax is the only environment whose scenario is set via a dotted key
    scenario_key = "env.scenario.task_name" if env == "smax" else "env/scenario"
    return f"env={env} \\\n{scenario_key}={scenario} \\\n"


if __name__ == "__main__":
    # Entry point: for every selected (system, task) row of the hyperparameter
    # CSV, write a `run.sh` containing the full command line and execute it.
    # Runs are executed one after another; a failing run is logged and the
    # sweep continues with the next row.
    logging.basicConfig(level=logging.INFO)
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # each system type has its own dataframe
    dfs = {
        "on_policy": pd.read_csv("base_policy_hyperparameters/onpolicy_params.csv"),
    }

    # filter for systems and scenarios we want to run for each system's dataframe
    # (only existing keys are reassigned, so mutating during iteration is safe)
    for system_type, df in dfs.items():
        mask = (df["system_name"].isin(systems_to_run)) & (df["task"].isin(scenarios_to_run))
        dfs[system_type] = df[mask]
        logging.info("%s: selected %d of %d rows", system_type, int(mask.sum()), len(df))

    # (system_name, scenario, exit code) of every run that did not exit cleanly
    failed_runs = []

    for df in dfs.values():
        for _, row in df.iterrows():
            system_name: str = row["system_name"]  # type: ignore
            scenario: str = row["task"]  # type: ignore
            env: str = row["env_name"]  # type: ignore
            # drop non-system params
            hyperparams = row.drop(["system_name", "task", "env_name", "compute"])

            init_script = (
                "#!/bin/bash\n"
                f"# Script auto-generated by experiment_dispatcher.py on {timestamp}\n"
                f"python {system_name_to_run_file(system_name)} -m \\\n"
                f"system.seed={seed_string} \\\n"
                f"logger.use_neptune=False \\\n"
                f"logger.kwargs.neptune_tag=[{NEPTUNE_TAG}] \\\n"
                f"arch.num_envs={int(hyperparams['num_envs'])} \\\n"
                f"system.num_updates={NUM_UPDATES} \\\n"
                f"arch.num_evaluation={NUM_EVALS} \\\n"
                f"system.rollout_length=128 \\\n"
            )

            system_script = get_system_script(hyperparams)
            env_script = get_env_script(env, scenario)
            script = init_script + system_script + env_script

            # network overrides that only apply to the sable system
            if system_name == "rec_sable":
                script += f"network.memory_config.decay_scaling_factor={float(hyperparams['decay_scaling_factor'])} \\\n"
                script += f"network.net_config.n_head={int(hyperparams['n_head'])} \\\n"
                script += f"network.net_config.n_block={int(hyperparams['n_block'])} \\\n"
                script += f"network.net_config.embed_dim={int(hyperparams['embed_dim'])} \\\n"

            # override that only applies to the recurrent PPO systems
            if system_name in ["rec_mappo", "rec_ippo"]:
                script += (
                    f"system.recurrent_chunk_size={int(hyperparams['recurrent_chunk_size'])} \\\n"
                )

            script += "logger.checkpointing.load_model=False \\\n"
            script += "logger.checkpointing.download_model=False \\\n"
            script += "logger.checkpointing.upload_model=True \\\n"
            script += "logger.checkpointing.save_model=True \\\n"
            script += "logger.checkpointing.delete_local_checkpoints=False \\\n"
            script += "logger.checkpointing.save_args.checkpoint_uid=paper-baselines \\\n"

            # run.sh is overwritten for every row; it is only ever written here
            with open("run.sh", mode="w") as f:
                f.write(script)

            # a failed chmod would make the run itself fail with a less obvious
            # error, so surface it immediately
            subprocess.run(["chmod", "+x", "run.sh"], check=True)

            logging.info("Dispatching %s on %s/%s", system_name, env, scenario)
            result = subprocess.run(["./run.sh"])
            if result.returncode != 0:
                # keep going: one failed experiment should not abort the sweep
                logging.warning(
                    "Run of %s on %s/%s exited with code %d",
                    system_name,
                    env,
                    scenario,
                    result.returncode,
                )
                failed_runs.append((system_name, scenario, result.returncode))

    if failed_runs:
        logging.warning("%d run(s) failed: %s", len(failed_runs), failed_runs)
    else:
        logging.info("All dispatched runs exited successfully")
